File size: 6,871 Bytes
0b23237
1f47b32
0b23237
da6e24a
0b23237
1f47b32
0b23237
da6e24a
0b23237
 
 
 
 
1f47b32
da6e24a
 
 
 
1f47b32
da6e24a
 
 
 
 
 
 
1f47b32
da6e24a
0b23237
 
 
da6e24a
 
 
0b23237
da6e24a
0b23237
 
da6e24a
0b23237
da6e24a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f47b32
da6e24a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f47b32
da6e24a
 
0b23237
da6e24a
 
 
 
0b23237
da6e24a
0b23237
 
 
da6e24a
0b23237
da6e24a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b23237
da6e24a
1f47b32
 
da6e24a
 
0b23237
 
da6e24a
0b23237
da6e24a
0b23237
da6e24a
 
 
 
 
 
 
 
 
1f47b32
da6e24a
 
1f47b32
da6e24a
 
 
 
1f47b32
 
 
0b23237
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import torch
import gradio as gr
import plotly.express as px
import numpy as np
from transformers import AutoModel, AutoTokenizer

########################################
# 1) Load DistilBERT with attention
########################################
# Checkpoint name shared by the tokenizer and the model.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# output_attentions=True makes each forward pass expose per-layer attention
# weights; eval() puts the module in inference mode and returns the module.
model = AutoModel.from_pretrained(model_name, output_attentions=True).eval()

########################################
# 2) Generate attention analysis
########################################
def analyze_attention(text, layer=5, top_k=3, show_heatmap=True):
    """
    Run DistilBERT on 'text' and summarize one layer's self-attention.

    1. Tokenize 'text' (truncated to the tokenizer's configured maximum length).
    2. Forward pass DistilBERT (output_attentions=True).
    3. Extract attention from the chosen layer (0..5 for this checkpoint).
    4. Average across heads => (seq_len, seq_len).
    5. Optionally create a Plotly heatmap (fig_dict).
    6. Create a text summary of the top-k focuses for each token.
    7. Append the "interpretation" produced by interpret_attention.

    Args:
        text: Input string to analyze.
        layer: Index of the attention layer to inspect. Coerced to int because
            UI sliders may deliver floats such as 5.0.
        top_k: Number of most-attended tokens to list per token. Coerced to
            int and floored at 0.
        show_heatmap: When true, build the heatmap; otherwise fig_dict is None.

    Returns:
        Tuple (fig_dict, combined_md): the Plotly figure as a dict (or None)
        and the Markdown report.

    Raises:
        IndexError: If 'layer' does not index one of the model's attention layers.
    """
    # Floats cannot index a tuple or bound a slice, so normalize up front.
    layer = int(layer)
    top_k = max(int(top_k), 0)

    with torch.no_grad():
        # truncation=True keeps over-long input within the model's position limit
        # instead of failing inside the forward pass.
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        outputs = model(**inputs)
        all_attentions = outputs.attentions  # tuple: (#layers) each => (1, #heads, seq_len, seq_len)

        num_layers = len(all_attentions)
        if not -num_layers <= layer < num_layers:
            raise IndexError(
                f"layer {layer} is out of range for a model with {num_layers} attention layers"
            )
        att = all_attentions[layer].mean(dim=1)  # average across heads => (1, seq_len, seq_len)

    att_matrix = att[0].cpu().numpy()  # (seq_len, seq_len)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # (Optional) Heatmap
    fig_dict = None
    if show_heatmap:
        fig = px.imshow(
            att_matrix,
            x=tokens,
            y=tokens,
            labels={"x": "Token Being Looked At", "y": "Token Doing the Looking"},
            color_continuous_scale="Blues",
            title=f"DistilBERT Self-Attention (Layer {layer})"
        )
        fig.update_xaxes(side="top")
        fig.update_traces(
            hovertemplate="Row token: %{y}<br>Column token: %{x}<br>Focus Weight: %{z:.3f}"
        )
        fig_dict = fig.to_dict()

    # Top-K summary for each row; pieces are collected and joined once.
    summary_parts = [
        "## Top-K Focus for Each Token\n",
        f"Showing the **top {top_k}** tokens each token focuses on.\n\n",
    ]
    for row_token, row_weights in zip(tokens, att_matrix):
        # argsort is ascending, so reverse it to get the strongest weights first.
        top_indices = row_weights.argsort()[::-1][:top_k]
        summary_parts.append(f"**Token '{row_token}'** focuses on:\n")
        summary_parts.extend(
            f" - `{tokens[j]}` (weight={row_weights[j]:.3f})\n" for j in top_indices
        )
        summary_parts.append("\n")
    summary_md = "".join(summary_parts)

    # Generate an additional "interpretation" to highlight patterns
    interpretation_md = interpret_attention(att_matrix, tokens)

    # Combine summaries
    combined_md = summary_md + "\n" + interpretation_md

    return fig_dict, combined_md


