Hierarchical networks and a possible bug in get_final_encoder_states

Hi all,

I have built a lightweight wrapper of type Seq2VecEncoder that combines a Seq2SeqEncoder and Attention. One of the ways in which it can aggregate the output of the Seq2SeqEncoder is by taking the final hidden state via allennlp.nn.util.get_final_encoder_states(). This works fine in the regular, flattened setting, but it breaks when I use the same encoder inside a hierarchical attention network (HAN).
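
For context, the wrapper is roughly the following (a stripped-down sketch; the attention path is omitted and the class name is illustrative, not my exact code):

import torch
from allennlp.modules import Seq2SeqEncoder, Seq2VecEncoder
from allennlp.nn.util import get_final_encoder_states


class FinalStateSeq2VecWrapper(Seq2VecEncoder):
    """Sketch: run a Seq2SeqEncoder (a GRU in my case) and keep its final hidden state."""

    def __init__(self, encoder: Seq2SeqEncoder) -> None:
        super().__init__()
        self._encoder = encoder

    def get_input_dim(self) -> int:
        return self._encoder.get_input_dim()

    def get_output_dim(self) -> int:
        return self._encoder.get_output_dim()

    def forward(self, tokens: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Encode the sequence, then aggregate by taking the final encoder state.
        encoded_tokens = self._encoder(tokens, mask)
        return get_final_encoder_states(encoded_tokens, mask,
                                        bidirectional=self._encoder.is_bidirectional())
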
Consider a multi-sentence document that I call a summary. Here is my implementation of a forward pass for the encoder:

@overrides
def forward(self, inputs: Dict[str, torch.LongTensor],
            text_embedder: TextFieldEmbedder) -> torch.Tensor:
    """
    Forward pass implementation for the summary encoder.
    Note that `inputs['tokens']` is a tensor of shape batch_size x max_num_sentences_in_batch
    x max_num_words_in_sent_in_batch for a hierarchical model, and
    batch_size x num_words_in_batch for a flattened (non-hierarchical) model.
    """
    embedded_tokens = text_embedder(inputs)
    if self._flattened_text:
        token_mask = util.get_text_field_mask(inputs)
        encoded_summary = self._word_encoder(embedded_tokens, token_mask)
    else:
        token_mask = util.get_text_field_mask(inputs, num_wrapping_dims=1)
        batch_size, max_sent_count, max_sent_length = token_mask.size()
        # Collapse the sentence dimension into the batch dimension so the word-level
        # encoder sees a flat batch of (batch_size * max_sent_count) sentences.
        embedded_tokens = embedded_tokens.view(batch_size * max_sent_count, max_sent_length, -1)
        word_mask = token_mask.view(batch_size * max_sent_count, max_sent_length)
        encoded_sentences = self._word_encoder(embedded_tokens, word_mask)
        # A sentence counts as real (not padding) if it has at least one unmasked token.
        sentence_mask = token_mask.sum(dim=-1).gt(0)
        encoded_summary = self._sentence_encoder(
            encoded_sentences.reshape(batch_size, max_sent_count, -1), sentence_mask)
    return encoded_summary
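
To make the reshapes concrete, here is what the shapes look like with toy sizes (purely illustrative numbers, not my real data):

import torch

# Toy sizes for the hierarchical branch above.
batch_size, max_sent_count, max_sent_length, embed_dim = 2, 3, 4, 8
token_mask = torch.ones(batch_size, max_sent_count, max_sent_length, dtype=torch.long)
embedded_tokens = torch.randn(batch_size, max_sent_count, max_sent_length, embed_dim)

flat_embedded = embedded_tokens.view(batch_size * max_sent_count, max_sent_length, -1)  # (6, 4, 8)
word_mask = token_mask.view(batch_size * max_sent_count, max_sent_length)               # (6, 4)
sentence_mask = token_mask.sum(dim=-1).gt(0)                                            # (2, 3)

print(flat_embedded.shape, word_mask.shape, sentence_mask.shape)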

The forward pass of self._word_encoder is essentially get_final_encoder_states() applied to the output of a GRU (bidirectional or not, it doesn't matter). The stack trace with CUDA_LAUNCH_BLOCKING=1 is:

THCudaCheck FAIL file=/pytorch/aten/src/THC/generic/THCTensorScatterGather.cu line=71 error=710 : device-side assert triggered
Traceback (most recent call last):
  File "/apps/python3/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/apps/python3/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/apps/python3/lib/python3.7/site-packages/allennlp/run.py", line 21, in <module>
    run()
  File "/apps/python3/lib/python3.7/site-packages/allennlp/run.py", line 18, in run
    main(prog="allennlp")
  File "/apps/python3/lib/python3.7/site-packages/allennlp/commands/__init__.py", line 102, in main
    args.func(args)
  File "/apps/python3/lib/python3.7/site-packages/allennlp/commands/train.py", line 124, in train_model_from_args
    args.cache_prefix)
  File "/apps/python3/lib/python3.7/site-packages/allennlp/commands/train.py", line 168, in train_model_from_file
    cache_directory, cache_prefix)
  File "/apps/python3/lib/python3.7/site-packages/allennlp/commands/train.py", line 252, in train_model
    metrics = trainer.train()
  File "/apps/python3/lib/python3.7/site-packages/allennlp/training/trainer.py", line 478, in train
    train_metrics = self._train_epoch(epoch)
  File "/apps/python3/lib/python3.7/site-packages/allennlp/training/trainer.py", line 320, in _train_epoch
    loss = self.batch_loss(batch_group, for_training=True)
  File "/apps/python3/lib/python3.7/site-packages/allennlp/training/trainer.py", line 261, in batch_loss
    output_dict = self.model(**batch)
  File "/apps/python3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/source_to_tags/code/mux/models/model.py", line 303, in forward
    encoded_summary = self.summary_encoder(summary, self.text_field_embedder)
  File "/apps/python3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/source_to_tags/code/mux/modules/summary_encoder.py", line 65, in forward
    encoded_sentences = self._word_encoder(embedded_tokens, word_mask)
  File "/apps/python3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/source_to_tags/code/mux/modules/attention_encoder.py", line 80, in forward
    encoded_tokens, mask, bidirectional=self._encoder.is_bidirectional)
  File "/apps/python3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/source_to_tags/code/mux/modules/boe_encoder.py", line 55, in forward
    aggregated = get_final_encoder_states(tokens, mask, bidirectional)  # fails with HAN
  File "/apps/python3/lib/python3.7/site-packages/allennlp/nn/util.py", line 202, in get_final_encoder_states
    final_encoder_output = encoder_outputs.gather(1, expanded_indices)
RuntimeError: cuda runtime error (710) : device-side assert triggered at /pytorch/aten/src/THC/generic/THCTensorScatterGather.cu:71
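
For reference, the lines of get_final_encoder_states that the trace points at do roughly the following (paraphrased from my reading of allennlp.nn.util, so details may be slightly off):

# Paraphrase of the relevant part of allennlp.nn.util.get_final_encoder_states,
# not a verbatim copy of the library source.
last_word_indices = mask.sum(1).long() - 1
batch_size, _, encoder_output_dim = encoder_outputs.size()
expanded_indices = last_word_indices.view(-1, 1, 1).expand(batch_size, 1, encoder_output_dim)
final_encoder_output = encoder_outputs.gather(1, expanded_indices)  # line where the assert fires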

I suspect something is going wrong in the embedded_tokens.view or token_mask.view calls. Any ideas?

Avneesh