"""Gradio demo that visualizes and summarizes DistilBERT self-attention."""

import torch
import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
from transformers import AutoModel, AutoTokenizer

########################################
# 1) Load DistilBERT with attention
########################################
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)
model.eval()


########################################
# 2) Generate attention analysis
########################################
def analyze_attention(text, layer=5, top_k=3, show_heatmap=True):
    """Run DistilBERT on ``text`` and describe one layer's self-attention.

    Steps:
      1. Tokenize ``text``.
      2. Forward pass through the model (loaded with output_attentions=True).
      3. Take the attention tensor of the chosen layer.
      4. Average across heads to get a (seq_len, seq_len) matrix.
      5. Optionally build a Plotly heatmap, returned as a figure dict.
      6. Build a Markdown summary of the top-k focuses of every token.
      7. Append the pattern summary from ``interpret_attention``.

    Args:
        text: Input sentence. ``None`` is treated as the empty string.
        layer: Index into the tuple of per-layer attentions (0..5 for the
            6-layer DistilBERT checkpoint loaded above). Coerced to int.
        top_k: Number of highest-weight tokens listed per row. Coerced to
            int and clamped to be non-negative.
        show_heatmap: When true, also build the heatmap.

    Returns:
        ``(fig_dict, markdown)`` where ``fig_dict`` is ``fig.to_dict()`` of
        the heatmap, or ``None`` when ``show_heatmap`` is false.

    Raises:
        IndexError: If ``layer`` is outside the range of available layers.
    """
    # Gradio documents slider values as floats, but tuple indexing and
    # slicing below require real ints.
    layer = int(layer)
    top_k = max(int(top_k), 0)

    with torch.no_grad():
        # truncation=True caps the input at the model's maximum length so an
        # over-long text cannot overflow the position embeddings.
        inputs = tokenizer(text or "", return_tensors="pt", truncation=True)
        outputs = model(**inputs)

    # Tuple with one entry per layer, each (1, #heads, seq_len, seq_len).
    all_attentions = outputs.attentions
    num_layers = len(all_attentions)
    if not -num_layers <= layer < num_layers:
        raise IndexError(
            f"layer must be in 0..{num_layers - 1}, got {layer}"
        )

    # Average across heads => (1, seq_len, seq_len); drop the batch axis.
    att_matrix = all_attentions[layer].mean(dim=1)[0].cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

    # (Optional) Heatmap
    fig_dict = None
    if show_heatmap:
        fig = px.imshow(
            att_matrix,
            x=tokens,
            y=tokens,
            labels={"x": "Token Being Looked At", "y": "Token Doing the Looking"},
            color_continuous_scale="Blues",
            title=f"DistilBERT Self-Attention (Layer {layer})",
        )
        fig.update_xaxes(side="top")
        # Plotly hover text is HTML-like; <br> separates the three lines.
        fig.update_traces(
            hovertemplate=(
                "Row token: %{y}<br>"
                "Column token: %{x}<br>"
                "Focus Weight: %{z:.3f}"
            )
        )
        fig_dict = fig.to_dict()

    # Top-K summary for each row, collected in a list and joined once.
    parts = [
        "## Top-K Focus for Each Token\n",
        f"Showing the **top {top_k}** tokens each token focuses on.\n\n",
    ]
    for row_token, row_weights in zip(tokens, att_matrix):
        parts.append(f"**Token '{row_token}'** focuses on:\n")
        # Indices of the largest weights first.
        top_indices = row_weights.argsort()[::-1][:top_k]
        parts.extend(
            f" - `{tokens[j]}` (weight={row_weights[j]:.3f})\n"
            for j in top_indices
        )
        parts.append("\n")
    summary_md = "".join(parts)

    # Generate an additional "interpretation" to highlight patterns
    interpretation_md = interpret_attention(att_matrix, tokens)

    return fig_dict, summary_md + "\n" + interpretation_md


