rdave88 commited on
Commit
ddf6886
·
1 Parent(s): 1b07353

Switch to Gradio version

Browse files
Files changed (1) hide show
  1. app.py +40 -13
app.py CHANGED
@@ -2,8 +2,8 @@ import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import numpy as np
4
 
5
- def plot_3d(max_t, num_points):
6
- # Generate sequence
7
  t = np.linspace(0, max_t, num_points)
8
  x = np.sin(t)
9
  y = np.cos(t)
@@ -15,19 +15,46 @@ def plot_3d(max_t, num_points):
15
  ax.plot(x, y, z, label="3D Spiral")
16
  ax.legend()
17
 
 
 
 
 
 
 
 
18
  return fig
19
 
20
- # Build Gradio interface with sliders
21
- demo = gr.Interface(
22
- fn=plot_3d,
23
- inputs=[
24
- gr.Slider(5, 50, value=20, step=1, label="Spiral Length"),
25
- gr.Slider(100, 2000, value=500, step=50, label="Number of Points")
26
- ],
27
- outputs="plot",
28
- title="3D Hidden States Visualization",
29
- description="Adjust the spiral length and number of points to explore the 3D curve."
30
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  if __name__ == "__main__":
33
  demo.launch()
 
2
  import matplotlib.pyplot as plt
3
  import numpy as np
4
 
5
+ def plot_3d(max_t, num_points, sentence, show_words):
6
+ # Generate spiral
7
  t = np.linspace(0, max_t, num_points)
8
  x = np.sin(t)
9
  y = np.cos(t)
 
15
  ax.plot(x, y, z, label="3D Spiral")
16
  ax.legend()
17
 
18
+ # Add token labels if requested
19
+ if show_words and sentence.strip():
20
+ tokens = sentence.strip().split()
21
+ idxs = np.linspace(0, len(t) - 1, len(tokens), dtype=int)
22
+ for i, token in zip(idxs, tokens):
23
+ ax.text(x[i], y[i], z[i], token, fontsize=9, color="red")
24
+
25
  return fig
26
 
27
+ with gr.Blocks() as demo:
28
+ gr.Markdown("# 3D Hidden States Visualization")
29
+ gr.Markdown(
30
+ """
31
+ This plot shows a **3D spiral**, generated using sine and cosine for the x/y axes
32
+ and a linear sequence for the z-axis:
33
+
34
+ - $x = \\sin(t)$
35
+ - $y = \\cos(t)$
36
+ - $z = t$
37
+
38
+ Together, these equations trace a spiral around the z-axis.
39
+ Think of it as an analogy for **hidden states in a neural network**,
40
+ which evolve over time (the z-axis), while oscillating in complex patterns (x & y axes).
41
+
42
+ ✨ Try typing your own sentence — each word will be placed along the spiral,
43
+ showing how tokens could be mapped into hidden state space.
44
+ """
45
+ )
46
+
47
+ with gr.Row():
48
+ max_t = gr.Slider(5, 50, value=20, step=1, label="Spiral Length")
49
+ num_points = gr.Slider(100, 2000, value=500, step=50, label="Number of Points")
50
+
51
+ sentence = gr.Textbox(value="I love hidden states in transformers", label="Sentence")
52
+ show_words = gr.Checkbox(label="Show Tokens", value=True)
53
+
54
+ plot = gr.Plot()
55
+
56
+ btn = gr.Button("Generate Plot")
57
+ btn.click(plot_3d, inputs=[max_t, num_points, sentence, show_words], outputs=plot)
58
 
59
  if __name__ == "__main__":
60
  demo.launch()