rdave88 committed · Commit 3c13c5f · verified · 1 Parent(s): ab500dc

Update app.py

Files changed (1)
  1. app.py +61 -33
app.py CHANGED
@@ -1,60 +1,88 @@
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np

-def plot_3d(max_t, num_points, sentence, show_words):
-    # Generate spiral
-    t = np.linspace(0, max_t, num_points)
-    x = np.sin(t)
-    y = np.cos(t)
-    z = t
-
-    # Plot
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
-    ax.plot(x, y, z, label="3D Spiral")
-    ax.legend()
-
-    # Add token labels if requested
-    if show_words and sentence.strip():
-        tokens = sentence.strip().split()
-        idxs = np.linspace(0, len(t) - 1, len(tokens), dtype=int)
-        for i, token in zip(idxs, tokens):
-            ax.text(x[i], y[i], z[i], token, fontsize=9, color="red")
-
    return fig

with gr.Blocks() as demo:
-    gr.Markdown("# 3D Hidden States Visualization")
    gr.Markdown(
        """
-        This plot shows a **3D spiral**, generated using sine and cosine for the x/y axes
-        and a linear sequence for the z-axis:
-
-        - $x = \\sin(t)$
-        - $y = \\cos(t)$
-        - $z = t$
-
-        Together, these equations trace a spiral around the z-axis.
-        Think of it as an analogy for **hidden states in a neural network**,
-        which evolve over time (the z-axis), while oscillating in complex patterns (x & y axes).
-
-        ✨ Try typing your own sentence — each word will be placed along the spiral,
-        showing how tokens could be mapped into hidden state space.
        """
    )

    with gr.Row():
-        max_t = gr.Slider(5, 50, value=20, step=1, label="Spiral Length")
-        num_points = gr.Slider(100, 2000, value=500, step=50, label="Number of Points")

    sentence = gr.Textbox(value="I love hidden states in transformers", label="Sentence")
-    show_words = gr.Checkbox(label="Show Tokens", value=True)

    plot = gr.Plot()

    btn = gr.Button("Generate Plot")
-    btn.click(plot_3d, inputs=[max_t, num_points, sentence, show_words], outputs=plot)

if __name__ == "__main__":
    demo.launch()
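Note on the removed version: word labels were spread evenly along the spiral by mapping token positions to point indices with np.linspace. A minimal sketch of that mapping, with hypothetical values (not part of the commit):

import numpy as np

# 5 tokens spread across a 500-point curve: evenly spaced indices into x, y, z
tokens = "I love hidden states here".split()
idxs = np.linspace(0, 500 - 1, len(tokens), dtype=int)
print(idxs)  # [  0 124 249 374 499]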
 
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
+import torch
+from transformers import AutoTokenizer, AutoModel
+from sklearn.decomposition import PCA
+
+# Load model & tokenizer once (tiny DistilBERT for speed on Spaces)
+MODEL_NAME = "distilbert-base-uncased"
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True)
+model.eval()
+
+def plot_hidden_states(mode, max_tokens, sentence, show_words, focus_token):
+    # Tokenize
+    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=max_tokens)
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+    all_layers = torch.stack(outputs.hidden_states).squeeze(1).numpy()  # [num_layers+1, seq_len, hidden_dim]

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

+    if mode == "Per-token trajectory":
+        hs = outputs.last_hidden_state.squeeze(0).numpy()
+        xy = PCA(n_components=2).fit_transform(hs)
+        x, y = xy[:, 0], xy[:, 1]
+        z = np.arange(len(x))
+
+        ax.plot(x, y, z, label="Hidden state trajectory")
+        ax.legend()
+
+        if show_words:
+            for i, tok in enumerate(tokens):
+                ax.text(x[i], y[i], z[i], tok, fontsize=9, color="red")
+
+    elif mode == "Per-layer trajectory":
+        if focus_token.strip() in tokens:
+            idx = tokens.index(focus_token.strip())
+        else:
+            idx = 0
+
+        path_layers = all_layers[:, idx, :]  # [num_layers+1, hidden_dim]
+        xy = PCA(n_components=2).fit_transform(path_layers)
+        x, y = xy[:, 0], xy[:, 1]
+        z = np.arange(len(x))

+        ax.plot(x, y, z, label=f"Layer evolution for '{tokens[idx]}'")
+        ax.legend()
+
+        for i in range(len(z)):
+            ax.text(x[i], y[i], z[i], f"L{i}", fontsize=8, color="blue")
+
+    ax.set_xlabel("PC1")
+    ax.set_ylabel("PC2")
+    ax.set_zlabel("Index")
+    ax.set_title(mode)
+    plt.tight_layout()
    return fig

with gr.Blocks() as demo:
+    gr.Markdown("# 🌀 3D Hidden States Explorer")
    gr.Markdown(
        """
+        Visualize **transformer hidden states** in 3D.
+        Choose between two modes:
+        - **Per-token trajectory:** how tokens in a sentence evolve in the final layer.
+        - **Per-layer trajectory:** how one token moves across all layers.
        """
    )

    with gr.Row():
+        mode = gr.Radio(["Per-token trajectory", "Per-layer trajectory"], value="Per-token trajectory", label="Mode")
+        max_tokens = gr.Slider(10, 64, value=32, step=1, label="Max Tokens")

    sentence = gr.Textbox(value="I love hidden states in transformers", label="Sentence")
+    show_words = gr.Checkbox(label="Show Tokens (per-token mode)", value=True)
+    focus_token = gr.Textbox(value="hidden", label="Focus Token (per-layer mode)")

    plot = gr.Plot()

    btn = gr.Button("Generate Plot")
+    btn.click(plot_hidden_states, inputs=[mode, max_tokens, sentence, show_words, focus_token], outputs=plot)

if __name__ == "__main__":
    demo.launch()
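For readers following the new code: outputs.hidden_states is a tuple holding the embedding output plus one tensor per transformer layer, which is why the stacked array has num_layers + 1 entries. A minimal sketch to verify the shapes the plotting code relies on (assuming distilbert-base-uncased, whose 6 layers yield 7 hidden states):

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased", output_hidden_states=True)
model.eval()

inputs = tokenizer("I love hidden states", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Tuple: embeddings + one tensor per layer, each [batch, seq_len, hidden_dim]
print(len(outputs.hidden_states))  # 7 (6 layers + embeddings)
print(torch.stack(outputs.hidden_states).squeeze(1).shape)  # [7, 6, 768]: 6 tokens incl. [CLS]/[SEP]

One design note: PCA is fit separately for each figure, so the PC1/PC2 axes are not comparable across modes or across sentences; every plot defines its own projection.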