File size: 5,770 Bytes
0b23237
1f47b32
0b23237
 
1f47b32
0b23237
e93333c
0b23237
 
 
 
 
1f47b32
e93333c
1f47b32
e93333c
 
 
 
 
1f47b32
0b23237
 
 
e93333c
 
 
 
 
 
0b23237
e93333c
0b23237
 
da6e24a
e93333c
 
 
 
 
 
 
 
 
 
da6e24a
e93333c
 
 
 
 
da6e24a
e93333c
da6e24a
e93333c
da6e24a
e93333c
 
 
da6e24a
e93333c
 
 
 
 
 
 
 
 
da6e24a
e93333c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f47b32
0b23237
e93333c
 
 
 
 
 
 
da6e24a
0b23237
e93333c
0b23237
e93333c
 
 
0b23237
e93333c
 
 
da6e24a
e93333c
 
 
da6e24a
e93333c
 
 
c0fdcd2
e93333c
 
1f47b32
 
e93333c
 
0b23237
 
e93333c
0b23237
e93333c
da6e24a
e93333c
 
 
 
 
 
 
 
da6e24a
1f47b32
e93333c
 
 
 
 
 
 
1f47b32
e93333c
 
 
 
 
 
1f47b32
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import torch
import gradio as gr
import plotly.express as px
from transformers import AutoModel, AutoTokenizer

########################################
# Load Transformer (DistilBERT) with attention
########################################
# DistilBERT checkpoint used throughout; output_attentions=True makes every
# forward pass also return the per-layer attention matrices visualized below.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)
model.eval()  # inference only — disables dropout for deterministic attention maps

def visualize_attention(text, layer=5):
    """
    Render a self-attention heatmap for one DistilBERT layer.

    1. Tokenize input text.
    2. Run DistilBERT forward pass to get attention matrices.
    3. Pick a layer (0..5) and average across attention heads.
    4. Generate a heatmap (Plotly) of shape (seq_len x seq_len).
    5. Label axes with tokens (Query vs. Key).

    Args:
        text: Sentence to analyze.
        layer: Layer index. Gradio sliders may deliver a float, so the
            value is coerced to int and clamped to the valid range.

    Returns:
        A Plotly figure containing the attention heatmap.
    """
    with torch.no_grad():
        # tokenizer(...) is the modern equivalent of the deprecated encode_plus.
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model(**inputs)
        all_attentions = outputs.attentions
        # DistilBERT has 6 layers => valid layer indices: 0..5.
        # Coerce/clamp because tuple indexing raises TypeError on a float
        # (gr.Slider values are not guaranteed to be ints).
        layer = max(0, min(int(layer), len(all_attentions) - 1))
        attn_layer = all_attentions[layer].mean(dim=1)  # shape: (1, seq_len, seq_len)

    # Convert to numpy for plotting
    attn_matrix = attn_layer[0].cpu().numpy()

    # Get tokens (including special tokens like [CLS]/[SEP])
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Build a Plotly heatmap: rows are Query tokens, columns are Key tokens.
    fig = px.imshow(
        attn_matrix,
        x=tokens,
        y=tokens,
        labels={"x": "Key (Being Attended to)", "y": "Query (Focusing)"},
        color_continuous_scale="Blues",
        title=f"DistilBERT Attention (Layer {layer})"
    )
    fig.update_xaxes(side="top")

    # Add tooltip showing the exact weight for each (query, key) cell.
    fig.update_traces(
        hovertemplate="Query: %{y}<br>Key: %{x}<br>Attention Weight: %{z:.3f}"
    )
    fig.update_layout(coloraxis_colorbar=dict(title="Attention Weight"))

    return fig

