# LLMdemo / app.py
# Author: kevin1911 · commit da6e24a ("Update app.py") · 6.87 kB
import torch
import gradio as gr
import plotly.express as px
import numpy as np
from transformers import AutoModel, AutoTokenizer
########################################
# 1) Load DistilBERT with attention
########################################
# Checkpoint shared by the tokenizer and the model below.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# output_attentions=True makes every forward pass also return the attention
# weights of each layer. nn.Module.eval() switches the model to inference
# mode and returns the module itself, so it can be chained here.
model = AutoModel.from_pretrained(model_name, output_attentions=True).eval()
########################################
# 2) Generate attention analysis
########################################
def _attention_heatmap(att_matrix, tokens, layer):
    """Build the Plotly heatmap of ``att_matrix`` and return it as a dict."""
    # Plotly treats string axis values as categories and merges duplicates
    # (e.g. two "the" tokens) into a single slot, which would misplace cells.
    # Prefixing each token with its position keeps every label unique.
    axis_labels = [f"{idx}: {tok}" for idx, tok in enumerate(tokens)]
    fig = px.imshow(
        att_matrix,
        x=axis_labels,
        y=axis_labels,
        labels={"x": "Token Being Looked At", "y": "Token Doing the Looking"},
        color_continuous_scale="Blues",
        title=f"DistilBERT Self-Attention (Layer {layer})",
    )
    fig.update_xaxes(side="top")
    fig.update_traces(
        hovertemplate="Row token: %{y}<br>Column token: %{x}<br>Focus Weight: %{z:.3f}"
    )
    return fig.to_dict()


def _top_k_summary(att_matrix, tokens, top_k):
    """Return Markdown listing the ``top_k`` strongest focuses of every token."""
    parts = [
        "## Top-K Focus for Each Token\n",
        f"Showing the **top {top_k}** tokens each token focuses on.\n\n",
    ]
    for row_token, row_weights in zip(tokens, att_matrix):
        parts.append(f"**Token '{row_token}'** focuses on:\n")
        # argsort is ascending; reverse it so the strongest focus comes first.
        for j in row_weights.argsort()[::-1][:top_k]:
            parts.append(f" - `{tokens[j]}` (weight={row_weights[j]:.3f})\n")
        parts.append("\n")
    return "".join(parts)


def analyze_attention(text, layer=5, top_k=3, show_heatmap=True):
    """
    Run DistilBERT on ``text`` and summarize the self-attention of one layer.

    Steps:
      1. Tokenize ``text``, truncated to the tokenizer's maximum length so
         long inputs cannot overflow the model's position embeddings.
      2. Forward pass DistilBERT (the module-level model was loaded with
         output_attentions=True).
      3. Extract attention from the chosen layer (0..5 for DistilBERT).
      4. Average across heads => (seq_len, seq_len).
      5. Optionally create a Plotly heatmap, returned as a dict.
      6. Create a Markdown summary of the top-k focuses for each token.
      7. Append the interpretation produced by ``interpret_attention``.

    Args:
        text: Input text to analyze; ``None`` is treated as "".
        layer: Index of the layer to inspect. UI sliders may deliver floats,
            so the value is converted with ``int()``.
        top_k: Number of attended-to tokens to list per token (converted
            with ``int()``; negative values are treated as 0).
        show_heatmap: Whether to build the heatmap figure.

    Returns:
        ``(fig_dict, combined_md)``: the Plotly figure as a dict (``None``
        when ``show_heatmap`` is false) and the Markdown report.

    Raises:
        IndexError: if ``layer`` is outside ``0..num_layers - 1``.
    """
    layer = int(layer)
    top_k = max(0, int(top_k))

    with torch.no_grad():
        # truncation=True caps the sequence at tokenizer.model_max_length;
        # without it, over-long input makes the forward pass fail.
        inputs = tokenizer(text or "", return_tensors="pt", truncation=True)
        outputs = model(**inputs)

    # One entry per layer, each (1, num_heads, seq_len, seq_len).
    all_attentions = outputs.attentions
    num_layers = len(all_attentions)
    if not 0 <= layer < num_layers:
        raise IndexError(f"layer must be in 0..{num_layers - 1}, got {layer}")

    # Average across heads, then drop the batch dimension => (seq_len, seq_len).
    att_matrix = all_attentions[layer].mean(dim=1)[0].cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

    fig_dict = _attention_heatmap(att_matrix, tokens, layer) if show_heatmap else None
    summary_md = _top_k_summary(att_matrix, tokens, top_k)
    interpretation_md = interpret_attention(att_matrix, tokens)
    return fig_dict, summary_md + "\n" + interpretation_md
