wjnwjn59 commited on
Commit
99b9f93
·
1 Parent(s): 7cc25b5

update parsing

Browse files
Files changed (1) hide show
  1. src/llm/chat.py +20 -16
src/llm/chat.py CHANGED
@@ -15,6 +15,7 @@ You SHOULD NOT include any other text in the response.
15
 
16
  Here is a list of functions in JSON format that you can invoke.\n\n{functions}\n""".format(functions=FUNCTION_SCHEMA)
17
 
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  class FunctionCallingChat:
@@ -42,21 +43,24 @@ class FunctionCallingChat:
42
 
43
  output = self.model.generate(tokenized, generation_config=generation_cfg)
44
  raw = self.tokenizer.decode(output[0], skip_special_tokens=True)
45
- tool_calls_str = raw.split("assistant")[-1]
46
 
47
  try:
48
- calls = ast.literal_eval(tool_calls_str)
49
- except Exception as e:
50
- calls = []
51
- print(f"Error parsing tool calls: {e}", file=sys.stderr)
52
-
53
- if calls == []:
54
- return {"raw_tool_call": tool_calls_str, "results": "No function call detected."}
55
- else:
56
- results = []
57
- for call in calls:
58
- fn_name = call.func.id
59
- kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
60
- results.append(TOOLS[fn_name](**kwargs))
61
-
62
- return {"raw_tool_call": tool_calls_str, "results": results}
 
 
 
 
15
 
16
  Here is a list of functions in JSON format that you can invoke.\n\n{functions}\n""".format(functions=FUNCTION_SCHEMA)
17
 
18
+
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
  class FunctionCallingChat:
 
43
 
44
  output = self.model.generate(tokenized, generation_config=generation_cfg)
45
  raw = self.tokenizer.decode(output[0], skip_special_tokens=True)
46
+ tool_calls_lst_str = raw.split("assistant")[-1]
47
 
48
  try:
49
+ tree = ast.parse(tool_calls_lst_str, mode="eval")
50
+ call_nodes = tree.body.elts
51
+ except SyntaxError:
52
+
53
+ return {"raw_tool_call": tool_calls_lst_str,
54
+ "results": "Cannot parse the function call."}
55
+
56
+ tool_calls_result = []
57
+ for call in call_nodes:
58
+ function_name = call.func.id
59
+ parameters = {kw.arg: ast.literal_eval(kw.value)
60
+ for kw in call.keywords}
61
+
62
+ result = TOOLS[function_name](**parameters)
63
+ tool_calls_result.append(result)
64
+
65
+ return {"raw_tool_call": tool_calls_lst_str,
66
+ "results": tool_calls_result}