Regenerating token representations within model

I am trying to implement the model described here in AllenNLP.
The model contains two submodels: a rationale extractor and a classifier. The extractor takes token-level embeddings and returns a 0/1 output for each token, telling us whether that token is important (think of it as a sequence tagging problem). We use this decision to subset the tokens in the document (keeping only the tokens with output 1), then re-encode the remaining tokens to get new token representations and perform classification. The two models are trained jointly with the REINFORCE algorithm.
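To make the setup concrete, here is a minimal, self-contained sketch of the joint training step described above. The module names (`ToyExtractor`, `ToyClassifier`, `joint_step`) are illustrative stand-ins, not the actual model; the point is the shape of the computation: sample a 0/1 selection per token, classify on the selected tokens, and feed the classification loss back to the extractor as a REINFORCE reward.

```python
import torch
import torch.nn as nn


class ToyExtractor(nn.Module):
    """Scores each token with a keep-probability (illustrative)."""
    def __init__(self, dim: int):
        super().__init__()
        self.score = nn.Linear(dim, 1)

    def forward(self, emb):  # emb: (batch, seq_len, dim)
        return torch.sigmoid(self.score(emb)).squeeze(-1)  # (batch, seq_len)


class ToyClassifier(nn.Module):
    """Classifies using only the selected tokens (illustrative)."""
    def __init__(self, dim: int, num_classes: int):
        super().__init__()
        self.out = nn.Linear(dim, num_classes)

    def forward(self, emb, selection):  # selection: (batch, seq_len) of 0/1
        masked = emb * selection.unsqueeze(-1)
        # Mean-pool over the kept tokens (clamp avoids division by zero)
        pooled = masked.sum(1) / selection.sum(1, keepdim=True).clamp(min=1)
        return self.out(pooled)


def joint_step(extractor, classifier, emb, labels):
    probs = extractor(emb)
    dist = torch.distributions.Bernoulli(probs=probs)
    selection = dist.sample()                 # 0/1 decision per token
    logits = classifier(emb, selection)
    cls_loss = nn.functional.cross_entropy(logits, labels)
    # REINFORCE: reward the extractor for selections that lower the loss
    reward = -cls_loss.detach()
    reinforce_loss = -(dist.log_prob(selection).sum(-1) * reward).mean()
    return cls_loss + reinforce_loss
```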

Currently I use a TextField to pass the document to the rationale extractor, but I have no direct way to subset and re-encode documents before passing them to the classifier.

My current workaround is to also pass the original tokens in a MetadataField (along with a reference to the dataset reader object). Re-encoding is done by subsetting these tokens, converting them into an Instance, batching, padding, and converting to a tensor dict. Essentially the whole interface between the model and the dataset reader is broken, and the two interact back and forth. I was hoping to find a better way to implement this.

    from typing import List

    from allennlp.data import Instance, Token
    from allennlp.data.dataset import Batch  # allennlp.data.batch in newer versions
    from allennlp.data.fields import TextField

    # This is a method in my dataset reader
    def convert_tokens_to_instance(self, tokens: List[Token]):
        fields = {}
        fields["document"] = TextField(tokens, self._token_indexers)
        return Instance(fields)

    # This is a method in my Model
    def regenerate_tokens(self, metadata, token_selection):
        # token_selection contains a 0/1 decision: keep each token or not
        device = token_selection.device
        token_selection = token_selection.cpu().data.numpy()

        # Tokens for each document
        tokens = [m["tokens"] for m in metadata]
        instances = []
        for words, mask in zip(tokens, token_selection):
            mask = mask[: len(words)]
            # For each document, keep only the tokens with mask == 1
            new_words = [w for w, m in zip(words, mask) if m == 1]

            # Convert them into a TextField instance
            instances.append(metadata[0]["convert_tokens_to_instance"](new_words))

        batch = Batch(instances)
        # The new instances must be indexed before they can become tensors
        batch.index_instances(self.vocab)
        padding_lengths = batch.get_padding_lengths()

        batch = batch.as_tensor_dict(padding_lengths)
        return {k: v.to(device) for k, v in batch["document"].items()}

Yeah, this kind of thing is going to be hard. I don’t have specific advice for you, other than, don’t feel like you have to use any particular abstraction if it’s not helpful to you. If you know what kind of encoding you’re going to do, you could just put your tokens into an ArrayField, and manipulate that directly in your model.
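As a sketch of that suggestion: if you keep the token embeddings as tensors in the model, you can drop and left-pack the selected tokens with plain tensor operations, never going back through the dataset reader. The function name `select_and_compact` is illustrative, and this assumes your re-encoder accepts an embedding tensor plus a mask.

```python
import torch


def select_and_compact(embeddings: torch.Tensor, selection: torch.Tensor):
    """Drop unselected tokens and left-pack the survivors, purely in tensors.

    embeddings: (batch, seq_len, dim); selection: (batch, seq_len) of 0/1.
    Returns compacted embeddings (batch, max_kept, dim) and a boolean mask
    marking the real (non-padding) positions, ready for a second encoder.
    """
    batch, seq_len, dim = embeddings.shape
    kept = selection.bool()
    lengths = kept.sum(1)                      # tokens kept per document
    max_kept = int(lengths.max().item())
    out = embeddings.new_zeros(batch, max_kept, dim)
    new_mask = torch.zeros(batch, max_kept, dtype=torch.bool)
    for i in range(batch):
        n = int(lengths[i].item())
        out[i, :n] = embeddings[i, kept[i]]    # boolean-index the kept rows
        new_mask[i, :n] = True
    return out, new_mask
```

The compacted tensor and mask can then be fed straight to the classifier's encoder, which keeps the whole select-and-re-encode step inside the model's forward pass.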