Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -21,13 +21,28 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
|
21 |
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
|
22 |
|
23 |
class OrcaChatBot:
|
|
|
24 |
def __init__(self, model, tokenizer, system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."):
|
25 |
self.model = model
|
26 |
self.tokenizer = tokenizer
|
27 |
self.system_message = system_message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
|
30 |
-
|
|
|
31 |
inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
|
32 |
input_ids = inputs["input_ids"].to(self.model.device)
|
33 |
|
@@ -42,7 +57,7 @@ class OrcaChatBot:
|
|
42 |
)
|
43 |
|
44 |
response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
45 |
-
|
46 |
return response
|
47 |
|
48 |
Orca_bot = OrcaChatBot(model, tokenizer)
|
|
|
21 |
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
|
22 |
|
23 |
class OrcaChatBot:
|
24 |
+
# Code below from [microsoft/ari9dam](https://huggingface.co/spaces/ari9dam/Orca-2-13B)
|
25 |
def __init__(self, model, tokenizer, system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."):
|
26 |
self.model = model
|
27 |
self.tokenizer = tokenizer
|
28 |
self.system_message = system_message
|
29 |
+
self.conversation_history = []
|
30 |
+
|
31 |
+
def update_conversation_history(self, user_message, assistant_message):
    """Append the given turns to ``self.conversation_history``.

    Turns are stored as ``(role, text)`` tuples, the user turn first.
    A side that is empty (``""`` or ``None``) is skipped: ``predict``
    records one side at a time and passes ``""`` for the other, and
    storing that placeholder would add a blank turn to the history.

    Args:
        user_message: Text of the user turn, or ``""`` to record none.
        assistant_message: Text of the assistant turn, or ``""`` to
            record none.
    """
    turns = (("user", user_message), ("assistant", assistant_message))
    self.conversation_history.extend(
        (role, text) for role, text in turns if text
    )
|
34 |
+
|
35 |
+
|
36 |
+
def format_prompt(self):
    """Render the system message and the stored turns as one prompt string.

    Returns:
        A string made of the system message in a ``system`` turn, then
        every ``(role, message)`` entry of ``self.conversation_history``
        in order, then an opened ``assistant`` turn left unterminated.
    """
    # The system message gets its own "system" role; tagged "assistant"
    # it would read as something the model had said earlier.
    segments = [f"<|im_start|>system\n{self.system_message}<|im_end|>\n"]
    segments.extend(
        f"<|im_start|>{role}\n{message}<|im_end|>\n"
        for role, message in self.conversation_history
    )
    # No space after the marker, so the role is spelled exactly as in the
    # turns above.
    segments.append("<|im_start|>assistant\n")
    return "".join(segments)
|
42 |
|
43 |
def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
|
44 |
+
self.update_conversation_history(user_message, "")
|
45 |
+
prompt = self.format_prompt()
|
46 |
inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
|
47 |
input_ids = inputs["input_ids"].to(self.model.device)
|
48 |
|
|
|
57 |
)
|
58 |
|
59 |
response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
60 |
+
self.update_conversation_history("", response)
|
61 |
return response
|
62 |
|
63 |
Orca_bot = OrcaChatBot(model, tokenizer)
|