gnumanth committed
Commit f2da4dd · verified · 1 Parent(s): 128311c

chore: cleanup @gr.funcs

Files changed (1): app.py (+0 −5)
app.py CHANGED
@@ -6,21 +6,17 @@ model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
 device = 'cuda'
 torch_dtype = torch.bfloat16
 
-@gr.funcs
 def load_model() -> AutoModelForCausalLM:
     return AutoModelForCausalLM.from_pretrained(model_name, device=device, torch_dtype=torch_dtype)
 
-@gr.funcs
 def load_tokenizer() -> AutoTokenizer:
     return AutoTokenizer.from_pretrained(model_name)
 
-@gr.funcs
 def preprocess_messages(message: str, history: list, system_prompt: str) -> dict:
     messages = [{'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': message}]
     prompt = load_tokenizer().apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     return prompt
 
-@gr.funcs
 def generate_text(prompt: str, max_new_tokens: int, temperature: float) -> str:
     model = load_model()
     terminators = [load_tokenizer().eos_token_id, load_tokenizer().convert_tokens_to_ids(['\n'])]
@@ -35,7 +31,6 @@ def generate_text(prompt: str, max_new_tokens: int, temperature: float) -> str:
     )
     return load_tokenizer().decode(outputs[0], skip_special_tokens=True)
 
-@gr.funcs
 def chat_function(
     message: str,
     history: list,
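For context, the removed @gr.funcs decorator does not appear to correspond to a documented Gradio API, so dropping it leaves these as plain module-level functions with unchanged behavior. Below is a minimal sketch (not part of this commit) of how chat_function is typically exposed through gr.ChatInterface; the system_prompt, max_new_tokens, and temperature inputs are assumptions inferred from preprocess_messages and generate_text above, since the diff truncates chat_function's signature after `history: list,`.

```python
import gradio as gr

# Hypothetical wiring at the bottom of app.py, assuming
# chat_function(message, history, system_prompt, max_new_tokens, temperature);
# the trailing parameters are a guess based on the helpers shown in the diff.
demo = gr.ChatInterface(
    chat_function,  # defined earlier in app.py
    additional_inputs=[
        gr.Textbox(value="You are a helpful assistant.", label="System prompt"),
        gr.Slider(minimum=32, maximum=1024, value=256, step=32, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature"),
    ],
)

if __name__ == "__main__":
    demo.launch()
```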