def interpret_token_attention(text, token_index=0, layer=5):
    """
    Provide a textual explanation for why a particular token (Query) attends
    to other tokens in the input, highlighting the top 3 tokens it focuses on.

    Args:
        text: Sentence to analyze.
        token_index: 0-based index of the Query token. Gradio's Number widget
            delivers a float by default, so the value is coerced to int.
        layer: Layer index (0..5); also coerced to int for tuple indexing.

    Returns:
        A Markdown-formatted explanation string, or an error message when
        token_index is out of range or not a valid number.
    """
    # Gradio Number/Slider widgets deliver floats; sequence and tuple
    # indexing below require ints.
    try:
        token_index = int(token_index)
        layer = int(layer)
    except (TypeError, ValueError):
        return "Invalid token index. Please choose a valid token index."

    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model(**inputs)
        all_attentions = outputs.attentions
        attn_layer = all_attentions[layer].mean(dim=1)  # shape: (1, seq_len, seq_len)

    # Get tokens (including special tokens like [CLS]/[SEP])
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Safety check for token_index
    if token_index < 0 or token_index >= len(tokens):
        return "Invalid token index. Please choose a valid token index."

    # Extract the row corresponding to our Query token
    query_attn = attn_layer[0, token_index, :].cpu().numpy()  # shape: (seq_len,)

    # Sort tokens by attention weight (descending) and keep the top 3.
    sorted_indices = query_attn.argsort()[::-1]
    top_indices = sorted_indices[:3]
    top_tokens = [tokens[i] for i in top_indices]
    top_weights = [query_attn[i] for i in top_indices]

    # Build an explanation
    query_token_str = tokens[token_index]
    explanation = (
        f"**You chose token index {token_index}, which is '{query_token_str}'.**\n\n"
        "In Transformers, each token is converted into Query, Key, and Value vectors:\n"
        "- **Query** = What this token is looking for\n"
        "- **Key**   = What another token has to offer\n"
        "- **Value** = The actual information from that token\n\n"
        f"As a Query, '{query_token_str}' attends most strongly to:\n"
    )

    for t, w in zip(top_tokens, top_weights):
        explanation += f"- **{t}** with attention weight ~ {w:.3f}\n"

    explanation += (
        "\nA higher attention weight indicates that this Query token is 'looking at' or "
        "focusing on that Key token more strongly, likely because it finds the Key token "
        "relevant to its meaning or context."
    )

    return explanation

# Markdown shown at the top of the UI (rendered via gr.Markdown).
# NOTE: this is a runtime string displayed to users — keep wording in sync
# with the heatmap axis labels (Query = rows, Key = columns).
description_text = """
## Understanding Transformer Self-Attention

- **Rows = Query token** (the token doing the 'looking').  
- **Columns = Key token** (the token being 'looked at').  
- Darker color = stronger attention weight.

**Transformers** process all tokens in **parallel**, allowing any token to attend to any other token in the sentence.
This makes it easier for the model to capture long-distance relationships.
"""

########################################
# Gradio Interface
########################################
# Builds the Blocks UI: one text input + layer slider feeding the heatmap,
# and a token-index picker feeding the per-token explanation.
with gr.Blocks(css="footer{display:none !important}") as demo:
    gr.Markdown("# Transformer Self-Attention Visualization (DistilBERT)")
    gr.Markdown(description_text)

    with gr.Row():
        text_input = gr.Textbox(
            label="Enter a sentence",
            value="Transformers handle long-range context in parallel."
        )
        layer_slider = gr.Slider(
            minimum=0, maximum=5, step=1, value=5,
            label="DistilBERT Layer (0=lowest, 5=highest)"
        )
    output_plot = gr.Plot(label="Attention Heatmap")

    # Visualization Button
    visualize_button = gr.Button("Visualize Attention")
    visualize_button.click(
        fn=visualize_attention,
        inputs=[text_input, layer_slider],
        outputs=output_plot
    )

    # Number input to choose a token index for interpretation.
    # precision=0 makes Gradio deliver an int instead of its default float,
    # so the callback can index the token list directly.
    token_index = gr.Number(
        label="Choose a token index to interpret (0-based)",
        value=0,
        precision=0
    )

    interpretation_output = gr.Markdown(label="Interpretation")

    # Interpretation Button
    interpret_button = gr.Button("Explain This Token's Attention")
    interpret_button.click(
        fn=interpret_token_attention,
        inputs=[text_input, token_index, layer_slider],
        outputs=interpretation_output
    )

demo.launch()