|
import torch |
|
import gradio as gr |
|
import plotly.express as px |
|
import numpy as np |
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
|
|
|
|
|
|
# Load the tokenizer and encoder once at import time so every request handled
# by the app reuses the same objects.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# output_attentions=True makes each forward pass expose the per-layer attention
# tensors; eval() switches the module to inference mode and returns the module.
model = AutoModel.from_pretrained(model_name, output_attentions=True).eval()
|
|
|
|
|
|
|
|
|
def analyze_attention(text, layer=5, top_k=3, show_heatmap=True):
    """Summarise DistilBERT self-attention for ``text`` at one layer.

    Steps:
      1. Tokenize ``text`` (truncated to the model's maximum length).
      2. Run a forward pass with attentions enabled.
      3. Take the attention tensor of ``layer`` and average it over heads,
         giving a ``(seq_len, seq_len)`` matrix.
      4. Optionally build a Plotly heatmap of that matrix.
      5. List, for every token, the ``top_k`` tokens it attends to most.
      6. Append the bullet-list interpretation from ``interpret_attention``.

    Args:
        text: Sentence to analyse. ``None`` is treated as an empty string.
        layer: Index into the model's per-layer attention outputs. Coerced
            to ``int`` because it is used as a sequence index.
        top_k: Number of attended tokens to list per token. Coerced to
            ``int`` and clamped at 0 because it is used as a slice bound.
        show_heatmap: When false, no figure is built and ``None`` is
            returned in its place.

    Returns:
        A ``(fig_dict, combined_md)`` tuple: the heatmap as a Plotly figure
        dict (or ``None``) and the Markdown summary plus interpretation.
    """
    # Callers such as UI sliders may pass floats; both values are used as an
    # index / slice bound below, which requires real ints. A negative top_k
    # would otherwise silently slice "all but the last k".
    layer = int(layer)
    top_k = max(int(top_k), 0)

    with torch.no_grad():
        # truncation=True keeps over-long inputs within the model's maximum
        # sequence length instead of failing inside the forward pass.
        inputs = tokenizer(text or "", return_tensors="pt", truncation=True)
        outputs = model(**inputs)

    # Average over the head dimension, then drop the batch dimension.
    head_mean = outputs.attentions[layer].mean(dim=1)
    att_matrix = head_mean[0].cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    fig_dict = None
    if show_heatmap:
        # String axis values are treated as categories, so a token that occurs
        # twice would collapse onto a single row/column. Prefix positions only
        # when needed, leaving labels untouched for duplicate-free input.
        if len(set(tokens)) < len(tokens):
            axis_labels = [f"{pos}: {tok}" for pos, tok in enumerate(tokens)]
        else:
            axis_labels = tokens

        fig = px.imshow(
            att_matrix,
            x=axis_labels,
            y=axis_labels,
            labels={"x": "Token Being Looked At", "y": "Token Doing the Looking"},
            color_continuous_scale="Blues",
            title=f"DistilBERT Self-Attention (Layer {layer})",
        )
        fig.update_xaxes(side="top")
        fig.update_traces(
            hovertemplate="Row token: %{y}<br>Column token: %{x}<br>Focus Weight: %{z:.3f}"
        )
        # NOTE(review): the figure is handed on as a plain dict; confirm the
        # plot component that consumes it accepts dicts rather than Figures.
        fig_dict = fig.to_dict()

    # Collect the Markdown pieces and join once instead of repeated `+=`.
    parts = [
        "## Top-K Focus for Each Token\n",
        f"Showing the **top {top_k}** tokens each token focuses on.\n\n",
    ]
    for row_token, row_weights in zip(tokens, att_matrix):
        # argsort is ascending, so reverse it to get the strongest first.
        top_indices = row_weights.argsort()[::-1][:top_k]
        parts.append(f"**Token '{row_token}'** focuses on:\n")
        parts.extend(
            f" - `{tokens[j]}` (weight={row_weights[j]:.3f})\n"
            for j in top_indices
        )
        parts.append("\n")
    summary_md = "".join(parts)

    interpretation_md = interpret_attention(att_matrix, tokens)
    combined_md = summary_md + "\n" + interpretation_md

    return fig_dict, combined_md
