# app.py
import json
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load model and tokenizer
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
# Define a simple addition function schema
function_schema = {
    "name": "add_numbers",
    "description": "Add two numbers together",
    "parameters": {
        "type": "object",
        "properties": {
            "number1": {
                "type": "number",
                "description": "The first number"
            },
            "number2": {
                "type": "number",
                "description": "The second number"
            }
        },
        "required": ["number1", "number2"]
    }
}
# Create prompt with function definition
def create_prompt(user_input, function):
    prompt = f"<|system|>\nYou are a helpful assistant that can use functions. Please call the add_numbers function for any addition requests.\n\nAvailable function:\n{json.dumps(function)}\n<|user|>\n{user_input}\n<|assistant|>\n"
    return prompt
# Extract function call from response
def extract_function_call(response_text):
    try:
        if "<functioncall>" in response_text and "</functioncall>" in response_text:
            func_text = response_text.split("<functioncall>")[1].split("</functioncall>")[0].strip()
            return json.loads(func_text)
        return None
    except Exception as e:
        print(f"Error extracting function call: {e}")
        return None
# Actually perform the addition
def execute_add_numbers(params):
    try:
        num1 = float(params.get("number1", 0))
        num2 = float(params.get("number2", 0))
        return {"result": num1 + num2}
    except Exception as e:
        return {"error": str(e)}
def process_query(query, debug=False):
    # Create the initial prompt
    prompt = create_prompt(query, function_schema)
    # Generate the initial response
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,  # sampling must be enabled for temperature to take effect
        temperature=0.1
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    # Process the response
    try:
        assistant_response = response.split("<|assistant|>")[1].strip()
    except IndexError:
        return "Error parsing model response."
    debug_info = f"Initial response:\n{assistant_response}\n\n" if debug else ""
    # Check for function call
    function_call = extract_function_call(assistant_response)
    if not function_call:
        return debug_info + "No function call detected in the response."
    if debug:
        debug_info += f"Function call detected:\n{json.dumps(function_call, indent=2)}\n\n"
    # Execute the function
    result = execute_add_numbers(function_call)
    if debug:
        debug_info += f"Function result:\n{json.dumps(result, indent=2)}\n\n"
    # Create follow-up prompt with function result
    follow_up_prompt = f"{prompt}\n<functioncall>\n{json.dumps(function_call)}\n</functioncall>\n\n<functionresponse>\n{json.dumps(result)}\n</functionresponse>\n"
    # Generate final response
    follow_up_inputs = tokenizer(follow_up_prompt, return_tensors="pt").to(model.device)
    follow_up_outputs = model.generate(
        **follow_up_inputs,
        max_new_tokens=256,
        do_sample=True,  # sampling must be enabled for temperature to take effect
        temperature=0.1
    )
    follow_up_response = tokenizer.decode(follow_up_outputs[0], skip_special_tokens=False)
    try:
        if "<functionresponse>" in follow_up_response and "</functionresponse>" in follow_up_response:
            final_response = follow_up_response.split("</functionresponse>")[1].strip()
        else:
            final_response = follow_up_response.split("<|assistant|>")[1].strip()
    except IndexError:
        return debug_info + "Error extracting final response."
    if debug:
        return debug_info + f"Final response:\n{final_response}"
    else:
        return final_response
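# Illustrative end-to-end flow: process_query("What is 24 plus 18?") should lead the
# model to emit an add_numbers <functioncall>, which is executed locally, with the
# <functionresponse> result fed back for the final natural-language answer.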
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Llama 3.1 Function Calling: Addition Calculator")
    gr.Markdown("Ask the model to add numbers, and it will use the `add_numbers` function")
    with gr.Row():
        query_input = gr.Textbox(
            label="Your Question",
            placeholder="Example: What is 24 plus 18?",
            lines=2
        )
        debug_checkbox = gr.Checkbox(label="Show Debug Info", value=False)
    submit_btn = gr.Button("Submit")
    output = gr.Textbox(label="Response", lines=10)
    submit_btn.click(
        fn=process_query,
        inputs=[query_input, debug_checkbox],
        outputs=output
    )
    gr.Examples(
        [
            ["What is 25 plus 17?"],
            ["Can you add 123 and 456?"],
            ["Calculate 3.14 + 2.71"]
        ],
        inputs=query_input
    )
demo.launch()
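# Running locally (a rough sketch; exact package set is an assumption):
#   pip install torch transformers accelerate gradio
#   python app.py
# Meta-Llama-3.1-8B-Instruct is a gated checkpoint, so you may also need to
# accept the model license and authenticate with `huggingface-cli login`.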