|
import io
import string

import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
# Toy vocabulary. Special tokens come first: "<PAD>" sits at index 0 and
# "<UNK>" (the fallback used for out-of-vocabulary words) at index 1.
vocab_list = [
    "<PAD>", "<UNK>",
    "the", "cat", "dog", "was", "chasing", "and", "it", "fell", "over", "hello", "world",
]

# Reverse lookup: word -> integer index into the embedding table.
vocab_dict = dict(zip(vocab_list, range(len(vocab_list))))

vocab_size = len(vocab_list)
embedding_dim = 8  # width of each token embedding vector
hidden_dim = 8  # width of the LSTM hidden/cell state
|
|
|
|
|
class TinyRNN(nn.Module):
    """Minimal encoder: an embedding table followed by a single-layer LSTM.

    There is no output head; ``forward`` returns the raw LSTM results.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each token embedding (the LSTM input size).
        hidden_dim: Size of the LSTM hidden and cell states.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: tensors are laid out as (batch, seq, feature).
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, input_ids):
        """Embed token indices and encode them with the LSTM.

        Args:
            input_ids: Tensor of integer token indices. The visualizer in this
                file passes shape (1, seq_len), i.e. batch-first.

        Returns:
            ``(outputs, (h_n, c_n))`` as produced by ``nn.LSTM``: the hidden
            state at every timestep plus the final hidden and cell states.
        """
        embedded = self.embedding(input_ids)
        per_step_states, final_states = self.lstm(embedded)
        return per_step_states, final_states
|
|
|
|
|
# Single module-level model instance used by every visualization call.
# Its weights are left at their random initialization (the demo describes it
# as an untrained LSTM); eval() returns the module itself, so chaining is safe.
tiny_rnn = TinyRNN(vocab_size, embedding_dim, hidden_dim).eval()
|
|
|
|
|
|
|
def simple_tokenize(text):
    """Split ``text`` into lowercase word tokens.

    The text is lowercased and split on whitespace. Leading and trailing
    punctuation is then stripped from each token so that words such as
    ``"dog."`` or ``"hello,"`` still match their vocabulary entries. Tokens
    made only of punctuation are dropped. Inner punctuation is kept, so
    ``"it's"`` stays ``"it's"``.

    Args:
        text: Input string. ``None`` or an empty string yields no tokens.

    Returns:
        list[str]: Tokens in order of appearance; may be empty.
    """
    if not text:
        return []
    stripped = (raw.strip(string.punctuation) for raw in text.lower().split())
    return [token for token in stripped if token]
|
|
|
def numericalize(tokens):
    """Convert tokens to vocabulary indices.

    Any token missing from ``vocab_dict`` is mapped to the index of
    ``"<UNK>"``.

    Args:
        tokens: Iterable of token strings.

    Returns:
        list[int]: One index per input token, in the same order.
    """
    return [
        vocab_dict[token] if token in vocab_dict else vocab_dict["<UNK>"]
        for token in tokens
    ]
|
|
|
|
|
|
|
def visualize_rnn_states(input_text):
    """Render a heatmap of the LSTM's per-timestep hidden states for ``input_text``.

    Steps:
    1) Tokenize ``input_text``. If it yields no tokens, a single ``"<UNK>"``
       token is used instead.
    2) Convert the tokens to vocabulary indices.
    3) Run a forward pass through the module-level ``tiny_rnn`` without
       tracking gradients.
    4) Plot the LSTM outputs as a heatmap: one row per token, one column per
       hidden dimension.

    Args:
        input_text: Free-form text from the Gradio textbox.

    Returns:
        PIL.Image.Image: The rendered PNG heatmap.
    """
    tokens = simple_tokenize(input_text)
    if not tokens:
        # An empty sequence cannot be fed through the LSTM; show one UNK step instead.
        tokens = ["<UNK>"]
    indices = numericalize(tokens)

    # Add a batch dimension -> (1, seq_len); the model's LSTM is batch_first.
    input_tensor = torch.tensor(indices).unsqueeze(0)

    with torch.no_grad():
        outputs, _ = tiny_rnn(input_tensor)

    # Drop the batch dimension -> (seq_len, hidden size) NumPy array.
    states = outputs.squeeze(0).cpu().numpy()
    seq_len, n_dims = states.shape

    # Use the object-oriented API with a dedicated Agg canvas instead of
    # pyplot. pyplot keeps process-global "current figure" state, and the
    # original plt.close() closed whichever figure was current (leaking ours
    # if plotting raised). A standalone Figure is not registered with pyplot,
    # so it is simply garbage-collected once this call returns.
    fig = Figure(figsize=(6, max(3, seq_len * 0.4)))
    FigureCanvasAgg(fig)
    ax = fig.add_subplot(1, 1, 1)

    sns.heatmap(
        states,
        ax=ax,
        yticklabels=tokens,
        xticklabels=[f"h{i}" for i in range(n_dims)],
        cmap="coolwarm",
        center=0,
    )
    ax.set_title("RNN Hidden States Heatmap")
    ax.set_ylabel("Tokens")
    # Take the size from the data so the label stays correct if hidden_dim changes.
    ax.set_xlabel(f"Hidden State Dimensions (size={n_dims})")
    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)
|
|
|
|
|
|
|
# Gradio UI: a text box in, the rendered hidden-state heatmap image out.
demo = gr.Interface(
    fn=visualize_rnn_states,
    inputs=gr.Textbox(lines=2, label="Input Text", value="The cat was chasing the dog"),
    outputs="image",
    title="RNN (LSTM) Hidden States Visualizer",
    description=(
        # Derive the dimension from hidden_dim so the text cannot drift from the model.
        f"Visualize how an untrained LSTM's hidden state (dim={hidden_dim}) changes "
        "for each token in your input text. Rows=timesteps, Columns=hidden dim."
    ),
)

# Launch only when executed as a script or notebook cell. Merely importing this
# module (for example from tests, or tooling that looks up `demo`) then does
# not start a server.
if __name__ == "__main__":
    # NOTE(review): share=True requests a public share link; confirm this is intended.
    demo.launch(debug=True, share=True)
|
|