davanstrien HF Staff commited on
Commit
35991d0
·
verified ·
1 Parent(s): e644ba4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  import gradio as gr
3
  import torch
4
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
@@ -7,7 +6,6 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
7
  model = GPT2LMHeadModel.from_pretrained("gpt2")
8
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
9
 
10
- @spaces.GPU
11
  def get_next_token_probs(text):
12
  # Handle empty input
13
  if not text.strip():
@@ -48,10 +46,13 @@ with gr.Blocks(css="footer {display: none}") as demo:
48
  # Input textbox
49
  input_text = gr.Textbox(
50
  label="Text Input",
51
- placeholder="Type here and watch predictions update...",
52
  value="The weather tomorrow will be"
53
  )
54
 
 
 
 
55
  # Simple header for results
56
  gr.Markdown("##### Most likely next tokens:")
57
 
@@ -64,8 +65,8 @@ with gr.Blocks(css="footer {display: none}") as demo:
64
 
65
  token_outputs = [token1, token2, token3, token4, token5]
66
 
67
- # Set up the live update
68
- input_text.change(
69
  fn=get_next_token_probs,
70
  inputs=input_text,
71
  outputs=token_outputs
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
6
  model = GPT2LMHeadModel.from_pretrained("gpt2")
7
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
8
 
 
9
  def get_next_token_probs(text):
10
  # Handle empty input
11
  if not text.strip():
 
46
  # Input textbox
47
  input_text = gr.Textbox(
48
  label="Text Input",
49
+ placeholder="Type text here...",
50
  value="The weather tomorrow will be"
51
  )
52
 
53
+ # Predict button
54
+ predict_btn = gr.Button("Predict Next Tokens")
55
+
56
  # Simple header for results
57
  gr.Markdown("##### Most likely next tokens:")
58
 
 
65
 
66
  token_outputs = [token1, token2, token3, token4, token5]
67
 
68
+ # Set up button click event
69
+ predict_btn.click(
70
  fn=get_next_token_probs,
71
  inputs=input_text,
72
  outputs=token_outputs