iqramukhtiar committed
Commit e1243ba · verified · 1 Parent(s): b0cf356

Upload 2 files

Files changed (2)
  1. App.py +74 -0
  2. requirements.txt +5 -0
App.py ADDED
@@ -0,0 +1,74 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import tempfile
+
+ # Helper function to plot attention heatmap
+ def plot_attention(attn, tokens, layer=0, head=0):
+     plt.figure(figsize=(10, 8))
+     sns.heatmap(attn[layer][head], xticklabels=tokens, yticklabels=tokens, cmap="viridis")
+     plt.title(f"Attention Map - Layer {layer}, Head {head}")
+     plt.xlabel("Keys")
+     plt.ylabel("Queries")
+     tmp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+     plt.savefig(tmp_file.name, bbox_inches='tight')
+     plt.close()
+     return tmp_file.name
+
+ # Main logic: tokenize, run the model, and extract attentions and hidden states
+ def process_input(text, model_name, layer, head):
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
+     model.eval()
+
+     inputs = tokenizer(text, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+     attentions = [a.squeeze(0).cpu().numpy() for a in outputs.attentions]  # layers, heads, seq_len, seq_len
+     hidden_states = [h.squeeze(0).cpu().numpy().tolist() for h in outputs.hidden_states]  # embeddings + layers, seq_len, dim
+
+     # Sliders can return floats and smaller models (e.g. DistilBERT) have fewer layers/heads, so cast and clamp
+     layer = min(int(layer), len(attentions) - 1)
+     head = min(int(head), attentions[0].shape[0] - 1)
+     attn_img_path = plot_attention(attentions, tokens, layer=layer, head=head)
+
+     return tokens, hidden_states[layer], attn_img_path
+
+ # Gradio interface function
+ def gradio_interface(text, model_name, layer, head):
+     tokens, hidden, attn_img = process_input(text, model_name, layer, head)
+     return tokens, hidden, attn_img
+
+ # Available transformer models
+ model_choices = [
+     "bert-base-uncased",
+     "distilbert-base-uncased",
+     "roberta-base",
+     "gpt2",
+ ]
+
+ # Launch the Gradio app
+ interface = gr.Interface(
+     fn=gradio_interface,
+     inputs=[
+         gr.Textbox(label="Input Text", placeholder="Type a sentence here..."),
+         gr.Dropdown(label="Model", choices=model_choices, value="bert-base-uncased"),
+         gr.Slider(label="Attention Layer", minimum=0, maximum=11, step=1, value=0),
+         gr.Slider(label="Attention Head", minimum=0, maximum=11, step=1, value=0),
+     ],
+     outputs=[
+         gr.JSON(label="Tokens"),
+         gr.Dataframe(label="Hidden States (Selected Layer)"),
+         gr.Image(label="Attention Map"),
+     ],
+     title="🔍 Transformer Visualizer",
+     description="Visualize tokenization, attention maps, and hidden states of popular Hugging Face Transformer models.",
+ )
+
+ if __name__ == "__main__":
+     interface.launch()
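
A quick way to sanity-check the pipeline outside the Gradio UI is to call process_input directly. This is a minimal sketch, assuming the file above is saved as App.py on the import path; the example sentence is arbitrary, and the model weights download on first run:

    from App import process_input

    tokens, hidden, attn_img = process_input(
        "The quick brown fox jumps over the lazy dog.",
        "bert-base-uncased",
        layer=0,
        head=0,
    )
    print(tokens)       # wordpiece tokens, e.g. ['[CLS]', 'the', 'quick', ...]
    print(len(hidden))  # one hidden-state vector per token at the selected layer
    print(attn_img)     # path to the saved attention heatmap PNG

Importing App builds the gr.Interface at module top level but does not start the server, since launch() is guarded by the __main__ check.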
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ transformers
+ torch
+ gradio
+ matplotlib
+ seaborn
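
Unpinned requirements pull the latest release of each package on every rebuild, which can break the app when an API changes. If reproducibility matters, pinned versions are an option; the numbers below are illustrative assumptions, not versions this commit was tested against:

    transformers==4.41.0
    torch==2.2.2
    gradio==4.29.0
    matplotlib==3.8.4
    seaborn==0.13.2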