"""Gradio app that visualizes tokenization, attention maps and hidden states.

The user picks a Hugging Face model, types a sentence and selects an attention
layer/head. The app shows the resulting tokens, the hidden states for that
layer, and a heatmap of the selected attention head.
"""

import functools
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from transformers import AutoModel, AutoTokenizer


# Helper function to plot attention heatmap
def plot_attention(attn, tokens, layer=0, head=0):
    """Render one attention head as a heatmap and save it to a PNG file.

    Args:
        attn: Per-layer attention arrays, indexed as ``attn[layer][head]`` to
            obtain a (queries x keys) matrix.
        tokens: Token strings used as both x and y tick labels.
        layer: Index of the layer to plot.
        head: Index of the head within that layer.

    Returns:
        Path of a temporary PNG file. It is created with ``delete=False`` so
        Gradio can still read it after this function returns.
    """
    layer, head = int(layer), int(head)  # Defensive: UI widgets may send floats.
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(
            attn[layer][head],
            xticklabels=tokens,
            yticklabels=tokens,
            cmap="viridis",
            ax=ax,
        )
        ax.set_title(f"Attention Map - Layer {layer}, Head {head}")
        ax.set_xlabel("Keys")
        ax.set_ylabel("Queries")
        # Close our own handle before writing by name: an open handle leaks a
        # file descriptor per request and prevents re-opening on Windows.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
            path = tmp_file.name
        fig.savefig(path, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails, so a
        # long-running server does not accumulate open figures.
        plt.close(fig)
    return path


@functools.lru_cache(maxsize=4)
def _load_model(model_name):
    """Load and memoize the tokenizer and eval-mode model for ``model_name``.

    Loading weights on every request is slow; keeping a few recently used
    models in memory makes repeated queries fast. ``maxsize`` bounds memory.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(
        model_name, output_attentions=True, output_hidden_states=True
    )
    model.eval()
    return tokenizer, model


# Main logic
def process_input(text, model_name, layer, head):
    """Run ``text`` through ``model_name`` and collect visualization data.

    Args:
        text: Input sentence.
        model_name: Hugging Face model identifier.
        layer: Attention layer index to visualize.
        head: Attention head index within ``layer``.

    Returns:
        Tuple ``(tokens, hidden, attn_img_path)``:
        - the token strings;
        - ``hidden_states[layer]`` as a nested list with one row per token;
        - the path to the attention heatmap PNG.

    Raises:
        gr.Error: if ``text`` is empty, or ``layer``/``head`` is out of range
            for the chosen model.
    """
    layer, head = int(layer), int(head)
    if not text or not text.strip():
        raise gr.Error("Please enter some text.")

    tokenizer, model = _load_model(model_name)
    # Truncate to the model's maximum length instead of crashing on long input.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Drop the batch dimension: one (heads, seq_len, seq_len) array per layer.
    attentions = [a.squeeze(0).cpu().numpy() for a in outputs.attentions]

    # The UI offers 0-11 for every model, but e.g. distilbert has fewer layers.
    num_layers = len(attentions)
    if not 0 <= layer < num_layers:
        raise gr.Error(
            f"{model_name} has {num_layers} layers; "
            f"choose a layer between 0 and {num_layers - 1}."
        )
    num_heads = attentions[layer].shape[0]
    if not 0 <= head < num_heads:
        raise gr.Error(
            f"{model_name} has {num_heads} heads per layer; "
            f"choose a head between 0 and {num_heads - 1}."
        )

    # Per the transformers docs, hidden_states[0] is the embedding output, so
    # hidden_states[layer] is the representation fed into the selected layer.
    # Only the selected layer is converted to a list.
    hidden = outputs.hidden_states[layer].squeeze(0).cpu().numpy().tolist()

    attn_img_path = plot_attention(attentions, tokens, layer=layer, head=head)
    return tokens, hidden, attn_img_path


# Gradio interface function
def gradio_interface(text, model_name, layer, head):
    """Gradio callback: forward the widget values to ``process_input``."""
    tokens, hidden, attn_img = process_input(text, model_name, layer, head)
    return tokens, hidden, attn_img


# Available transformer models
model_choices = [
    "bert-base-uncased",
    "distilbert-base-uncased",
    "roberta-base",
    "gpt2",
]

# Launch the Gradio app
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="Type a sentence here..."),
        gr.Dropdown(label="Model", choices=model_choices, value="bert-base-uncased"),
        gr.Slider(label="Attention Layer", minimum=0, maximum=11, step=1, value=0),
        gr.Slider(label="Attention Head", minimum=0, maximum=11, step=1, value=0),
    ],
    outputs=[
        gr.JSON(label="Tokens"),
        gr.Dataframe(label="Hidden States (Selected Layer)"),
        gr.Image(label="Attention Map"),
    ],
    title="🔍 Transformer Visualizer",
    description="Visualize tokenization, attention maps, and hidden states of popular Hugging Face Transformer models.",
)

if __name__ == "__main__":
    interface.launch()