# Spaces: No application file
import functools
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from transformers import AutoTokenizer, AutoModel
# Helper function to plot attention heatmap
def plot_attention(attn, tokens, layer=0, head=0):
    """Draw one attention head as a heatmap and save it to a PNG file.

    Args:
        attn: Nested sequence indexed as ``attn[layer][head]``; the selected
            entry is handed to ``seaborn.heatmap`` as the data matrix.
        tokens: Labels used for the ticks on both axes.
        layer: Index of the layer to plot.
        head: Index of the head inside that layer.

    Returns:
        Path of the PNG file that was written. The file is created with
        ``delete=False``, so it stays on disk until the caller removes it.

    Raises:
        Any error raised by ``attn[layer][head]`` (for example an
        ``IndexError`` for an out-of-range index) propagates unchanged.
    """
    # Index first: a bad layer/head should fail before a figure or a
    # temporary file has been created.
    weights = attn[layer][head]

    # Only the unique name is needed. Closing the handle right away avoids
    # leaking a file descriptor and lets savefig reopen the path on platforms
    # that forbid opening a file twice.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
        image_path = tmp_file.name

    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(weights, xticklabels=tokens, yticklabels=tokens, cmap="viridis")
        plt.title(f"Attention Map - Layer {layer}, Head {head}")
        plt.xlabel("Keys")
        plt.ylabel("Queries")
        fig.savefig(image_path, bbox_inches='tight')
    finally:
        # Release this figure even when drawing or saving fails, so figures
        # do not accumulate across requests.
        plt.close(fig)
    return image_path
# Main logic
@functools.lru_cache(maxsize=4)
def _load_model(model_name):
    """Load the tokenizer and model for *model_name*, cached across calls."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
    model.eval()
    return tokenizer, model


def process_input(text, model_name, layer, head):
    """Run *text* through a transformer and collect data for visualization.

    Args:
        text: Input string to tokenize.
        model_name: Checkpoint name passed to ``from_pretrained``.
        layer: Layer index used for both the attention map and the returned
            hidden states. Coerced with ``int()``.
        head: Attention head index to plot. Coerced with ``int()``.

    Returns:
        A ``(tokens, hidden, attn_img_path)`` tuple: the token strings for the
        encoded input, the selected hidden-state entry as nested Python lists,
        and the path of the PNG written by ``plot_attention``.

    Raises:
        IndexError: If *layer* or *head* is outside the range the loaded
            model provides.
    """
    # UI widgets may deliver numbers as floats; indices must be ints.
    layer, head = int(layer), int(head)

    # Loading a checkpoint is by far the slowest step, so reuse it between
    # requests instead of re-reading it on every call.
    tokenizer, model = _load_model(model_name)

    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Fail with a descriptive message instead of a bare "index out of range".
    # Negative indices stay accepted, as with plain sequence indexing.
    num_layers = len(outputs.attentions)
    if not -num_layers <= layer < num_layers:
        raise IndexError(
            f"layer {layer} is out of range for {model_name!r}, "
            f"which has {num_layers} attention layers"
        )
    num_heads = outputs.attentions[layer].shape[1]
    if not -num_heads <= head < num_heads:
        raise IndexError(
            f"head {head} is out of range for {model_name!r}, "
            f"which has {num_heads} attention heads per layer"
        )

    attentions = [a.squeeze(0).cpu().numpy() for a in outputs.attentions]  # layers, heads, seq_len, seq_len

    # Convert only the requested entry; turning every layer into nested Python
    # lists is expensive and the rest would be discarded.
    # NOTE(review): Hugging Face documents hidden_states[0] as the embedding
    # output, so index `layer` is the representation entering attention layer
    # `layer`, not leaving it. Kept as-is for backward compatibility; confirm
    # which one the UI is meant to show.
    hidden = outputs.hidden_states[layer].squeeze(0).cpu().numpy().tolist()  # seq_len, dim

    attn_img_path = plot_attention(attentions, tokens, layer=layer, head=head)
    return tokens, hidden, attn_img_path
# Callback wired into the Gradio UI
def gradio_interface(text, model_name, layer, head):
    """Forward the UI inputs to ``process_input`` and return its three results.

    Returns:
        The ``(tokens, hidden states, attention image path)`` triple produced
        by ``process_input``, in that order.
    """
    result = process_input(text, model_name, layer, head)
    # Unpack explicitly so a result of the wrong length fails here.
    token_list, layer_hidden, heatmap_path = result
    return (token_list, layer_hidden, heatmap_path)
# Checkpoint names the user can pick from
model_choices = ["bert-base-uncased", "distilbert-base-uncased", "roberta-base", "gpt2"]
# Build the Gradio UI: four inputs feeding gradio_interface, three outputs.
# NOTE(review): the leading "π" in the title looks like a mis-decoded emoji;
# confirm the intended character before changing it.
# NOTE(review): both sliders are fixed at 0-11 for every model; presumably not
# all listed checkpoints have 12 layers - TODO confirm.
interface = gr.Interface(
    title="π Transformer Visualizer",
    description="Visualize tokenization, attention maps, and hidden states of popular Hugging Face Transformer models.",
    fn=gradio_interface,
    inputs=[
        gr.Textbox(placeholder="Type a sentence here...", label="Input Text"),
        gr.Dropdown(choices=model_choices, value="bert-base-uncased", label="Model"),
        gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Attention Layer"),
        gr.Slider(minimum=0, maximum=11, step=1, value=0, label="Attention Head"),
    ],
    outputs=[
        gr.JSON(label="Tokens"),
        gr.Dataframe(label="Hidden States (Selected Layer)"),
        gr.Image(label="Attention Map"),
    ],
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    interface.launch()