How to Visualize Model Internals and Attention in Hugging Face Transformers

Learn how to visualize the internals and attention of Hugging Face Transformers models.




In this tutorial, we will look inside a Hugging Face Transformers model and visualize its internals and attention.

Preparation

 
This tutorial requires the Transformers library and several visualization packages. We can install them with the following command:

pip install transformers matplotlib seaborn bertviz

 

You will also need PyTorch, which you can install with pip install torch.

With the packages installed, we can move on to the next part.
 

Model Internals and Attention Visualization

 
The internals of a Transformers model and its attention mechanism can be hard to grasp, because they require a deep understanding of the model architecture. Visualizing them makes it easier to build intuition for how the model arrives at its predictions.

In general, the Transformer is a deep learning architecture introduced by researchers at Google in 2017 and built around multi-head attention. It was a breakthrough in NLP because it processes all the tokens of a text in parallel rather than one at a time. Through self-attention, the model weighs how important each word in a sentence is to every other word, which lets it capture context.
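To make self-attention concrete, here is a minimal NumPy sketch of single-head scaled dot-product attention, softmax(Q Kᵀ / √d_k) V. The projection matrices w_q, w_k, w_v and the random inputs are illustrative assumptions, not BERT's trained weights. Each row of the returned weight matrix sums to 1 and describes how much one token attends to every token in the sequence, which is the kind of matrix visualized in the rest of this tutorial.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention (illustrative sketch).

    x: (seq_len, d_model) token representations.
    w_q, w_k, w_v: (d_model, d_k) projections, random here, not BERT's weights.
    Returns (output, weights) shaped (seq_len, d_k) and (seq_len, seq_len).
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    d_k = q.shape[-1]
    # Every token scores every other token in one matrix product,
    # which is what allows all positions to be processed in parallel.
    scores = q @ k.T / np.sqrt(d_k)
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v, weights


rng = np.random.default_rng(0)
seq_len, d_model, d_k = 5, 16, 8  # toy sizes; BERT-base uses a hidden size of 768
x = rng.normal(size=(seq_len, d_model))
w_q, w_k, w_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

output, weights = self_attention(x, w_q, w_k, w_v)
print(output.shape, weights.shape)  # (5, 8) (5, 5)
print(weights.sum(axis=-1))         # all ones, up to floating-point error
```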

We will visualize the model internals and the attention mechanism using BERT. First, we will perform a gradient-based visualization to see which words the BERT model deems important when processing the sentence.

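import torch
import matplotlib.pyplot as plt
from transformers import AutoModel, AutoTokenizer

model_name = "bert-base-uncased"
# output_attentions=True makes every forward pass also return the attention weights.
# attn_implementation="eager" asks recent transformers versions for the standard
# attention code path, which can return those weights (fused kernels may not).
model = AutoModel.from_pretrained(model_name, output_attentions=True, attn_implementation="eager")
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

# Look up the word embeddings ourselves so we can read their gradients later.
embeddings = model.embeddings.word_embeddings(inputs["input_ids"])
embeddings.retain_grad()  # embeddings is not a leaf tensor, so keep its .grad explicitly

# BERT adds position and token-type embeddings on top of inputs_embeds internally.
outputs = model(inputs_embeds=embeddings, attention_mask=inputs["attention_mask"])

# Reduce the final hidden states to a scalar so we can backpropagate from it.
loss = outputs.last_hidden_state.sum()
loss.backward()

# Average each token's gradient over the 768 hidden dimensions.
gradients = embeddings.grad
average_gradients = gradients[0].mean(dim=1).detach().numpy()

plt.plot(average_gradients, marker="o")
plt.title("Averaged Gradients for Input Tokens")
plt.xlabel("Token Index")
plt.ylabel("Average Gradient Value")
plt.xticks(ticks=range(len(average_gradients)), labels=tokens, rotation=45)
plt.grid(True)
plt.show()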

 


The token jumps has the largest averaged gradient, followed by brown. This suggests that the model output (here, the sum of the final hidden states) is most sensitive to these tokens: small changes to their embeddings would change the output the most.
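Averaging signed gradients over the hidden dimensions lets positive and negative components cancel, so a token can look unimportant even when its gradient is large. A common alternative is the L2 norm of each token's gradient vector. The sketch below compares the two reductions; token_saliency is a hypothetical helper, and the toy gradient array is made up to stand in for embeddings.grad[0].detach().numpy() from the snippet above.

```python
import numpy as np


def token_saliency(grad):
    """Reduce a (seq_len, hidden_size) gradient array to one score per token.

    Hypothetical helper: "mean" mirrors the averaged-gradient plot,
    "l2" is the gradient norm, which cannot cancel out across dimensions.
    """
    return {
        "mean": grad.mean(axis=1),
        "l2": np.linalg.norm(grad, axis=1),
    }


# Made-up gradients for two tokens with four hidden dimensions each.
toy_grad = np.array([
    [1.0, -1.0, 1.0, -1.0],  # large but sign-alternating: the mean cancels to 0
    [0.1, 0.1, 0.1, 0.1],    # small but consistent
])
scores = token_saliency(toy_grad)
print(scores["mean"])  # 0.0 for the first token, 0.1 for the second
print(scores["l2"])    # 2.0 for the first token, 0.2 for the second
```

With the mean, the first token looks irrelevant; with the L2 norm, it is clearly the more influential one. Applied to the real gradients, token_saliency(...)["l2"] gives one non-negative score per token that can be plotted the same way.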

Next, we will visualize the attention weights of a single head to see which words the BERT model treats as contextually related.

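import seaborn as sns

# outputs.attentions is a tuple with one tensor per layer,
# each shaped (batch_size, num_heads, seq_len, seq_len).
attention = outputs.attentions
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

# First layer, first (and only) sentence in the batch, first head.
# Rows are the attending (query) tokens, columns the attended (key) tokens.
attention_matrix = attention[0][0][0].detach().numpy()

sns.heatmap(attention_matrix, xticklabels=tokens, yticklabels=tokens, cmap="viridis")
plt.title("Attention Weights (Layer 1, Head 1)")
plt.show()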

 


In this head, the tokens fox and dog receive higher weights than most others, which suggests that the model attends to these words and treats them as the key nouns of the sentence.
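Reading a heatmap by eye can be imprecise. A small helper like the hypothetical top_attended below lists, for each query token (row), the key token (column) that receives the most attention. You could pass it attention_matrix and the token list from the snippet above; here it runs on a made-up 3×3 matrix so that it is self-contained.

```python
import numpy as np


def top_attended(attention_matrix, tokens):
    """For each query token (row), return (query, most-attended key, weight).

    attention_matrix: (seq_len, seq_len) weights for one head with rows summing
    to 1, e.g. outputs.attentions[layer][0][head].detach().numpy().
    """
    pairs = []
    for i, row in enumerate(attention_matrix):
        j = int(np.argmax(row))
        pairs.append((tokens[i], tokens[j], float(row[j])))
    return pairs


# Made-up weights for a three-token sequence (not real BERT output).
toy_tokens = ["[CLS]", "fox", "[SEP]"]
toy_attention = np.array([
    [0.2, 0.5, 0.3],
    [0.1, 0.2, 0.7],
    [0.6, 0.3, 0.1],
])
for query, key, weight in top_attended(toy_attention, toy_tokens):
    print(f"{query:>6} -> {key:<6} ({weight:.2f})")
```

Many BERT heads place a large share of their weight on special tokens such as [CLS] or [SEP], so it helps to check those columns before interpreting what a head focuses on.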

Transformers use a multi-head attention mechanism, so we will visualize all 12 heads of the first layer. As a reminder, BERT-base, the size used by bert-base-uncased, has 12 layers with 12 attention heads in each layer.
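Multi-head attention splits the hidden size across heads: in BERT-base, the 768 hidden units are divided into 12 heads of 64 dimensions, and each head computes its own attention matrix. The NumPy sketch below shows that reshaping with random, illustrative projection weights (not BERT's). The resulting (num_heads, seq_len, seq_len) array has the same layout as one layer's attention for a single sentence, such as attention[0][0].

```python
import numpy as np


def softmax(x, axis=-1):
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def multi_head_weights(x, w_q, w_k, num_heads):
    """Return attention weights shaped (num_heads, seq_len, seq_len).

    x: (seq_len, hidden) inputs; w_q, w_k: (hidden, hidden) projections.
    Value vectors are omitted because only the weights are visualized.
    """
    seq_len, hidden = x.shape
    head_dim = hidden // num_heads  # 768 // 12 = 64 in BERT-base
    # Project, then split the last axis into heads: (num_heads, seq_len, head_dim).
    q = (x @ w_q).reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)
    k = (x @ w_k).reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    return softmax(scores, axis=-1)


rng = np.random.default_rng(0)
seq_len, hidden, num_heads = 10, 768, 12  # toy sequence length, BERT-base sizes
x = rng.normal(size=(seq_len, hidden))
w_q = rng.normal(size=(hidden, hidden)) / np.sqrt(hidden)
w_k = rng.normal(size=(hidden, hidden)) / np.sqrt(hidden)

weights = multi_head_weights(x, w_q, w_k, num_heads)
print(weights.shape)  # (12, 10, 10): one 10x10 matrix per head
```

The plotting code that follows draws the real per-head matrices from BERT's first layer in the same way, one subplot per head.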

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

# One subplot per head of the first layer: 3 x 4 = 12 heads.
fig, axes = plt.subplots(3, 4, figsize=(15, 10))
for i, ax in enumerate(axes.flat):
    sns.heatmap(attention[0][0][i].detach().numpy(), ax=ax, cmap="viridis",
                xticklabels=tokens, yticklabels=tokens)
    ax.set_title(f"Head {i+1}")
plt.tight_layout()
plt.show()

 


Visualizing all the heads shows where each head directs its attention within the sentence, and that different heads focus on different aspects of it.
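Twelve heatmaps are hard to compare at a glance. One way to summarize them is the average entropy of each head's attention rows: low entropy means a head concentrates on a few tokens, high entropy means it spreads attention broadly. The sketch below assumes you pass in one layer's weights for a single sentence, for example attention[0][0].detach().numpy() with shape (num_heads, seq_len, seq_len); head_entropy is a hypothetical helper, and the toy input contrasts a perfectly focused head with a uniform one.

```python
import numpy as np


def head_entropy(layer_attention, eps=1e-12):
    """Mean attention entropy (in nats) per head.

    layer_attention: (num_heads, seq_len, seq_len) with rows summing to 1.
    Returns an array of shape (num_heads,).
    """
    # Entropy of each query row, then averaged over the query tokens.
    row_entropy = -(layer_attention * np.log(layer_attention + eps)).sum(axis=-1)
    return row_entropy.mean(axis=-1)


seq_len = 4
focused = np.eye(seq_len)                             # each token attends to one token
uniform = np.full((seq_len, seq_len), 1.0 / seq_len)  # attention spread evenly
toy_layer = np.stack([focused, uniform])              # two toy heads

print(head_entropy(toy_layer))  # focused head: about 0, uniform head: log(4), about 1.386
```

On real BERT attention, plotting these twelve values as a bar chart gives a quick overview of which heads are sharp and which are diffuse.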

Next, we can visualize the hidden states of a token at each layer. This analysis helps show how a word's representation evolves as it passes through the model.

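# Run the model again, this time also returning the hidden states.
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# For BERT-base this tuple has 13 entries (embedding output + 12 layers),
# each shaped (batch_size, seq_len, hidden_size).
hidden_states = outputs.hidden_states

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
fox_index = tokens.index("fox")  # position 4, after [CLS] the quick brown
fox_hidden_states = [state[0, fox_index, :].detach().numpy() for state in hidden_states]

plt.plot([state.mean() for state in fox_hidden_states], marker="o")
plt.title("Mean Hidden State of fox Token Across Layers")
plt.xlabel("Layer (0 = embedding output)")
plt.ylabel("Mean Activation")
plt.show()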

 


The mean activation of the fox token drops sharply in the first layer, as the model adjusts its representation early on, and rises again in the last layer, where the model finalizes how fox relates to the rest of the sentence. Keep in mind that a mean over 768 dimensions is only a coarse summary of the representation.
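Another way to see how a token's representation evolves is the cosine similarity between its vectors in consecutive layers: values near 1 mean a layer changed the representation little, lower values mean it changed more. The sketch assumes a list of per-layer vectors for one token, such as [state[0, 4].detach().numpy() for state in hidden_states] using hidden_states from the code above (index 4 is fox). layer_to_layer_similarity is a hypothetical helper and the toy vectors are made up.

```python
import numpy as np


def layer_to_layer_similarity(token_states):
    """Cosine similarity between a token's vectors in consecutive layers.

    token_states: list of (hidden_size,) vectors, one per layer. For BERT-base,
    hidden_states has 13 entries, which gives 12 similarities.
    """
    stacked = np.stack(token_states)                       # (num_layers, hidden)
    unit = stacked / np.linalg.norm(stacked, axis=1, keepdims=True)
    return (unit[:-1] * unit[1:]).sum(axis=1)              # (num_layers - 1,)


# Made-up 3-dimensional "layer outputs" for one token.
toy_states = [
    np.array([1.0, 0.0, 0.0]),
    np.array([1.0, 0.0, 0.0]),  # unchanged: similarity 1
    np.array([0.0, 1.0, 0.0]),  # rotated 90 degrees: similarity 0
]
print(layer_to_layer_similarity(toy_states))  # similarities of about 1.0 and 0.0
```

Unlike the mean activation, cosine similarity ignores the overall scale of the vectors and focuses on how their direction changes from layer to layer.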

Lastly, we can use the bertviz package to explore multi-head attention across all layers interactively.

from bertviz import head_view

# head_view renders an interactive widget, so run this in a Jupyter notebook.
attention = outputs.attentions
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
head_view(attention, tokens)

 


Visualizing the internals and attention of Hugging Face Transformers models is a practical way to understand how your model works.

 


Cornellius Yudha Wijaya is a data science assistant manager and data writer. While working full-time at Allianz Indonesia, he loves to share Python and data tips via social media and writing media. Cornellius writes on a variety of AI and machine learning topics.




