Tonic committed on
Commit
eb1851a
·
1 Parent(s): 1ac0fa5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -30,15 +30,15 @@ model.eval()
30
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'
31
 
32
  class StarlingBot:
33
- def __init__(self, system_prompt="I am Starling-7B by Tonic-AI, I ready to do anything to help my user."):
34
  self.system_prompt = system_prompt
35
 
36
- def predict(self, user_message, assistant_message, system_prompt, mode, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
37
  try:
38
  if mode == "Assistant":
39
- conversation = f" GPT4 Correct Assistant: {system_prompt if system_prompt else self.system_prompt} GPT4 Correct Assistant: {assistant_message if assistant_message else ''} GPT4 Correct User: {user_message} GPT4 Correct Assistant:"
40
  else: # mode == "Coder"
41
- conversation = f" Code Assistant: {system_prompt if system_prompt else self.system_prompt} Code Assistant: {assistant_message if assistant_message else ''} Code User:: {user_message} Code Assistant:"
42
  input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=True)
43
  input_ids = input_ids.to(device)
44
  response = model.generate(
@@ -72,10 +72,10 @@ examples = [
72
  1.9, # repetition_penalty
73
  ]
74
  ]
75
- # Initialize StarlingBot
76
  starling_bot = StarlingBot()
77
 
78
- def gradio_starling(user_message, assistant_message, system_prompt, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty):
79
  response = starling_bot.predict(user_message, assistant_message, system_prompt, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty)
80
  return response
81
 
@@ -83,8 +83,7 @@ with gr.Blocks(theme="ParityError/Anime") as demo:
83
  gr.Markdown(title)
84
  gr.Markdown(description)
85
  with gr.Row():
86
- system_prompt = gr.Textbox(label="Optional💫🌠Starling System Prompt", lines=2)
87
- assistant_message = gr.Textbox(label="💫🌠Starling Assistant Message", lines=2)
88
  user_message = gr.Textbox(label="Your Message", lines=3)
89
  with gr.Row():
90
  mode = gr.Radio(choices=["Assistant", "Coder"], value="Assistant", label="Mode")
@@ -101,7 +100,7 @@ with gr.Blocks(theme="ParityError/Anime") as demo:
101
 
102
  submit_button.click(
103
  gradio_starling,
104
- inputs=[user_message, assistant_message, system_prompt, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty],
105
  outputs=output_text
106
  )
107
 
 
30
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'
31
 
32
  class StarlingBot:
33
+ def __init__(self, assistant_message="I am Starling-7B by Tonic-AI, I am ready to do anything to help my user."):
34
  self.system_prompt = system_prompt
35
 
36
+ def predict(self, user_message, assistant_message, mode, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
37
  try:
38
  if mode == "Assistant":
39
+ conversation = f"GPT4 Correct Assistant: {assistant_message if assistant_message else ''} GPT4 Correct User: {user_message} GPT4 Correct Assistant:"
40
  else: # mode == "Coder"
41
+ conversation = f"Code Assistant: {assistant_message if assistant_message else ''} Code User:: {user_message} Code Assistant:"
42
  input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=True)
43
  input_ids = input_ids.to(device)
44
  response = model.generate(
 
72
  1.9, # repetition_penalty
73
  ]
74
  ]
75
+
76
  starling_bot = StarlingBot()
77
 
78
+ def gradio_starling(user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty):
79
  response = starling_bot.predict(user_message, assistant_message, system_prompt, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty)
80
  return response
81
 
 
83
  gr.Markdown(title)
84
  gr.Markdown(description)
85
  with gr.Row():
86
+ assistant_message = gr.Textbox(label="Optional💫🌠Starling Assistant Message", lines=2)
 
87
  user_message = gr.Textbox(label="Your Message", lines=3)
88
  with gr.Row():
89
  mode = gr.Radio(choices=["Assistant", "Coder"], value="Assistant", label="Mode")
 
100
 
101
  submit_button.click(
102
  gradio_starling,
103
+ inputs=[user_message, assistant_message, mode, do_sample, temperature, max_new_tokens, top_p, repetition_penalty],
104
  outputs=output_text
105
  )
106