Suggested change to dataset readers in distributed case doesn't work?

In the pre-release notes for v1.0, there is a fix mentioned for a DatasetReader used in distributed training:

rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
for idx, inputs in enumerate(data_file):
	if idx % world_size == rank:
		yield self.text_to_instance(inputs)
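The filter `idx % world_size == rank` deals the input lines out round-robin, so each worker keeps a disjoint slice. The standalone sketch below (plain Python, no torch) applies the same rule with hard-coded ranks. `shard_by_rank` is a hypothetical helper for illustration, not part of AllenNLP. It only shows that the shards do not overlap and that together they cover every line.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def shard_by_rank(items: Iterable[T], rank: int, world_size: int) -> Iterator[T]:
    """Yield only the items whose position modulo world_size equals rank.

    Hypothetical helper that mirrors the filter from the release notes.
    """
    if world_size < 1 or not 0 <= rank < world_size:
        raise ValueError(f"invalid rank {rank} for world_size {world_size}")
    for idx, item in enumerate(items):
        # Round-robin assignment: line idx belongs to worker idx % world_size.
        if idx % world_size == rank:
            yield item


if __name__ == "__main__":
    lines = [f"line {i}" for i in range(7)]
    world_size = 3
    shards: List[List[str]] = [
        list(shard_by_rank(lines, r, world_size)) for r in range(world_size)
    ]
    for rank, shard in enumerate(shards):
        print(rank, shard)
```

With 7 lines and 3 workers, rank 0 gets lines 0, 3 and 6, rank 1 gets lines 1 and 4, and rank 2 gets lines 2 and 5.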

This code change would be made to the _read method of our DatasetReader, right? Assuming that is true, I run into an error when using this fix:

0it [00:00, ?it/s]Traceback (most recent call last):
  File "/home/johnmg/t2t/bin/allennlp", line 11, in <module>
    load_entry_point('allennlp', 'console_scripts', 'allennlp')()
  File "/lustre04/scratch/johnmg/t2t/allennlp/allennlp/", line 18, in run
  File "/lustre04/scratch/johnmg/t2t/allennlp/allennlp/commands/", line 93, in main
  File "/lustre04/scratch/johnmg/t2t/allennlp/allennlp/commands/", line 145, in train_model_from_args
  File "/lustre04/scratch/johnmg/t2t/allennlp/allennlp/commands/", line 204, in train_model_from_file
  File "/lustre04/scratch/johnmg/t2t/allennlp/allennlp/commands/", line 306, in train_model
    params.duplicate(), serialization_dir, print_statistics=dry_run
  File "/lustre04/scratch/johnmg/t2t/allennlp/allennlp/training/", line 461, in make_vocab_from_params
    all_datasets = datasets_from_params(params)
  File "/lustre04/scratch/johnmg/t2t/allennlp/allennlp/training/", line 205, in datasets_from_params
    train_data =
  File "/lustre04/scratch/johnmg/t2t/allennlp/allennlp/data/dataset_readers/", line 201, in read
    instances = [instance for instance in Tqdm.tqdm(instances)]
  File "/lustre04/scratch/johnmg/t2t/allennlp/allennlp/data/dataset_readers/", line 201, in <listcomp>
    instances = [instance for instance in Tqdm.tqdm(instances)]
  File "/home/johnmg/t2t/lib/python3.7/site-packages/tqdm/", line 1087, in __iter__
    for obj in iterable:
  File "/scratch/johnmg/t2t/t2t/data/dataset_readers/", line 62, in _read
    rank = dist.get_rank()
  File "/home/johnmg/t2t/lib/python3.7/site-packages/torch/distributed/", line 564, in get_rank
  File "/home/johnmg/t2t/lib/python3.7/site-packages/torch/distributed/", line 193, in _check_default_pg
    "Default process group is not initialized"
AssertionError: Default process group is not initialized
0it [00:00, ?it/s]

I believe this is because init_process_group is not called until this line, which is only reached after data loading (and therefore _read) has already run.

Am I correct or just misunderstanding something? Any help is appreciated!

@JohnG You’re correct in your observations. The dataset reader is called once to build the vocabulary, and that happens before the distributed process group is set up, hence the error. The solution is to always check whether the distributed process group has been initialized before calling get_rank and the other methods.

So your checks could be rewritten as follows:

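import torch

# Inside the _read method of your DatasetReader:
distributed = torch.distributed.is_initialized()
if distributed:
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
for idx, inputs in enumerate(data_file):
    # In the distributed case, skip lines assigned to other workers.
    if distributed and idx % world_size != rank:
        continue
    yield self.text_to_instance(inputs)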
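A variant of the same guard, written as a standalone sketch. Instead of checking `distributed` inside the loop, it falls back to `rank = 0, world_size = 1` when no process group exists, so the modulo test passes for every line during the vocabulary pass. The names `rank_and_world_size` and `read_sharded` are hypothetical; in a real reader the loop would yield `self.text_to_instance(...)` instead of the raw line. The only actual PyTorch calls are `torch.distributed.is_available`, `is_initialized`, `get_rank` and `get_world_size`.

```python
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def rank_and_world_size() -> Tuple[int, int]:
    """Return (rank, world_size), or (0, 1) if no process group is initialized.

    torch is imported lazily so this sketch also runs where torch is absent.
    """
    try:
        import torch.distributed as dist
    except ImportError:
        return 0, 1
    # is_available() guards against torch builds without distributed support.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1


def read_sharded(lines: Iterable[T]) -> Iterator[T]:
    """Hypothetical stand-in for the body of a DatasetReader._read method."""
    rank, world_size = rank_and_world_size()
    for idx, line in enumerate(lines):
        # With the (0, 1) fallback this is always true, so every line is kept.
        if idx % world_size == rank:
            yield line
```

Folding the non-distributed case into `world_size = 1` keeps a single code path. The sketch assumes every worker iterates the data in the same order; otherwise the shards would no longer be disjoint.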

@Ananda_Seelan Thanks a lot! That is exactly the solution I was looking for.