Duibonduil commited on
Commit
2a0b0bf
·
verified ·
1 Parent(s): 98a672f

Upload model_response_parse.py

Browse files
aworld/trace/instrumentation/uni_llmmodel/model_response_parse.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import aworld.trace.instrumentation.semconv as semconv
4
+ from aworld.models.model_response import ModelResponse, ToolCall
5
+ from aworld.trace.base import Span
6
+ from aworld.trace.instrumentation.openai.inout_parse import should_trace_prompts, need_flatten_messages
7
+ from aworld.logs.util import logger
8
+
9
+
10
def parser_request_params(kwargs, instance: 'aworld.models.llm.LLMModel'):
    """Build the request-side GenAI semconv attributes for a chat call.

    Absent kwargs default to "" here; empty values are expected to be
    filtered out by the caller before being attached to the span.

    Args:
        kwargs: keyword arguments of the LLM invocation.
        instance: the LLMModel handling the call.

    Returns:
        dict mapping semconv attribute keys to request parameter values.
    """
    get = kwargs.get
    return {
        semconv.GEN_AI_SYSTEM: instance.provider_name,
        semconv.GEN_AI_REQUEST_MODEL: instance.provider.model_name,
        semconv.GEN_AI_REQUEST_MAX_TOKENS: get("max_tokens", ""),
        semconv.GEN_AI_REQUEST_TEMPERATURE: get("temperature", ""),
        semconv.GEN_AI_REQUEST_STOP_SEQUENCES: str(get("stop", [])),
        semconv.GEN_AI_REQUEST_FREQUENCY_PENALTY: get("frequency_penalty", ""),
        semconv.GEN_AI_REQUEST_PRESENCE_PENALTY: get("presence_penalty", ""),
        semconv.GEN_AI_REQUEST_USER: get("user", ""),
        semconv.GEN_AI_REQUEST_EXTRA_HEADERS: get("extra_headers", ""),
        semconv.GEN_AI_REQUEST_STREAMING: get("stream", ""),
        semconv.GEN_AI_REQUEST_TOP_P: get("top_p", ""),
        semconv.GEN_AI_OPERATION_NAME: "chat",
    }
26
+
27
+
28
async def handle_request(span: Span, kwargs, instance):
    """Record request-side attributes on the active trace span.

    Converts the LLM call kwargs into semconv attributes (flattening
    messages/tools when configured), drops empty values, and sets the
    result on the span. Any failure is logged and swallowed so tracing
    can never break the underlying request.

    Args:
        span: the current trace span; may be None or not recording.
        kwargs: keyword arguments of the LLM call (messages, tools, ...).
        instance: the LLMModel instance handling the call.
    """
    if not span or not span.is_recording():
        return
    try:
        attributes = parser_request_params(kwargs, instance)
        if should_trace_prompts():
            messages = kwargs.get("messages")
            if need_flatten_messages():
                attributes.update(parse_request_message(messages))
            else:
                attributes.update({
                    semconv.GEN_AI_PROMPT: covert_to_jsonstr(messages)
                })
            tools = kwargs.get("tools")
            if tools:
                if need_flatten_messages():
                    attributes.update(parse_prompt_tools(tools))
                else:
                    attributes.update({
                        semconv.GEN_AI_PROMPT_TOOLS: covert_to_jsonstr(tools)
                    })

        # BUG FIX: the original filtered with `v and v is not ""`, which
        # compares identity against a str literal (SyntaxWarning on
        # CPython 3.8+ and unreliable by language spec). Since "" is falsy,
        # plain truthiness is the intended, equivalent filter.
        filtered_attrs = {k: v for k, v in attributes.items() if v}

        span.set_attributes(filtered_attrs)
    except Exception as e:
        logger.warning(f"trace handle openai request error: {e}")
56
+
57
+
58
def get_common_attributes_from_response(instance: 'LLMModel', is_async, is_streaming):
    """Return the response-side span attributes shared by all call modes.

    The operation name encodes both the sync/async and the
    streaming/non-streaming dimension of the completion call.

    Args:
        instance: the LLMModel that produced the response.
        is_async: whether the call was made through the async API.
        is_streaming: whether the response is streamed.

    Returns:
        dict of common semconv response attributes.
    """
    if is_streaming:
        operation = "astream_completion" if is_async else "stream_completion"
    else:
        operation = "acompletion" if is_async else "completion"
    return {
        semconv.GEN_AI_SYSTEM: instance.provider_name,
        semconv.GEN_AI_RESPONSE_MODEL: instance.provider.model_name,
        semconv.GEN_AI_METHOD_NAME: operation,
        semconv.GEN_AI_SERVER_ADDRESS: instance.provider.base_url,
    }
68
+
69
+
70
def accumulate_stream_response(chunk: ModelResponse, complete_response: dict):
    """Fold a streamed chunk into the accumulated response.

    NOTE(review): currently a logging-only stub — it records the chunk
    but does not yet merge anything into *complete_response*.
    """
    logger.info(f"accumulate_stream_response chunk= {chunk}")
