How to use PretrainedTransformer as both the token embedder and the model

Hi,

Apologies if this has been asked before. My goal is to use a PretrainedTransformer as the encoder of an encoder-decoder model. For now, this is a SimpleSeq2Seq model.

The problem is that the encoder-decoder models in AllenNLP expect both a source embedder and an encoder, but the PretrainedTransformer model is essentially both (it accepts input ids, maps them to embeddings, and then feeds those embeddings through the layers of the transformer).
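
To make this concrete, here is a minimal sketch of what I mean, using the Transformers API directly (roberta-base is just the model I happen to be using):

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModel.from_pretrained("roberta-base")

# The transformer takes raw token ids and performs both the embedding lookup and the
# forward pass through its layers, so there is no separate "embedder" to plug in.
input_ids = torch.tensor([tokenizer.encode("some source text")])
hidden_states = model(input_ids)[0]  # shape: (batch_size, sequence_length, hidden_size)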

For now, I have just subclassed SimpleSeq2Seq, removing all mentions of the source_embedder and using the encoder in its place. My config is below:

// This should be either a registered name in the Transformers library, or a path on disk to a
// serialized transformer model. Note, to avoid issues, please name the serialized model folder
// [bert|roberta|gpt2|distilbert|etc]-[base|large|etc]-[uncased|cased|etc]
local pretrained_transformer_model_name = "roberta-base";
// This will be used to set the max # of tokens and the max # of decoding steps
local max_sequence_length = 512;

{
    "dataset_reader": {
        "lazy": false,
        // TODO (John): Because our source and target text are identical, we should subclass this
        // dataset reader into one that expects only one document per line.
        "type": "seq2seq",
        "source_tokenizer": {
            "type": "pretrained_transformer",
            "model_name": pretrained_transformer_model_name,
            "max_length": max_sequence_length,
        },
        "target_tokenizer": {
            "type": "spacy"
        },
        // TODO (John): For now, use different namespaces for source and target indexers. It may
        // make more sense to use the same namespaces in the future.
        "source_token_indexers": {
            "tokens": {
                "type": "pretrained_transformer",
                "model_name": pretrained_transformer_model_name,
            },
        },
        "target_token_indexers": {
            "tokens": {
                "type": "single_id",
                "namespace": "target_tokens"
            },
        },
        // Adding start/end tokens would break the pretrained_transformer token indexer, so they
        // are disabled for now. In the future, we should assign the special start and end sequence
        // tokens to one of BERT's unused vocab ids.
        // See: https://github.com/allenai/allennlp/issues/3435#issuecomment-558668277
        "source_add_start_token": false,
        "source_add_end_token": false,
    },
    "train_data_path": "datasets/pubmed/train.tsv",
    "validation_data_path": "datasets/pubmed/valid.tsv",
    "model": {
        "type": "pretrained_transformer",
        "encoder": {
            "type": "pretrained_transformer",
            "model_name": pretrained_transformer_model_name,
        },
        "target_namespace": "target_tokens",
        "max_decoding_steps": max_sequence_length,
        "beam_size": 8,
        "target_embedding_dim": 256,
        "use_bleu": true
    },
    "iterator": {
        "type": "bucket",
        "sorting_keys": [["source_tokens", "num_tokens"]],
        "batch_size": 2
    },
    "trainer": {
        "optimizer": {
            "type": "adam",
            // TODO (John): Because our decoder is trained from scratch, we will likely need a larger
            // learning rate. Idea: use different learning rates for the encoder and decoder?
            "lr": 5e-5,
        },
        // "validation_metric": "-loss",
        "num_serialized_models_to_keep": 1,
        "num_epochs": 5,
        "cuda_device": 0,
        "grad_norm": 1.0,
    }
}

My subclassed SimpleSeq2Seq is:

from typing import Dict

import torch
from overrides import overrides
from torch.nn.modules.linear import Linear
from torch.nn.modules.rnn import LSTMCell

from allennlp.common.checks import ConfigurationError
from allennlp.common.util import END_SYMBOL
from allennlp.common.util import START_SYMBOL
from allennlp.data.vocabulary import Vocabulary
from allennlp.models.encoder_decoders.simple_seq2seq import SimpleSeq2Seq
from allennlp.models.model import Model
from allennlp.modules import Attention
from allennlp.modules import Seq2SeqEncoder
from allennlp.modules.attention import LegacyAttention
from allennlp.modules.similarity_functions import SimilarityFunction
from allennlp.modules.token_embedders import Embedding
from allennlp.nn import util
from allennlp.nn.beam_search import BeamSearch
from allennlp.training.metrics import BLEU


@Model.register("pretrained_transformer")
class PretrainedTransformerSeq2Seq(SimpleSeq2Seq):
    def __init__(
        self,
        vocab: Vocabulary,
        encoder: Seq2SeqEncoder,
        max_decoding_steps: int,
        attention: Attention = None,
        attention_function: SimilarityFunction = None,
        beam_size: int = None,
        target_namespace: str = "tokens",
        target_embedding_dim: int = None,
        scheduled_sampling_ratio: float = 0.0,
        use_bleu: bool = True,
    ) -> None:
        # TODO (John): This is hacky. We call Model.__init__ directly to avoid the SimpleSeq2Seq
        # constructor, which requires a source_embedder.
        Model.__init__(self, vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(
                self.vocab._padding_token, self._target_namespace
            )
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(
            self._end_index, max_steps=max_decoding_steps, beam_size=beam_size
        )

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        if attention:
            if attention_function:
                raise ConfigurationError(
                    "You can only specify an attention module or an "
                    "attention function, but not both."
                )
            self._attention = attention
        elif attention_function:
            self._attention = LegacyAttention(attention_function)
        else:
            self._attention = None

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or encoder.get_output_dim()
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

    @overrides
    def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(source_tokens['tokens'], source_mask)
        return {"source_mask": source_mask, "encoder_outputs": encoder_outputs}

And my encoder

import torch
from overrides import overrides

from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
from transformers import AutoModel


@Seq2SeqEncoder.register("pretrained_transformer")
class PretrainedTransformerEncoder(Seq2SeqEncoder):
    """
    Implements an encoder from a pretrained transformer model (e.g. BERT).

    Parameters
    ----------
    vocab : ``Vocabulary``
    bert_model : ``Union[str, BertModel]``
        The BERT model to be wrapped. If a string is provided, we will call
        ``BertModel.from_pretrained(bert_model)`` and use the result.
    num_labels : ``int``, optional (default: None)
        How many output classes to predict. If not provided, we'll use the
        vocab_size for the ``label_namespace``.
    index : ``str``, optional (default: "bert")
        The index of the token indexer that generates the BERT indices.
    label_namespace : ``str``, optional (default : "labels")
        Used to determine the number of classes if ``num_labels`` is not supplied.
    trainable : ``bool``, optional (default : True)
        If True, the weights of the pretrained BERT model will be updated during training.
        Otherwise, they will be frozen and only the final linear layer will be trained.
    initializer : ``InitializerApplicator``, optional
        If provided, will be used to initialize the final linear layer *only*.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, model_name: str) -> None:
        super().__init__()
        self.transformer_model = AutoModel.from_pretrained(model_name)
        # I'm not sure if this works for all models; open an issue on github if you find a case
        # where it doesn't work.
        self.output_dim = self.transformer_model.config.hidden_size

    @overrides
    def get_output_dim(self):
        return self.output_dim

    @overrides
    def is_bidirectional(self) -> bool:
        """
        Returns ``True`` if this encoder is bidirectional.  If so, we assume the forward direction
        of the encoder is the first half of the final dimension, and the backward direction is the
        second half.
        """
        return False

    @overrides
    def forward(self, inputs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:  # type: ignore
        return self.transformer_model(inputs, attention_mask=mask)[0]
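
For reference, this is roughly how I have been sanity checking the encoder. The ids below are random, so it is just a shape check (sizes shown are for roberta-base):

import torch

encoder = PretrainedTransformerEncoder("roberta-base")
input_ids = torch.randint(0, encoder.transformer_model.config.vocab_size, (2, 16))
mask = torch.ones(2, 16, dtype=torch.long)
outputs = encoder(input_ids, mask)
print(outputs.shape)  # torch.Size([2, 16, 768])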

This feels wrong, however. Is there a "right" way to use PretrainedTransformers with the encoder-decoder models of AllenNLP?

I think I answered my own question: I just found the PassThroughTokenEmbedder, which looks like exactly what I want.
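
In case it helps anyone else, here is a rough, untested sketch of the wiring I have in mind. The pass-through "embedder" just forwards the token ids unchanged, and the pretrained transformer encoder above does the real embedding and contextualization (the hidden_dim argument is only used to report an output dim, so it is a placeholder here):

from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import PassThroughTokenEmbedder

# The "embedder" returns the token ids unchanged; the encoder defined above then maps
# those ids to contextual embeddings with the pretrained transformer.
source_embedder = BasicTextFieldEmbedder(
    {"tokens": PassThroughTokenEmbedder(hidden_dim=768)}
)
encoder = PretrainedTransformerEncoder("roberta-base")

# With this wiring, the stock SimpleSeq2Seq should work unmodified, e.g.
# SimpleSeq2Seq(vocab, source_embedder, encoder, max_decoding_steps=512, ...)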