# app.py
import json

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
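# Note: the 8B model in bfloat16 needs roughly 16 GB of GPU memory, and
# device_map="auto" requires the accelerate package to place the weights.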
# Define a simple addition function schema
function_schema = {
    "name": "add_numbers",
    "description": "Add two numbers together",
    "parameters": {
        "type": "object",
        "properties": {
            "number1": {
                "type": "number",
                "description": "The first number"
            },
            "number2": {
                "type": "number",
                "description": "The second number"
            }
        },
        "required": ["number1", "number2"]
    }
}
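# The schema above follows the common OpenAI-style function-calling layout
# (name / description / JSON Schema "parameters" object), a format many
# instruct-tuned models have been exposed to.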
# Create the prompt with the function definition injected into the system turn.
# The <|system|>/<|user|>/<|assistant|> markers are a custom convention, not
# Llama 3.1's native chat template; the tokenizer treats them as plain text,
# and the parsing below splits on the same markers, so the scheme is
# self-consistent.
def create_prompt(user_input, function):
    prompt = (
        "<|system|>\nYou are a helpful assistant that can use functions. "
        "Please call the add_numbers function for any addition requests.\n\n"
        f"Available function:\n{json.dumps(function)}\n"
        f"<|user|>\n{user_input}\n<|assistant|>\n"
    )
    return prompt
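# Alternative sketch (untested assumption): Llama 3.1 Instruct also ships with
# a native chat template, so the prompt could instead be built with
# tokenizer.apply_chat_template, e.g.:
#
#   messages = [
#       {"role": "system", "content": f"Available function:\n{json.dumps(function_schema)}"},
#       {"role": "user", "content": user_input},
#   ]
#   prompt = tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True
#   )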
# Extract the function call (JSON between <functioncall> tags) from the response
def extract_function_call(response_text):
    try:
        if "<functioncall>" in response_text and "</functioncall>" in response_text:
            func_text = response_text.split("<functioncall>")[1].split("</functioncall>")[0].strip()
            return json.loads(func_text)
        return None
    except Exception as e:
        print(f"Error extracting function call: {e}")
        return None
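# Expected format (hypothetical model output):
#   extract_function_call('<functioncall>{"number1": 24, "number2": 18}</functioncall>')
#   -> {"number1": 24, "number2": 18}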
# Actually perform the addition
def execute_add_numbers(params):
    try:
        num1 = float(params.get("number1", 0))
        num2 = float(params.get("number2", 0))
        return {"result": num1 + num2}
    except Exception as e:
        return {"error": str(e)}
def process_query(query, debug=False):
    # Create the initial prompt
    prompt = create_prompt(query, function_schema)

    # Generate the initial response (do_sample=True so temperature takes
    # effect; without it, generate() decodes greedily and ignores temperature)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.1
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    # Process the response: keep only the text after the assistant marker
    try:
        assistant_response = response.split("<|assistant|>")[1].strip()
    except IndexError:
        return "Error parsing model response."
    debug_info = f"Initial response:\n{assistant_response}\n\n" if debug else ""
    # Check for a function call
    function_call = extract_function_call(assistant_response)
    if not function_call:
        return debug_info + "No function call detected in the response."
    if debug:
        debug_info += f"Function call detected:\n{json.dumps(function_call, indent=2)}\n\n"

    # Execute the function
    result = execute_add_numbers(function_call)
    if debug:
        debug_info += f"Function result:\n{json.dumps(result, indent=2)}\n\n"

    # Create the follow-up prompt with the function result appended
    follow_up_prompt = (
        f"{prompt}\n<functioncall>\n{json.dumps(function_call)}\n</functioncall>\n\n"
        f"<functionresponse>\n{json.dumps(result)}\n</functionresponse>\n"
    )
    # Generate the final response, now conditioned on the function result
    follow_up_inputs = tokenizer(follow_up_prompt, return_tensors="pt").to(model.device)
    follow_up_outputs = model.generate(
        **follow_up_inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.1
    )
    follow_up_response = tokenizer.decode(follow_up_outputs[0], skip_special_tokens=False)
    try:
        if "</functionresponse>" in follow_up_response:
            # The decoded output includes the prompt, so the model's new text
            # is whatever follows the closing </functionresponse> tag
            final_response = follow_up_response.split("</functionresponse>")[1].strip()
        else:
            final_response = follow_up_response.split("<|assistant|>")[1].strip()
    except IndexError:
        return debug_info + "Error extracting final response."

    if debug:
        return debug_info + f"Final response:\n{final_response}"
    return final_response
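# Hypothetical end-to-end trace (actual model output will vary):
#   process_query("What is 24 plus 18?")
#   1. first generation emits: <functioncall>{"number1": 24, "number2": 18}</functioncall>
#   2. execute_add_numbers returns {"result": 42.0}
#   3. second generation answers in natural language, e.g. "24 plus 18 is 42."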
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Llama 3.1 Function Calling: Addition Calculator")
    gr.Markdown("Ask the model to add numbers, and it will use the `add_numbers` function")
    with gr.Row():
        query_input = gr.Textbox(
            label="Your Question",
            placeholder="Example: What is 24 plus 18?",
            lines=2
        )
    debug_checkbox = gr.Checkbox(label="Show Debug Info", value=False)
    submit_btn = gr.Button("Submit")
    output = gr.Textbox(label="Response", lines=10)
    submit_btn.click(
        fn=process_query,
        inputs=[query_input, debug_checkbox],
        outputs=output
    )
    gr.Examples(
        [
            ["What is 25 plus 17?"],
            ["Can you add 123 and 456?"],
            ["Calculate 3.14 + 2.71"]
        ],
        inputs=query_input
    )

demo.launch()
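# Run with `python app.py`; locally Gradio serves at http://127.0.0.1:7860 by
# default, while on Hugging Face Spaces the app is launched automatically.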