NaN loss when training on large sentences with dependency parser with xlm-mlm-100-1280

I tried to train a dependency parser with XLM embeddings on a smaller GPU, so I set the maximum length to 96. This leads to a NaN loss and a crash. With multilingual BERT the same setup runs fine, and XLM also runs with a max_len of 256. My configuration file is as follows:

```jsonnet
//local transformer_model = "bert-base-multilingual-cased";
//local transformer_dim = 768;
local transformer_model = "xlm-mlm-100-1280";
local transformer_dim = 1280;

local max_len = 96;
//local max_len = 256;

{
    "dataset_reader": {
        // ...
        "token_indexers": {
            "tokens": {
                "type": "pretrained_transformer_mismatched",
                "max_length": max_len,
                "model_name": transformer_model
            }
        }
    },
    "train_data_path": "../corpora/parsing/ud-treebanks-v2.5/UD_English-EWT/en_ewt-ud-train.conllu",
    "validation_data_path": "../corpora/parsing/ud-treebanks-v2.5/UD_English-EWT/en_ewt-ud-dev.conllu",
    "model": {
        "type": "biaffine_parser",
        "text_field_embedder": {
            "type": "basic",
            "token_embedders": {
                "tokens": {
                    "type": "pretrained_transformer_mismatched",
                    "last_layer_only": true,
                    "max_length": max_len,
                    "model_name": transformer_model,
                    "train_parameters": true
                }
            }
        },
        "pos_tag_embedding": {
            "embedding_dim": 0,
            "vocab_namespace": "pos"
        },
        "encoder": {
            "type": "pass_through",
            "input_dim": transformer_dim
        },
        "use_mst_decoding_for_validation": true,
        "arc_representation_dim": 500,
        "tag_representation_dim": 100,
        "dropout": 0.3,
        "input_dropout": 0.3,
        "initializer": {
            "regexes": [
                [".*projection.*weight", {"type": "xavier_uniform"}],
                [".*projection.*bias", {"type": "zero"}],
                [".*tag_bilinear.*weight", {"type": "xavier_uniform"}],
                [".*tag_bilinear.*bias", {"type": "zero"}],
                [".*weight_ih.*", {"type": "xavier_uniform"}],
                [".*weight_hh.*", {"type": "orthogonal"}],
                [".*bias_ih.*", {"type": "zero"}],
                [".*bias_hh.*", {"type": "lstm_hidden_bias"}]
            ]
        }
    },
    "data_loader": {
        "batch_sampler": {
            "type": "bucket",
            "batch_size": 32
        }
    },
    "trainer": {
        "learning_rate_scheduler": {
            "type": "slanted_triangular",
            "cut_frac": 0.2,
            "decay_factor": 0.38,
            "discriminative_fine_tuning": false,
            "gradual_unfreezing": false
        },
        "optimizer": {
            "type": "huggingface_adamw",
            "betas": [0.9, 0.99],
            "correct_bias": false,
            "lr": 0.00001,
            "parameter_groups": [
                // ...
            ],
            "weight_decay": 0.01
        },
        "num_epochs": 20,
        "grad_norm": 5.0,
        "patience": 5,
        "cuda_device": -1,
        "validation_metric": "+LAS"
    }
}
```

I run this configuration on AllenNLP 1.1, with:
```shell
allennlp train dependency_parser.xlm.jsonnet -s outDir.xlm --include-package allennlp_models.structured_prediction
```

Just to clarify:

- mBERT, max_len 256: runs
- mBERT, max_len 96: runs
- XLM, max_len 256: runs
- XLM, max_len 96: fails

Edit: this most probably has to do with sentences that are longer than 96 wordpieces (it also breaks on other EWT versions, even when using the simple tagger instead of the dependency parser). However, training on a single long sentence runs fine, and I inspected the wordpiece IDs of long sentences for both mBERT and XLM; both seem fine (XLM has a `</s>` and an `<s>` around position 96).

In my experience NaN is almost always caused by something being zero that shouldn’t be zero, like the number of non-masked tokens, the number of non-masked sequences, the number of distinct token type ids, etc.
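A minimal illustration of that failure mode (a hypothetical masked mean, not AllenNLP's actual code): when a row's mask is entirely False, the token count is zero and the normalization becomes 0/0.

```python
import torch

# A row whose attention mask is entirely False, e.g. a fully padded
# segment produced by splitting a batch into max_len pieces.
mask = torch.zeros(1, 96, dtype=torch.bool)
embeddings = torch.randn(1, 96, 8)

# A masked mean divides by the number of non-masked tokens; with zero
# such tokens this is 0 / 0, which is NaN.
count = mask.sum()                                    # tensor(0)
masked_sum = (embeddings * mask.unsqueeze(-1)).sum(dim=(0, 1))
mean = masked_sum / count
print(torch.isnan(mean).all())                        # tensor(True)
```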

AllenNLP has an assert for the loss not being NaN, but when that assert triggers, you’ll have to dig into the model to find out exactly where it came from. I am personally a big fan of debuggers to help with that kind of thing, but you can also sprinkle NaN asserts all over the model code to find out where it comes from.

This is a pretty generic answer for chasing down NaN issues. I don’t know enough about the specific transformer models you’re using to have a more precise answer.
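For the "sprinkle NaN asserts" approach, one generic option is to attach forward hooks so that training stops at the first submodule that emits a NaN, rather than at the final loss assert. This is a plain PyTorch sketch, not AllenNLP-specific:

```python
import torch
import torch.nn as nn

def add_nan_hooks(model: nn.Module) -> None:
    """Raise as soon as any submodule's forward output contains a NaN."""
    def check(module, inputs, output):
        outs = output if isinstance(output, (tuple, list)) else (output,)
        for out in outs:
            if isinstance(out, torch.Tensor) and torch.isnan(out).any():
                raise RuntimeError(
                    f"NaN produced by {module.__class__.__name__}"
                )
    # modules() yields the model itself plus every submodule.
    for module in model.modules():
        module.register_forward_hook(check)
```

`torch.autograd.set_detect_anomaly(True)` does the analogous thing for the backward pass.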

Thanks for your suggestions! I have inspected all items in words['tokens'] in the forward pass and couldn't find anything suspicious. However, I did notice that it runs with batch size 1, and after that I found that it only crashes if a batch contains both a sentence with fewer than 96 wordpieces and one with more than 96. If a batch consists only of >96-wordpiece sentences, it runs fine. Also, it crashes only in the forward pass of the batch after the problematic one (getting all NaNs from the embedder); it seems like the 'wrong' batch somehow leaves the transformer model in a bad state.

Furthermore, I have found that this error also occurs with at least the following embeddings: distilbert-base-uncased, flaubert/flaubert_large_cased, xlm, xlm-mlm-100-1280, xlm-mlm-17-1280, xlm-mlm-tlm-xnli15-1024, xlm-mlm-xnli15-1024.

Ok, I have narrowed down the problem, and made a hack for it, but would like to hear about cleaner solutions.

The problem arises when a shorter and a longer sentence end up in the same batch. In that case both are padded to the same length (by the indexer). When they are passed into the embedder, both are split, which leaves one row containing only padding tokens and 0's (which could be the same). The mask for that row is simply [False] * max_len. With mBERT this just returns empty embeddings of the expected dimensions; for the previously mentioned embeddings, however, it breaks the transformer model, as they do not expect a fully masked-out row. I have hacked a solution for xlm-mlm-tlm-xnli15-1024, which you can find here:

To summarize: I remove the empty rows before calling forward on the transformer model, and add them back afterwards, because somewhere down the line the alignment breaks if I don’t.
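A sketch of that workaround (a hypothetical helper, not the actual patch; `transformer` stands for the wrapped HuggingFace model, assumed to return a tuple whose first element is the hidden states):

```python
import torch

def forward_without_empty_rows(transformer, input_ids, attention_mask):
    # Rows whose mask is all False are the pure-padding segments
    # created by folding short sentences to max_len.
    keep = attention_mask.any(dim=-1)                 # (rows,)
    hidden = transformer(
        input_ids=input_ids[keep],
        attention_mask=attention_mask[keep],
    )[0]                                              # (kept, seq, dim)
    # Reinsert zero embeddings for the dropped rows so that the
    # row alignment expected downstream is preserved.
    full = hidden.new_zeros(input_ids.size(0), *hidden.shape[1:])
    full[keep] = hidden
    return full
```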

However, this still leaves me with 2 questions:

  1. Is the splitting supposed to work like this, i.e. that even a short sentence can end up as two rows?
  2. Does anyone know of a cleaner solution to this problem? (Mine is slow, inelegant, and only works when the padding index is 2.)

What’s happening here is that it’s splitting a sentence with more than 96 tokens into two sequences, one with the first 96 tokens, and another with the rest. Correct?

Slightly more generally, it creates n sequences of length 96 for each source sequence. It does this in a new dimension (IIRC), so other sequences that are shorter than 96 still get a second sequence added, except it is all padding.
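A simplified sketch of that folding (my own illustration, ignoring the special tokens AllenNLP also inserts around each segment): pad every sequence in the batch to a multiple of max_len, then reshape so each sequence contributes ceil(len / max_len) rows. A short sentence batched with a long one then yields a row whose mask is entirely False.

```python
import torch
import torch.nn.functional as F

def fold(ids, max_len, pad_id=0):
    batch, length = ids.shape
    n = -(-length // max_len)                  # ceil(length / max_len)
    padded = F.pad(ids, (0, n * max_len - length), value=pad_id)
    folded = padded.view(batch * n, max_len)   # each sequence -> n rows
    return folded, folded != pad_id

# A 3-token sentence batched with a 6-token one, max_len = 4:
ids = torch.tensor([[1, 2, 3, 0, 0, 0],
                    [4, 5, 6, 7, 8, 9]])
fids, fmask = fold(ids, max_len=4)
print(fmask[1])   # tensor([False, False, False, False]) -- pure padding
```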

So the first question is, why do you use such a short sequence length anyway? Two sequences of length 96*2=192 will take more memory than four sequences of 96 each, but the difference won't be large.

Second question is, do you have to use the “folding” it does, where it splits a long sequence that doesn’t fit into multiple smaller ones? Can you just truncate the sequence instead? I guess for dependency parsing that’s maybe not a solution.

If the answer to these questions is that you really need both of these things, then we’ll need to fix the code so that it can handle the case where a sequence is 100% padding. That should work anyways, but clearly it does not. If you have a workable patch for it, I’d love to add it to the library.

Yes, correct. So this is indeed the intended behavior.

Well, for the mBERT model, this small difference makes it fit exactly on a 1080 GPU (12 GB); without max_len it goes OOM. I agree that 96 is rather small, but in fact most sentences from UD fit within it without a problem.

We are using it for sequence labeling/dependency parsing/NER, so that would indeed be undesirable.

I don't have a proper patch at the moment (besides the hack above); I will submit a pull request if I can think of/implement something nicer. I can actually run with max_len=512 on other machines I have access to (with more GPU RAM), but I guess it will still break if there is a single sentence longer than that, which actually occurs in UD.

Just use a batch size of 1 then, and a longer sequence length. If you need the mathematical batch size to be larger than 1, use gradient accumulation.

Edit: With a batch size of 1, this problem will never happen, because it will never batch together a short and long sentence. It’s not a proper solution, but it will solve the problem with (most likely) minimum performance loss. And the quality of your results might go up!
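In AllenNLP config terms, that combination could look like this (a sketch; `num_gradient_accumulation_steps` is the trainer option that accumulates gradients over several batches before each optimizer step):

```jsonnet
"data_loader": {
    "batch_sampler": {
        "type": "bucket",
        "batch_size": 1
    }
},
"trainer": {
    // 32 batches of size 1 per optimizer step = effective batch size 32
    "num_gradient_accumulation_steps": 32
}
```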