vtrv.vls commited on
Commit
8c11891
·
1 Parent(s): df22c76

Clear inputs on model change

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -39,7 +39,7 @@ def qwen_gen(content, chat_history):
39
  send_to_s3(res, f'protobench/tiny_{str(datetime.now()).replace(" ", "_")}.json', S3_SESSION)
40
  return '', chat_history
41
 
42
- def model_gen(content, model_type: str, chat_history):
43
  if content is None:
44
  return None, None
45
  gen = MODEL_LIB[model_type]
@@ -49,8 +49,12 @@ def model_gen(content, model_type: str, chat_history):
49
  def clear_chat():
50
  return '', []
51
 
52
- def rerun(msg, chatbot):
53
- return msg, chatbot[:-1]
 
 
 
 
54
 
55
  MODEL_LIB = {'RUBASE': giga_gen, 'TINYLLAMA': tiny_gen, 'QWEN2INS1B': qwen_gen}
56
 
@@ -72,9 +76,9 @@ def tab_arena():
72
  with gradio.Row():
73
  with gradio.Accordion("Parameters", open=False):
74
  context = gradio.Checkbox(label="No context", value=False)
75
- top_p = gradio.Slider(label='Top P', minimum=0, maximum=1, value=1, step=0.5)
76
- temp = gradio.Slider(label='Temperature', minimum=0, maximum=1, value=0.7, step=0.5)
77
- max_tokens = gradio.Slider(label='Max ouput tokens', minimum=1, maximum=2048, value=512, step=1)
78
 
79
  with gradio.Row():
80
  msg = gradio.Textbox()
@@ -83,14 +87,14 @@ def tab_arena():
83
  clear = gradio.ClearButton([msg, chatbot_left, chatbot_right], value='Clear history')
84
  regen_left = gradio.Button(value='Regenerate left answer')
85
  regen_right = gradio.Button(value='Regenerate right answer')
86
- regen_left.click(rerun, [msg, chatbot_left], [msg, chatbot_left])
87
- regen_right.click(rerun, [msg, chatbot_right], [msg, chatbot_left])
88
 
89
  with gradio.Blocks():
90
  model_left.change(clear_chat, [], [msg, chatbot_left])
91
  model_right.change(clear_chat, [], [msg, chatbot_right])
92
- msg.submit(model_gen, [msg, model_left, chatbot_left], [msg, chatbot_left])
93
- msg.submit(model_gen, [msg, model_right, chatbot_right], [msg, chatbot_right])
94
 
95
  # with gradio.Column():
96
  # gradio.ChatInterface(
 
39
  send_to_s3(res, f'protobench/tiny_{str(datetime.now()).replace(" ", "_")}.json', S3_SESSION)
40
  return '', chat_history
41
 
42
+ def model_gen(content, chat_history, model_type: str):
43
  if content is None:
44
  return None, None
45
  gen = MODEL_LIB[model_type]
 
49
  def clear_chat():
50
  return '', []
51
 
52
+ def model_regen(content, chat_history, model_type: str):
53
+ if content is None:
54
+ return None, None
55
+ gen = MODEL_LIB[model_type]
56
+ print(MODEL_LIB[model_type])
57
+ return gen(content, chat_history[:-1])
58
 
59
  MODEL_LIB = {'RUBASE': giga_gen, 'TINYLLAMA': tiny_gen, 'QWEN2INS1B': qwen_gen}
60
 
 
76
  with gradio.Row():
77
  with gradio.Accordion("Parameters", open=False):
78
  context = gradio.Checkbox(label="No context", value=False)
79
+ top_p = gradio.Slider(label='Top P', minimum=0, maximum=1, value=1, step=0.5, interactive=True)
80
+ temp = gradio.Slider(label='Temperature', minimum=0, maximum=1, value=0.7, step=0.5, interactive=True)
81
+ max_tokens = gradio.Slider(label='Max ouput tokens', minimum=1, maximum=2048, value=512, step=1, interactive=True)
82
 
83
  with gradio.Row():
84
  msg = gradio.Textbox()
 
87
  clear = gradio.ClearButton([msg, chatbot_left, chatbot_right], value='Clear history')
88
  regen_left = gradio.Button(value='Regenerate left answer')
89
  regen_right = gradio.Button(value='Regenerate right answer')
90
+ regen_left.click(model_regen, [msg, chatbot_left, model_left], [msg, chatbot_left])
91
+ regen_right.click(model_regen, [msg, chatbot_right, model_right], [msg, chatbot_left])
92
 
93
  with gradio.Blocks():
94
  model_left.change(clear_chat, [], [msg, chatbot_left])
95
  model_right.change(clear_chat, [], [msg, chatbot_right])
96
+ msg.submit(model_gen, [msg, chatbot_left, model_left], [msg, chatbot_left])
97
+ msg.submit(model_gen, [msg, chatbot_right, model_right], [msg, chatbot_right])
98
 
99
  # with gradio.Column():
100
  # gradio.ChatInterface(