File size: 3,745 Bytes
9f0b7a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import torch.nn as nn
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import io
from PIL import Image

# -------- 1) Define a Tiny RNN Model (LSTM) and Vocab --------
# For demonstration, we keep the model untrained with small dimensions.

# Toy vocabulary: index 0 is the padding token, index 1 the out-of-vocabulary
# token, and the remaining entries are the only words the demo recognizes.
vocab_list = [
    "<PAD>",
    "<UNK>",
    "the",
    "cat",
    "dog",
    "was",
    "chasing",
    "and",
    "it",
    "fell",
    "over",
    "hello",
    "world",
]
# Reverse lookup: token string -> its position in vocab_list.
vocab_dict = {token: index for index, token in enumerate(vocab_list)}

# Model dimensions; kept tiny so the heatmap stays readable.
vocab_size = len(vocab_list)  # 13 entries in the list above
embedding_dim = 8
hidden_dim = 8

# Simple LSTM model
class TinyRNN(nn.Module):
    """Embedding layer followed by a single batch-first LSTM.

    The module only exposes the recurrent states; there is no output
    projection or classification head.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        """Build the embedding table and the LSTM.

        Args:
            vocab_size: Number of rows in the embedding table.
            embedding_dim: Width of each token embedding (LSTM input size).
            hidden_dim: Width of the LSTM hidden state.
        """
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )

    def forward(self, input_ids):
        """Embed token ids and run them through the LSTM.

        Args:
            input_ids: Integer tensor of shape (batch_size, seq_len).

        Returns:
            A pair ``(outputs, (h_n, c_n))`` where ``outputs`` has shape
            (batch_size, seq_len, hidden_dim) and holds the hidden state at
            every time step, while ``h_n`` / ``c_n`` have shape
            (1, batch_size, hidden_dim) and hold the final hidden / cell state.
        """
        embedded = self.embedding(input_ids)  # (batch_size, seq_len, embedding_dim)
        per_step_states, (last_hidden, last_cell) = self.lstm(embedded)
        return per_step_states, (last_hidden, last_cell)

# Initialize the model (untrained, random weights)
# Build the randomly-initialized model and put it straight into eval mode;
# it is only ever used for forward passes, never trained.
tiny_rnn = TinyRNN(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
).eval()

# -------- 2) Tokenizer / Indexing Functions --------

def simple_tokenize(text):
    """Lower-case *text* and split it on runs of whitespace.

    Punctuation is not separated from words, so "cat," stays a single token.
    """
    return text.lower().split()

def numericalize(tokens):
    """Map each token to its index in ``vocab_dict``.

    Tokens that are not in the vocabulary are replaced by the index of
    ``"<UNK>"``. The result has one entry per input token, in order.
    """
    return [
        vocab_dict[tok] if tok in vocab_dict else vocab_dict["<UNK>"]
        for tok in tokens
    ]

# -------- 3) Visualization Function --------

def visualize_rnn_states(input_text):
    """Render the LSTM's per-token hidden states as a heatmap image.

    Steps:
      1) Tokenize ``input_text`` with ``simple_tokenize``.
      2) Convert the tokens to vocab indices with ``numericalize``.
      3) Run a forward pass through the module-level ``tiny_rnn``.
      4) Plot the hidden state at each time step as one heatmap row.

    Args:
        input_text: Raw text from the UI. ``None`` or whitespace-only input is
            replaced by a single ``<UNK>`` token so the plot always has a row.

    Returns:
        A ``PIL.Image.Image`` opened from the PNG-encoded figure; rows are
        tokens and columns are hidden-state dimensions.
    """
    # Tokenize & numericalize; fall back to one <UNK> token when nothing is left.
    tokens = simple_tokenize(input_text or "")
    if not tokens:
        tokens = ["<UNK>"]
    indices = numericalize(tokens)

    # Add the batch axis expected by the batch-first LSTM: (1, seq_len).
    input_tensor = torch.tensor(indices, dtype=torch.long).unsqueeze(0)

    # Forward pass only; the final (h_n, c_n) pair is not needed for the plot.
    with torch.no_grad():
        outputs, _ = tiny_rnn(input_tensor)
    states = outputs.squeeze(0).cpu().numpy()  # (seq_len, hidden_dim)

    seq_len, hidden_dim_ = states.shape

    # Work on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure", so this call never draws on or closes a figure that
    # belongs to another request. Height grows with the number of tokens.
    fig, ax = plt.subplots(figsize=(6, max(3, seq_len * 0.4)))
    try:
        sns.heatmap(
            states,
            ax=ax,
            yticklabels=tokens,
            xticklabels=[f"h{i}" for i in range(hidden_dim_)],
            cmap="coolwarm",
            center=0,
        )
        ax.set_title("RNN Hidden States Heatmap")
        ax.set_ylabel("Tokens")
        # Report the width actually plotted rather than a hard-coded number.
        ax.set_xlabel(f"Hidden State Dimensions (size={hidden_dim_})")
        fig.tight_layout()

        # Serialize the figure to an in-memory PNG for Gradio.
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)

# -------- 4) Gradio Interface --------

# Input widget, pre-filled with a sentence made of in-vocabulary words.
_text_input = gr.Textbox(
    lines=2,
    label="Input Text",
    value="The cat was chasing the dog",
)

# Blurb shown under the title in the UI.
_description = (
    "Visualize how an untrained LSTM's hidden state (dim=8) changes "
    "for each token in your input text. Rows=timesteps, Columns=hidden dim."
)

# Text in, heatmap image out.
demo = gr.Interface(
    fn=visualize_rnn_states,
    inputs=_text_input,
    outputs="image",
    title="RNN (LSTM) Hidden States Visualizer",
    description=_description,
)

# debug=True keeps the process attached and prints errors;
# share=True asks Gradio to create a public share link.
demo.launch(debug=True, share=True)