Tonic committed
Commit ea3b3e9 · 1 Parent(s): bfb620f

Update app.py

Files changed (1): app.py +8 -13
app.py CHANGED
@@ -24,7 +24,7 @@ examples = [
 ]
 
 model_name = "berkeley-nest/Starling-RM-7B-alpha"
-base_model = "meta-llama/Llama-2-7b-chat-hf"
+base_model = "michaelfeil/ct2fast-Llama-2-7b-chat-hf"
 
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -46,9 +46,8 @@ class StarlingBot:
     def __init__(self, system_prompt="The following dialogue is a conversation"):
         self.system_prompt = system_prompt
 
-    def predict(self, user_message, assistant_message, system_prompt, advanced, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
+    def predict(self, user_message, assistant_message, system_prompt, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
         conversation = f" <s> [INST] {self.system_prompt} [INST] {assistant_message if assistant_message else ''} </s> [/INST] {user_message} </s> "
-        # Encode the conversation using the tokenizer
         input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=False)
         input_ids = input_ids.to(device)
         response = model.generate(
@@ -67,21 +66,18 @@ class StarlingBot:
         response_text = tokenizer.decode(response[0], skip_special_tokens=True)
         return response_text
 
-# Create the Falcon chatbot instance
-StarlingBot_bot = StarlingBot()
-
-starling_bot = StarlingBot()  # Renamed for consistency
+starling_bot = StarlingBot()
 
 iface = gr.Interface(
-    fn=starling_bot.predict,  # Corrected to match the instance name
+    fn=starling_bot.predict,
     title=title,
     description=description,
-    examples=examples,
+    # examples=examples,
     inputs=[
-        gr.Textbox(label="User Message", type="text", lines=5),
+        gr.Textbox(label="🌟🤩User Message", type="text", lines=5),
         gr.Textbox(label="💫🌠Starling Assistant Message or Instructions ", lines=2),
         gr.Textbox(label="💫🌠Starling System Prompt or Instruction", lines=2),
-        gr.Checkbox(label="Advanced", value=False),  # Ensure this is connected to functionality
+        gr.Checkbox(label="Advanced", value=False),
         gr.Slider(label="Temperature", value=0.7, minimum=0.05, maximum=1.0, step=0.05),
         gr.Slider(label="Max new tokens", value=100, minimum=25, maximum=256, step=1),
         gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05),
@@ -89,5 +85,4 @@ iface = gr.Interface(
     ],
     outputs="text",
     theme="ParityError/Anime"
-)
-
+)
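
The second hunk cuts off at response = model.generate(, so the commit does not show how the renamed do_sample flag is actually consumed. Below is a minimal sketch of a plausible completion of that call, assuming predict()'s visible parameters are passed straight through to transformers' generate(); the pad_token_id line is an extra assumption, not part of the diff.

    # Sketch only: a plausible body for the truncated model.generate( call,
    # built from the parameters visible in predict()'s signature above.
    response = model.generate(
        input_ids,
        do_sample=do_sample,                  # renamed from `advanced` in this commit
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,  # assumed: avoids the missing-pad-token warning
    )

Note that gr.Interface maps the inputs list to the wrapped function positionally, so the Textbox/Checkbox/Slider order above has to keep matching predict()'s (user_message, assistant_message, system_prompt, do_sample, ...) parameter order.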