Tonic commited on
Commit
e4a1a3c
·
1 Parent(s): 26d0891

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -33,14 +33,13 @@ repetition_penalty=1.7
33
 
34
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
35
  model = transformers.AutoModelForCausalLM.from_pretrained(model_name,
36
- device_map="auto"
37
- # torch_dtype=torch.bfloat16,
38
- # load_in_4bit=True
39
  )
40
- # model.eval()
41
 
42
  class StarlingBot:
43
- def __init__(self, system_prompt="The following dialogue is a conversation"):
44
  self.system_prompt = system_prompt
45
 
46
  def predict(self, user_message, assistant_message, system_prompt, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
@@ -75,7 +74,6 @@ iface = gr.Interface(
75
  fn=starling_bot.predict,
76
  title=title,
77
  description=description,
78
- # examples=examples,
79
  inputs=[
80
  gr.Textbox(label="🌟🤩User Message", type="text", lines=5),
81
  gr.Textbox(label="💫🌠Starling Assistant Message or Instructions ", lines=2),
@@ -87,5 +85,5 @@ iface = gr.Interface(
87
  gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)
88
  ],
89
  outputs="text",
90
- # theme="ParityError/Anime"
91
  )
 
33
 
34
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
35
  model = transformers.AutoModelForCausalLM.from_pretrained(model_name,
36
+ device_map=device,
37
+ torch_dtype=torch.bfloat16
 
38
  )
39
+ model.eval()
40
 
41
  class StarlingBot:
42
+ def __init__(self, system_prompt="I am Starling-7B by Tonic-AI, I ready to do anything to help my user."):
43
  self.system_prompt = system_prompt
44
 
45
  def predict(self, user_message, assistant_message, system_prompt, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
 
74
  fn=starling_bot.predict,
75
  title=title,
76
  description=description,
 
77
  inputs=[
78
  gr.Textbox(label="🌟🤩User Message", type="text", lines=5),
79
  gr.Textbox(label="💫🌠Starling Assistant Message or Instructions ", lines=2),
 
85
  gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)
86
  ],
87
  outputs="text",
88
+ theme="ParityError/Anime"
89
  )