|
|
|
|
|
|
|
|
|
|
|
def interpret_attention(att_matrix: np.ndarray, tokens: list) -> str:
    """Build a short Markdown interpretation of an attention matrix.

    The report contains:
      - how many tokens put their largest weight on themselves (diagonal);
      - the single largest weight in the matrix and the tokens involved;
      - two fixed hints on how to read those numbers.

    Args:
        att_matrix: 2-D array of attention weights, expected to have one
            row and one column per entry in ``tokens`` (row = attending
            token, column = attended token).
        tokens: Token strings labelling the rows/columns of ``att_matrix``.

    Returns:
        A Markdown string beginning with ``## Additional Interpretation``.
        For empty input the diagonal count is reported as ``0/0`` and the
        highest-focus bullet states that the matrix is empty.
    """
    matrix = np.asarray(att_matrix)
    seq_len = len(tokens)

    if seq_len == 0 or matrix.size == 0:
        # Nothing to index into: report zero instead of failing on tokens[0].
        diagonal_focus_count = 0
        global_msg = "- The matrix is empty, so there is no highest focus to report."
    else:
        # Column each row attends to most; ties resolve to the lowest index.
        strongest_cols = matrix.argmax(axis=1)
        diagonal_focus_count = int(
            np.count_nonzero(strongest_cols == np.arange(strongest_cols.size))
        )

        # The flat argmax scans row-major and returns the first maximal cell,
        # and it needs no sentinel start value, so matrices whose entries are
        # all very negative are still reported correctly.
        row, col = np.unravel_index(matrix.argmax(), matrix.shape)
        max_val = float(matrix[row, col])
        global_msg = (
            f"- The **highest single focus** in the matrix is **{max_val:.3f}**, "
            f"from token '{tokens[row]}' onto '{tokens[col]}'."
        )

    diag_msg = (
        f"- **{diagonal_focus_count}/{seq_len} tokens** focus most on "
        "themselves (the diagonal)."
    )

    return "".join(
        [
            "## Additional Interpretation\n\n",
            "Here are some overall patterns in the attention matrix that might help you:\n\n",
            f"{diag_msg}\n",
            f"{global_msg}\n",
            "\n- A strong diagonal means tokens often reference themselves.\n",
            "- If a token's top focus is another token, that suggests it's referencing or depending on that other token.\n",
        ]
    )
|
|
|
|
|
|
|
|
|
|
|
# Markdown shown at the top of the page: usage instructions followed by a
# guide to reading the heatmap and the text output underneath it.
description_md = """
# DistilBERT Self-Attention with Extra Interpretation

**Instructions:**
1. Type your text into the box.
2. Choose which **layer** of DistilBERT to visualize. (Layers range 0..5).
3. Decide how many top tokens you want listed for each token (Top-K).
4. (Optional) Check "Show Heatmap" to see the matrix. If it's too overwhelming, uncheck and just see the summary.

**Reading the Heatmap**:
- **Rows** = tokens doing the looking (focus).
- **Columns** = tokens being looked at.
- **Color intensity** = how strongly the row token focuses on the column token.

Below the heatmap, you'll see:
- A **Top-K focus** summary for each token.
- An **interpretation** bullet list that highlights interesting overall patterns.
"""
|
|
|
def run_demo(text, layer, top_k, show_heatmap):
    """Click handler: forward the widget values to ``analyze_attention``.

    Returns the ``(figure dict or None, Markdown report)`` pair that
    ``analyze_attention`` produces, in the order the output components
    are wired.
    """
    plot_payload, report_md = analyze_attention(
        text,
        layer=layer,
        top_k=top_k,
        show_heatmap=show_heatmap,
    )
    return plot_payload, report_md
|
|
|
# Page layout: instructions on top, one row of controls, then the outputs.
with gr.Blocks() as demo:
    gr.Markdown(description_md)

    # All inputs and the trigger button share a single row.
    with gr.Row():
        text_in = gr.Textbox(
            label="Enter text",
            value="Transformers handle long-range context in parallel.",
        )
        layer_in = gr.Slider(
            label="DistilBERT Layer",
            minimum=0,
            maximum=5,
            step=1,
            value=5,
        )
        topk_in = gr.Slider(
            label="Top-K Focus",
            minimum=1,
            maximum=6,
            step=1,
            value=3,
        )
        show_heatmap_check = gr.Checkbox(label="Show Heatmap?", value=True)
        run_btn = gr.Button("Analyze Attention")

    # Output widgets, listed in the order run_demo returns its values.
    out_plot = gr.Plot(label="Attention Heatmap")
    out_summary = gr.Markdown(label="Attention Summaries & Interpretation")

    run_btn.click(
        fn=run_demo,
        inputs=[text_in, layer_in, topk_in, show_heatmap_check],
        outputs=[out_plot, out_summary],
    )

# Start the web server as soon as the module is executed.
demo.launch()
|
|
|
|