chore: cleanup @gr.funcs
app.py CHANGED
@@ -6,21 +6,17 @@ model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
 device = 'cuda'
 torch_dtype = torch.bfloat16
 
-@gr.funcs
 def load_model() -> AutoModelForCausalLM:
     return AutoModelForCausalLM.from_pretrained(model_name, device=device, torch_dtype=torch_dtype)
 
-@gr.funcs
 def load_tokenizer() -> AutoTokenizer:
     return AutoTokenizer.from_pretrained(model_name)
 
-@gr.funcs
 def preprocess_messages(message: str, history: list, system_prompt: str) -> dict:
     messages = [{'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': message}]
     prompt = load_tokenizer().apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     return prompt
 
-@gr.funcs
 def generate_text(prompt: str, max_new_tokens: int, temperature: float) -> str:
     model = load_model()
     terminators = [load_tokenizer().eos_token_id, load_tokenizer().convert_tokens_to_ids(['\n'])]
@@ -35,7 +31,6 @@ def generate_text(prompt: str, max_new_tokens: int, temperature: float) -> str:
     )
     return load_tokenizer().decode(outputs[0], skip_special_tokens=True)
 
-@gr.funcs
 def chat_function(
     message: str,
     history: list,
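With the stray @gr.funcs decorators removed (Gradio has no such decorator, hence this cleanup), the functions in app.py are plain Python callables. Below is a minimal sketch of how chat_function could be wired into a Gradio app via gr.ChatInterface. The diff truncates chat_function's signature, so the extra inputs (system prompt, max new tokens, temperature) and their defaults are assumptions for illustration, not part of this commit.

# Illustrative wiring only -- not part of this diff. Assumes chat_function
# accepts (message, history, system_prompt, max_new_tokens, temperature)
# and returns the generated reply as a string.
import gradio as gr

demo = gr.ChatInterface(
    fn=chat_function,
    additional_inputs=[
        gr.Textbox(value="You are a helpful assistant.", label="System prompt"),
        gr.Slider(minimum=1, maximum=1024, value=256, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
    ],
)

if __name__ == "__main__":
    demo.launch()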