########################################
# 3) Interpretation function
########################################
def interpret_attention(att_matrix: np.ndarray, tokens: list) -> str:
    """
    Provide a short Markdown bullet-list interpretation of an attention matrix.

    Reports:
      - how many tokens attend most strongly to themselves (row argmax lies
        on the diagonal);
      - the largest single weight in the matrix and the row/column tokens
        involved (first occurrence in row-major order on ties);
      - two general hints on reading the matrix.

    Args:
        att_matrix: 2-D array whose row i holds the weights token i assigns
            to each token. Only the first ``len(tokens)`` rows are inspected.
        tokens: Token strings labelling the rows/columns.

    Returns:
        A Markdown string starting with "## Additional Interpretation".
        An empty ``tokens`` list yields a short "no tokens" note instead of
        raising.
    """
    header = "## Additional Interpretation\n\n"
    seq_len = len(tokens)
    if seq_len == 0:
        return header + "- No tokens to analyze.\n"

    rows = np.asarray(att_matrix)[:seq_len]

    # 1) Diagonal focus: rows whose strongest column is the row itself.
    best_cols = rows.argmax(axis=1)
    diagonal_focus_count = int(np.count_nonzero(best_cols == np.arange(seq_len)))
    diag_msg = f"- **{diagonal_focus_count}/{seq_len} tokens** focus most on themselves (the diagonal)."

    # 2) Global max: a flat argmax returns the first maximum in row-major
    # order, i.e. the earliest row holding it, then its first column.
    i, j = np.unravel_index(rows.argmax(), rows.shape)
    max_val = rows[i, j]
    global_msg = f"- The **highest single focus** in the matrix is **{max_val:.3f}**, from token '{tokens[i]}' onto '{tokens[j]}'."

    parts = [
        header,
        "Here are some overall patterns in the attention matrix that might help you:\n\n",
        f"{diag_msg}\n",
        f"{global_msg}\n",
        "\n- A strong diagonal means tokens often reference themselves.\n",
        "- If a token's top focus is another token, that suggests it's referencing or depending on that other token.\n",
    ]
    return "".join(parts)
########################################
# 4) Gradio UI
########################################
# Markdown shown at the top of the Gradio page (rendered by gr.Markdown in the
# UI block). The "Layers range 0..5" text mirrors the layer slider's
# minimum/maximum; keep the two in sync if either changes.
description_md = """
# DistilBERT Self-Attention with Extra Interpretation
**Instructions:**
1. Type your text into the box.
2. Choose which **layer** of DistilBERT to visualize. (Layers range 0..5).
3. Decide how many top tokens you want listed for each token (Top-K).
4. (Optional) Check "Show Heatmap" to see the matrix. If it's too overwhelming, uncheck and just see the summary.
**Reading the Heatmap**:
- **Rows** = tokens doing the looking (focus).
- **Columns** = tokens being looked at.
- **Color intensity** = how strongly the row token focuses on the column token.
Below the heatmap, you'll see:
- A **Top-K focus** summary for each token.
- An **interpretation** bullet list that highlights interesting overall patterns.
"""
def run_demo(text, layer, top_k, show_heatmap):
    """
    Gradio click handler: run ``analyze_attention`` and adapt its output for the UI.

    ``analyze_attention`` returns the heatmap as a plain dict (``fig.to_dict()``),
    but ``gr.Plot`` expects a figure object rather than a dict, so the dict is
    rebuilt into a ``plotly.graph_objects.Figure`` before being returned.

    Returns:
        ``(figure_or_None, summary_markdown)`` for the plot and Markdown outputs.
    """
    # Local import keeps the handler self-contained; plotly is already a
    # dependency of this file via plotly.express.
    import plotly.graph_objects as go

    fig_dict, summary_md = analyze_attention(text, layer, top_k, show_heatmap)
    fig = go.Figure(fig_dict) if fig_dict is not None else None
    return fig, summary_md
with gr.Blocks() as demo:
    gr.Markdown(description_md)

    # Input controls, side by side.
    with gr.Row():
        text_in = gr.Textbox(
            value="Transformers handle long-range context in parallel.",
            label="Enter text",
        )
        layer_in = gr.Slider(
            label="DistilBERT Layer",
            minimum=0,
            maximum=5,
            value=5,
            step=1,
        )
        topk_in = gr.Slider(
            label="Top-K Focus",
            minimum=1,
            maximum=6,
            value=3,
            step=1,
        )
        show_heatmap_check = gr.Checkbox(value=True, label="Show Heatmap?")

    run_btn = gr.Button("Analyze Attention")

    # Outputs: the optional heatmap, then the Markdown report.
    out_plot = gr.Plot(label="Attention Heatmap")
    out_summary = gr.Markdown(label="Attention Summaries & Interpretation")

    # Component order must match run_demo(text, layer, top_k, show_heatmap)
    # and its (figure, markdown) return value.
    analysis_inputs = [text_in, layer_in, topk_in, show_heatmap_check]
    analysis_outputs = [out_plot, out_summary]
    run_btn.click(fn=run_demo, inputs=analysis_inputs, outputs=analysis_outputs)

demo.launch()