Use --overrides to "remove" a model component?

I have a model that contains a FeedForward module at its output. During training, this module is trained along with the rest of the model. At test time, I would like to drop it (the model is being used to embed text and there’s some evidence that the representations before this layer perform better in downstream tasks).

To do this, I tried using the --overrides option of allennlp predict. This looks like:

--overrides '{"model.feedforward": null}'

This appears to work. In my log, I can see:

2020-02-20 12:36:07,412 - INFO - allennlp.common.params - model.feedforward = None

But, I get the following error:

RuntimeError: Error(s) in loading state_dict for ContrastiveTextEncoder:
	Unexpected key(s) in state_dict: "_feedforward._linear_layers.0.weight", "_feedforward._linear_layers.0.bias", "_feedforward._linear_layers.1.weight", 

This makes sense. The model's state_dict has entries for _feedforward, and I have just removed that component from the model with --overrides.

My question is, am I misusing --overrides? Is there a better way to drop a component from a model at test time when using allennlp predict, such that I don’t get a RuntimeError when loading the state_dict?


I noticed that load_state_dict has a parameter strict, which, if False, would allow arbitrary mismatches between the model and the state_dict. Is there any way for me to set strict=False in AllenNLP (without modifying the source code)? I see this related issue but can’t tell if anything came out of it.
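For reference, here is a minimal sketch of what strict=False does in plain PyTorch, outside of AllenNLP. The classes EncoderWithHead and EncoderOnly are hypothetical stand-ins for the trained model and the model with the feedforward removed; they are not part of my actual code.

```python
import torch
from torch import nn


class EncoderWithHead(nn.Module):
    """Hypothetical stand-in for the trained model: encoder plus feedforward head."""

    def __init__(self, dim: int = 4) -> None:
        super().__init__()
        self._encoder = nn.Linear(dim, dim)
        self._feedforward = nn.Linear(dim, dim)


class EncoderOnly(nn.Module):
    """Hypothetical stand-in for the same model with the feedforward head dropped."""

    def __init__(self, dim: int = 4) -> None:
        super().__init__()
        self._encoder = nn.Linear(dim, dim)


trained = EncoderWithHead()
state_dict = trained.state_dict()

model = EncoderOnly()
# With strict=True (the default) this raises a RuntimeError about unexpected keys.
# With strict=False the extra keys are skipped and reported in the return value.
result = model.load_state_dict(state_dict, strict=False)
print(result.unexpected_keys)
```

The return value also has a missing_keys field, so it is still possible to check that only the expected keys were skipped.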

Yes, this is not an intended use of --overrides, and I would not expect it to work. A couple of options:

  1. Write your own prediction code (if you look at what predict actually does, it’s quite simple) that grabs a different output than the top layer. It’s pretty easy to have your model return multiple outputs, and you can have a Predictor, or whatever script you’re using, pick the output from before the FeedForward layer that you want to remove. This seems easiest.
  2. Following the issue you quoted, override _load() in your model, copying everything from the base class, but changing the call to load_state_dict.
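A rough sketch of the first option, written in plain PyTorch rather than against the AllenNLP Model API. The class name TextEncoder and the output keys "embeddings" and "projections" are made up for illustration; the point is only that forward returns both tensors, so the feedforward stays in the model (and in the state_dict) while the prediction code reads the earlier output.

```python
import torch
from torch import nn


class TextEncoder(nn.Module):
    """Hypothetical model that exposes the representation before and after the feedforward."""

    def __init__(self, dim: int = 4) -> None:
        super().__init__()
        self._encoder = nn.Linear(dim, dim)
        self._feedforward = nn.Linear(dim, dim)

    def forward(self, inputs: torch.Tensor) -> dict:
        encoded = self._encoder(inputs)
        projected = self._feedforward(encoded)
        # Return both, so callers can choose which representation to use.
        return {"embeddings": encoded, "projections": projected}


model = TextEncoder().eval()
with torch.no_grad():
    outputs = model(torch.randn(2, 4))

# At prediction time, read the output from before the feedforward layer.
embeddings = outputs["embeddings"]
```

Because the model architecture is unchanged, the archived weights load without any mismatch and no override is needed.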

I see! Okay thanks a lot for the guidance.