Tonic commited on
Commit
03c59e6
·
1 Parent(s): 90ffb86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -14
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import optimum
2
  import transformers
3
  from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
4
- # from optimum.bettertransformer import BetterTransformer
5
  import torch
6
  import gradio as gr
7
  import json
@@ -24,8 +24,6 @@ examples = [
24
  ]
25
 
26
  model_name = "berkeley-nest/Starling-LM-7B-alpha"
27
- # base_model = "meta-llama/Llama-2-7b-chat-hf"
28
-
29
 
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
  temperature=0.4
@@ -40,14 +38,14 @@ model = transformers.AutoModelForCausalLM.from_pretrained(model_name,
40
  torch_dtype=torch.bfloat16,
41
  load_in_4bit=True
42
  )
43
- # model = BetterTransformer.transform(model)
44
  model.eval()
45
 
46
  class StarlingBot:
47
- def __init__(self, system_prompt="The following dialogue is a conversation"):
48
- self.system_prompt = system_prompt
49
 
50
- def predict(self, user_message, assistant_message, system_prompt, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
 
51
  conversation = f" <s> [INST] {self.system_prompt} [INST] {assistant_message if assistant_message else ''} </s> [/INST] {user_message} </s> "
52
  input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=False)
53
  input_ids = input_ids.to(device)
@@ -56,7 +54,7 @@ class StarlingBot:
56
  use_cache=False,
57
  early_stopping=False,
58
  bos_token_id=model.config.bos_token_id,
59
- eos_token_id=model.config.eos_token_id,
60
  pad_token_id=model.config.eos_token_id,
61
  temperature=temperature,
62
  do_sample=True,
@@ -65,13 +63,12 @@ class StarlingBot:
65
  repetition_penalty=repetition_penalty
66
  )
67
  response_text = tokenizer.decode(response[0], skip_special_tokens=True)
68
- response_text = response.strip()
69
  # response_text = response.split("<|assistant|>\n")[-1]
70
  return response_text
71
- finally:
72
- del input_ids, attention_mask, output_ids
73
- gc.collect()
74
- torch.cuda.empty_cache()
75
 
76
  starling_bot = StarlingBot()
77
 
@@ -79,7 +76,7 @@ iface = gr.Interface(
79
  fn=starling_bot.predict,
80
  title=title,
81
  description=description,
82
- # examples=examples,
83
  inputs=[
84
  gr.Textbox(label="🌟🤩User Message", type="text", lines=5),
85
  gr.Textbox(label="💫🌠Starling Assistant Message or Instructions ", lines=2),
 
1
  import optimum
2
  import transformers
3
  from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
4
+
5
  import torch
6
  import gradio as gr
7
  import json
 
24
  ]
25
 
26
  model_name = "berkeley-nest/Starling-LM-7B-alpha"
 
 
27
 
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  temperature=0.4
 
38
  torch_dtype=torch.bfloat16,
39
  load_in_4bit=True
40
  )
 
41
  model.eval()
42
 
43
  class StarlingBot:
44
+ def __init__(self, system_prompt="The following dialogue is a conversation"):
45
+ self.system_prompt = system_prompt
46
 
47
+ def predict(self, user_message, assistant_message, system_prompt, do_sample, temperature=0.4, max_new_tokens=700, top_p=0.99, repetition_penalty=1.9):
48
+ try:
49
  conversation = f" <s> [INST] {self.system_prompt} [INST] {assistant_message if assistant_message else ''} </s> [/INST] {user_message} </s> "
50
  input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=False)
51
  input_ids = input_ids.to(device)
 
54
  use_cache=False,
55
  early_stopping=False,
56
  bos_token_id=model.config.bos_token_id,
57
+ eos_token_id=model.config.eos_token_id,
58
  pad_token_id=model.config.eos_token_id,
59
  temperature=temperature,
60
  do_sample=True,
 
63
  repetition_penalty=repetition_penalty
64
  )
65
  response_text = tokenizer.decode(response[0], skip_special_tokens=True)
 
66
  # response_text = response.split("<|assistant|>\n")[-1]
67
  return response_text
68
+ finally:
69
+ del input_ids, attention_mask, output_ids
70
+ gc.collect()
71
+ torch.cuda.empty_cache()
72
 
73
  starling_bot = StarlingBot()
74
 
 
76
  fn=starling_bot.predict,
77
  title=title,
78
  description=description,
79
+ examples=examples,
80
  inputs=[
81
  gr.Textbox(label="🌟🤩User Message", type="text", lines=5),
82
  gr.Textbox(label="💫🌠Starling Assistant Message or Instructions ", lines=2),