kevin1911 committed on
Commit
9f0b7a5
·
verified ·
1 Parent(s): 2d0eee9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import gradio as gr
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import io
7
+ from PIL import Image
8
+
9
+ # -------- 1) Define a Tiny RNN Model (LSTM) and Vocab --------
10
+ # For demonstration, we keep the model untrained with small dimensions.
11
+
12
+ # A small toy vocab:
13
# Toy vocabulary: a word's position in this list is its embedding index.
vocab_list = [
    "<PAD>",
    "<UNK>",
    "the",
    "cat",
    "dog",
    "was",
    "chasing",
    "and",
    "it",
    "fell",
    "over",
    "hello",
    "world",
]
# Reverse lookup from word to its index in vocab_list.
vocab_dict = {word: position for position, word in enumerate(vocab_list)}

# Model dimensions, kept tiny because the network is only used for plotting.
vocab_size = len(vocab_list)  # 13 entries in the list above
embedding_dim = 8
hidden_dim = 8
19
+
20
+ # Simple LSTM model
21
class TinyRNN(nn.Module):
    """Embedding layer followed by a batch-first LSTM.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each token embedding (the LSTM input size).
        hidden_dim: Size of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, input_ids):
        """Run token ids through the embedding and the LSTM.

        Args:
            input_ids: Integer tensor of shape (batch_size, seq_len).

        Returns:
            A pair ``(per_step, (h_n, c_n))``: ``per_step`` has shape
            (batch_size, seq_len, hidden_dim) and holds the hidden state at
            every time step; ``h_n`` and ``c_n`` are the final hidden and cell
            states, each of shape (1, batch_size, hidden_dim).
        """
        # (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        embedded = self.embedding(input_ids)
        # The LSTM already returns (output, (h_n, c_n)); pass the state
        # tuple through untouched.
        per_step, final_state = self.lstm(embedded)
        return per_step, final_state
34
+
35
+ # Initialize the model (untrained, random weights)
36
# Build the randomly initialised model and switch it to eval mode, since it
# is only used for forward passes; nn.Module.eval() returns the module itself.
tiny_rnn = TinyRNN(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
).eval()
38
+
39
+ # -------- 2) Tokenizer / Indexing Functions --------
40
+
41
def simple_tokenize(text):
    """Lower-case ``text`` and split it on runs of whitespace.

    Returns a list of token strings; empty or whitespace-only text yields an
    empty list.
    """
    return text.lower().split()
45
+
46
def numericalize(tokens):
    """Map each token to its index in ``vocab_dict``.

    Tokens that are not in the vocabulary are mapped to the index of
    ``"<UNK>"``. Returns a list of ints in the same order as ``tokens``.
    """
    unknown_index = vocab_dict["<UNK>"]
    return [vocab_dict.get(token, unknown_index) for token in tokens]
55
+
56
+ # -------- 3) Visualization Function --------
57
+
58
def visualize_rnn_states(input_text):
    """Render the LSTM's per-token hidden states as a heatmap image.

    The text is tokenized with ``simple_tokenize``, converted to vocabulary
    indices with ``numericalize``, and run through the module-level
    ``tiny_rnn`` as a batch of one. Each heatmap row is one token and each
    column is one hidden-state dimension.

    Args:
        input_text: Raw text from the UI. ``None`` is treated like an empty
            string; if no tokens result, a single ``"<UNK>"`` token is used so
            there is always at least one row to plot.

    Returns:
        A ``PIL.Image.Image`` opened from the PNG rendering of the heatmap.
    """
    # Tokenize and numericalize; guard against None so .lower() cannot fail.
    tokens = simple_tokenize(input_text or "")
    if not tokens:
        tokens = ["<UNK>"]
    indices = numericalize(tokens)

    # Add a leading batch dimension: (seq_len,) -> (1, seq_len).
    input_tensor = torch.tensor(indices).unsqueeze(0)

    # Forward pass only; the final (h_n, c_n) state is not needed here.
    with torch.no_grad():
        outputs, _ = tiny_rnn(input_tensor)
        # (1, seq_len, hidden_dim) -> (seq_len, hidden_dim)
        states = outputs.squeeze(0).cpu().numpy()

    seq_len, hidden_dim_ = states.shape

    # Use an explicit Figure/Axes instead of pyplot's implicit "current
    # figure" so this call does not depend on global pyplot state.
    # Height grows with the number of tokens.
    fig, ax = plt.subplots(figsize=(6, max(3, seq_len * 0.4)))
    try:
        sns.heatmap(
            states,
            yticklabels=tokens,
            xticklabels=[f"h{i}" for i in range(hidden_dim_)],
            cmap="coolwarm",
            center=0,
            ax=ax,
        )
        ax.set_title("RNN Hidden States Heatmap")
        ax.set_ylabel("Tokens")
        # Report the actual width instead of a hard-coded 8.
        ax.set_xlabel(f"Hidden State Dimensions (size={hidden_dim_})")
        fig.tight_layout()

        # Render to an in-memory PNG so it can be returned as an image.
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
102
+
103
+ # -------- 4) Gradio Interface --------
104
+
105
# Text shown under the title in the UI.
_DESCRIPTION = (
    "Visualize how an untrained LSTM's hidden state (dim=8) changes "
    "for each token in your input text. Rows=timesteps, Columns=hidden dim."
)

# One text box in, one heatmap image out.
demo = gr.Interface(
    fn=visualize_rnn_states,
    inputs=gr.Textbox(
        lines=2,
        label="Input Text",
        value="The cat was chasing the dog",
    ),
    outputs="image",
    title="RNN (LSTM) Hidden States Visualizer",
    description=_DESCRIPTION,
)

# Start the app with debug mode on and a share link requested.
demo.launch(share=True, debug=True)