File size: 3,745 Bytes
9f0b7a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import torch
import torch.nn as nn
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import io
from PIL import Image
# -------- 1) Define a Tiny RNN Model (LSTM) and Vocab --------
# Demo only: the model defined below is never trained here, and all of its
# dimensions are deliberately tiny.
# Toy vocabulary: "<PAD>" sits at index 0 and "<UNK>" (substituted for
# out-of-vocabulary words by numericalize) at index 1.
vocab_list = [
    "<PAD>", "<UNK>",
    "the", "cat", "dog", "was", "chasing", "and", "it", "fell", "over",
    "hello", "world",
]
# Reverse lookup: token -> position in vocab_list.
vocab_dict = dict(zip(vocab_list, range(len(vocab_list))))
vocab_size = len(vocab_list)  # 13 entries for the list above
embedding_dim = 8
hidden_dim = 8
# Simple LSTM model
class TinyRNN(nn.Module):
    """Token embedding followed by a single-layer, batch-first LSTM.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding (the LSTM input size).
        hidden_dim: Width of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, input_ids):
        """Embed ``input_ids`` and feed the embeddings through the LSTM.

        Args:
            input_ids: Integer token ids shaped ``(batch_size, seq_len)``.

        Returns:
            The LSTM result unchanged, i.e. ``(outputs, (h_n, c_n))``:
            ``outputs`` is ``(batch_size, seq_len, hidden_dim)`` and holds the
            hidden state at *every* timestep; ``h_n`` is the final hidden state,
            shaped ``(1, batch_size, hidden_dim)`` for this single-layer LSTM.
        """
        embedded = self.embedding(input_ids)  # (batch_size, seq_len, embedding_dim)
        return self.lstm(embedded)
# Build the demo model once at import time. Its weights keep their random
# initialization (nothing in this script trains it); eval() switches the module
# to evaluation mode and returns the module itself, so it can be chained.
tiny_rnn = TinyRNN(vocab_size, embedding_dim, hidden_dim).eval()
# -------- 2) Tokenizer / Indexing Functions --------
def simple_tokenize(text):
    """Lower-case ``text`` and split it on runs of whitespace.

    Deliberately naive: punctuation stays attached to words (``"dog."`` is one
    token), and blank input yields an empty list.
    """
    lowered = text.lower()
    return lowered.split()
def numericalize(tokens, vocab=None, unk_token="<UNK>"):
    """Map each token to its vocabulary index, using ``unk_token`` for unknown words.

    Args:
        tokens: Iterable of token strings (e.g. the output of ``simple_tokenize``).
        vocab: Mapping of token -> index. When omitted, the module-level
            ``vocab_dict`` is used, so existing ``numericalize(tokens)`` calls
            behave exactly as before.
        unk_token: Key in ``vocab`` whose index replaces out-of-vocabulary tokens.

    Returns:
        A list of ints, one per input token, in input order.

    Raises:
        KeyError: If ``unk_token`` is not a key of ``vocab``.
    """
    if vocab is None:
        vocab = vocab_dict
    # Resolve the fallback id once rather than once per unknown token.
    unk_id = vocab[unk_token]
    return [vocab.get(tok, unk_id) for tok in tokens]
# -------- 3) Visualization Function --------
def visualize_rnn_states(input_text):
    """Render the LSTM hidden state at every timestep of ``input_text`` as a heatmap.

    1) Tokenize ``input_text``. Blank input becomes a single ``"<UNK>"`` token so
       the LSTM always receives at least one timestep.
    2) Convert the tokens to vocab indices.
    3) Run a gradient-free forward pass through ``tiny_rnn``.
    4) Plot the ``(seq_len, hidden_dim)`` output matrix with one row per token
       and one column per hidden unit, using a diverging colormap centred on 0.

    Args:
        input_text: Raw text from the Gradio textbox.

    Returns:
        A ``PIL.Image.Image`` containing the rendered PNG.
    """
    # Function-scope import: only this handler needs the pyplot-free Figure class.
    from matplotlib.figure import Figure

    tokens = simple_tokenize(input_text) or ["<UNK>"]
    indices = numericalize(tokens)
    input_tensor = torch.tensor([indices], dtype=torch.long)  # (1, seq_len)

    with torch.no_grad():
        outputs, _ = tiny_rnn(input_tensor)
    # Drop the batch axis: (1, seq_len, hidden_dim) -> (seq_len, hidden_dim).
    states = outputs.squeeze(0).cpu().numpy()
    seq_len, n_hidden = states.shape

    # Draw on a standalone Figure instead of through pyplot. pyplot keeps a
    # process-wide "current figure" and is not thread-safe, while Gradio runs
    # handlers in worker threads and may serve requests concurrently depending on
    # its queue settings. A standalone Figure is never registered with pyplot,
    # so it needs no plt.close() and cannot leak if plotting raises.
    fig = Figure(figsize=(6, max(3, seq_len * 0.4)))  # taller for longer inputs
    ax = fig.subplots()
    sns.heatmap(
        states,
        ax=ax,
        yticklabels=tokens,
        xticklabels=[f"h{i}" for i in range(n_hidden)],
        cmap="coolwarm",
        center=0,
    )
    ax.set_title("RNN Hidden States Heatmap")
    ax.set_ylabel("Tokens")
    # Take the width from the data rather than hard-coding 8, so the label stays
    # correct if hidden_dim changes.
    ax.set_xlabel(f"Hidden State Dimensions (size={n_hidden})")
    fig.tight_layout()

    # Encode to PNG in memory. PIL reads lazily from buf, and the returned image
    # holds a reference to it.
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)
# -------- 4) Gradio Interface --------
# Single text box in, rendered heatmap image out.
demo = gr.Interface(
    fn=visualize_rnn_states,
    inputs=gr.Textbox(lines=2, label="Input Text", value="The cat was chasing the dog"),
    outputs="image",
    title="RNN (LSTM) Hidden States Visualizer",
    description=(
        # Interpolate hidden_dim so the text cannot drift from the model config.
        f"Visualize how an untrained LSTM's hidden state (dim={hidden_dim}) changes "
        "for each token in your input text. Rows=timesteps, Columns=hidden dim."
    ),
)
# NOTE: launches as a side effect of running/importing this module.
# share=True requests a public share link and debug=True makes the call block
# (see Gradio's launch() documentation).
demo.launch(debug=True, share=True)
|