File size: 3,002 Bytes
823cebb
 
 
 
 
2faad0e
823cebb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2faad0e
 
 
 
 
 
823cebb
 
2faad0e
823cebb
2faad0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
823cebb
2faad0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
823cebb
2faad0e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Model selection: base GPT-2 (124M parameters), pulled from the HF hub on first run.
model_name = "gpt2"
# Tokenizer and weights are independent downloads; load order does not matter.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

def get_token_probabilities(text, top_k=10):
    """Return a bar-chart image path and a token->probability mapping for the
    next-token distribution GPT-2 predicts after *text*.

    Args:
        text: Prompt string to condition the model on.
        top_k: How many of the most likely next tokens to report. May arrive
            as a float from the Gradio slider; it is coerced to int.

    Returns:
        (plot_path, probs) where plot_path is the saved PNG file and probs
        maps decoded token strings to their probabilities.

    Raises:
        ValueError: If *text* tokenizes to an empty sequence (e.g. empty input),
            since there is then no position to predict from.
    """
    # Gradio sliders can deliver floats; torch.topk requires an integer k.
    top_k = int(top_k)

    # Tokenize the input text into a (1, seq_len) tensor of ids.
    input_ids = tokenizer.encode(text, return_tensors="pt")
    if input_ids.shape[1] == 0:
        raise ValueError("Input text produced no tokens; enter some text first.")

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(input_ids)

    # Logits at the final position predict the token that would follow `text`.
    next_token_logits = outputs.logits[0, -1, :]
    next_token_probs = torch.softmax(next_token_logits, dim=0)

    # Top-k most likely next tokens.
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)
    topk_probs = topk_probs.numpy()
    topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices.tolist()]

    # Render the distribution as a horizontal bar chart.
    plt.figure(figsize=(10, 6))
    sns.barplot(x=topk_probs, y=topk_tokens)
    plt.title(f"Top {top_k} token probabilities after: '{text}'")
    plt.xlabel("Probability")
    plt.ylabel("Tokens")
    plt.tight_layout()

    # Persist the plot where the gr.Image component can read it.
    os.makedirs("tmp", exist_ok=True)
    plot_path = os.path.join("tmp", "token_probabilities.png")
    plt.savefig(plot_path)
    plt.close()

    # NOTE(review): distinct token ids can decode to identical strings, in which
    # case dict() keeps only the last — acceptable for display, but verify if
    # exact top-k coverage ever matters.
    return plot_path, dict(zip(topk_tokens, topk_probs.tolist()))

# Build the Gradio UI: a two-column layout with prompt controls on the left
# and the probability chart + raw values on the right.
with gr.Blocks() as demo:
    gr.Markdown("# GPT-2 Next Token Probability Visualizer")
    gr.Markdown("Enter text and see the probabilities of possible next tokens.")
    
    with gr.Row():
        with gr.Column():
            # Prompt the model is conditioned on.
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Type some text here...",
                value="Hello, my name is"
            )
            # How many of the most likely tokens to display (integer steps).
            top_k = gr.Slider(
                minimum=5, 
                maximum=20, 
                value=10, 
                step=1, 
                label="Number of top tokens to show"
            )
            btn = gr.Button("Generate Probabilities")
        
        with gr.Column():
            # Bar chart saved by get_token_probabilities, shown from its file path.
            output_image = gr.Image(label="Probability Distribution")
            # Raw token -> probability mapping for inspection.
            output_table = gr.JSON(label="Token Probabilities")
    
    # Wire the button: (text, k) in, (plot path, probability dict) out.
    btn.click(
        fn=get_token_probabilities,
        inputs=[input_text, top_k],
        outputs=[output_image, output_table]
    )
    
    # One-click example prompts that pre-fill both inputs.
    gr.Examples(
        examples=[
            ["Hello, my name is", 10],
            ["The capital of France is", 10],
            ["Once upon a time", 10],
            ["The best way to learn is to", 10]
        ],
        inputs=[input_text, top_k],
    )

# Launch the app (blocks and serves the UI; runs on import as a script).
demo.launch()