RNN / app.py
kevin1911's picture
Create app.py
9f0b7a5 verified
raw
history blame
3.75 kB
import torch
import torch.nn as nn
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import io
from PIL import Image
# -------- 1) Define a Tiny RNN Model (LSTM) and Vocab --------
# The model below is intentionally tiny and never trained; it exists only so the
# hidden-state heatmap has something to display.
#
# Toy vocabulary: position 0 holds "<PAD>", position 1 holds "<UNK>" (the id
# used for out-of-vocabulary words), and the rest are lowercase words.
vocab_list = [
    "<PAD>", "<UNK>",
    "the", "cat", "dog", "was", "chasing", "and", "it", "fell", "over",
    "hello", "world",
]
# word -> integer id, where the id is the word's position in vocab_list
vocab_dict = dict(zip(vocab_list, range(len(vocab_list))))
vocab_size = len(vocab_list)  # 13 entries
embedding_dim = 8  # width of each token embedding (LSTM input size)
hidden_dim = 8  # width of the LSTM hidden/cell state
# Simple LSTM model
class TinyRNN(nn.Module):
    """Token embedding followed by a single-layer, batch-first LSTM.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding (the LSTM input size).
        hidden_dim: Width of the LSTM hidden and cell states.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, input_ids):
        """Embed ``input_ids`` and run them through the LSTM.

        Args:
            input_ids: Integer token-id tensor of shape (batch_size, seq_len).

        Returns:
            ``(outputs, (h_n, c_n))`` as produced by ``nn.LSTM``: ``outputs``
            is (batch_size, seq_len, hidden_dim) and holds the hidden state at
            every timestep; ``h_n``/``c_n`` are the final hidden/cell states,
            each (1, batch_size, hidden_dim) for this single-layer LSTM.
        """
        embedded = self.embedding(input_ids)
        per_step_states, final_states = self.lstm(embedded)
        return per_step_states, final_states
# Build the model once at import time. Its weights are random and never
# trained; the demo only runs forward passes. Module.eval() returns the module
# itself, so it is chained onto construction to switch off training mode.
tiny_rnn = TinyRNN(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
).eval()
# -------- 2) Tokenizer / Indexing Functions --------
def simple_tokenize(text):
    """Lowercase ``text`` and split it on runs of whitespace.

    This is deliberately naive: punctuation stays attached, so ``"dog."`` is
    one token. Returns an empty list for empty or all-whitespace input.
    """
    lowered = text.lower()
    return lowered.split()
def numericalize(tokens):
    """Map each token to its integer id in ``vocab_dict``.

    Tokens absent from the vocabulary map to the ``"<UNK>"`` id. Returns a new
    list of ints with the same length and order as ``tokens``.
    """
    return [
        vocab_dict[tok] if tok in vocab_dict else vocab_dict["<UNK>"]
        for tok in tokens
    ]
# -------- 3) Visualization Function --------
def visualize_rnn_states(input_text):
    """Render the LSTM's per-token hidden states as a heatmap image.

    Steps:
        1) Tokenize ``input_text`` (lowercase + whitespace split).
        2) Map tokens to vocab indices (out-of-vocabulary -> ``<UNK>``).
        3) Run a forward pass through ``tiny_rnn`` with gradients disabled.
        4) Plot a heatmap with one row per token and one column per
           hidden-state dimension.

    Args:
        input_text: Text entered by the user. ``None`` or text containing no
            tokens is shown as a single ``<UNK>`` row so the plot is never empty.

    Returns:
        A ``PIL.Image.Image`` holding the PNG-rendered heatmap.
    """
    tokens = simple_tokenize(input_text or "")
    if not tokens:
        # Fall back to one <UNK> token rather than running on an empty sequence.
        tokens = ["<UNK>"]
    indices = numericalize(tokens)

    # Shape (batch_size=1, seq_len).
    input_tensor = torch.tensor(indices).unsqueeze(0)
    with torch.no_grad():
        outputs, _ = tiny_rnn(input_tensor)
    # (1, seq_len, hidden_dim) -> (seq_len, hidden_dim)
    states = outputs.squeeze(0).cpu().numpy()
    seq_len, state_dim = states.shape

    # Draw on an explicit Figure/Axes instead of pyplot's implicit "current
    # figure": Gradio may serve requests concurrently, and a bare plt.close()
    # could close another request's figure. try/finally also releases the
    # figure if plotting raises, so a long-running server does not leak memory.
    fig, ax = plt.subplots(figsize=(6, max(3, seq_len * 0.4)))  # taller for long inputs
    buf = io.BytesIO()
    try:
        sns.heatmap(
            states,
            ax=ax,
            yticklabels=tokens,
            xticklabels=[f"h{i}" for i in range(state_dim)],
            cmap="coolwarm",
            center=0,
        )
        ax.set_title("RNN Hidden States Heatmap")
        ax.set_ylabel("Tokens")
        # Report the real width rather than a hard-coded 8, so the label stays
        # correct if hidden_dim changes.
        ax.set_xlabel(f"Hidden State Dimensions (size={state_dim})")
        fig.tight_layout()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# -------- 4) Gradio Interface --------
demo = gr.Interface(
    fn=visualize_rnn_states,
    inputs=gr.Textbox(lines=2, label="Input Text", value="The cat was chasing the dog"),
    outputs="image",
    title="RNN (LSTM) Hidden States Visualizer",
    description=(
        # Take the size from the model config instead of hard-coding it, so the
        # text stays accurate if hidden_dim is changed.
        f"Visualize how an untrained LSTM's hidden state (dim={hidden_dim}) changes "
        "for each token in your input text. Rows=timesteps, Columns=hidden dim."
    ),
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module (tests, tooling, Gradio reload mode) does not block on
# launch(). Note: share=True opens a public tunnel when run locally.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)