73
+
74
+
75
def record_stream_token_usage(complete_response, request_kwargs) -> tuple[int, int]:
    """Compute token usage for a streamed completion.

    NOTE(review): currently a stub — it logs the accumulated response
    and always reports zero usage.

    Returns:
        tuple[int, int]: (prompt_usage, completion_usage).
    """
    logger.info(
        f"record_stream_token_usage complete_response= {complete_response}")
    prompt_usage, completion_usage = 0, 0
    return (prompt_usage, completion_usage)
82
+
83
+
84
def parse_request_message(messages):
    """Flatten request messages into span attributes.

    Each message contributes `{GEN_AI_PROMPT}.{i}.role`, `.content`,
    `.tool_call_id` and `.tool_calls.{j}.*` entries. Tool calls may be
    plain dicts (OpenAI wire format) or ToolCall objects.

    Args:
        messages: list of chat messages (dicts with role/content/...).

    Returns:
        dict of flattened span attributes.
    """
    attributes = {}
    for i, msg in enumerate(messages):
        prefix = f"{semconv.GEN_AI_PROMPT}.{i}"
        attributes[f"{prefix}.role"] = msg.get("role")
        if msg.get("content"):
            # json.dumps never mutates its input, so the original
            # copy.deepcopy here was unnecessary and is removed.
            attributes[f"{prefix}.content"] = json.dumps(
                msg.get("content"), ensure_ascii=False)
        if msg.get("tool_call_id"):
            attributes[f"{prefix}.tool_call_id"] = msg.get("tool_call_id")
        tool_calls = msg.get("tool_calls")
        if tool_calls:
            # BUG FIX: the inner loop previously reused `i`, shadowing the
            # message index; a distinct index `j` avoids the latent bug.
            for j, tool_call in enumerate(tool_calls):
                if isinstance(tool_call, dict):
                    function = tool_call.get('function')
                    attributes[f"{prefix}.tool_calls.{j}.id"] = tool_call.get("id")
                    attributes[f"{prefix}.tool_calls.{j}.name"] = function.get("name")
                    attributes[f"{prefix}.tool_calls.{j}.arguments"] = function.get("arguments")
                elif isinstance(tool_call, ToolCall):
                    function = tool_call.function
                    attributes[f"{prefix}.tool_calls.{j}.id"] = tool_call.id
                    attributes[f"{prefix}.tool_calls.{j}.name"] = function.name
                    attributes[f"{prefix}.tool_calls.{j}.arguments"] = function.arguments
    return attributes
120
+
121
+
122
def parse_prompt_tools(tools):
    """Flatten the request `tools` list into span attributes.

    Only dict-shaped tools are recorded; each contributes a
    `{GEN_AI_PROMPT_TOOLS}.{i}.type` entry and, when the type-specific
    spec is present, a `.name` entry.
    """
    attributes = {}
    for idx, tool in enumerate(tools):
        if not isinstance(tool, dict):
            continue
        key_prefix = f"{semconv.GEN_AI_PROMPT_TOOLS}.{idx}"
        tool_type = tool.get("type")
        attributes[f"{key_prefix}.type"] = tool_type
        spec = tool.get(tool_type)
        if spec:
            attributes[f"{key_prefix}.name"] = spec.get("name")
    return attributes
134
+
135
+
136
def parse_response_message(tool_calls) -> dict:
    """Convert completion tool calls into span attributes.

    Flattens each call to `{GEN_AI_COMPLETION_TOOL_CALLS}.{i}.*` entries
    when flattening is enabled; otherwise serializes the whole list into
    a single JSON attribute. Empty input yields an empty dict.
    """
    attributes = {}
    key_root = semconv.GEN_AI_COMPLETION_TOOL_CALLS
    if not tool_calls:
        return attributes
    if not need_flatten_messages():
        attributes[key_root] = covert_to_jsonstr(tool_calls)
        return attributes
    for idx, call in enumerate(tool_calls):
        # Flattened path assumes dict-shaped calls (OpenAI wire format).
        function = call.get("function")
        attributes[f"{key_root}.{idx}.id"] = call.get("id")
        attributes[f"{key_root}.{idx}.name"] = function.get("name")
        attributes[f"{key_root}.{idx}.arguments"] = function.get("arguments")
    return attributes
154
+
155
+
156
def response_to_dic(response: ModelResponse) -> dict:
    """Log the completion response and return its dict representation."""
    logger.info(f"completion response= {response}")
    result = response.to_dict()
    return result
159
+
160
+
161
def covert_to_jsonstr(obj):
    """Serialize *obj* to a JSON string, coercing nested objects first.

    (Name kept as-is — "covert" is a historical typo relied on by callers.)
    """
    serializable = _to_serializable(obj)
    return json.dumps(serializable, ensure_ascii=False)
163
+
164
+
165
+ def _to_serializable(obj):
166
+ if isinstance(obj, dict):
167
+ return {k: _to_serializable(v) for k, v in obj.items()}
168
+ elif isinstance(obj, list):
169
+ return [_to_serializable(i) for i in obj]
170
+ elif hasattr(obj, "to_dict"):
171
+ return obj.to_dict()
172
+ elif hasattr(obj, "model_dump"):
173
+ return obj.model_dump()
174
+ elif hasattr(obj, "dict"):
175
+ return obj.dict()
176
+ else:
177
+ return obj