# app.py
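"""Gradio demo of function calling with Llama 3.1 8B Instruct.

The model is prompted with a single `add_numbers` tool schema; when it emits a
<functioncall> block, the app executes the addition locally and feeds the
result back to the model for a final natural-language answer.
"""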
import json
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load model and tokenizer
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
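# Note: device_map="auto" requires the `accelerate` package to be installed,
# and torch.bfloat16 assumes hardware with bfloat16 support.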
# Define a simple addition function schema
function_schema = {
    "name": "add_numbers",
    "description": "Add two numbers together",
    "parameters": {
        "type": "object",
        "properties": {
            "number1": {
                "type": "number",
                "description": "The first number",
            },
            "number2": {
                "type": "number",
                "description": "The second number",
            },
        },
        "required": ["number1", "number2"],
    },
}
# Create prompt with function definition
def create_prompt(user_input, function):
    prompt = (
        "<|system|>\nYou are a helpful assistant that can use functions. "
        "Please call the add_numbers function for any addition requests.\n\n"
        f"Available function:\n{json.dumps(function)}\n"
        f"<|user|>\n{user_input}\n<|assistant|>\n"
    )
    return prompt
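# Note: the <|system|>/<|user|>/<|assistant|> tags above are a hand-rolled
# scheme rather than the official Llama 3.1 chat template; a stricter version
# would build the prompt with tokenizer.apply_chat_template instead.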
# Extract function call from response
def extract_function_call(response_text):
    try:
        if "<functioncall>" in response_text and "</functioncall>" in response_text:
            func_text = response_text.split("<functioncall>")[1].split("</functioncall>")[0].strip()
            return json.loads(func_text)
        return None
    except Exception as e:
        print(f"Error extracting function call: {e}")
        return None
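# The parser above expects the model to emit, for example:
#   <functioncall>{"number1": 24, "number2": 18}</functioncall>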
# Actually perform the addition
def execute_add_numbers(params):
    try:
        num1 = float(params.get("number1", 0))
        num2 = float(params.get("number2", 0))
        return {"result": num1 + num2}
    except Exception as e:
        return {"error": str(e)}
def process_query(query, debug=False):
    # Build the initial prompt containing the function schema
    prompt = create_prompt(query, function_schema)

    # Generate the initial response (do_sample=True so temperature takes effect)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.1,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=False)

    # Isolate the assistant's portion of the decoded text
    try:
        assistant_response = response.split("<|assistant|>")[1].strip()
    except IndexError:
        return "Error parsing model response."

    debug_info = f"Initial response:\n{assistant_response}\n\n" if debug else ""

    # Check for a function call
    function_call = extract_function_call(assistant_response)
    if not function_call:
        return debug_info + "No function call detected in the response."
    if debug:
        debug_info += f"Function call detected:\n{json.dumps(function_call, indent=2)}\n\n"

    # Execute the function
    result = execute_add_numbers(function_call)
    if debug:
        debug_info += f"Function result:\n{json.dumps(result, indent=2)}\n\n"

    # Create a follow-up prompt that feeds the function result back to the model
    follow_up_prompt = (
        f"{prompt}\n<functioncall>\n{json.dumps(function_call)}\n</functioncall>\n\n"
        f"<functionresponse>\n{json.dumps(result)}\n</functionresponse>\n"
    )

    # Generate the final, natural-language response
    follow_up_inputs = tokenizer(follow_up_prompt, return_tensors="pt").to(model.device)
    follow_up_outputs = model.generate(
        **follow_up_inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.1,
    )
    follow_up_response = tokenizer.decode(follow_up_outputs[0], skip_special_tokens=False)

    # Everything after the echoed <functionresponse> block is the new text
    try:
        if "<functionresponse>" in follow_up_response and "</functionresponse>" in follow_up_response:
            final_response = follow_up_response.split("</functionresponse>")[1].strip()
        else:
            final_response = follow_up_response.split("<|assistant|>")[1].strip()
    except IndexError:
        return debug_info + "Error extracting final response."

    if debug:
        return debug_info + f"Final response:\n{final_response}"
    return final_response
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Llama 3.1 Function Calling: Addition Calculator")
    gr.Markdown("Ask the model to add numbers, and it will use the `add_numbers` function.")

    with gr.Row():
        query_input = gr.Textbox(
            label="Your Question",
            placeholder="Example: What is 24 plus 18?",
            lines=2,
        )
    debug_checkbox = gr.Checkbox(label="Show Debug Info", value=False)
    submit_btn = gr.Button("Submit")
    output = gr.Textbox(label="Response", lines=10)

    submit_btn.click(
        fn=process_query,
        inputs=[query_input, debug_checkbox],
        outputs=output,
    )

    gr.Examples(
        [
            ["What is 25 plus 17?"],
            ["Can you add 123 and 456?"],
            ["Calculate 3.14 + 2.71"],
        ],
        inputs=query_input,
    )
demo.launch()