Tonic commited on
Commit
8f4fc52
·
1 Parent(s): 588f72e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -33,9 +33,12 @@ class StarlingBot:
33
  def __init__(self, system_prompt="I am Starling-7B by Tonic-AI, I ready to do anything to help my user."):
34
  self.system_prompt = system_prompt
35
 
36
- def predict(self, user_message, assistant_message, system_prompt, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
37
  try:
38
- conversation = f" GPT4 Correct Assistant: {system_prompt if system_prompt else self.system_prompt} <|end_of_turn|> GPT4 Correct Assistant: {assistant_message if assistant_message else ''} <|end_of_turn|> GPT4 Correct User: {user_message} <|end_of_turn|> GPT4 Correct Assistant:"
 
 
 
39
  input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=True)
40
  input_ids = input_ids.to(device)
41
  response = model.generate(
@@ -84,8 +87,8 @@ with gr.Blocks(theme="ParityError/Anime") as demo:
84
  assistant_message = gr.Textbox(label="💫🌠Starling Assistant Message", lines=2)
85
  user_message = gr.Textbox(label="Your Message", lines=3)
86
  with gr.Row():
87
- do_sample = gr.Checkbox(label="Advanced", value=True)
88
-
89
  with gr.Accordion("Advanced Settings", open=lambda do_sample: do_sample):
90
  with gr.Row():
91
  temperature = gr.Slider(label="Temperature", value=0.4, minimum=0.05, maximum=1.0, step=0.05)
@@ -98,7 +101,7 @@ with gr.Blocks(theme="ParityError/Anime") as demo:
98
 
99
  submit_button.click(
100
  gradio_starling,
101
- inputs=[user_message, assistant_message, system_prompt, do_sample, temperature, max_new_tokens, top_p, repetition_penalty],
102
  outputs=output_text
103
  )
104
 
 
33
  def __init__(self, system_prompt="I am Starling-7B by Tonic-AI, I ready to do anything to help my user."):
34
  self.system_prompt = system_prompt
35
 
36
+ def predict(self, user_message, assistant_message, system_prompt, mode, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
37
  try:
38
+ if mode == "Assistant":
39
+ conversation = f" GPT4 Correct Assistant: {system_prompt if system_prompt else self.system_prompt} GPT4 Correct Assistant: {assistant_message if assistant_message else ''} GPT4 Correct User: {user_message} GPT4 Correct Assistant:"
40
+ else: # mode == "Coder"
41
+ conversation = f" Code Assistant: {system_prompt if system_prompt else self.system_prompt} Code Assistant: {assistant_message if assistant_message else ''} GPT4 Correct User: {user_message} Code Assistant:"
42
  input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=True)
43
  input_ids = input_ids.to(device)
44
  response = model.generate(
 
87
  assistant_message = gr.Textbox(label="💫🌠Starling Assistant Message", lines=2)
88
  user_message = gr.Textbox(label="Your Message", lines=3)
89
  with gr.Row():
90
+ mode = gr.Radio(choices=["Assistant", "Coder"], value="Assistant", label="Mode")
91
+ do_sample = gr.Checkbox(label="Advanced", value=True)
92
  with gr.Accordion("Advanced Settings", open=lambda do_sample: do_sample):
93
  with gr.Row():
94
  temperature = gr.Slider(label="Temperature", value=0.4, minimum=0.05, maximum=1.0, step=0.05)
 
101
 
102
  submit_button.click(
103
  gradio_starling,
104
+ inputs=[user_message, assistant_message, system_prompt, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty],
105
  outputs=output_text
106
  )
107