update parsing
Browse files- src/llm/chat.py +20 -16
src/llm/chat.py
CHANGED
@@ -15,6 +15,7 @@ You SHOULD NOT include any other text in the response.
|
|
15 |
|
16 |
Here is a list of functions in JSON format that you can invoke.\n\n{functions}\n""".format(functions=FUNCTION_SCHEMA)
|
17 |
|
|
|
18 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
|
20 |
class FunctionCallingChat:
|
@@ -42,21 +43,24 @@ class FunctionCallingChat:
|
|
42 |
|
43 |
output = self.model.generate(tokenized, generation_config=generation_cfg)
|
44 |
raw = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
45 |
-
|
46 |
|
47 |
try:
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
15 |
|
16 |
Here is a list of functions in JSON format that you can invoke.\n\n{functions}\n""".format(functions=FUNCTION_SCHEMA)
|
17 |
|
18 |
+
|
19 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
20 |
|
21 |
class FunctionCallingChat:
|
|
|
43 |
|
44 |
output = self.model.generate(tokenized, generation_config=generation_cfg)
|
45 |
raw = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
46 |
+
tool_calls_lst_str = raw.split("assistant")[-1]
|
47 |
|
48 |
try:
|
49 |
+
tree = ast.parse(tool_calls_lst_str, mode="eval")
|
50 |
+
call_nodes = tree.body.elts
|
51 |
+
except SyntaxError:
|
52 |
+
|
53 |
+
return {"raw_tool_call": tool_calls_lst_str,
|
54 |
+
"results": "Cannot parse the function call."}
|
55 |
+
|
56 |
+
tool_calls_result = []
|
57 |
+
for call in call_nodes:
|
58 |
+
function_name = call.func.id
|
59 |
+
parameters = {kw.arg: ast.literal_eval(kw.value)
|
60 |
+
for kw in call.keywords}
|
61 |
+
|
62 |
+
result = TOOLS[function_name](**parameters)
|
63 |
+
tool_calls_result.append(result)
|
64 |
+
|
65 |
+
return {"raw_tool_call": tool_calls_lst_str,
|
66 |
+
"results": tool_calls_result}
|