ceymox committed on
Commit 3c9cad1 · verified · 1 Parent(s): 7111481

Update app.py

Files changed (1)
  1. app.py +0 -181
app.py CHANGED
@@ -1,181 +0,0 @@
- # app.py
- import json
- import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import torch
-
- # Load model and tokenizer
- model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
-
- # Add this workaround for the RoPE scaling issue
- from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
- import os
- import json
-
- # Fix the rope_scaling configuration before loading the model
- config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.json")
- if not os.path.exists(config_path):
-     # Download the config file if it doesn't exist
-     from huggingface_hub import hf_hub_download
-     config_path = hf_hub_download(repo_id=model_id, filename=CONFIG_NAME)
-
- # Load and modify the config
- with open(config_path, 'r') as f:
-     config = json.load(f)
-
- # Fix the rope_scaling format
- if 'rope_scaling' in config and not (isinstance(config['rope_scaling'], dict) and 'type' in config['rope_scaling'] and 'factor' in config['rope_scaling']):
-     # Convert to the expected format
-     old_scaling = config['rope_scaling']
-     config['rope_scaling'] = {
-         'type': 'dynamic',
-         'factor': old_scaling.get('factor', 8.0)
-     }
-     # Save the modified config
-     with open(config_path, 'w') as f:
-         json.dump(config, f)
-
- # Now load the model with the fixed config
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     torch_dtype=torch.bfloat16,
-     device_map="auto"
- )
-
- # Define a simple addition function schema
- function_schema = {
-     "name": "add_numbers",
-     "description": "Add two numbers together",
-     "parameters": {
-         "type": "object",
-         "properties": {
-             "number1": {
-                 "type": "number",
-                 "description": "The first number"
-             },
-             "number2": {
-                 "type": "number",
-                 "description": "The second number"
-             }
-         },
-         "required": ["number1", "number2"]
-     }
- }
-
- # Create prompt with function definition
- def create_prompt(user_input, function):
-     prompt = f"<|system|>\nYou are a helpful assistant that can use functions. Please call the add_numbers function for any addition requests.\n\nAvailable function:\n{json.dumps(function)}\n<|user|>\n{user_input}\n<|assistant|>\n"
-     return prompt
-
- # Extract function call from response
- def extract_function_call(response_text):
-     try:
-         if "<functioncall>" in response_text and "</functioncall>" in response_text:
-             func_text = response_text.split("<functioncall>")[1].split("</functioncall>")[0].strip()
-             return json.loads(func_text)
-         return None
-     except Exception as e:
-         print(f"Error extracting function call: {e}")
-         return None
-
- # Actually perform the addition
- def execute_add_numbers(params):
-     try:
-         num1 = float(params.get("number1", 0))
-         num2 = float(params.get("number2", 0))
-         return {"result": num1 + num2}
-     except Exception as e:
-         return {"error": str(e)}
-
- def process_query(query, debug=False):
-     # Create the initial prompt
-     prompt = create_prompt(query, function_schema)
-
-     # Generate the initial response
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     outputs = model.generate(
-         **inputs,
-         max_new_tokens=256,
-         temperature=0.1
-     )
-     response = tokenizer.decode(outputs[0], skip_special_tokens=False)
-
-     # Process the response
-     try:
-         assistant_response = response.split("<|assistant|>")[1].strip()
-     except Exception:
-         return "Error parsing model response."
-
-     debug_info = f"Initial response:\n{assistant_response}\n\n" if debug else ""
-
-     # Check for function call
-     function_call = extract_function_call(assistant_response)
-     if not function_call:
-         return debug_info + "No function call detected in the response."
-
-     debug_info += f"Function call detected:\n{json.dumps(function_call, indent=2)}\n\n" if debug else ""
-
-     # Execute the function
-     result = execute_add_numbers(function_call)
-
-     debug_info += f"Function result:\n{json.dumps(result, indent=2)}\n\n" if debug else ""
-
-     # Create follow-up prompt with function result
-     follow_up_prompt = f"{prompt}\n<functioncall>\n{json.dumps(function_call)}\n</functioncall>\n\n<functionresponse>\n{json.dumps(result)}\n</functionresponse>\n"
-
-     # Generate final response
-     follow_up_inputs = tokenizer(follow_up_prompt, return_tensors="pt").to(model.device)
-     follow_up_outputs = model.generate(
-         **follow_up_inputs,
-         max_new_tokens=256,
-         temperature=0.1
-     )
-     follow_up_response = tokenizer.decode(follow_up_outputs[0], skip_special_tokens=False)
-
-     try:
-         if "<functionresponse>" in follow_up_response and "</functionresponse>" in follow_up_response:
-             final_response = follow_up_response.split("</functionresponse>")[1].strip()
-         else:
-             final_response = follow_up_response.split("<|assistant|>")[1].strip()
-     except Exception:
-         return debug_info + "Error extracting final response."
-
-     if debug:
-         return debug_info + f"Final response:\n{final_response}"
-     else:
-         return final_response
-
- # Create Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown("# Llama 3.1 Function Calling: Addition Calculator")
-     gr.Markdown("Ask the model to add numbers, and it will use the `add_numbers` function")
-
-     with gr.Row():
-         query_input = gr.Textbox(
-             label="Your Question",
-             placeholder="Example: What is 24 plus 18?",
-             lines=2
-         )
-         debug_checkbox = gr.Checkbox(label="Show Debug Info", value=False)
-
-     submit_btn = gr.Button("Submit")
-
-     output = gr.Textbox(label="Response", lines=10)
-
-     submit_btn.click(
-         fn=process_query,
-         inputs=[query_input, debug_checkbox],
-         outputs=output
-     )
-
-     gr.Examples(
-         [
-             ["What is 25 plus 17?"],
-             ["Can you add 123 and 456?"],
-             ["Calculate 3.14 + 2.71"]
-         ],
-         inputs=query_input
-     )
-
- demo.launch()