Duibonduil committed on
Commit ee2e701 · verified · 1 Parent(s): 10c7d0d

Upload inout_parse.py

aworld/trace/instrumentation/openai/inout_parse.py ADDED
@@ -0,0 +1,296 @@
import asyncio
import os
import threading
import copy
import json
import openai
from importlib.metadata import version
from aworld.logs.util import logger
from aworld.trace.base import Span
from aworld.utils import import_package
import aworld.trace.instrumentation.semconv as semconv

_PYDANTIC_VERSION = version("pydantic")

# Cache of tiktoken encodings, keyed by model name
# (used by get_token_count_from_string below).
tiktoken_encodings = {}


def should_trace_prompts():
    '''Determine whether the prompt messages should be recorded.'''
    return (os.getenv("SHOULD_TRACE_PROMPTS") or "true").lower() == "true"


def need_flatten_messages():
    '''Determine whether the messages should be flattened into individual attributes.'''
    return (os.getenv("TRACE_FLATTEN_MESSAGES") or "false").lower() == "true"


def run_async(method):
    '''Run a coroutine to completion, even from inside a running event loop.'''
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop and loop.is_running():
        # An event loop is already running in this thread, so run the
        # coroutine in a fresh loop on a separate thread and wait for it.
        thread = threading.Thread(target=lambda: asyncio.run(method))
        thread.start()
        thread.join()
    else:
        asyncio.run(method)


async def handle_openai_request(span: Span, kwargs, instance):
    if not span or not span.is_recording():
        return
    try:
        attributes = parser_request_params(kwargs, instance)
        if should_trace_prompts():
            messages = kwargs.get("messages")
            if need_flatten_messages():
                attributes.update(parse_request_message(messages))
            else:
                attributes.update({
                    semconv.GEN_AI_PROMPT: str(messages),
                })
        span.set_attributes(attributes)
    except ValueError as e:
        logger.warning(f"trace handle openai request error: {e}")


def parser_request_params(kwargs, instance):
    attributes = {
        semconv.GEN_AI_SYSTEM: "OpenAI",
        semconv.GEN_AI_REQUEST_MODEL: kwargs.get("model", ""),
        semconv.GEN_AI_REQUEST_MAX_TOKENS: kwargs.get("max_tokens", ""),
        semconv.GEN_AI_REQUEST_TEMPERATURE: kwargs.get("temperature", ""),
        semconv.GEN_AI_REQUEST_TOP_P: kwargs.get("top_p", ""),
        semconv.GEN_AI_REQUEST_FREQUENCY_PENALTY: kwargs.get("frequency_penalty", ""),
        semconv.GEN_AI_REQUEST_PRESENCE_PENALTY: kwargs.get("presence_penalty", ""),
        semconv.GEN_AI_REQUEST_USER: kwargs.get("user", ""),
        semconv.GEN_AI_REQUEST_EXTRA_HEADERS: kwargs.get("extra_headers", ""),
        semconv.GEN_AI_REQUEST_STREAMING: kwargs.get("stream", ""),
        semconv.GEN_AI_OPERATION_NAME: "chat"
    }

    client = instance._client
    if isinstance(client, (openai.AsyncOpenAI, openai.OpenAI)):
        attributes.update({"llm.base_url": str(client.base_url)})

    # Drop attributes whose values are unset or empty.
    filtered_attrs = {k: v for k, v in attributes.items() if v}
    return filtered_attrs


def is_streaming_response(response):
    return isinstance(response, (openai.Stream, openai.AsyncStream))


def parse_openai_response(response, request_kwargs, instance, is_streaming):
    return {
        semconv.GEN_AI_RESPONSE_MODEL: response.get("model") or request_kwargs.get("model") or None,
        semconv.GEN_AI_SERVER_ADDRESS: _get_openai_base_url(instance)
    }


def record_stream_token_usage(complete_response, request_kwargs) -> tuple[int, int]:
    '''Return (prompt_usage, completion_usage), counted with tiktoken.'''
    prompt_usage = 0
    completion_usage = 0

    # prompt_usage
    if request_kwargs and request_kwargs.get("messages"):
        prompt_content = ""
        model_name = complete_response.get(
            "model") or request_kwargs.get("model") or "gpt-4"
        for msg in request_kwargs.get("messages"):
            if msg.get("content"):
                prompt_content += msg.get("content")
        if model_name:
            prompt_usage = get_token_count_from_string(
                prompt_content, model_name) or 0

    # completion_usage
    if complete_response.get("choices"):
        completion_content = ""
        model_name = complete_response.get("model") or "gpt-4"

        for choice in complete_response.get("choices"):
            if choice.get("message") and choice.get("message").get("content"):
                completion_content += choice["message"]["content"]

        if model_name:
            completion_usage = get_token_count_from_string(
                completion_content, model_name) or 0

    return (prompt_usage, completion_usage)


def _get_openai_base_url(instance):
    if hasattr(instance, "_client"):
        client = instance._client  # pylint: disable=protected-access
        if isinstance(client, (openai.AsyncOpenAI, openai.OpenAI)):
            return str(client.base_url)

    return ""


def get_token_count_from_string(string: str, model_name: str):
    import_package("tiktoken")
    import tiktoken

    if tiktoken_encodings.get(model_name) is None:
        try:
            encoding = tiktoken.encoding_for_model(model_name)
        except KeyError as ex:
            logger.warning(
                f"Failed to get tiktoken encoding for model_name {model_name}, error: {str(ex)}")
            return None

        tiktoken_encodings[model_name] = encoding
    else:
        encoding = tiktoken_encodings.get(model_name)

    token_count = len(encoding.encode(string))
    return token_count


def record_stream_response_chunk(chunk, complete_response):
    chunk = model_as_dict(chunk)
    complete_response["model"] = chunk.get("model")
    complete_response["id"] = chunk.get("id")

    # prompt filter results
    if chunk.get("prompt_filter_results"):
        complete_response["prompt_filter_results"] = chunk.get(
            "prompt_filter_results")

    for choice in chunk.get("choices"):
        index = choice.get("index")
        if len(complete_response.get("choices")) <= index:
            complete_response["choices"].append(
                {"index": index, "message": {"content": "", "role": ""}})
        complete_choice = complete_response.get("choices")[index]
        if choice.get("finish_reason"):
            complete_choice["finish_reason"] = choice.get("finish_reason")
        if choice.get("content_filter_results"):
            complete_choice["content_filter_results"] = choice.get(
                "content_filter_results")

        delta = choice.get("delta")

        if delta and delta.get("content"):
            complete_choice["message"]["content"] += delta.get("content")
        if delta and delta.get("role"):
            complete_choice["message"]["role"] = delta.get("role")
        if delta and delta.get("tool_calls"):
            tool_calls = delta.get("tool_calls")
            if not isinstance(tool_calls, list) or len(tool_calls) == 0:
                continue

            if not complete_choice["message"].get("tool_calls"):
                complete_choice["message"]["tool_calls"] = []

            for tool_call in tool_calls:
                i = int(tool_call["index"])
                if len(complete_choice["message"]["tool_calls"]) <= i:
                    complete_choice["message"]["tool_calls"].append(
                        {"id": "", "function": {"name": "", "arguments": ""}}
                    )

                span_tool_call = complete_choice["message"]["tool_calls"][i]
                span_function = span_tool_call["function"]
                tool_call_function = tool_call.get("function")

                if tool_call.get("id"):
                    span_tool_call["id"] = tool_call.get("id")
                if tool_call_function and tool_call_function.get("name"):
                    span_function["name"] = tool_call_function.get("name")
                if tool_call_function and tool_call_function.get("arguments"):
                    span_function["arguments"] += tool_call_function.get(
                        "arguments")

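# A minimal usage sketch for the two streaming helpers above (illustrative
# only; `stream` stands for a hypothetical openai.Stream and `request_kwargs`
# for the original request arguments):
#
#     complete_response = {"choices": [], "model": ""}
#     for chunk in stream:
#         record_stream_response_chunk(chunk, complete_response)
#     prompt_tokens, completion_tokens = record_stream_token_usage(
#         complete_response, request_kwargs)
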
def parse_request_message(messages):
    '''Flatten the request messages into span attributes.'''
    attributes = {}
    for i, msg in enumerate(messages):
        prefix = f"{semconv.GEN_AI_PROMPT}.{i}"
        attributes.update({f"{prefix}.role": msg.get("role")})
        if msg.get("content"):
            content = copy.deepcopy(msg.get("content"))
            content = json.dumps(content)
            attributes.update({f"{prefix}.content": content})
        if msg.get("tool_call_id"):
            attributes.update({
                f"{prefix}.tool_call_id": msg.get("tool_call_id")})
        tool_calls = msg.get("tool_calls")
        if tool_calls:
            # Use a separate index so the outer message index is not shadowed.
            for j, tool_call in enumerate(tool_calls):
                tool_call = model_as_dict(tool_call)
                function = tool_call.get("function")
                attributes.update({
                    f"{prefix}.tool_calls.{j}.id": tool_call.get("id")})
                attributes.update({
                    f"{prefix}.tool_calls.{j}.name": function.get("name")})
                attributes.update({
                    f"{prefix}.tool_calls.{j}.arguments": function.get("arguments")})
    return attributes

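# For example, flattening [{"role": "user", "content": "hi"}] yields
# (assuming semconv.GEN_AI_PROMPT == "gen_ai.prompt"):
#     {"gen_ai.prompt.0.role": "user", "gen_ai.prompt.0.content": '"hi"'}
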
def parse_response_message(choices) -> dict:
    attributes = {}
    if not should_trace_prompts():
        return attributes
    for choice in choices:
        index = choice.get("index")
        prefix = f"{semconv.GEN_AI_COMPLETION}.{index}"
        attributes.update(
            {f"{prefix}.finish_reason": choice.get("finish_reason")})

        message = choice.get("message")
        if not message:
            continue

        attributes.update({f"{prefix}.role": message.get("role")})

        if message.get("refusal"):
            attributes.update({f"{prefix}.refusal": message.get("refusal")})
        else:
            attributes.update({f"{prefix}.content": message.get("content")})

        function_call = message.get("function_call")
        if function_call:
            attributes.update(
                {f"{prefix}.tool_calls.0.name": function_call.get("name")})
            attributes.update(
                {f"{prefix}.tool_calls.0.arguments": function_call.get("arguments")})

        tool_calls = message.get("tool_calls")
        if tool_calls:
            for i, tool_call in enumerate(tool_calls):
                function = tool_call.get("function")
                attributes.update(
                    {f"{prefix}.tool_calls.{i}.id": tool_call.get("id")})
                attributes.update(
                    {f"{prefix}.tool_calls.{i}.name": function.get("name")})
                attributes.update(
                    {f"{prefix}.tool_calls.{i}.arguments": function.get("arguments")})
    return attributes


def model_as_dict(model):
    if isinstance(model, dict):
        return model
    # Compare the major version numerically; comparing version strings
    # lexicographically would mis-order e.g. "10.0.0" < "2.0.0".
    if int(_PYDANTIC_VERSION.split(".")[0]) < 2:
        return model.dict()
    if hasattr(model, "model_dump"):
        return model.model_dump()
    elif hasattr(model, "parse"):  # Raw API response
        return model_as_dict(model.parse())
    else:
        return model
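
A minimal sketch of how these helpers could be wired together around a chat-completion call. The `span` argument and the `traced_chat_completion` wrapper are hypothetical; only `handle_openai_request`, `run_async`, `is_streaming_response`, and `parse_openai_response` come from the file above:

    import openai
    from aworld.trace.instrumentation.openai.inout_parse import (
        handle_openai_request, is_streaming_response,
        parse_openai_response, run_async)

    def traced_chat_completion(span, client: openai.OpenAI, **kwargs):
        # Record the request parameters (and, if enabled, the prompt) on the
        # span. `client.chat.completions` plays the role of `instance`: it
        # carries the `_client` attribute the parsers inspect.
        run_async(handle_openai_request(span, kwargs, client.chat.completions))
        response = client.chat.completions.create(**kwargs)
        if not is_streaming_response(response):
            # Non-streaming responses can be parsed into attributes directly;
            # streaming ones would be accumulated chunk by chunk instead.
            span.set_attributes(parse_openai_response(
                response.model_dump(), kwargs, client.chat.completions, False))
        return response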