File size: 6,871 Bytes
0b23237 1f47b32 0b23237 da6e24a 0b23237 1f47b32 0b23237 da6e24a 0b23237 1f47b32 da6e24a 1f47b32 da6e24a 1f47b32 da6e24a 0b23237 da6e24a 0b23237 da6e24a 0b23237 da6e24a 0b23237 da6e24a 1f47b32 da6e24a 1f47b32 da6e24a 0b23237 da6e24a 0b23237 da6e24a 0b23237 da6e24a 0b23237 da6e24a 0b23237 da6e24a 1f47b32 da6e24a 0b23237 da6e24a 0b23237 da6e24a 0b23237 da6e24a 1f47b32 da6e24a 1f47b32 da6e24a 1f47b32 0b23237 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import torch
import gradio as gr
import plotly.express as px
import numpy as np
from transformers import AutoModel, AutoTokenizer
########################################
# 1) Load DistilBERT with attention
########################################
# Single checkpoint name shared by the tokenizer and the model.
model_name = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# output_attentions=True makes each forward pass expose `outputs.attentions`;
# eval() puts the module in inference mode and returns the module itself.
model = AutoModel.from_pretrained(model_name, output_attentions=True).eval()
########################################
# 2) Generate attention analysis
########################################
def analyze_attention(text, layer=5, top_k=3, show_heatmap=True):
    """
    Run the module-level DistilBERT model on `text` and summarise one layer's
    self-attention.

    Steps:
      1. Tokenize `text` (truncated to the tokenizer's maximum length).
      2. Forward pass through `model` (loaded with output_attentions=True).
      3. Pick the attention tensor of `layer` and average it across heads,
         giving a (seq_len, seq_len) matrix.
      4. Optionally build a Plotly heatmap of that matrix.
      5. List, for every token, the `top_k` tokens it attends to most.
      6. Append the bullet list produced by `interpret_attention`.

    Args:
        text: Input string to analyse.
        layer: Index into the per-layer attention tuple (0..5 for
            DistilBERT). Coerced with int(), so whole-number floats work.
        top_k: How many attended tokens to list per token. Coerced with int().
        show_heatmap: When falsy, the heatmap is skipped and None is returned
            in its place.

    Returns:
        Tuple `(fig_dict, combined_md)`: the Plotly figure converted with
        `to_dict()` (or None), and the markdown summary string.

    Raises:
        IndexError: If `layer` does not index an available attention layer.
    """
    # Numeric UI widgets may hand over whole numbers as floats; tuple indexing
    # and slicing below require real ints.
    layer = int(layer)
    top_k = int(top_k)

    with torch.no_grad():
        # truncation=True keeps over-long text within the model's limit
        # instead of failing inside the forward pass.
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        outputs = model(**inputs)

    # Tuple with one entry per layer, each (1, #heads, seq_len, seq_len).
    all_attentions = outputs.attentions
    num_layers = len(all_attentions)
    if not -num_layers <= layer < num_layers:
        raise IndexError(
            f"layer must be between 0 and {num_layers - 1}, got {layer}"
        )

    # Average across heads, then drop the batch dimension.
    att_matrix = all_attentions[layer].mean(dim=1)[0].cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # (Optional) Heatmap
    fig_dict = None
    if show_heatmap:
        fig = px.imshow(
            att_matrix,
            x=tokens,
            y=tokens,
            labels={"x": "Token Being Looked At", "y": "Token Doing the Looking"},
            color_continuous_scale="Blues",
            title=f"DistilBERT Self-Attention (Layer {layer})"
        )
        fig.update_xaxes(side="top")
        fig.update_traces(
            hovertemplate="Row token: %{y}<br>Column token: %{x}<br>Focus Weight: %{z:.3f}"
        )
        # NOTE(review): the figure is returned as a plain dict; confirm the
        # consuming plot component accepts that form.
        fig_dict = fig.to_dict()

    # Top-K summary for each row; pieces are collected and joined once.
    pieces = [
        "## Top-K Focus for Each Token\n",
        f"Showing the **top {top_k}** tokens each token focuses on.\n\n",
    ]
    for row_token, row_weights in zip(tokens, att_matrix):
        pieces.append(f"**Token '{row_token}'** focuses on:\n")
        # Column indices by descending weight, limited to the first top_k.
        for j in row_weights.argsort()[::-1][:top_k]:
            pieces.append(f" - `{tokens[j]}` (weight={row_weights[j]:.3f})\n")
        pieces.append("\n")
    summary_md = "".join(pieces)

    # Extra bullet-list interpretation highlighting overall patterns.
    interpretation_md = interpret_attention(att_matrix, tokens)

    combined_md = summary_md + "\n" + interpretation_md
    return fig_dict, combined_md
