import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Load model and tokenizer
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

def get_token_probabilities(text, top_k=10):
    # Gradio sliders can deliver the value as a float; torch.topk needs an int
    top_k = int(top_k)

    # Tokenize the input text
    input_ids = tokenizer.encode(text, return_tensors="pt")

    # Get the last token's position
    last_token_position = input_ids.shape[1] - 1

    # Get model predictions
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits

    # Get probabilities for the next token after the last token
    next_token_logits = logits[0, last_token_position, :]
    next_token_probs = torch.softmax(next_token_logits, dim=0)

    # Get top k most likely tokens
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)

    # Convert to numpy for easier handling
    topk_probs = topk_probs.numpy()
    topk_indices = topk_indices.numpy()

    # Decode tokens
    topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices]

    # Create a plot
    plt.figure(figsize=(10, 6))
    sns.barplot(x=topk_probs, y=topk_tokens)
    plt.title(f"Top {top_k} token probabilities after: '{text}'")
    plt.xlabel("Probability")
    plt.ylabel("Tokens")
    plt.tight_layout()

    # Ensure temp directory exists
    os.makedirs("tmp", exist_ok=True)

    # Save the plot to a file in the temp directory
    plot_path = os.path.join("tmp", "token_probabilities.png")
    plt.savefig(plot_path)
    plt.close()

    return plot_path, dict(zip(topk_tokens, topk_probs.tolist()))

with gr.Blocks() as demo:
    gr.Markdown("# GPT-2 Next Token Probability Visualizer")
    gr.Markdown("Enter text and see the probabilities of possible next tokens.")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Type some text here...",
                value="Hello, my name is"
            )
            top_k = gr.Slider(
                minimum=5,
                maximum=20,
                value=10,
                step=1,
                label="Number of top tokens to show"
            )
            btn = gr.Button("Generate Probabilities")

        with gr.Column():
            output_image = gr.Image(label="Probability Distribution")
            output_table = gr.JSON(label="Token Probabilities")

    btn.click(
        fn=get_token_probabilities,
        inputs=[input_text, top_k],
        outputs=[output_image, output_table]
    )

    gr.Examples(
        examples=[
            ["Hello, my name is", 10],
            ["The capital of France is", 10],
            ["Once upon a time", 10],
            ["The best way to learn is to", 10]
        ],
        inputs=[input_text, top_k],
    )

# Launch the app
demo.launch()
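
# Optional launch settings: demo.launch(share=True) also creates a temporary
# public URL, and demo.launch(server_name="0.0.0.0", server_port=7860) binds a
# specific host/port, e.g. when running behind a reverse proxy. Both are
# standard arguments of Gradio's launch().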