########################################
# 3) Interpretation function
########################################
def interpret_attention(att_matrix: np.ndarray, tokens: list) -> str:
    """Return a short Markdown bullet list describing the attention matrix.

    Reports:
      - how many tokens put their largest weight on themselves (diagonal);
      - the single largest weight in the matrix and the tokens involved.

    Args:
        att_matrix: 2-D array of weights; row i is the distribution of
            token i over all tokens.
        tokens: Token strings labelling the rows and columns.

    Returns:
        Markdown text starting with an "Additional Interpretation" heading.
    """
    header = "## Additional Interpretation\n\n"
    seq_len = len(tokens)
    att = np.asarray(att_matrix)[:seq_len]

    # Nothing to analyse; indexing tokens[0] below would fail otherwise.
    if seq_len == 0 or att.size == 0:
        return header + "No tokens to interpret.\n"

    # For each row, the column receiving the most weight.
    best_cols = att.argmax(axis=1)
    diagonal_focus_count = int((best_cols == np.arange(len(best_cols))).sum())

    # Global maximum; argmax on the flattened array returns the first
    # occurrence in row-major order.
    i, j = np.unravel_index(att.argmax(), att.shape)
    max_val = att[i, j]

    # 1) Diagonal focus stat
    diag_msg = (
        f"- **{diagonal_focus_count}/{seq_len} tokens** focus most on "
        "themselves (the diagonal)."
    )
    # 2) Global max
    global_msg = (
        f"- The **highest single focus** in the matrix is **{max_val:.3f}**, "
        f"from token '{tokens[i]}' onto '{tokens[j]}'."
    )

    return "".join(
        [
            header,
            "Here are some overall patterns in the attention matrix that might help you:\n\n",
            f"{diag_msg}\n",
            f"{global_msg}\n",
            "\n- A strong diagonal means tokens often reference themselves.\n",
            "- If a token's top focus is another token, that suggests it's referencing or depending on that other token.\n",
        ]
    )


########################################
# 4) Gradio UI
########################################
description_md = """
# DistilBERT Self-Attention with Extra Interpretation

**Instructions:**
1. Type your text into the box.
2. Choose which **layer** of DistilBERT to visualize. (Layers range 0..5).
3. Decide how many top tokens you want listed for each token (Top-K).
4. (Optional) Check "Show Heatmap" to see the matrix. If it's too overwhelming, uncheck and just see the summary.

**Reading the Heatmap**:
- **Rows** = tokens doing the looking (focus).
- **Columns** = tokens being looked at.
- **Color intensity** = how strongly the row token focuses on the column token.

Below the heatmap, you'll see:
- A **Top-K focus** summary for each token.
- An **interpretation** bullet list that highlights interesting overall patterns.
"""


def run_demo(text, layer, top_k, show_heatmap):
    """Gradio callback: analyze ``text`` and return (plot, markdown).

    The plot is a Plotly ``Figure`` rebuilt from the dict produced by
    ``analyze_attention`` (gr.Plot is documented to take figure objects),
    or ``None`` when the heatmap is disabled.
    """
    fig_dict, summary_md = analyze_attention(text, layer, top_k, show_heatmap)
    figure = go.Figure(fig_dict) if fig_dict is not None else None
    return figure, summary_md


with gr.Blocks() as demo:
    gr.Markdown(description_md)

    with gr.Row():
        text_in = gr.Textbox(
            label="Enter text",
            value="Transformers handle long-range context in parallel.",
        )
        layer_in = gr.Slider(
            minimum=0,
            maximum=5,
            step=1,
            value=5,
            label="DistilBERT Layer",
        )
        topk_in = gr.Slider(
            minimum=1,
            maximum=6,
            step=1,
            value=3,
            label="Top-K Focus",
        )
        show_heatmap_check = gr.Checkbox(
            label="Show Heatmap?",
            value=True,
        )

    run_btn = gr.Button("Analyze Attention")

    out_plot = gr.Plot(label="Attention Heatmap")
    out_summary = gr.Markdown(label="Attention Summaries & Interpretation")

    run_btn.click(
        fn=run_demo,
        inputs=[text_in, layer_in, topk_in, show_heatmap_check],
        outputs=[out_plot, out_summary],
    )

# Only start the server when run as a script, so the module can be imported
# without side effects beyond loading the model and building the UI.
if __name__ == "__main__":
    demo.launch()