import spaces
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load model and tokenizer once at import time (transformers caches the weights).
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


@spaces.GPU
def get_next_token_probs(text, top_k=5):
    """Return the ``top_k`` most likely next tokens for ``text``.

    Each entry is a formatted string like ``1. "token" (12.3%)``.
    Empty or whitespace-only input yields ``top_k`` empty strings so
    the UI slots clear cleanly instead of showing stale predictions.
    """
    # Guard: nothing to predict from an empty prompt.
    if not text.strip():
        return [""] * top_k

    input_ids = tokenizer.encode(text, return_tensors="pt")

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        logits = model(input_ids).logits

    # Probability distribution over the vocabulary for the next token
    # (logits of the last position in the sequence).
    next_token_probs = torch.softmax(logits[0, -1, :], dim=-1)
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)

    formatted_results = []
    for rank, (idx, prob) in enumerate(zip(topk_indices, topk_probs), start=1):
        token = tokenizer.decode([idx])
        # Make every space visible (GPT-2 tokens often carry a leading space,
        # which would otherwise be indistinguishable in the UI).
        display_token = token.replace(" ", "␣")
        formatted_results.append(f"{rank}. \"{display_token}\" ({prob.item() * 100:.1f}%)")
    return formatted_results


# Custom CSS for the prediction panel.
custom_css = """
.token-box { margin-top: 10px; padding: 15px; border-radius: 8px; background-color: #f7f7f7; font-family: monospace; font-size: 16px; }
.token-item { margin: 8px 0; padding: 8px; background-color: white; border-left: 4px solid #2c8ecb; border-radius: 4px; }
footer {display: none}
"""

# Minimal live-updating interface.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("### GPT-2 Next Token Predictor")

    # Input textbox; predictions refresh on every change event.
    input_text = gr.Textbox(
        label="Text Input",
        placeholder="Type here and watch predictions update...",
        value="The weather tomorrow will be",
    )

    # Container for the token displays.
    # NOTE: gr.Box was removed in Gradio 4.x; gr.Group is the supported container.
    with gr.Group(elem_classes=["token-box"]):
        gr.Markdown("##### Most likely next tokens:")
        token_outputs = [gr.Markdown(elem_classes=["token-item"]) for _ in range(5)]

    def update_tokens(text):
        """Adapter between the Textbox event and the prediction function."""
        return get_next_token_probs(text)

    # Live update on every keystroke.
    input_text.change(
        fn=update_tokens,
        inputs=input_text,
        outputs=token_outputs,
    )

    # Populate the panel for the default prompt when the page loads.
    demo.load(
        fn=update_tokens,
        inputs=input_text,
        outputs=token_outputs,
    )

# Launch the app.
demo.launch()