gnumanth commited on
Commit
78011f3
·
verified ·
1 Parent(s): dc16a81
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+ model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
6
+ device = 'cuda'
7
+ torch_dtype = torch.bfloat16
8
+
9
+ @gr.funcs
10
+ def load_model() -> AutoModelForCausalLM:
11
+ return AutoModelForCausalLM.from_pretrained(model_name, device=device, torch_dtype=torch_dtype)
12
+
13
+ @gr.funcs
14
+ def load_tokenizer() -> AutoTokenizer:
15
+ return AutoTokenizer.from_pretrained(model_name)
16
+
17
+ @gr.funcs
18
+ def preprocess_messages(message: str, history: list, system_prompt: str) -> dict:
19
+ messages = [{'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': message}]
20
+ prompt = load_tokenizer().apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
21
+ return prompt
22
+
23
+ @gr.funcs
24
+ def generate_text(prompt: str, max_new_tokens: int, temperature: float) -> str:
25
+ model = load_model()
26
+ terminators = [load_tokenizer().eos_token_id, load_tokenizer().convert_tokens_to_ids(['\n'])]
27
+ temp = temperature + 0.1
28
+ outputs = model.generate(
29
+ prompt,
30
+ max_new_tokens=max_new_tokens,
31
+ eos_token_id=terminators[0],
32
+ do_sample=True,
33
+ temperature=temp,
34
+ top_p=0.9
35
+ )
36
+ return load_tokenizer().decode(outputs[0], skip_special_tokens=True)
37
+
38
+ @gr.funcs
39
+ def chat_function(
40
+ message: str,
41
+ history: list,
42
+ system_prompt: str,
43
+ max_new_tokens: int,
44
+ temperature: float
45
+ ) -> str:
46
+ prompt = preprocess_messages(message, history, system_prompt)
47
+ return generate_text(prompt, max_new_tokens, temperature)
48
+
49
+ gr.ChatInterface(
50
+ chat_function,
51
+ chatbot=gr.Chatbot(height=400),
52
+ textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
53
+ title="LLAMA3 Chat",
54
+ description="""Chat with llama3""",
55
+ theme="soft",
56
+ additional_inputs=[
57
+ gr.Textbox("You shall answer to all the questions as very smart AI", label="System Prompt"),
58
+ gr.Slider(512, 4096, label="Max New Tokens"),
59
+ gr.Slider(0, 1, label="Temperature")
60
+ ]
61
+ ).launch()