import torch
import gradio as gr
import plotly.express as px
from transformers import AutoModel, AutoTokenizer
########################################
# Load Transformer (DistilBERT) with attention
########################################
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)
model.eval()
def visualize_attention(text, layer=5):
    """
    Render DistilBERT self-attention over *text* as a Plotly heatmap.

    1. Tokenize input text.
    2. Run DistilBERT forward pass to get attention matrices.
    3. Pick a layer (0..5) and average across attention heads.
    4. Generate a heatmap (Plotly) of shape (seq_len x seq_len).
    5. Label axes with tokens (Query vs. Key).

    Args:
        text: Sentence to analyze.
        layer: DistilBERT layer index to display. Gradio sliders may
            deliver a float, so the value is coerced to int and clamped
            to the number of layers the model actually returns.

    Returns:
        A Plotly figure with one heatmap trace (rows = Query, cols = Key).
    """
    with torch.no_grad():
        # Calling the tokenizer directly is the modern equivalent of the
        # deprecated encode_plus().
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model(**inputs)
    all_attentions = outputs.attentions

    # Sliders can hand back floats; tuple indexing needs an int. Clamp so
    # an out-of-range value cannot raise IndexError (DistilBERT => 0..5).
    layer = max(0, min(int(layer), len(all_attentions) - 1))

    # Average over heads -> shape: (1, seq_len, seq_len)
    attn_layer = all_attentions[layer].mean(dim=1)

    # Convert to numpy for plotting
    attn_matrix = attn_layer[0].cpu().numpy()

    # Get tokens (including special tokens) for axis labels
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Build a Plotly heatmap
    fig = px.imshow(
        attn_matrix,
        x=tokens,
        y=tokens,
        labels={"x": "Key (Being Attended to)", "y": "Query (Focusing)"},
        color_continuous_scale="Blues",
        title=f"DistilBERT Attention (Layer {layer})"
    )
    fig.update_xaxes(side="top")
    # Add tooltip
    fig.update_traces(
        hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}"
    )
    fig.update_layout(coloraxis_colorbar=dict(title="Attention Weight"))
    return fig
def interpret_token_attention(text, token_index=0, layer=5):
    """
    Provides a textual explanation for why a particular token (Query) attends
    to other tokens in the input, highlighting the top 3 tokens it focuses on.

    Args:
        text: Sentence to analyze.
        token_index: 0-based index of the Query token to explain. Gradio's
            gr.Number delivers a float by default, so the value is coerced
            to int before indexing.
        layer: DistilBERT layer index (0..5); coerced to int and clamped.

    Returns:
        A Markdown string explaining the top attention targets, or an
        error message when token_index is out of range.
    """
    # gr.Number / gr.Slider may deliver floats; list/tensor indexing needs
    # ints (tokens[token_index] would otherwise raise TypeError).
    token_index = int(token_index)

    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model(**inputs)
    all_attentions = outputs.attentions

    # Clamp to the number of layers the model actually returned.
    layer = max(0, min(int(layer), len(all_attentions) - 1))
    attn_layer = all_attentions[layer].mean(dim=1)  # shape: (1, seq_len, seq_len)

    # Get tokens (including special tokens)
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Safety check for token_index
    if token_index < 0 or token_index >= len(tokens):
        return "Invalid token index. Please choose a valid token index."

    # Extract the row corresponding to our Query token
    query_attn = attn_layer[0, token_index, :].cpu().numpy()  # shape: (seq_len,)

    # Sort tokens by attention weight (descending) and grab the top 3
    sorted_indices = query_attn.argsort()[::-1]
    top_indices = sorted_indices[:3]
    top_tokens = [tokens[i] for i in top_indices]
    top_weights = [query_attn[i] for i in top_indices]

    # Build an explanation
    query_token_str = tokens[token_index]
    explanation = (
        f"**You chose token index {token_index}, which is '{query_token_str}'.**\n\n"
        "In Transformers, each token is converted into Query, Key, and Value vectors:\n"
        "- **Query** = What this token is looking for\n"
        "- **Key** = What another token has to offer\n"
        "- **Value** = The actual information from that token\n\n"
        f"As a Query, '{query_token_str}' attends most strongly to:\n"
    )
    for t, w in zip(top_tokens, top_weights):
        explanation += f"- **{t}** with attention weight ~ {w:.3f}\n"
    explanation += (
        "\nA higher attention weight indicates that this Query token is 'looking at' or "
        "focusing on that Key token more strongly, likely because it finds the Key token "
        "relevant to its meaning or context."
    )
    return explanation
# Short explanation text for the UI (rendered as Markdown below the title).
# NOTE: this string is user-facing content, not a comment — keep wording
# in sync with the heatmap's axis labels (rows=Query, cols=Key).
description_text = """
## Understanding Transformer Self-Attention
- **Rows = Query token** (the token doing the 'looking').
- **Columns = Key token** (the token being 'looked at').
- Darker color = stronger attention weight.
**Transformers** process all tokens in **parallel**, allowing any token to attend to any other token in the sentence.
This makes it easier for the model to capture long-distance relationships.
"""
########################################
# Gradio Interface
########################################
with gr.Blocks(css="footer{display:none !important}") as demo:
gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
gr.Markdown(description_text)
with gr.Row():
text_input = gr.Textbox(
label="Enter a sentence",
value="Transformers handle long-range context in parallel."
)
layer_slider = gr.Slider(
minimum=0, maximum=5, step=1, value=5,
label="DistilBERT Layer (0=lowest, 5=highest)"
)
output_plot = gr.Plot(label="Attention Heatmap")
# Visualization Button
visualize_button = gr.Button("Visualize Attention")
visualize_button.click(
fn=visualize_attention,
inputs=[text_input, layer_slider],
outputs=output_plot
)
# Dropdown (or Slider) to choose a token index for interpretation
token_index = gr.Number(
label="Choose a token index to interpret (0-based)",
value=0
)
interpretation_output = gr.Markdown(label="Interpretation")
# Interpretation Button
interpret_button = gr.Button("Explain This Token's Attention")
interpret_button.click(
fn=interpret_token_attention,
inputs=[text_input, token_index, layer_slider],
outputs=interpretation_output
)
demo.launch()