# app.py
import json
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load model and tokenizer
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
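
# Note: the 8B model in bfloat16 needs roughly 16 GB just for the weights, and
# device_map="auto" (which relies on the `accelerate` package) spreads them over
# the available GPU/CPU memory. The meta-llama repository is gated, so a
# Hugging Face token with access must be available in the environment.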

# Define a simple addition function schema
function_schema = {
    "name": "add_numbers",
    "description": "Add two numbers together",
    "parameters": {
        "type": "object",
        "properties": {
            "number1": {
                "type": "number",
                "description": "The first number"
            },
            "number2": {
                "type": "number",
                "description": "The second number"
            }
        },
        "required": ["number1", "number2"]
    }
}

# Create prompt with the function definition and the expected <functioncall> reply format
def create_prompt(user_input, function):
    prompt = (
        "<|system|>\nYou are a helpful assistant that can use functions. For any addition request, reply with "
        '<functioncall>{"name": "add_numbers", "arguments": {"number1": ..., "number2": ...}}</functioncall>\n\n'
        f"Available function:\n{json.dumps(function)}\n<|user|>\n{user_input}\n<|assistant|>\n"
    )
    return prompt
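
# Note: the <|system|>/<|user|>/<|assistant|> markers above are a simplified,
# hand-rolled prompt format rather than Llama 3.1's official chat template
# (which uses <|start_header_id|>...<|end_header_id|> headers); for the
# canonical format, tokenizer.apply_chat_template could be used instead.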

# Extract function call from response
def extract_function_call(response_text):
    try:
        if "<functioncall>" in response_text and "</functioncall>" in response_text:
            func_text = response_text.split("<functioncall>")[1].split("</functioncall>")[0].strip()
            return json.loads(func_text)
        return None
    except Exception as e:
        print(f"Error extracting function call: {e}")
        return None

# Actually perform the addition
def execute_add_numbers(params):
    try:
        num1 = float(params.get("number1", 0))
        num2 = float(params.get("number2", 0))
        return {"result": num1 + num2}
    except Exception as e:
        return {"error": str(e)}
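
# In a fuller agent loop this is the point where arbitrary tools would be
# dispatched; here only add_numbers exists, so its result dict is produced directly.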

def process_query(query, debug=False):
    # Create the initial prompt
    prompt = create_prompt(query, function_schema)
    
    # Generate the initial response (do_sample=True so temperature takes effect;
    # pad_token_id is set explicitly because the Llama tokenizer has no pad token)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.1,
        pad_token_id=tokenizer.eos_token_id
    )
    # skip_special_tokens=True drops real special tokens such as <|begin_of_text|>
    # and <|eot_id|> while keeping the plain-text <|assistant|> marker used below
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Process the response
    try:
        assistant_response = response.split("<|assistant|>")[1].strip()
    except IndexError:
        return "Error parsing model response."
    
    debug_info = f"Initial response:\n{assistant_response}\n\n" if debug else ""
    
    # Check for function call
    function_call = extract_function_call(assistant_response)
    if not function_call:
        return debug_info + "No function call detected in the response."
    
    debug_info += f"Function call detected:\n{json.dumps(function_call, indent=2)}\n\n" if debug else ""
    
    # Execute the function; accept either a flat argument dict or the
    # {"name": ..., "arguments": {...}} shape requested in the prompt
    params = function_call.get("arguments", function_call)
    result = execute_add_numbers(params)
    
    debug_info += f"Function result:\n{json.dumps(result, indent=2)}\n\n" if debug else ""
    
    # Create follow-up prompt with function result
    follow_up_prompt = f"{prompt}\n<functioncall>\n{json.dumps(function_call)}\n</functioncall>\n\n<functionresponse>\n{json.dumps(result)}\n</functionresponse>\n"
    
    # Generate the final response with the same sampling settings as above
    follow_up_inputs = tokenizer(follow_up_prompt, return_tensors="pt").to(model.device)
    follow_up_outputs = model.generate(
        **follow_up_inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.1,
        pad_token_id=tokenizer.eos_token_id
    )
    follow_up_response = tokenizer.decode(follow_up_outputs[0], skip_special_tokens=True)
    
    try:
        if "<functionresponse>" in follow_up_response and "</functionresponse>" in follow_up_response:
            final_response = follow_up_response.split("</functionresponse>")[1].strip()
        else:
            final_response = follow_up_response.split("<|assistant|>")[1].strip()
    except IndexError:
        return debug_info + "Error extracting final response."
    
    if debug:
        return debug_info + f"Final response:\n{final_response}"
    else:
        return final_response

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Llama 3.1 Function Calling: Addition Calculator")
    gr.Markdown("Ask the model to add numbers, and it will use the `add_numbers` function")
    
    with gr.Row():
        query_input = gr.Textbox(
            label="Your Question",
            placeholder="Example: What is 24 plus 18?",
            lines=2
        )
        debug_checkbox = gr.Checkbox(label="Show Debug Info", value=False)
    
    submit_btn = gr.Button("Submit")
    
    output = gr.Textbox(label="Response", lines=10)
    
    submit_btn.click(
        fn=process_query,
        inputs=[query_input, debug_checkbox],
        outputs=output
    )
    
    gr.Examples(
        [
            ["What is 25 plus 17?"],
            ["Can you add 123 and 456?"],
            ["Calculate 3.14 + 2.71"]
        ],
        inputs=query_input
    )

demo.launch()
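
# To run locally: `python app.py`; demo.launch() serves the Gradio UI on
# http://127.0.0.1:7860 by default. Requires transformers, torch, accelerate and gradio.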