Fine-tuning only part of an embedding matrix, given a vocabulary


I have a rather open-ended problem that I’m not entirely sure how to approach. Say I wanted to train BiDAF and freeze the pretrained GloVe embedding layer, except for the 1000 most common question words. Do you have a sense of how I’d go about making this work?

Pretty sure this would work:

Have three TokenIndexers (but only two corresponding Embedders, since the third indexer just produces a mask):

  • A frozen SingleIdTokenIndexer (and corresponding Embedder) that embeds all words (possibly minus the most frequent, but it doesn’t matter too much)
  • A tunable SingleIdTokenIndexer (and corresponding Embedder) that embeds the most frequent words (you could also have it embed all words, but that just wastes memory, and might also affect certain gradient normalizations / initializations)
  • A TokenIndexer that produces a mask saying whether each word is frequent or not
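
As a rough sketch of the setup above in plain PyTorch (the names, toy sizes, and the `frequent_ids` mapping are all illustrative, not actual AllenNLP API): the frozen table covers the full vocabulary, the tunable table covers only the frequent words, and the mask is derived from which words appear in the frequent set.

```python
import torch

# Toy sizes for the sketch; in practice these come from your vocabulary.
vocab_size, num_frequent, dim = 50, 10, 8

# Frozen embedder covering the full vocabulary (e.g. pretrained GloVe).
frozen = torch.nn.Embedding(vocab_size, dim)
frozen.weight.requires_grad = False

# Tunable embedder covering only the frequent words. Index 0 is a
# placeholder for "not frequent"; its vector is always masked out.
tunable = torch.nn.Embedding(num_frequent + 1, dim)

# Hypothetical mapping from full-vocab id to tunable-table id
# (0 = not in the frequent set).
frequent_ids = {3: 1, 7: 2, 11: 3}

token_ids = torch.tensor([[3, 42, 11]])  # ids from the full vocabulary
tunable_ids = torch.tensor(
    [[frequent_ids.get(int(i), 0) for i in row] for row in token_ids]
)
mask = (tunable_ids != 0).float()  # 1.0 where the word is frequent
```

In AllenNLP terms, each of these three tensors would be produced by one of the three TokenIndexers, so the model receives them together in the text field's output dictionary.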

Write your own custom TextFieldEmbedder that takes the two Embedders mentioned above as constructor arguments and, in its forward() method, looks for those three inputs. It calls both embedders and combines their outputs with a masked add: frequent_embedding * mask + infrequent_embedding * (1 - mask).
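
The forward pass described above might look like the following sketch (a plain PyTorch module rather than a real TextFieldEmbedder subclass; class and argument names are hypothetical):

```python
import torch

class MixedEmbedder(torch.nn.Module):
    """Combines a frozen and a tunable embedding table via a mask."""

    def __init__(self, frozen: torch.nn.Embedding, tunable: torch.nn.Embedding):
        super().__init__()
        self.frozen = frozen
        self.tunable = tunable

    def forward(self, frozen_ids, tunable_ids, mask):
        # frozen_ids, tunable_ids: (batch, seq); mask: (batch, seq),
        # 1.0 where the word is frequent.
        frequent = self.tunable(tunable_ids)    # (batch, seq, dim)
        infrequent = self.frozen(frozen_ids)    # (batch, seq, dim)
        m = mask.unsqueeze(-1)                  # broadcast over dim
        return frequent * m + infrequent * (1 - m)

# Toy usage: first token is "frequent", second is not.
dim = 4
frozen = torch.nn.Embedding(6, dim)
frozen.weight.requires_grad = False
tunable = torch.nn.Embedding(3, dim)

embedder = MixedEmbedder(frozen, tunable)
out = embedder(
    torch.tensor([[1, 2]]),        # full-vocab ids
    torch.tensor([[1, 0]]),        # tunable-table ids (0 = not frequent)
    torch.tensor([[1.0, 0.0]]),    # mask
)
```

Since the frozen table's `requires_grad` is off and the infrequent positions contribute nothing to the tunable branch, gradients only flow into the rows of the tunable table that correspond to frequent words actually seen in the batch.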