davanstrien HF Staff commited on
Commit
2faad0e
·
verified ·
1 Parent(s): 823cebb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -48
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
- import numpy as np
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
 
7
 
8
  # Load model and tokenizer
9
  model_name = "gpt2"
@@ -44,55 +44,55 @@ def get_token_probabilities(text, top_k=10):
44
  plt.ylabel("Tokens")
45
  plt.tight_layout()
46
 
47
- # Save the plot to a file
48
- plt.savefig("token_probabilities.png")
 
 
 
 
49
  plt.close()
50
 
51
- return "token_probabilities.png", dict(zip(topk_tokens, topk_probs.tolist()))
52
 
53
- def interface():
54
- with gr.Blocks() as demo:
55
- gr.Markdown("# GPT-2 Next Token Probability Visualizer")
56
- gr.Markdown("Enter text and see the probabilities of possible next tokens.")
57
-
58
- with gr.Row():
59
- with gr.Column():
60
- input_text = gr.Textbox(
61
- label="Input Text",
62
- placeholder="Type some text here...",
63
- value="Hello, my name is"
64
- )
65
- top_k = gr.Slider(
66
- minimum=5,
67
- maximum=20,
68
- value=10,
69
- step=1,
70
- label="Number of top tokens to show"
71
- )
72
- btn = gr.Button("Generate Probabilities")
73
-
74
- with gr.Column():
75
- output_image = gr.Image(label="Probability Distribution")
76
- output_table = gr.JSON(label="Token Probabilities")
77
-
78
- btn.click(
79
- fn=get_token_probabilities,
80
- inputs=[input_text, top_k],
81
- outputs=[output_image, output_table]
82
- )
83
-
84
- gr.Examples(
85
- examples=[
86
- ["Hello, my name is", 10],
87
- ["The capital of France is", 10],
88
- ["Once upon a time", 10],
89
- ["The best way to learn is to", 10]
90
- ],
91
- inputs=[input_text, top_k],
92
- )
93
 
94
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- if __name__ == "__main__":
97
- demo = interface()
98
- demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
6
+ import os
7
 
8
  # Load model and tokenizer
9
  model_name = "gpt2"
 
44
  plt.ylabel("Tokens")
45
  plt.tight_layout()
46
 
47
+ # Ensure temp directory exists
48
+ os.makedirs("tmp", exist_ok=True)
49
+
50
+ # Save the plot to a file in the temp directory
51
+ plot_path = os.path.join("tmp", "token_probabilities.png")
52
+ plt.savefig(plot_path)
53
  plt.close()
54
 
55
+ return plot_path, dict(zip(topk_tokens, topk_probs.tolist()))
56
 
57
+ with gr.Blocks() as demo:
58
+ gr.Markdown("# GPT-2 Next Token Probability Visualizer")
59
+ gr.Markdown("Enter text and see the probabilities of possible next tokens.")
60
+
61
+ with gr.Row():
62
+ with gr.Column():
63
+ input_text = gr.Textbox(
64
+ label="Input Text",
65
+ placeholder="Type some text here...",
66
+ value="Hello, my name is"
67
+ )
68
+ top_k = gr.Slider(
69
+ minimum=5,
70
+ maximum=20,
71
+ value=10,
72
+ step=1,
73
+ label="Number of top tokens to show"
74
+ )
75
+ btn = gr.Button("Generate Probabilities")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ with gr.Column():
78
+ output_image = gr.Image(label="Probability Distribution")
79
+ output_table = gr.JSON(label="Token Probabilities")
80
+
81
+ btn.click(
82
+ fn=get_token_probabilities,
83
+ inputs=[input_text, top_k],
84
+ outputs=[output_image, output_table]
85
+ )
86
+
87
+ gr.Examples(
88
+ examples=[
89
+ ["Hello, my name is", 10],
90
+ ["The capital of France is", 10],
91
+ ["Once upon a time", 10],
92
+ ["The best way to learn is to", 10]
93
+ ],
94
+ inputs=[input_text, top_k],
95
+ )
96
 
97
+ # Launch the app
98
+ demo.launch()