Duibonduil committed
Commit c8568d3 · verified · 1 Parent(s): ee2e701

Upload __init__.py

aworld/trace/instrumentation/__init__.py ADDED
@@ -0,0 +1,343 @@
+import wrapt
+import time
+import inspect
+import traceback
+import aworld.trace.instrumentation.semconv as semconv
+from typing import Collection, Any
+from aworld.trace.instrumentation import Instrumentor
+from aworld.trace.base import (
+    Tracer,
+    SpanType,
+    get_tracer_provider_silent
+)
+from aworld.trace.constants import ATTRIBUTES_MESSAGE_RUN_TYPE_KEY, RunType
+from aworld.trace.instrumentation.llm_metrics import (
+    record_exception_metric,
+    record_chat_response_metric,
+    record_streaming_time_to_first_token,
+    record_streaming_time_to_generate
+)
+from aworld.trace.instrumentation.uni_llmmodel.model_response_parse import (
+    accumulate_stream_response,
+    get_common_attributes_from_response,
+    record_stream_token_usage,
+    parse_response_message,
+    response_to_dic,
+    handle_request
+)
+from aworld.trace.instrumentation.openai.inout_parse import run_async
+
+from aworld.models.model_response import ModelResponse
+from aworld.logs.util import logger
+
+
+def _completion_wrapper(tracer: Tracer):
+
+    @wrapt.decorator
+    def wrapper(wrapped, instance, args, kwargs):
+        model_name = instance.provider.model_name
+        if not model_name:
+            model_name = "LLMModel"
+        span_attributes = {ATTRIBUTES_MESSAGE_RUN_TYPE_KEY: RunType.LLM.value}
+
+        span = tracer.start_span(
+            name=model_name, span_type=SpanType.CLIENT, attributes=span_attributes)
+
+        run_async(handle_request(span, kwargs, instance))
+        start_time = time.time()
+        try:
+            response = wrapped(*args, **kwargs)
+        except Exception as e:
+            record_exception(span=span,
+                             start_time=start_time,
+                             exception=e)
+            span.end()
+            raise e
+
+        if is_streaming_response(response):
+            # Streaming: defer attribute recording and span.end() to the proxy,
+            # which closes the span once the generator is exhausted.
+            return WrappedGeneratorResponse(span=span,
+                                            response=response,
+                                            instance=instance,
+                                            start_time=start_time,
+                                            request_kwargs=kwargs)
+        record_completion(span=span,
+                          start_time=start_time,
+                          response=response,
+                          request_kwargs=kwargs,
+                          instance=instance,
+                          is_async=False)
+        span.end()
+        return response
+
+    return wrapper
+
+
+def _acompletion_class_wrapper(tracer: Tracer):
+
+    async def awrapper(wrapped, instance, args, kwargs):
+        model_name = instance.provider.model_name
+        if not model_name:
+            model_name = "LLMModel"
+        span_attributes = {ATTRIBUTES_MESSAGE_RUN_TYPE_KEY: RunType.LLM.value}
+
+        span = tracer.start_span(
+            name=model_name, span_type=SpanType.CLIENT, attributes=span_attributes)
+
+        await handle_request(span, kwargs, instance)
+        start_time = time.time()
+        try:
+            response = await wrapped(*args, **kwargs)
+        except Exception as e:
+            record_exception(span=span,
+                             start_time=start_time,
+                             exception=e)
+            span.end()
+            raise e
+
+        if is_streaming_response(response):
+            return WrappedGeneratorResponse(span=span,
+                                            response=response,
+                                            instance=instance,
+                                            start_time=start_time,
+                                            request_kwargs=kwargs)
+        record_completion(span=span,
+                          start_time=start_time,
+                          response=response,
+                          request_kwargs=kwargs,
+                          instance=instance,
+                          is_async=True)
+        span.end()
+        return response
+
+    return awrapper
+
+
+def _acompletion_instance_wrapper(tracer: Tracer):
+    # Must be a plain (synchronous) factory: wrap_llmmodel calls it directly
+    # and applies the returned decorator to bound instance methods.
+
+    @wrapt.decorator
+    async def _awrapper(wrapped, instance, args, kwargs):
+        wrapper_func = _acompletion_class_wrapper(tracer)
+        return await wrapper_func(wrapped, instance, args, kwargs)
+
+    return _awrapper
+
+
+def is_streaming_response(response):
+    # Async streaming endpoints yield async generators, which
+    # inspect.isgenerator() does not match.
+    return inspect.isgenerator(response) or inspect.isasyncgen(response)
+
+
+def record_exception(span, start_time, exception):
+    '''
+    Record an LLM completion exception to trace and metrics.
+    '''
+    try:
+        duration = time.time() - start_time
+        if span.is_recording:
+            span.record_exception(exception=exception)
+        record_exception_metric(exception=exception, duration=duration)
+    except Exception as e:
+        logger.warning(f"llmmodel instrument record exception error.{e}")
+
+
+def record_completion(span,
+                      start_time,
+                      response,
+                      request_kwargs,
+                      instance,
+                      is_async):
+    '''
+    Record a chat completion to trace and metrics.
+    '''
+    duration = time.time() - start_time
+    response_dict = response_to_dic(response)
+    attributes = get_common_attributes_from_response(instance, is_async, False)
+    usage = response_dict.get("usage")
+    content = response_dict.get("content", "")
+    tool_calls = response_dict.get("tool_calls")
+    prompt_tokens = -1
+    completion_tokens = -1
+    total_tokens = -1
+    if usage:
+        prompt_tokens = usage.get("prompt_tokens")
+        completion_tokens = usage.get("completion_tokens")
+        total_tokens = usage.get("total_tokens")
+
+    span_attributes = {
+        **attributes,
+        semconv.GEN_AI_USAGE_INPUT_TOKENS: prompt_tokens,
+        semconv.GEN_AI_USAGE_OUTPUT_TOKENS: completion_tokens,
+        semconv.GEN_AI_USAGE_TOTAL_TOKENS: total_tokens,
+        semconv.GEN_AI_DURATION: duration,
+        semconv.GEN_AI_COMPLETION_CONTENT: content
+    }
+    span_attributes.update(parse_response_message(tool_calls))
+    span.set_attributes(span_attributes)
+    record_chat_response_metric(attributes=attributes,
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                duration=duration)
+
+
+class WrappedGeneratorResponse(wrapt.ObjectProxy):
+    '''
+    Proxy over a (sync or async) streaming response that accumulates chunks,
+    records time-to-first-token, and closes the span once the stream ends.
+    '''
+
+    def __init__(
+        self,
+        span,
+        response,
+        instance=None,
+        start_time=None,
+        request_kwargs=None
+    ):
+        super().__init__(response)
+        self._span = span
+        self._instance = instance
+        self._start_time = start_time
+        self._complete_response = {"choices": [], "model": ""}
+        self._first_token_recorded = False
+        self._time_of_first_token = None
+        self._request_kwargs = request_kwargs
+
+    def __iter__(self):
+        return self
+
+    def __aiter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            chunk = self.__wrapped__.__next__()
+        except Exception as e:
+            if isinstance(e, StopIteration):
+                self._close_span(False)
+            raise e
+        else:
+            self._process_stream_chunk(chunk, False)
+            return chunk
+
+    async def __anext__(self):
+        try:
+            chunk = await self.__wrapped__.__anext__()
+        except Exception as e:
+            if isinstance(e, StopAsyncIteration):
+                self._close_span(True)
+            raise e
+        else:
+            self._process_stream_chunk(chunk, True)
+            return chunk
+
+    def _process_stream_chunk(self, chunk: ModelResponse, is_async):
+        accumulate_stream_response(chunk, self._complete_response)
+
+        if not self._first_token_recorded:
+            self._time_of_first_token = time.time()
+            duration = self._time_of_first_token - self._start_time
+            attribute = get_common_attributes_from_response(
+                self._instance, is_async, True)
+            record_streaming_time_to_first_token(duration, attribute)
+            self._first_token_recorded = True
+
+    def _close_span(self, is_async):
+        duration = None
+        first_token_duration = None
+        first_token_to_generate_duration = None
+        if self._start_time and isinstance(self._start_time, (float, int)):
+            duration = time.time() - self._start_time
+        if self._time_of_first_token and self._start_time and isinstance(self._start_time, (float, int)):
+            first_token_duration = self._time_of_first_token - self._start_time
+            first_token_to_generate_duration = time.time() - self._time_of_first_token
+
+        prompt_usage, completion_usage = record_stream_token_usage(
+            self._complete_response, self._request_kwargs)
+
+        attributes = get_common_attributes_from_response(
+            self._instance, is_async, True)
+
+        choices = self._complete_response.get("choices")
+        span_attributes = {
+            **attributes,
+            semconv.GEN_AI_USAGE_INPUT_TOKENS: prompt_usage,
+            semconv.GEN_AI_USAGE_OUTPUT_TOKENS: completion_usage,
+            semconv.GEN_AI_DURATION: duration,
+            semconv.GEN_AI_FIRST_TOKEN_DURATION: first_token_duration
+        }
+        span_attributes.update(parse_response_message(choices))
+
+        self._span.set_attributes(span_attributes)
+        record_chat_response_metric(attributes=attributes,
+                                    prompt_tokens=prompt_usage,
+                                    completion_tokens=completion_usage,
+                                    duration=duration,
+                                    choices=choices)
+        record_streaming_time_to_generate(
+            first_token_to_generate_duration, attributes)
+
+        self._span.end()
+
+
+class LLMModelInstrumentor(Instrumentor):
+
+    def instrumentation_dependencies(self) -> Collection[str]:
+        return ()
+
+    def _instrument(self, **kwargs):
+        tracer_provider = get_tracer_provider_silent()
+        if not tracer_provider:
+            return
+        tracer = tracer_provider.get_tracer(
+            "aworld.trace.instrumentation.llmmodel")
+
+        wrapt.wrap_function_wrapper(
+            "aworld.models.llm",
+            "LLMModel.completion",
+            _completion_wrapper(tracer=tracer)
+        )
+
+        wrapt.wrap_function_wrapper(
+            "aworld.models.llm",
+            "LLMModel.stream_completion",
+            _completion_wrapper(tracer=tracer)
+        )
+        wrapt.wrap_function_wrapper(
+            "aworld.models.llm",
+            "LLMModel.acompletion",
+            _acompletion_class_wrapper(tracer)
+        )
+
+        wrapt.wrap_function_wrapper(
+            "aworld.models.llm",
+            "LLMModel.astream_completion",
+            _acompletion_class_wrapper(tracer)
+        )
+
+    def _uninstrument(self, **kwargs: Any):
+        pass
+
+
+def wrap_llmmodel(client: 'aworld.models.llm.LLMModel'):
+    try:
+        tracer_provider = get_tracer_provider_silent()
+        if not tracer_provider:
+            return client
+        tracer = tracer_provider.get_tracer(
+            "aworld.trace.instrumentation.llmmodel")
+
+        wrapper = _completion_wrapper(tracer)
+        awrapper = _acompletion_instance_wrapper(tracer)
+        client.completion = wrapper(client.completion)
+        client.stream_completion = wrapper(client.stream_completion)
+        client.acompletion = awrapper(client.acompletion)
+        client.astream_completion = awrapper(client.astream_completion)
+    except Exception:
+        logger.warning(traceback.format_exc())
+
+    return client
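
For context, a minimal usage sketch (not part of the commit): the file exposes two entry points, module-level patching via LLMModelInstrumentor and per-instance wrapping via wrap_llmmodel. Only LLMModelInstrumentor, wrap_llmmodel, and aworld.models.llm.LLMModel appear in the diff; the import path, the LLMModel constructor arguments, and the assumption that the tracer provider is already configured are illustrative guesses.

    from aworld.models.llm import LLMModel
    from aworld.trace.instrumentation import LLMModelInstrumentor, wrap_llmmodel

    # Option 1: patch LLMModel.completion / stream_completion / acompletion /
    # astream_completion globally via wrapt. _instrument() silently does
    # nothing when no tracer provider has been configured.
    LLMModelInstrumentor()._instrument()

    # Option 2: wrap one client in place; returns the same client with its
    # four completion methods decorated.
    client = LLMModel(model_name="gpt-4o")  # hypothetical constructor arguments
    client = wrap_llmmodel(client)

Either path produces the same spans: a CLIENT span per call named after the model, with token usage and duration attributes, plus time-to-first-token and time-to-generate metrics for streamed responses.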