Get attention from PretrainedTransformerEmbedder


What would be the best way to access attention from the pretrained transformer embedders? It seems that the API for token_embedders doesn’t make it feasible.

I don’t need attention during training, so if there’s a simpler way to get it working for inference only I’m open to that (e.g. saving the fine-tuned embedder and loading it through huggingface transformers).

Thank you,

We have a way of using a hook to capture each module’s outputs. This is implemented as a context manager on the Predictor (`Predictor.capture_model_internals()`). If that works for you, then it is definitely the easiest way to get what you want: if the attention is returned as a module output anywhere in the huggingface code, the hook will capture it. If it doesn’t work, then I think the next best solution is something like what you’re suggesting, where you pull out the underlying transformer model and return its attention output yourself.

I ended up not using an AllenNLP embedder at all and handling the embedding directly in the model. It is not really in the spirit of the library, but it was a quick and dirty fix that worked for me, since I was only ever going to use BERT embeddings in my experiments. I am sharing the code in case anyone else has a similar issue.

from typing import Dict

import torch
from torch import nn
from torch.nn import functional as F
from transformers import BertModel

from allennlp.data import Vocabulary
from allennlp.models import Model
from allennlp.modules import TextFieldEmbedder, Seq2VecEncoder
from allennlp.nn import util
from allennlp.training.metrics import CategoricalAccuracy, Average, F1Measure

    'bert-base-uncased': BertModel,
    'bert-base-cased': BertModel,

class SimpleAttentionClassifier(Model):
    """Text classifier that calls a HuggingFace BERT model directly.

    The BERT model is loaded with ``output_attentions=True`` instead of going
    through an AllenNLP token embedder, so that ``forward`` can return the
    attention weights when no label is supplied (i.e. at inference time).

    Parameters
    ----------
    vocab:
        Vocabulary holding the label namespace and the wordpiece namespace
        (``"tags"``) used to map token ids back to strings.
    embedder:
        Name of the pretrained model; must be a key of ``EMBEDDERS``.
    encoder:
        Seq2VecEncoder that pools the BERT output into one vector per instance.
    label_namespace:
        Vocabulary namespace that holds the class labels.
    attention_layer:
        ``"first"`` or ``"last"``: which BERT layer's attention to report.
    finetune_bert:
        When ``False`` the BERT weights are frozen and BERT is kept in eval
        mode.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 embedder: str,
                 encoder: Seq2VecEncoder,
                 label_namespace: str = "labels",
                 attention_layer: str = "first",
                 finetune_bert: bool = True) -> None:
        # Must run before any sub-module is assigned: it initialises
        # nn.Module's bookkeeping and stores `vocab` as `self.vocab`.
        super().__init__(vocab)

        assert embedder in EMBEDDERS.keys(), f"embedder must be in {list(EMBEDDERS.keys())}"
        self.finetune_bert = finetune_bert
        self.attention_layer = attention_layer

        # encoder and embedder layers
        self.embedder = EMBEDDERS[embedder].from_pretrained(embedder, output_attentions=True)
        if not finetune_bert:
            # Freeze BERT so the optimizer cannot update it.
            for param in self.embedder.parameters():
                param.requires_grad = False
        self.encoder = encoder

        self.labels = vocab.get_index_to_token_vocabulary(namespace=label_namespace)
        self.classifier_out = len(self.labels)
        self.embed_dim = encoder.get_output_dim()

        self.classifier = nn.Linear(self.embed_dim, self.classifier_out)
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                text: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Classify a batch; with no ``label``, also return attention.

        Returns a dict that always contains ``'probs'`` (nested lists). With a
        ``label`` it also contains ``'loss'`` (a tensor) and the accuracy
        metric is updated. Without a ``label`` it contains ``'encoded_text'``,
        ``'attention'`` and ``'tokens'`` (all nested lists) instead.

        Raises
        ------
        ValueError
            At inference time, if ``attention_layer`` is not ``"first"`` or
            ``"last"``.
        """
        if label is None or not self.finetune_bert:
            # NOTE(review): the body of this branch was missing from the
            # posted snippet; switching BERT to eval mode (dropout off) is a
            # reconstruction -- confirm against the author's original.
            self.embedder.eval()

        token_ids = text["bert"]["token_ids"]
        # Shape: (batch_size, num_tokens)
        mask = util.get_text_field_mask(text)
        # Pass the mask so padding positions are not attended to.
        outputs = self.embedder(token_ids, attention_mask=mask)
        # Shape: (batch_size, num_tokens, embedding_dim)
        embedded_text = outputs[0]
        # Shape: (batch_size, encoding_dim)
        encoded_text = self.encoder(embedded_text, mask)
        # Shape: (batch_size, classifier_out)
        logits = self.classifier(encoded_text)
        # Shape: (batch_size, classifier_out)
        probs = F.softmax(logits, dim=-1)
        output = {'probs': probs.detach().tolist()}

        if label is None:
            # Inference: expose the pooled encoding, attention and tokens.
            attention = self._first_token_attention(outputs[-1])
            tokens = [
                [self.vocab.get_token_from_index(index, "tags") for index in seq]
                for seq in token_ids.tolist()
            ]
            output.update({'encoded_text': encoded_text.detach().tolist()})
            output.update({'attention': attention.detach().tolist()})
            output.update({'tokens': tokens})
        else:
            # Training / validation: a label is required for metric and loss.
            self.accuracy(logits, label)
            output['loss'] = F.cross_entropy(logits, label)
        return output

    def _first_token_attention(self, attentions) -> torch.Tensor:
        """Return head-averaged attention from position 0 for the chosen layer.

        ``attentions`` is the per-layer sequence of attention tensors returned
        by the embedder, each of shape (bsz, num_attn_heads, seq_len, seq_len).
        The result has shape (bsz, seq_len) and each row sums to one.
        """
        if self.attention_layer == "first":
            layer_attention = attentions[0]
        elif self.attention_layer == "last":
            layer_attention = attentions[-1]
        else:
            raise ValueError(
                f"attention_layer must be 'first' or 'last', got {self.attention_layer!r}"
            )
        # Keep only the row for query position 0 (the [CLS] token for BERT
        # inputs that include special tokens).
        # Shape: (bsz, num_attn_heads, seq_len)
        from_first = layer_attention[:, :, 0, :]
        # Shape: (bsz, seq_len)
        averaged = from_first.mean(dim=1)
        return averaged / averaged.sum(dim=1, keepdim=True)

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running classification accuracy, optionally resetting it."""
        return {"accuracy": self.accuracy.get_metric(reset)}