Pretrained Transformer Mismatched Embedder outputs NaN embeddings after a few iterations

Hi!

I've tried to replace the stacked BiLSTM encoder of the graph parser model (actually a slightly different model based on it) with a BERT encoder.
I'm using the pretrained-transformer-mismatched indexer and embedder, together with the "pass_through" encoder. After a few iterations, I get NaN while computing the loss (BCEWithLogits).
It seems like PretrainedTransformerMismatchedEmbedder() returns some NaN values in the embeddings of the given tokens (and consequently NaN logits and NaN CE losses). Any idea how to debug/fix this? I'm using grad_norm…

Here are the relevant changes in the configuration (the BASE is very similar to this one):

local BASE = import 'biaffine-graph-parser.jsonnet';

local bert_model = "bert-base-uncased";
local max_length = 128;
local bert_dim = 768;

BASE+{
  "dataset_reader"+: {
     "token_indexers": {
      "tokens": {
        "type": "pretrained_transformer_mismatched",
        "model_name": bert_model,
        "max_length": max_length
      },
    },
  },
  "model"+: {
     "text_field_embedder": {
        "token_embedders": {
            "tokens": {
                "type": "pretrained_transformer_mismatched",
                "model_name": bert_model,
                "max_length": max_length
            }
        }
     },
     "pos_tag_embedding"+:{
       "sparse": false  # huggingface_adamw cannot work with sparse
     },
//    "pos_tag_embedding": null,
     "encoder": {
      "type": "pass_through",
      "input_dim": bert_dim + $.model.pos_tag_embedding.embedding_dim
     },
   },
   "trainer"+: {
       "grad_norm": 1.0,
       "optimizer": {
          "type": "huggingface_adamw",
          "lr": 1e-3,
          "weight_decay": 0.01,
          "parameter_groups": [
            [[".*transformer.*"], {"lr": 1e-5}]
          ]
        }
   }
}

Thanks!

1e-3 seems like a really high learning rate for a transformer. NaNs in a loss that works with other configurations likely point to an optimization problem, and the first places to look for those kinds of problems are initializations and learning rates. As your initializations come from a pretrained model, that's likely not the issue, so it's probably the learning rate.

Thanks for your answer!

Actually, the learning rate for the transformer is 1e-5 (see the "parameter_groups" part); the 1e-3 rate is for the rest of the parser network (FF, Bilinear, …) - I took the optimizer part from the coref-bert-lstm example. I've tried other learning rates ((all: 1e-5, trans: 1e-5), (all: 1e-5, trans: 1e-7)), but got the same behavior…
The NaN values arise in the first epoch (so it's not a duration issue).
As for initialization, I initialize the rest of the network like here.

Ah, sorry, I looked at the parameters too quickly. I'm afraid I don't have other ideas. I think PyTorch has some built-in features for detecting NaNs, but I haven't used them. What I would do is try to figure out which instance causes the NaN, then walk through that instance by itself and see if there's anything up with the data processing. Maybe there's an impossible label somewhere, or something?
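For reference, a minimal sketch of what those built-in features look like: autograd anomaly detection plus an explicit finiteness check on the embedder output via a forward hook. The `model.text_field_embedder` attribute path is an assumption here, so adapt it to however your model exposes the embedder:

```python
import torch

# Autograd anomaly detection: the backward pass raises an error at the first
# operation that produces NaN/inf gradients (slow; enable only while debugging).
torch.autograd.set_detect_anomaly(True)

def nan_check_hook(module, inputs, output):
    # Fail fast as soon as the embedder emits a non-finite value.
    if not torch.isfinite(output).all():
        raise RuntimeError(
            f"Non-finite values in the output of {module.__class__.__name__}"
        )

# Hypothetical attribute path; `model` is whatever model object the trainer holds.
embedder = model.text_field_embedder
embedder.register_forward_hook(nan_check_hook)
```

Running training with the hook attached should pinpoint the first batch (and module) where the NaNs appear.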

Mmm, I'm not sure it's a data issue, since it's the same data as the original BiLSTM version, which ran successfully.
I'll try to figure it out and update if something interesting comes up.

Thanks :slight_smile:

It’s possible that there’s something in the conversion to and from wordpieces that is messing things up. That’s where I would look for a data issue.
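One rough way to inspect the wordpiece conversion is to run the raw tokens from a suspect instance through the underlying HuggingFace tokenizer and see whether any of them vanish or explode into many pieces. The token list below is made up for illustration:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Hypothetical tokens from a suspect instance; control/zero-width characters
# can tokenize to zero wordpieces with the BERT tokenizer.
tokens = ["The", "parser", "\u200b", "works", "\ufffd"]

for token in tokens:
    pieces = tokenizer.tokenize(token)
    print(f"{token!r:12} -> {len(pieces)} wordpieces: {pieces}")
```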

Ok, just figured it out - I used an old version of the mismatched embedder, where 0-length spans were not filtered out while averaging (yielding a division by 0). It was fixed in v1.0.
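For anyone hitting the same thing, here is a schematic (not the actual library code) of how averaging wordpiece embeddings over spans turns a 0-length span into NaN, and the kind of clamp that avoids it:

```python
import torch

# embeddings: (num_spans, max_span_width, dim); span_mask marks real wordpieces.
embeddings = torch.randn(3, 4, 8)
span_mask = torch.tensor([[1, 1, 0, 0],
                          [1, 0, 0, 0],
                          [0, 0, 0, 0]], dtype=torch.bool)  # last span is empty

masked = embeddings * span_mask.unsqueeze(-1)
span_sum = masked.sum(dim=1)
span_len = span_mask.sum(dim=1, keepdim=True)

naive_avg = span_sum / span_len              # 0 / 0 -> NaN for the empty span
safe_avg = span_sum / span_len.clamp(min=1)  # empty span becomes an all-zero vector

print(torch.isnan(naive_avg).any())  # tensor(True)
print(torch.isnan(safe_avg).any())   # tensor(False)
```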
