update parsing
src/llm/chat.py  +20 -16

@@ -15,6 +15,7 @@ You SHOULD NOT include any other text in the response.
 
 Here is a list of functions in JSON format that you can invoke.\n\n{functions}\n""".format(functions=FUNCTION_SCHEMA)
 
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 class FunctionCallingChat:
@@ -42,21 +43,24 @@ class FunctionCallingChat:
 
         output = self.model.generate(tokenized, generation_config=generation_cfg)
         raw = self.tokenizer.decode(output[0], skip_special_tokens=True)
-
+        tool_calls_lst_str = raw.split("assistant")[-1]
 
         try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            tree = ast.parse(tool_calls_lst_str, mode="eval")
+            call_nodes = tree.body.elts
+        except SyntaxError:
+
+            return {"raw_tool_call": tool_calls_lst_str,
+                    "results": "Cannot parse the function call."}
+
+        tool_calls_result = []
+        for call in call_nodes:
+            function_name = call.func.id
+            parameters = {kw.arg: ast.literal_eval(kw.value)
+                          for kw in call.keywords}
+
+            result = TOOLS[function_name](**parameters)
+            tool_calls_result.append(result)
+
+        return {"raw_tool_call": tool_calls_lst_str,
+                "results": tool_calls_result}
|