|
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
|
|
|
def plot_3d(max_t, num_points, sentence, show_words):
    """Render a 3D spiral and optionally label it with the words of a sentence.

    The curve is ``x = sin(t)``, ``y = cos(t)``, ``z = t``, sampled at
    ``num_points`` evenly spaced values of ``t`` in ``[0, max_t]``.

    Args:
        max_t: Upper bound of the parameter ``t`` (and therefore the height
            of the spiral along the z-axis).
        num_points: Number of samples along the curve. Coerced to ``int``
            because UI sliders may deliver numbers as floats, and
            ``np.linspace`` rejects a float ``num``.
        sentence: Text whose whitespace-separated tokens are placed along
            the curve. ``None`` is treated as an empty string.
        show_words: When truthy and ``sentence`` contains at least one token,
            each token is drawn as a red text label on the curve.

    Returns:
        A ``matplotlib.figure.Figure`` holding a single 3D axes.
    """
    num_points = int(num_points)
    t = np.linspace(0, max_t, num_points)
    x = np.sin(t)
    y = np.cos(t)
    z = t

    # Build the Figure directly rather than through pyplot. pyplot keeps every
    # figure it creates alive in a global registry until plt.close() is called,
    # so each request would leak a figure. That global state is also not
    # thread-safe, and the web handler may be invoked from worker threads.
    fig = Figure()
    ax = fig.add_subplot(111, projection="3d")
    ax.plot(x, y, z, label="3D Spiral")
    ax.legend()

    tokens = (sentence or "").split()
    if show_words and tokens and t.size:
        # Spread tokens evenly from the first to the last sample. If there are
        # more tokens than samples, several tokens share the same point.
        idxs = np.linspace(0, len(t) - 1, len(tokens), dtype=int)
        for i, token in zip(idxs, tokens):
            ax.text(x[i], y[i], z[i], token, fontsize=9, color="red")

    return fig
|
|
|
# Gradio UI: an explanatory header, spiral controls, a sentence to annotate,
# and a plot output driven by plot_3d.
with gr.Blocks() as demo:
    gr.Markdown("# 3D Hidden States Visualization")

    gr.Markdown(
        """
        This plot shows a **3D spiral**, generated using sine and cosine for the x/y axes
        and a linear sequence for the z-axis:

        - $x = \\sin(t)$
        - $y = \\cos(t)$
        - $z = t$

        Together, these equations trace a spiral around the z-axis.
        Think of it as an analogy for **hidden states in a neural network**,
        which evolve over time (the z-axis), while oscillating in complex patterns (x & y axes).

        ✨ Try typing your own sentence — each word will be placed along the spiral,
        showing how tokens could be mapped into hidden state space.
        """,
        # gr.Markdown only treats $$...$$ as LaTeX by default. Register single
        # dollars as inline delimiters so the equations above are typeset
        # instead of shown as raw "$x = \sin(t)$" text.
        latex_delimiters=[{"left": "$", "right": "$", "display": False}],
    )

    with gr.Row():
        max_t = gr.Slider(5, 50, value=20, step=1, label="Spiral Length")
        num_points = gr.Slider(100, 2000, value=500, step=50, label="Number of Points")

    sentence = gr.Textbox(value="I love hidden states in transformers", label="Sentence")
    show_words = gr.Checkbox(label="Show Tokens", value=True)

    plot = gr.Plot()

    btn = gr.Button("Generate Plot")

    # Component order must match plot_3d(max_t, num_points, sentence, show_words).
    plot_inputs = [max_t, num_points, sentence, show_words]
    btn.click(plot_3d, inputs=plot_inputs, outputs=plot)
    # Draw the default spiral on page load so the plot area is not blank
    # before the first click.
    demo.load(plot_3d, inputs=plot_inputs, outputs=plot)
|
|
|
def main() -> None:
    """Start the Gradio server hosting the demo UI."""
    demo.launch()


if __name__ == "__main__":
    main()
|
|