Ways to obtain multiple tensors from BasicTextFieldEmbedder

I use four token embedders in my program: embedding, token_character, elmo_token_embedder, and dependency label embedding.
If I directly use the text_field_embedder, I get a single concatenated tensor:

## tokens: Dict[str, torch.LongTensor]
all_embeddings = self.text_field_embedder(tokens)

However, in my use case, I want to take the embedding and token_character embeddings to do something first. For example, in my case, I use additional dependency head/parent tensor to obtain each embedding’s parent/head embedding. Then, I concatenate them together.
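The head-selection step can be sketched in isolation. This is a minimal, self-contained example of using torch.gather to pick out each token's parent embedding; the shapes and values are illustrative assumptions, not the actual model's data:

```python
import torch

# Toy shapes: 2 sentences, 4 tokens each, embedding size 3
batch_size, sent_len, emb_size = 2, 4, 3
token_embeddings = torch.arange(batch_size * sent_len * emb_size,
                                dtype=torch.float).view(batch_size, sent_len, emb_size)

# heads[b, i] is the index of token i's dependency parent in sentence b
heads = torch.tensor([[1, 0, 0, 2],
                      [0, 3, 1, 1]])

# Expand the index tensor to match the embedding dimensions, then
# gather along the sequence dimension (dim=1), so that
# dep_head_emb[b, i] == token_embeddings[b, heads[b, i]]
index = heads.unsqueeze(-1).expand(batch_size, sent_len, emb_size)
dep_head_emb = torch.gather(token_embeddings, 1, index)
```

The `.unsqueeze(-1).expand(...)` is equivalent to the `.view(...).expand(...)` in the code below; gather requires the index tensor to have the same number of dimensions as the source.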

Here is what I did:

elmo_embedding = self.text_field_embedder({"elmo": tokens["elmo"]})
token_embeddings = self.text_field_embedder({"tokens": tokens["tokens"], "token_characters": tokens["token_characters"]})
## heads, with size (batch_size, sent_len), holds the dependency parent index for each word/position.
dep_head_emb = torch.gather(token_embeddings, 1, heads.view(batch_size, sent_len, 1).expand(batch_size, sent_len, token_emb_size))  ## select each word's head embedding
dep_label_emb = self.text_field_embedder({"dependency_label": tokens["dependency_label"]})

all_embs = torch.cat((elmo_embedding, token_embeddings, dep_head_emb, dep_label_emb), dim=2)

But doing something like this doesn’t work: it triggers the “Mismatched token keys” error at https://github.com/allenai/allennlp/blob/master/allennlp/modules/text_field_embedders/basic_text_field_embedder.py#L96

Even if I set allow_unmatched_keys to True, a call like elmo_embedding = self.text_field_embedder({"elmo": tokens["elmo"]}) still enumerates all four original keys instead of only the single key I specify.

Some solutions in my mind:

  1. Should I create another BasicTextFieldEmbedder and modify its code to enumerate only the keys I specify? If that would work, one thing I’m not sure about is these few lines (https://github.com/allenai/allennlp/blob/master/allennlp/modules/text_field_embedders/basic_text_field_embedder.py#L108-L112). Can I keep them as they are and just enumerate the input keys?
  2. Because the keys are sorted, maybe I can take a sub-tensor from the complete embedded tensor, like embedded_sequence[:, :, start:end].
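Option 2 can be sketched like this. Assuming the embedder concatenates its sub-embeddings in sorted key order, the slice boundaries follow from each TokenEmbedder’s output dimension; the helper name and the example dimensions here are hypothetical:

```python
def slice_offsets(output_dims):
    """Map each token key to its (start, end) columns in the
    concatenated embedding, assuming concatenation in sorted key order.

    output_dims: dict mapping token key -> that embedder's output dim
    """
    offsets = {}
    start = 0
    for key in sorted(output_dims):
        offsets[key] = (start, start + output_dims[key])
        start += output_dims[key]
    return offsets

# Illustrative dims; sorted order is elmo, token_characters, tokens
offsets = slice_offsets({"tokens": 100, "elmo": 1024, "token_characters": 128})
# embedded_sequence[:, :, start:end] with offsets[key] then recovers
# one embedder's slice from the complete tensor
```

This is brittle, though: the offsets silently break if an embedder’s output dim changes or a key is added, which is one reason the direct-TokenEmbedder approach below is cleaner.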

Instead of using the TextFieldEmbedder, use a collection of TokenEmbedders directly. If you need more detail than this, I can give it.

Okay, I think I sort of understand the basic idea, but I’m not sure what I did is correct:

## First, modify the constructor: remove the text_field_embedder and take a dict of token embedders instead
def __init__(self,
             vocab: Vocabulary,
             # text_field_embedder: TextFieldEmbedder,
             token_embedders: Dict[str, TokenEmbedder]) -> None:
    super().__init__(vocab)
    self._token_embedders = token_embedders

def forward(self, tokens: Dict[str, torch.LongTensor]):
    elmo_embedding = self._token_embedders["elmo"](tokens["elmo"])
    token_embeddings = self._token_embedders["tokens"](tokens["tokens"])
    character_embeddings = self._token_embedders["token_characters"](tokens["token_characters"])
    word_emb = torch.cat((token_embeddings, character_embeddings), dim=2)
    dep_head_emb = torch.gather(word_emb, 1, heads.view(batch_size, sent_len, 1).expand(batch_size, sent_len, token_emb_size))
    dep_label_emb = self._token_embedders["dependency_label"](tokens["dependency_label"])
    all_embs = torch.cat((elmo_embedding, word_emb, dep_head_emb, dep_label_emb), dim=2)

Yes, that’s basically what I was suggesting. If it’s working for you, then it’s “correct” enough - the abstractions that we provide are only there to make things easier. If they don’t work for what you want to do, you don’t need to use them, or worry about what is the “correct” way to do things.

Thanks. It works well.

Hi @mattg, I ran into a device issue when running on GPU. Using the above code, I got this error:

  File "/home/allanj/anaconda3/envs/allennlp/bin/allennlp", line 10, in <module>
    sys.exit(run())
  File "/home/allanj/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/run.py", line 18, in run
    main(prog="allennlp")
  File "/home/allanj/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/commands/__init__.py", line 102, in main
    args.func(args)
  File "/home/allanj/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/commands/train.py", line 124, in train_model_from_args
    args.cache_prefix)
  File "/home/allanj/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/commands/train.py", line 168, in train_model_from_file
    cache_directory, cache_prefix)
  File "/home/allanj/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/commands/train.py", line 252, in train_model
    metrics = trainer.train()
  File "/home/allanj/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/training/trainer.py", line 478, in train
    train_metrics = self._train_epoch(epoch)
  File "/home/allanj/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/training/trainer.py", line 320, in _train_epoch
    loss = self.batch_loss(batch_group, for_training=True)
  File "/home/allanj/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/training/trainer.py", line 261, in batch_loss
    output_dict = self.model(**batch)
  File "/home/allanj/anaconda3/envs/allennlp/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "./src/models/ner_with_dep.py", line 120, in forward
    token_embeddings = self._token_embedders['tokens'](tokens["tokens"].to(device))
  File "/home/allanj/anaconda3/envs/allennlp/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/allanj/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/modules/token_embedders/embedding.py", line 144, in forward
    sparse=self.sparse)
  File "/home/allanj/anaconda3/envs/allennlp/lib/python3.6/site-packages/torch/nn/functional.py", line 1484, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_index_select

I made the workaround below and it works, but I’m still not sure why this error happens and would like a cleaner solution.

def forward(self, tokens: Dict[str, torch.LongTensor]):
    ## Move every embedder to the same device as the rest of the model
    device = torch.device("cuda:" + str(util.get_device_of(SOME_OTHER_TENSOR)))
    for key in self._token_embedders:
        self._token_embedders[key] = self._token_embedders[key].to(device)
    token_embeddings = self._token_embedders["tokens"](tokens["tokens"])

Okay, I think I solved this by calling add_module in the constructor, so each embedder is registered as a submodule and moved to the GPU with the rest of the model:

def __init__(self,
             vocab: Vocabulary,
             # text_field_embedder: TextFieldEmbedder,
             token_embedders: Dict[str, TokenEmbedder]) -> None:
    super().__init__(vocab)
    self._token_embedders = token_embedders
    ## Register each embedder so model.cuda() / .to(device) moves it too
    for key, embedder in token_embedders.items():
        name = "token_embedder_%s" % key
        self.add_module(name, embedder)

@Allan_Jie, you can also use a ModuleDict for self._token_embedders. See https://pytorch.org/docs/stable/nn.html#torch.nn.ModuleDict. This handles calling add_module for its elements automatically.
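For reference, here is a minimal, self-contained sketch of the ModuleDict approach. The model class and the dummy nn.Embedding layers are illustrative stand-ins for the real TokenEmbedders:

```python
import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self, token_embedders):
        super().__init__()
        # ModuleDict registers each embedder as a submodule automatically,
        # so no manual add_module loop is needed and .to(device) / .cuda()
        # moves every embedder along with the rest of the model.
        self._token_embedders = nn.ModuleDict(token_embedders)

    def forward(self, tokens):
        # Embed each key and concatenate in sorted key order
        return torch.cat([self._token_embedders[key](tokens[key])
                          for key in sorted(tokens)], dim=-1)

model = MyModel({"tokens": nn.Embedding(10, 4),
                 "token_characters": nn.Embedding(10, 2)})
out = model({"tokens": torch.tensor([[1, 2]]),
             "token_characters": torch.tensor([[3, 4]])})
```

Because the embedders now appear in model.parameters(), the optimizer also sees them, which a plain dict attribute would silently hide.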