########################################
# 3) Interpretation function
########################################
def interpret_attention(att_matrix: np.ndarray, tokens: list) -> str:
    """
    Provide a short bullet-list interpretation of the attention matrix:
    - Count how many tokens mostly attend to themselves (diagonal).
    - Find the global max attention weight (which row->col?), mention tokens involved.

    Args:
        att_matrix: 2-D array of weights; row i holds the weights that token i
            places on each token. Only the first len(tokens) rows are examined.
        tokens: Token strings labelling the rows and columns of att_matrix.

    Returns:
        Markdown text with the bullet list. When 'tokens' is empty the
        global-max bullet is omitted, since there is no entry to report.
    """
    seq_len = len(tokens)
    diagonal_focus_count = 0
    global_msg = None

    if seq_len:
        weights = np.asarray(att_matrix)[:seq_len]
        rows = np.arange(seq_len)

        # Column each row attends to most (first one wins on ties).
        top_cols = weights.argmax(axis=1)
        diagonal_focus_count = int(np.count_nonzero(top_cols == rows))

        # The global max is the largest of the per-row peaks; argmax returns the
        # first such row, so no sentinel starting value is needed.
        row_peaks = weights[rows, top_cols]
        i = int(row_peaks.argmax())
        j = int(top_cols[i])
        max_val = row_peaks[i]
        global_msg = (
            f"- The **highest single focus** in the matrix is **{max_val:.3f}**, "
            f"from token '{tokens[i]}' onto '{tokens[j]}'."
        )

    diag_msg = f"- **{diagonal_focus_count}/{seq_len} tokens** focus most on themselves (the diagonal)."

    parts = [
        "## Additional Interpretation\n\n",
        "Here are some overall patterns in the attention matrix that might help you:\n\n",
        f"{diag_msg}\n",
    ]
    if global_msg is not None:
        parts.append(f"{global_msg}\n")
    parts.append("\n- A strong diagonal means tokens often reference themselves.\n")
    parts.append(
        "- If a token's top focus is another token, that suggests it's referencing or depending on that other token.\n"
    )

    return "".join(parts)


########################################
# 4) Gradio UI
########################################
# Markdown shown at the top of the Gradio page (rendered via gr.Markdown below):
# usage instructions plus a guide to reading the heatmap and the summaries.
description_md = """
# DistilBERT Self-Attention with Extra Interpretation

**Instructions:**
1. Type your text into the box.
2. Choose which **layer** of DistilBERT to visualize. (Layers range 0..5).
3. Decide how many top tokens you want listed for each token (Top-K).
4. (Optional) Check "Show Heatmap" to see the matrix. If it's too overwhelming, uncheck and just see the summary.

**Reading the Heatmap**:
- **Rows** = tokens doing the looking (focus).
- **Columns** = tokens being looked at.
- **Color intensity** = how strongly the row token focuses on the column token.

Below the heatmap, you'll see:
- A **Top-K focus** summary for each token.
- An **interpretation** bullet list that highlights interesting overall patterns.
"""

def run_demo(text, layer, top_k, show_heatmap):
    """Gradio click handler: forward the UI values to analyze_attention.

    Returns the (figure dict or None, Markdown report) pair that feeds the
    plot and summary outputs.
    """
    heatmap, report = analyze_attention(
        text, layer=layer, top_k=top_k, show_heatmap=show_heatmap
    )
    return heatmap, report

# Page layout: description, a row of inputs, the heatmap toggle, the run button,
# then the two outputs. Components are created in display order.
with gr.Blocks() as demo:
    gr.Markdown(description_md)

    # Inputs: free text plus the two numeric controls, side by side.
    with gr.Row():
        text_in = gr.Textbox(label="Enter text", value="Transformers handle long-range context in parallel.")
        layer_in = gr.Slider(label="DistilBERT Layer", minimum=0, maximum=5, step=1, value=5)
        topk_in = gr.Slider(label="Top-K Focus", minimum=1, maximum=6, step=1, value=3)
    show_heatmap_check = gr.Checkbox(label="Show Heatmap?", value=True)
    run_btn = gr.Button("Analyze Attention")

    # Outputs: the heatmap and the Markdown report.
    out_plot = gr.Plot(label="Attention Heatmap")
    out_summary = gr.Markdown(label="Attention Summaries & Interpretation")

    # Wire the button to the handler; output order matches run_demo's return order.
    run_btn.click(
        fn=run_demo,
        inputs=[text_in, layer_in, topk_in, show_heatmap_check],
        outputs=[out_plot, out_summary],
    )

demo.launch()