########################################
# 3) Interpretation function
########################################
def interpret_attention(att_matrix: np.ndarray, tokens: list) -> str:
    """
    Build a short markdown bullet list describing an attention matrix.

    The text reports:
      - how many tokens have their largest weight on the diagonal, i.e.
        attend most to themselves;
      - the single largest weight in the matrix and the row/column tokens
        it connects (omitted when `tokens` is empty);
      - two fixed explanatory bullets on how to read these numbers.

    Args:
        att_matrix: 2-D array of weights; row i holds the weights token i
            places on every token.
        tokens: Token strings labelling the rows and columns.

    Returns:
        Markdown string starting with the "Additional Interpretation" header.
    """
    seq_len = len(tokens)
    diagonal_focus_count = 0
    global_msg = None

    if seq_len:
        # Only the rows that have a matching token label are inspected.
        rows = np.asarray(att_matrix)[:seq_len]

        # Column of the strongest weight in every row; a row "focuses on
        # itself" when that column equals the row index.
        top_cols = rows.argmax(axis=1)
        diagonal_focus_count = int(
            np.count_nonzero(top_cols == np.arange(len(top_cols)))
        )

        # argmax over the flattened matrix returns the first maximum in
        # row-major order, so ties resolve to the earliest row, then column.
        i, j = (int(k) for k in np.unravel_index(rows.argmax(), rows.shape))
        max_val = rows[i, j]
        global_msg = (
            f"- The **highest single focus** in the matrix is **{max_val:.3f}**, "
            f"from token '{tokens[i]}' onto '{tokens[j]}'."
        )

    diag_msg = (
        f"- **{diagonal_focus_count}/{seq_len} tokens** focus most on themselves (the diagonal)."
    )

    parts = [
        "## Additional Interpretation\n\n",
        "Here are some overall patterns in the attention matrix that might help you:\n\n",
        f"{diag_msg}\n",
    ]
    if global_msg is not None:
        parts.append(f"{global_msg}\n")
    parts.append("\n- A strong diagonal means tokens often reference themselves.\n")
    parts.append(
        "- If a token's top focus is another token, that suggests it's referencing or depending on that other token.\n"
    )
    return "".join(parts)
########################################
# 4) Gradio UI
########################################
# Markdown shown at the top of the UI: usage instructions plus a short guide
# to reading the heatmap and the text summaries.
description_md = """
# DistilBERT Self-Attention with Extra Interpretation
**Instructions:**
1. Type your text into the box.
2. Choose which **layer** of DistilBERT to visualize. (Layers range 0..5).
3. Decide how many top tokens you want listed for each token (Top-K).
4. (Optional) Check "Show Heatmap" to see the matrix. If it's too overwhelming, uncheck and just see the summary.
**Reading the Heatmap**:
- **Rows** = tokens doing the looking (focus).
- **Columns** = tokens being looked at.
- **Color intensity** = how strongly the row token focuses on the column token.
Below the heatmap, you'll see:
- A **Top-K focus** summary for each token.
- An **interpretation** bullet list that highlights interesting overall patterns.
"""
def run_demo(text, layer, top_k, show_heatmap):
    """Button callback: forward the UI values to `analyze_attention` and
    return its (figure dict, markdown summary) pair."""
    heatmap, report = analyze_attention(
        text, layer=layer, top_k=top_k, show_heatmap=show_heatmap
    )
    return heatmap, report
with gr.Blocks() as demo:
    gr.Markdown(description_md)

    # NOTE(review): nesting is reconstructed from a flattened source; the four
    # input widgets are assumed to share this row - confirm the intended layout.
    with gr.Row():
        text_in = gr.Textbox(
            value="Transformers handle long-range context in parallel.",
            label="Enter text",
        )
        layer_in = gr.Slider(
            label="DistilBERT Layer",
            minimum=0, maximum=5, step=1, value=5,
        )
        topk_in = gr.Slider(
            label="Top-K Focus",
            minimum=1, maximum=6, step=1, value=3,
        )
        show_heatmap_check = gr.Checkbox(value=True, label="Show Heatmap?")

    run_btn = gr.Button("Analyze Attention")

    # Outputs: the heatmap first, then the markdown report.
    out_plot = gr.Plot(label="Attention Heatmap")
    out_summary = gr.Markdown(label="Attention Summaries & Interpretation")

    # Component order here must match run_demo's parameters and return values.
    run_btn.click(
        fn=run_demo,
        inputs=[text_in, layer_in, topk_in, show_heatmap_check],
        outputs=[out_plot, out_summary],
    )

demo.launch()
|