ceymox committed
Commit b1678d4 (verified)
Parent: 1355df9

Update app.py

Files changed (1):
  app.py +151 -0
app.py CHANGED
@@ -0,0 +1,151 @@
# app.py
import json
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load model and tokenizer
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

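# Note: this checkpoint is gated on the Hugging Face Hub, so the Space
# presumably needs an approved access request and an HF token available at
# startup; in bfloat16 the 8B weights alone take roughly 16 GB of GPU memory.
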
# Define a simple addition function schema
function_schema = {
    "name": "add_numbers",
    "description": "Add two numbers together",
    "parameters": {
        "type": "object",
        "properties": {
            "number1": {
                "type": "number",
                "description": "The first number"
            },
            "number2": {
                "type": "number",
                "description": "The second number"
            }
        },
        "required": ["number1", "number2"]
    }
}

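# For reference, extract_function_call below expects the model to wrap its
# JSON arguments in tags, roughly like this (illustrative only, not a format
# the model is guaranteed to follow):
#   <functioncall>{"number1": 24, "number2": 18}</functioncall>
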
# Create prompt with function definition
def create_prompt(user_input, function):
    prompt = f"<|system|>\nYou are a helpful assistant that can use functions. Please call the add_numbers function for any addition requests, and wrap the JSON arguments in <functioncall> and </functioncall> tags.\n\nAvailable function:\n{json.dumps(function)}\n<|user|>\n{user_input}\n<|assistant|>\n"
    return prompt

# Extract function call from response
def extract_function_call(response_text):
    try:
        if "<functioncall>" in response_text and "</functioncall>" in response_text:
            func_text = response_text.split("<functioncall>")[1].split("</functioncall>")[0].strip()
            return json.loads(func_text)
        return None
    except Exception as e:
        print(f"Error extracting function call: {e}")
        return None

# Actually perform the addition
def execute_add_numbers(params):
    try:
        num1 = float(params.get("number1", 0))
        num2 = float(params.get("number2", 0))
        return {"result": num1 + num2}
    except Exception as e:
        return {"error": str(e)}

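# Quick sanity check of the two helpers (hypothetical model output; try it in
# a Python shell after these definitions):
#   sample = 'Sure. <functioncall>{"number1": 24, "number2": 18}</functioncall>'
#   extract_function_call(sample)                         # -> {'number1': 24, 'number2': 18}
#   execute_add_numbers({"number1": 24, "number2": 18})   # -> {'result': 42.0}
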
def process_query(query, debug=False):
    # Create the initial prompt
    prompt = create_prompt(query, function_schema)

    # Generate the initial response
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.1
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=False)

    # Process the response
    try:
        assistant_response = response.split("<|assistant|>")[1].strip()
    except IndexError:
        return "Error parsing model response."

    debug_info = f"Initial response:\n{assistant_response}\n\n" if debug else ""

    # Check for function call
    function_call = extract_function_call(assistant_response)
    if not function_call:
        return debug_info + "No function call detected in the response."

    debug_info += f"Function call detected:\n{json.dumps(function_call, indent=2)}\n\n" if debug else ""

    # Execute the function
    result = execute_add_numbers(function_call)

    debug_info += f"Function result:\n{json.dumps(result, indent=2)}\n\n" if debug else ""

    # Create follow-up prompt with function result
    follow_up_prompt = f"{prompt}\n<functioncall>\n{json.dumps(function_call)}\n</functioncall>\n\n<functionresponse>\n{json.dumps(result)}\n</functionresponse>\n"

    # Generate final response
    follow_up_inputs = tokenizer(follow_up_prompt, return_tensors="pt").to(model.device)
    follow_up_outputs = model.generate(
        **follow_up_inputs,
        max_new_tokens=256,
        temperature=0.1
    )
    follow_up_response = tokenizer.decode(follow_up_outputs[0], skip_special_tokens=False)

    try:
        if "<functionresponse>" in follow_up_response and "</functionresponse>" in follow_up_response:
            final_response = follow_up_response.split("</functionresponse>")[1].strip()
        else:
            final_response = follow_up_response.split("<|assistant|>")[1].strip()
    except IndexError:
        return debug_info + "Error extracting final response."

    if debug:
        return debug_info + f"Final response:\n{final_response}"
    else:
        return final_response

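# In short: process_query runs two generate() passes -- one to elicit the
# <functioncall>, then a second with the <functionresponse> appended so the
# model can phrase the final answer for the user.
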
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Llama 3.1 Function Calling: Addition Calculator")
    gr.Markdown("Ask the model to add numbers, and it will use the `add_numbers` function")

    with gr.Row():
        query_input = gr.Textbox(
            label="Your Question",
            placeholder="Example: What is 24 plus 18?",
            lines=2
        )
        debug_checkbox = gr.Checkbox(label="Show Debug Info", value=False)

    submit_btn = gr.Button("Submit")

    output = gr.Textbox(label="Response", lines=10)

    submit_btn.click(
        fn=process_query,
        inputs=[query_input, debug_checkbox],
        outputs=output
    )

    gr.Examples(
        [
            ["What is 25 plus 17?"],
            ["Can you add 123 and 456?"],
            ["Calculate 3.14 + 2.71"]
        ],
        inputs=query_input
    )

demo.launch()
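
A minimal dependency sketch for running this file, inferred from the imports above plus accelerate, which transformers expects when device_map="auto" is used (the Space presumably ships a requirements.txt along these lines; versions left unpinned as an assumption):

# requirements.txt (sketch)
gradio
transformers
torch
accelerate

With those installed and access to the gated Llama 3.1 weights, `python app.py` should start the Gradio server via demo.launch().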