Duibonduil committed on
Commit
7a3c0ee
·
verified ·
1 Parent(s): 2a0b0bf

Upload 11 files

Browse files
aworld/trace/instrumentation/__init__.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Collection
5
+ from packaging.requirements import Requirement, InvalidRequirement
6
+ from importlib_metadata import version, PackageNotFoundError
7
+ from aworld.logs.util import logger
8
+
9
+
10
class Instrumentor(ABC):
    """Base class for library instrumentors.

    ``__new__`` caches the first created instance on ``cls._instance`` and
    hands it back on every later construction, so an instrumentor behaves as
    a singleton. Subclasses implement ``_instrument``, ``_uninstrument`` and
    ``instrumentation_dependencies``; callers use ``instrument`` and
    ``uninstrument``.
    """
    _instance = None
    _has_instrumented = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = object.__new__(cls)

        return cls._instance

    def instrument(self, **kwargs: Any):
        """
        Instrument the library.

        Does nothing (and returns None) when this instrumentor has already
        instrumented, or when the dependency check fails. Otherwise forwards
        ``kwargs`` to ``_instrument`` and returns its result.
        """
        if self._has_instrumented:
            logger.warning(
                f"Instrumentor[{self.__class__.__name__}] has already instrumented, skip")
            return

        if not self._check_dependency_conflicts():
            return

        result = self._instrument(**kwargs)
        self._has_instrumented = True
        return result

    def uninstrument(self, **kwargs: Any):
        """
        Uninstrument the library.

        Does nothing when ``instrument`` has not completed before.
        """
        if not self._has_instrumented:
            logger.warning("Instrumentor has not instrumented, skip")
            return
        self._uninstrument(**kwargs)
        self._has_instrumented = False

    @abstractmethod
    def _uninstrument(self, **kwargs: Any):
        """
        Uninstrument the library.
        """

    @abstractmethod
    def _instrument(self, **kwargs: Any):
        """
        Instrument the library.
        """

    def _check_dependency_conflicts(self):
        """Return True when every declared dependency is installed and compatible.

        Each entry of ``instrumentation_dependencies()`` is parsed as a
        requirement string. An unparsable entry, a missing distribution, or an
        installed version outside the specifier is logged and reported as a
        conflict (False).
        """
        for dependence in self.instrumentation_dependencies():
            try:
                requirement = Requirement(dependence)
            except InvalidRequirement as exc:
                logger.warning(
                    f'error parsing dependency, reporting as a conflict: "{dependence}" - {exc}')
                return False
            try:
                dist_version = version(requirement.name)
            except PackageNotFoundError as exc:
                logger.warning(
                    f'dependency not found, reporting as a conflict: "{dependence}" - {exc}')
                return False

            if requirement.specifier and not requirement.specifier.contains(dist_version):
                # The message previously read self.required / self.found, which
                # do not exist on the instrumentor and raised AttributeError.
                logger.warning(
                    f'dependency version conflict, reporting as a conflict: requested: "{dependence}" but found: "{dist_version}"')
                return False

        return True

    @abstractmethod
    def instrumentation_dependencies(self) -> Collection[str]:
        """
        Return a list of dependencies that the instrumentation requires.
        """
aworld/trace/instrumentation/asgi.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from timeit import default_timer
2
+ from typing import Any, Awaitable, Callable
3
+ from functools import wraps
4
+ from aworld.metrics.context_manager import MetricContext
5
+ from aworld.trace.instrumentation.http_util import (
6
+ collect_request_attributes_asgi,
7
+ url_disabled,
8
+ parser_host_port_url_from_asgi
9
+ )
10
+ from aworld.trace.base import Span, TraceProvider, TraceContext, Tracer, SpanType
11
+ from aworld.trace.propagator import get_global_trace_propagator
12
+ from aworld.trace.propagator.carrier import DictCarrier, ListTupleCarrier
13
+ from aworld.metrics.metric import MetricType
14
+ from aworld.metrics.template import MetricTemplate
15
+ from aworld.logs.util import logger
16
+
17
+
18
def _wrapped_receive(
    server_span: Span,
    server_span_name: str,
    scope: dict[str, Any],
    receive: Callable[[], Awaitable[dict[str, Any]]],
    attributes: dict[str],
    client_request_hook: Callable = None
):
    """Build a replacement for the ASGI ``receive`` callable.

    The returned coroutine awaits the original ``receive``, passes
    ``(scope, message)`` to ``client_request_hook`` when one is supplied and
    callable, records the message type on ``server_span`` and returns the
    message unchanged. ``server_span_name`` and ``attributes`` are accepted
    but not used here.
    """

    @wraps(receive)
    async def otel_receive():
        incoming = await receive()

        hook = client_request_hook
        if hook and callable(hook):
            hook(scope, incoming)

        event_type = incoming.get("type", "")
        server_span.set_attribute("asgi.event.type", event_type)
        return incoming

    return otel_receive
37
+
38
+
39
def _wrapped_send(
    server_span: Span,
    server_span_name: str,
    scope: dict[str, Any],
    send: Callable[[dict[str, Any]], Awaitable[None]],
    attributes: dict[str],
    client_response_hook: Callable = None
):
    """Build a replacement for the ASGI ``send`` callable.

    The returned coroutine records the response status on ``server_span``,
    calls ``client_response_hook(scope, message)`` when it is callable,
    injects the span's trace context through the global propagator, forwards
    the message, and ends the span once the final body (or trailers) message
    has been sent. ``server_span_name`` and ``attributes`` are accepted but
    not used here.
    """
    expecting_trailers = False

    @wraps(send)
    async def otel_send(message: dict[str, Any]):
        nonlocal expecting_trailers

        kind = message["type"]
        if kind == "http.response.start":
            status_code = message["status"]
        elif kind == "websocket.send":
            status_code = 200
        else:
            status_code = None

        if status_code:
            server_span.set_attribute(
                "http.response.status_code", status_code)

        if callable(client_response_hook):
            client_response_hook(scope, message)

        # Re-read the type: the hook received the message and may have touched it.
        if message["type"] == "http.response.start":
            expecting_trailers = message.get("trailers", False)

        propagator = get_global_trace_propagator()
        if propagator:
            # NOTE(review): the carrier wraps the message dict itself, not its
            # "headers" entry - confirm this is what DictCarrier expects.
            propagator.inject(
                TraceContext(
                    trace_id=server_span.get_trace_id(),
                    span_id=server_span.get_span_id()
                ),
                DictCarrier(message))

        await send(message)

        # The response is complete after the last trailers message when
        # trailers were announced, otherwise after the last body message.
        final_type = message["type"]
        if expecting_trailers:
            response_done = (
                final_type == "http.response.trailers"
                and not message.get("more_trailers", False)
            )
        else:
            response_done = (
                final_type == "http.response.body"
                and not message.get("more_body", False)
            )
        if response_done:
            server_span.end()

    return otel_send
94
+
95
+
96
class TraceMiddleware:
    """
    An ASGI middleware for tracing requests and responses.

    For every "http" or "websocket" scope whose URL is not excluded, a server
    span is opened around the wrapped application, ``receive``/``send`` are
    replaced by span-aware wrappers, and (for "http" scopes, when metrics are
    initialized) request duration and the active-request count are recorded.
    """

    def __init__(
            self,
            app,
            excluded_urls=None,
            tracer_provider: TraceProvider = None,
            tracer: Tracer = None,
            server_request_hook: Callable = None,
            client_request_hook: Callable = None,
            client_response_hook: Callable = None,):
        """
        Args:
            app: The wrapped ASGI application.
            excluded_urls: URL patterns that must not be traced.
            tracer_provider: Used to obtain a tracer when ``tracer`` is None.
            tracer: Tracer to use; takes precedence over ``tracer_provider``.
            server_request_hook: Called with ``scope`` once the span is open.
            client_request_hook: Called with ``(scope, message)`` on receive.
            client_response_hook: Called with ``(scope, message)`` on send.
        """
        self.app = app
        self.excluded_urls = excluded_urls
        self.tracer_provider = tracer_provider
        self.server_request_hook = server_request_hook
        self.client_request_hook = client_request_hook
        self.client_response_hook = client_response_hook

        self.tracer: Tracer = (self.tracer_provider.get_tracer(
            "aworld.trace.instrumentation.asgi"
        ) if tracer is None else tracer)

        self.duration_histogram = MetricTemplate(
            type=MetricType.HISTOGRAM,
            name="asgi_request_duration_histogram",
            # Previously said "flask" - copy/paste from the Flask instrumentation.
            description="Duration of ASGI HTTP server requests."
        )

        self.active_requests_counter = MetricTemplate(
            type=MetricType.UPDOWNCOUNTER,
            name="asgi_active_request_counter",
            unit="1",
            description="Number of active HTTP server requests.",
        )

    async def __call__(
        self,
        scope: dict[str, Any],
        receive: Callable[[], Awaitable[dict[str, Any]]],
        send: Callable[[dict[str, Any]], Awaitable[None]],
    ):
        """Handle one ASGI connection, tracing it unless it is excluded."""
        start = default_timer()
        if scope["type"] not in ("http", "websocket"):
            return await self.app(scope, receive, send)

        _, _, url = parser_host_port_url_from_asgi(scope)
        if self.excluded_urls and url_disabled(url, self.excluded_urls):
            return await self.app(scope, receive, send)

        span_name = scope.get("method", "HTTP").strip(
        ).upper() + "_" + scope.get("path", "").strip()

        attributes = collect_request_attributes_asgi(scope)

        if scope["type"] == "http" and MetricContext.metric_initialized():
            MetricContext.inc(self.active_requests_counter, 1, attributes)

        trace_context = None
        propagator = get_global_trace_propagator()
        if propagator:
            trace_context = propagator.extract(
                ListTupleCarrier(scope.get("headers", [])))
            # NOTE(review): this logs the whole scope, request headers
            # included, at INFO level for every traced request.
            logger.info(
                f"asgi extract trace_context: {trace_context}, scope: {scope}")

        # Bound before the try so the finally block cannot hit an unbound
        # name (and mask the real error) when opening the span itself fails.
        span = None
        try:
            with self.tracer.start_as_current_span(
                span_name, span_type=SpanType.SERVER, trace_context=trace_context, attributes=attributes
            ) as span:

                if callable(self.server_request_hook):
                    self.server_request_hook(scope)

                wrappered_receive = _wrapped_receive(
                    span,
                    span_name,
                    scope,
                    receive,
                    attributes,
                    self.client_request_hook
                )
                wrappered_send = _wrapped_send(
                    span,
                    span_name,
                    scope,
                    send,
                    attributes,
                    self.client_response_hook
                )

                await self.app(scope, wrappered_receive, wrappered_send)
        finally:
            if scope["type"] == "http":
                duration_s = default_timer() - start

                if MetricContext.metric_initialized():
                    MetricContext.histogram_record(
                        self.duration_histogram,
                        duration_s,
                        attributes
                    )
                    MetricContext.inc(
                        self.active_requests_counter, -1, attributes)

            if span is not None and span.is_recording():
                span.end()
aworld/trace/instrumentation/eventbus.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wrapt
2
+ from typing import Any, Collection
3
+ from aworld.trace.instrumentation import Instrumentor
4
+ from aworld.trace.base import Tracer, get_tracer_provider_silent, TraceContext
5
+ from aworld.trace.propagator import get_global_trace_propagator, get_global_trace_context
6
+ from aworld.trace.propagator.carrier import DictCarrier
7
+ from aworld.logs.util import logger
8
+
9
+
10
def _emit_message_class_wrapper(tracer: Tracer):
    """Return a wrapt-style async wrapper for ``EventManager.emit_message``.

    Before delegating to the wrapped coroutine, the wrapper injects the
    current span's trace context into the headers of the emitted ``Message``
    (first positional argument, or the ``event`` keyword). Any failure while
    propagating is logged and never blocks the emit. ``tracer`` is unused.
    """
    async def awrapper(wrapped, instance, args, kwargs):
        from aworld.core.event.base import Message
        try:
            event = args[0] if args else kwargs.get("event")
            propagator = get_global_trace_propagator()
            trace_provider = get_tracer_provider_silent()
            propagatable = (
                trace_provider
                and propagator
                and event
                and isinstance(event, Message)
            )
            if propagatable:
                if not event.headers:
                    event.headers = {}
                active_span = trace_provider.get_current_span()
                if active_span:
                    propagator.inject(
                        trace_context=TraceContext(
                            trace_id=active_span.get_trace_id(),
                            span_id=active_span.get_span_id()),
                        carrier=DictCarrier(event.headers))
                    logger.info(
                        f"EventManager emit_message trace propagate, event.headers={event.headers}")
        except Exception as e:
            # Propagation is best effort; the message must still be emitted.
            logger.error(
                f"EventManager emit_message trace propagate exception: {e}")
        return await wrapped(*args, **kwargs)
    return awrapper
33
+
34
+
35
def _emit_message_instance_wrapper(tracer: Tracer):
    """Return a ``wrapt.decorator`` that applies the emit wrapper to one bound method."""

    @wrapt.decorator
    async def awrapper(wrapped, instance, args, kwargs):
        # Delegate to the class-level wrapper so both paths share one implementation.
        delegate = _emit_message_class_wrapper(tracer)
        result = await delegate(wrapped, instance, args, kwargs)
        return result

    return awrapper
43
+
44
+
45
def _consume_class_wrapper(tracer: Tracer):
    """Return a wrapt-style async wrapper for ``EventManager.consume``.

    After the wrapped coroutine returns an event, the wrapper extracts a
    trace context from the event headers (when the event is a ``Message``
    with headers) and stores it in the global trace context. Failures are
    logged; the consumed event is always returned. ``tracer`` is unused.
    """
    async def awrapper(wrapped, instance, args, kwargs):
        from aworld.core.event.base import Message
        consumed = await wrapped(*args, **kwargs)
        try:
            propagator = get_global_trace_propagator()
            has_headers = (
                propagator
                and consumed
                and isinstance(consumed, Message)
                and consumed.headers
            )
            if has_headers:
                extracted = propagator.extract(DictCarrier(consumed.headers))
                logger.info(
                    f"extract trace_context from event: {extracted}")
                if extracted:
                    get_global_trace_context().set(extracted)
        except Exception as e:
            # Best effort: never lose the consumed event because of tracing.
            logger.error(
                f"EventManager consume trace propagate exception: {e}")
        return consumed
    return awrapper
62
+
63
+
64
def _consume_instance_wrapper(tracer: Tracer):
    """Return a ``wrapt.decorator`` that applies the consume wrapper to one bound method."""

    @wrapt.decorator
    async def awrapper(wrapped, instance, args, kwargs):
        # Delegate to the class-level wrapper so both paths share one implementation.
        delegate = _consume_class_wrapper(tracer)
        result = await delegate(wrapped, instance, args, kwargs)
        return result

    return awrapper
72
+
73
+
74
class EventBusInstrumentor(Instrumentor):
    """Instrumentor that patches ``EventManager.emit_message`` and ``EventManager.consume``.

    The patches propagate trace context through message headers (see
    ``_emit_message_class_wrapper`` and ``_consume_class_wrapper``).
    """

    def instrumentation_dependencies(self) -> Collection[str]:
        """This instrumentation declares no external package requirements."""
        return ()

    def _uninstrument(self, **kwargs: Any):
        """Remove the wrappers installed by ``_instrument``.

        Previously a no-op, which left the patches in place and caused a later
        ``instrument()`` to wrap the already wrapped methods a second time.
        Only attributes that are wrapt proxies are restored, so this is safe
        when ``_instrument`` returned early without patching anything.
        """
        import importlib

        try:
            module = importlib.import_module("aworld.events.manager")
        except ImportError:
            return
        manager_cls = getattr(module, "EventManager", None)
        if manager_cls is None:
            return
        for method_name in ("emit_message", "consume"):
            patched = getattr(manager_cls, method_name, None)
            if isinstance(patched, wrapt.ObjectProxy) and hasattr(patched, "__wrapped__"):
                setattr(manager_cls, method_name, patched.__wrapped__)

    def _instrument(self, **kwargs: Any):
        """Patch the event manager; does nothing when no tracer provider is available."""
        tracer_provider = get_tracer_provider_silent()
        if not tracer_provider:
            return
        tracer = tracer_provider.get_tracer(
            "aworld.trace.instrumentation.eventbus")

        wrapt.wrap_function_wrapper(
            "aworld.events.manager",
            "EventManager.emit_message",
            _emit_message_class_wrapper(tracer=tracer)
        )

        wrapt.wrap_function_wrapper(
            "aworld.events.manager",
            "EventManager.consume",
            _consume_class_wrapper(tracer=tracer)
        )
100
+
101
+
102
def wrap_event_manager(manager: 'aworld.events.manager.EventManager'):
    """Wrap ``emit_message`` and ``consume`` on a single manager instance.

    Returns the same ``manager`` object; it is returned untouched when no
    tracer provider is available.
    """
    provider = get_tracer_provider_silent()
    if not provider:
        return manager

    tracer = provider.get_tracer("aworld.trace.instrumentation.eventbus")
    wrap_emit = _emit_message_instance_wrapper(tracer)
    wrap_consume = _consume_instance_wrapper(tracer)

    manager.emit_message = wrap_emit(manager.emit_message)
    manager.consume = wrap_consume(manager.consume)
    return manager
aworld/trace/instrumentation/fastapi.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable
2
+ from .asgi import TraceMiddleware
3
+ from aworld.trace.instrumentation import Instrumentor
4
+ from aworld.trace.base import TraceProvider, get_tracer_provider
5
+ from aworld.trace.instrumentation.http_util import (
6
+ get_excluded_urls,
7
+ parse_excluded_urls,
8
+ )
9
+ from aworld.utils.import_package import import_packages
10
+ from aworld.logs.util import logger
11
+
12
+ import_packages(['fastapi']) # noqa
13
+ import fastapi # noqa
14
+
15
+
16
class _InstrumentedFastAPI(fastapi.FastAPI):
    """FastAPI subclass that installs ``TraceMiddleware`` on construction.

    The class-level settings below are filled in by
    ``FastAPIInstrumentor._instrument`` before this class replaces
    ``fastapi.FastAPI``.
    """
    _tracer_provider: TraceProvider = None
    _excluded_urls: list[str] = None
    _server_request_hook: Callable = None
    _client_request_hook: Callable = None
    _client_response_hook: Callable = None
    # Every app created through this class, so they can be uninstrumented later.
    _instrumented_fastapi_apps = set()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        middleware_options = {
            "tracer": self._tracer_provider.get_tracer(
                "aworld.trace.instrumentation.fastapi"),
            "excluded_urls": self._excluded_urls,
            "server_request_hook": self._server_request_hook,
            "client_request_hook": self._client_request_hook,
            "client_response_hook": self._client_response_hook,
        }
        self.add_middleware(TraceMiddleware, **middleware_options)

        self._is_instrumented_by_trace = True
        self._instrumented_fastapi_apps.add(self)

    def __del__(self):
        # Drop the registry entry for this app, if it is still registered.
        self._instrumented_fastapi_apps.discard(self)
46
+
47
+
48
class FastAPIInstrumentor(Instrumentor):
    """Instrumentor that swaps ``fastapi.FastAPI`` for ``_InstrumentedFastAPI``."""
    _original_fastapi = None

    @staticmethod
    def uninstrument_app(app: fastapi.FastAPI):
        """Strip ``TraceMiddleware`` from ``app`` and rebuild its middleware stack."""
        kept = []
        for middleware in app.user_middleware:
            if middleware.cls is not TraceMiddleware:
                kept.append(middleware)
        app.user_middleware = kept
        app.middleware_stack = app.build_middleware_stack()
        app._is_instrumented_by_trace = False

    def instrumentation_dependencies(self) -> dict[str, Any]:
        """Return the dependency mapping; its keys are what the base class iterates."""
        return {"fastapi": fastapi}

    def _instrument(self, **kwargs):
        """Store the options on ``_InstrumentedFastAPI`` and patch ``fastapi.FastAPI``."""
        self._original_fastapi = fastapi.FastAPI

        for option in (
            "tracer_provider",
            "server_request_hook",
            "client_request_hook",
            "client_response_hook",
        ):
            setattr(_InstrumentedFastAPI, f"_{option}", kwargs.get(option))

        # Explicit excluded_urls win; otherwise fall back to the environment.
        excluded_urls = kwargs.get("excluded_urls")
        if excluded_urls is None:
            _InstrumentedFastAPI._excluded_urls = get_excluded_urls("FASTAPI")
        else:
            _InstrumentedFastAPI._excluded_urls = parse_excluded_urls(
                excluded_urls)

        fastapi.FastAPI = _InstrumentedFastAPI

    def _uninstrument(self, **kwargs):
        """Uninstrument every tracked app and restore the original ``fastapi.FastAPI``."""
        tracked_apps = _InstrumentedFastAPI._instrumented_fastapi_apps
        for app in tracked_apps:
            self.uninstrument_app(app)
        tracked_apps.clear()
        fastapi.FastAPI = self._original_fastapi
90
+
91
+
92
def instrument_fastapi(excluded_urls: str = None,
                       server_request_hook: Callable = None,
                       client_request_hook: Callable = None,
                       client_response_hook: Callable = None,
                       tracer_provider: TraceProvider = None,
                       **kwargs: Any,
                       ):
    """Instrument FastAPI through the ``FastAPIInstrumentor`` singleton.

    The named arguments override same-named entries in ``kwargs``. When
    ``tracer_provider`` is falsy, ``get_tracer_provider()`` supplies one.
    """
    provider = tracer_provider or get_tracer_provider()
    kwargs.update(
        excluded_urls=excluded_urls,
        server_request_hook=server_request_hook,
        client_request_hook=client_request_hook,
        client_response_hook=client_response_hook,
        tracer_provider=provider,
    )
    FastAPIInstrumentor().instrument(**kwargs)
    logger.info("FastAPI instrumented.")
aworld/trace/instrumentation/flask.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import flask
2
+ import weakref
3
+ from typing import Any, Callable, Collection
4
+ from time import time_ns
5
+ from timeit import default_timer
6
+ from importlib_metadata import version
7
+ from packaging import version as package_version
8
+ from aworld.trace.instrumentation import Instrumentor
9
+ from aworld.trace.base import Span, TraceProvider, TraceContext, Tracer, SpanType, get_tracer_provider
10
+ from aworld.metrics.metric import MetricType
11
+ from aworld.metrics.template import MetricTemplate
12
+ from aworld.logs.util import logger
13
+ from aworld.trace.instrumentation.http_util import (
14
+ collect_request_attributes,
15
+ url_disabled,
16
+ get_excluded_urls,
17
+ parse_excluded_urls,
18
+ HTTP_ROUTE
19
+ )
20
+ from aworld.trace.propagator import get_global_trace_propagator
21
+ from aworld.metrics.context_manager import MetricContext
22
+ from aworld.trace.propagator.carrier import ListTupleCarrier, DictCarrier
23
+
24
# Keys under which per-request tracing state is stored in the WSGI environ.
_ENVIRON_STARTTIME_KEY = "aworld-flask.starttime_key"
_ENVIRON_SPAN_KEY = "aworld-flask.span_key"
_ENVIRON_REQCTX_REF_KEY = "aworld-flask.reqctx_ref_key"

flask_version = version("flask")

# The request-context accessor differs from Flask 2.2.0 onwards, so the
# matching implementation of _request_ctx_ref is chosen once at import time.
_FLASK_HAS_REQUEST_CTX = (
    package_version.parse(flask_version) >= package_version.parse("2.2.0")
)

if _FLASK_HAS_REQUEST_CTX:

    def _request_ctx_ref() -> weakref.ReferenceType:
        """Return a weak reference to the current request context."""
        current_ctx = flask.globals.request_ctx._get_current_object()
        return weakref.ref(current_ctx)

else:

    def _request_ctx_ref() -> weakref.ReferenceType:
        """Return a weak reference to the top of the request context stack."""
        top_ctx = flask._request_ctx_stack.top
        return weakref.ref(top_ctx)
38
+
39
+
40
def _rewrapped_app(
    wsgi_app,
    active_requests_counter,
    duration_histogram,
    response_hook=None,
    excluded_urls=None,
):
    """Wrap a Flask ``wsgi_app`` with tracing and metrics bookkeeping.

    Args:
        wsgi_app: The WSGI callable being wrapped.
        active_requests_counter: Metric incremented on entry and decremented on exit.
        duration_histogram: Metric that receives the request duration in seconds.
        response_hook: Optional callable ``(span, status, response_headers)``.
        excluded_urls: URL patterns for which the span is left untouched.

    Returns:
        A WSGI callable with the same calling convention as ``wsgi_app``.
    """
    def _wrapped_app(wrapped_app_environ, start_response):
        # We want to measure the time for route matching, etc.
        # In theory, we could start the span here and use
        # update_name later but that API is "highly discouraged" so
        # we better avoid it.
        wrapped_app_environ[_ENVIRON_STARTTIME_KEY] = time_ns()
        start = default_timer()
        attributes = collect_request_attributes(wrapped_app_environ)

        if MetricContext.metric_initialized():
            MetricContext.inc(active_requests_counter, 1, attributes)

        def _start_response(status, response_headers, *args, **kwargs):
            # "not excluded_urls" (instead of "is None") so that an empty
            # exclusion list - the default produced by get_excluded_urls -
            # means "trace everything", matching the before/teardown hooks.
            if flask.request and (
                not excluded_urls
                or not url_disabled(flask.request.url, excluded_urls)
            ):
                span: Span = flask.request.environ.get(_ENVIRON_SPAN_KEY)

                propagator = get_global_trace_propagator()
                if propagator and span:
                    trace_context = TraceContext(
                        trace_id=span.get_trace_id(),
                        span_id=span.get_span_id()
                    )
                    propagator.inject(
                        trace_context, ListTupleCarrier(response_headers))

                if span and span.is_recording():
                    # Take only the code part; a status line without a reason
                    # phrase ("200") must not break the response.
                    status_code_str = status.split(" ", 1)[0]
                    try:
                        status_code = int(status_code_str)
                    except ValueError:
                        status_code = -1

                    span.set_attribute(
                        "http.response.status_code", status_code)
                    span.set_attributes(attributes)

                if response_hook is not None:
                    response_hook(span, status, response_headers)
            return start_response(status, response_headers, *args, **kwargs)

        try:
            return wsgi_app(wrapped_app_environ, _start_response)
        finally:
            # Runs on failure too, so the active-request counter cannot leak
            # when the application raises.
            duration_s = default_timer() - start

            if MetricContext.metric_initialized():
                MetricContext.histogram_record(
                    duration_histogram,
                    duration_s,
                    attributes
                )
                MetricContext.dec(active_requests_counter, 1, attributes)

    return _wrapped_app
108
+
109
+
110
def _wrapped_before_request(
    request_hook=None,
    tracer: Tracer = None,
    excluded_urls=None
):
    """Build the Flask ``before_request`` callback that opens the server span.

    The callback skips excluded URLs, extracts an incoming trace context via
    the global propagator, starts a SERVER span and stores the span plus a
    weak reference to the request context in the WSGI environ.
    """
    def _before_request():
        if excluded_urls and url_disabled(flask.request.url, excluded_urls):
            return

        environ = flask.request.environ
        logger.info(
            f"_wrapped_before_request flask_request_environ={environ}")

        attributes = collect_request_attributes(environ)

        url_rule = flask.request.url_rule
        if url_rule:
            attributes[HTTP_ROUTE] = url_rule.rule
            span_name = f"HTTP {url_rule.rule}"
        else:
            # For 404 that result from no route found, etc, we
            # don't have a url_rule.
            span_name = f"HTTP {flask.request.url}"

        propagator = get_global_trace_propagator()
        trace_context = (
            propagator.extract(DictCarrier(environ)) if propagator else None
        )

        logger.info(f"_wrapped_before_request trace_context={trace_context}")

        span = tracer.start_span(
            span_name,
            SpanType.SERVER,
            attributes=attributes,
            start_time=environ.get(_ENVIRON_STARTTIME_KEY),
            trace_context=trace_context
        )

        if request_hook:
            request_hook(span, environ)

        environ[_ENVIRON_SPAN_KEY] = span
        environ[_ENVIRON_REQCTX_REF_KEY] = _request_ctx_ref()

    return _before_request
155
+
156
+
157
def _wrapped_teardown_request(
    excluded_urls=None,
):
    """Build the Flask ``teardown_request`` callback that closes the server span.

    The span is ended only when the request context that is being torn down
    is the same one that opened it; an exception passed by Flask is recorded
    on the span first.
    """
    def _teardown_request(exc):
        if excluded_urls and url_disabled(flask.request.url, excluded_urls):
            return

        environ = flask.request.environ
        span: Span = environ.get(_ENVIRON_SPAN_KEY)
        opened_in_ctx = environ.get(_ENVIRON_REQCTX_REF_KEY)
        tearing_down_ctx = _request_ctx_ref()

        if not span or opened_in_ctx != tearing_down_ctx:
            return

        if exc is not None:
            span.record_exception(exc)
        span.end()

    return _teardown_request
179
+
180
+
181
class _InstrumentedFlask(flask.Flask):
    """Flask subclass that wires tracing and metrics into every new app.

    The class-level settings are filled in by ``FlaskInstrumentor._instrument``
    before this class replaces ``flask.Flask``.
    """
    _excluded_urls = None
    _tracer_provider: TraceProvider = None
    _request_hook = None
    _response_hook = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        tracer = self._tracer_provider.get_tracer(
            "aworld.trace.instrumentation.flask")
        excluded = _InstrumentedFlask._excluded_urls

        request_duration = MetricTemplate(
            type=MetricType.HISTOGRAM,
            name="flask_request_duration_histogram",
            description="Duration of flask HTTP server requests."
        )
        active_requests = MetricTemplate(
            type=MetricType.UPDOWNCOUNTER,
            name="flask_active_request_counter",
            unit="1",
            description="Number of active HTTP server requests.",
        )

        # Metrics and response handling wrap the WSGI entry point.
        self.wsgi_app = _rewrapped_app(
            self.wsgi_app,
            active_requests,
            request_duration,
            _InstrumentedFlask._response_hook,
            excluded_urls=excluded
        )

        # The span is opened in before_request and closed in teardown_request.
        before_request_callback = _wrapped_before_request(
            _InstrumentedFlask._request_hook,
            tracer,
            excluded_urls=excluded
        )
        self._before_request = before_request_callback
        self.before_request(before_request_callback)

        self.teardown_request(
            _wrapped_teardown_request(excluded_urls=excluded))
226
+
227
+
228
class FlaskInstrumentor(Instrumentor):
    """Instrumentor that swaps ``flask.Flask`` for ``_InstrumentedFlask``."""

    # The original flask.Flask class, captured by _instrument so that
    # _uninstrument can restore it.
    _original_flask = None

    def instrumentation_dependencies(self) -> Collection[str]:
        """Requirement strings checked before instrumenting."""
        return ("flask >= 1.0",)

    def _instrument(self, **kwargs: Any):
        """Store the options on ``_InstrumentedFlask`` and patch ``flask.Flask``.

        Recognized kwargs: ``request_hook``, ``response_hook``,
        ``tracer_provider`` and ``excluded_urls``.
        """
        logger.info("Flask _instrument entered.")
        self._original_flask = flask.Flask
        request_hook = kwargs.get("request_hook")
        response_hook = kwargs.get("response_hook")
        # Always assign, so hooks from an earlier instrument() call do not
        # survive a re-instrumentation that passes no (or non-callable) hooks.
        _InstrumentedFlask._request_hook = (
            request_hook if callable(request_hook) else None
        )
        _InstrumentedFlask._response_hook = (
            response_hook if callable(response_hook) else None
        )
        tracer_provider = kwargs.get("tracer_provider")
        _InstrumentedFlask._tracer_provider = tracer_provider
        excluded_urls = kwargs.get("excluded_urls")
        _InstrumentedFlask._excluded_urls = (
            get_excluded_urls("FLASK")
            if excluded_urls is None
            else parse_excluded_urls(excluded_urls)
        )
        flask.Flask = _InstrumentedFlask
        logger.info("Flask _instrument exited.")

    def _uninstrument(self, **kwargs):
        """Restore the original ``flask.Flask`` class.

        Apps already created from ``_InstrumentedFlask`` keep their wrappers.
        """
        flask.Flask = self._original_flask
255
+
256
+
257
def instrument_flask(excluded_urls: str = None,
                     request_hook: Callable = None,
                     response_hook: Callable = None,
                     tracer_provider: TraceProvider = None,
                     **kwargs: Any,
                     ):
    """
    Instrument the Flask application.

    Entries in ``kwargs`` override the named arguments of the same name.

    Args:
        excluded_urls (str): A comma separated list of URLs to be excluded from instrumentation.
        request_hook (Callable): A function to be called before a request is processed.
        response_hook (Callable): A function to be called after a request is processed.
        tracer_provider (TraceProvider): The trace provider to use; falls back
            to ``get_tracer_provider()`` when falsy.
    """
    options = dict(
        excluded_urls=excluded_urls,
        request_hook=request_hook,
        response_hook=response_hook,
        tracer_provider=tracer_provider or get_tracer_provider(),
    )
    options.update(kwargs)
    FlaskInstrumentor().instrument(**options)
    logger.info("Flask instrumented.")
aworld/trace/instrumentation/http_util.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from re import compile as re_compile
3
+ from re import search
4
+ from typing import Final, Iterable, Any
5
+ from urllib.parse import urlparse, urlunparse, unquote
6
+ from wsgiref.types import WSGIEnvironment
7
+ from requests.models import PreparedRequest
8
+
9
# Span/metric attribute keys describing the request, server, client and URL.
HTTP_REQUEST_METHOD: Final = "http.request.method"
HTTP_FLAVOR: Final = "http.flavor"
HTTP_HOST: Final = "http.host"
HTTP_SCHEME: Final = "http.scheme"
HTTP_USER_AGENT: Final = "http.user_agent"
HTTP_SERVER_NAME: Final = "http.server_name"
SERVER_ADDRESS: Final = "server.address"
SERVER_PORT: Final = "server.port"
URL_PATH: Final = "url.path"
URL_QUERY: Final = "url.query"
CLIENT_ADDRESS: Final = "client.address"
CLIENT_PORT: Final = "client.port"
URL_FULL: Final = "url.full"

# Attribute keys for request/response sizes, headers, status and route.
HTTP_REQUEST_BODY_SIZE: Final = "http.request.body.size"
HTTP_REQUEST_HEADER: Final = "http.request.header"
HTTP_REQUEST_SIZE: Final = "http.request.size"
HTTP_RESPONSE_BODY_SIZE: Final = "http.response.body.size"
HTTP_RESPONSE_HEADER: Final = "http.response.header"
HTTP_RESPONSE_SIZE: Final = "http.response.size"
HTTP_RESPONSE_STATUS_CODE: Final = "http.response.status_code"
# Annotated Final like every other key in this module.
HTTP_ROUTE: Final = "http.route"
31
+
32
+
33
def collect_request_attributes(environ: WSGIEnvironment):
    """Collect span attributes from a WSGI environ.

    Method, protocol, scheme, server name, host, client address/port and
    user agent are always present (empty string when missing). The server
    port and the URL path/query are added only when the environ provides them.
    """
    attributes: dict[str] = {
        HTTP_REQUEST_METHOD: environ.get("REQUEST_METHOD", "").upper(),
        HTTP_FLAVOR: environ.get("SERVER_PROTOCOL", ""),
        HTTP_SCHEME: environ.get("wsgi.url_scheme", ""),
        HTTP_SERVER_NAME: environ.get("SERVER_NAME", ""),
        HTTP_HOST: environ.get("HTTP_HOST", ""),
    }

    server_port = environ.get("SERVER_PORT")
    if server_port:
        attributes[SERVER_PORT] = server_port

    # RAW_URI is preferred; REQUEST_URI is the fallback when it is absent.
    target = environ.get("RAW_URI")
    if target is None:
        target = environ.get("REQUEST_URI")
    if target:
        attributes[URL_PATH], attributes[URL_QUERY] = _parse_url_query(target)

    # A REMOTE_HOST that differs from the address takes precedence over it.
    remote_addr = environ.get("REMOTE_ADDR", "")
    remote_host = environ.get("REMOTE_HOST")
    prefer_host = remote_host and remote_host != remote_addr
    attributes[CLIENT_ADDRESS] = remote_host if prefer_host else remote_addr
    attributes[CLIENT_PORT] = environ.get("REMOTE_PORT", "")
    attributes[HTTP_USER_AGENT] = environ.get("HTTP_USER_AGENT", "")
    return attributes
62
+
63
+
64
def collect_attributes_from_request(request: PreparedRequest) -> dict[str]:
    """Collect span attributes from an outgoing prepared request.

    The full URL is recorded with credentials removed; scheme, host and port
    are added only when the URL contains them.
    """
    url = remove_url_credentials(request.url)
    attributes: dict[str] = {
        HTTP_REQUEST_METHOD: request.method,
        URL_FULL: url,
    }

    parts = urlparse(url)
    optional_fields = (
        (HTTP_SCHEME, parts.scheme),
        (HTTP_HOST, parts.hostname),
        (SERVER_PORT, parts.port),
    )
    for key, value in optional_fields:
        if value:
            attributes[key] = value
    return attributes
78
+
79
+
80
def url_disabled(url: str, excluded_urls: Iterable[str]) -> bool:
    """
    Check if the url is disabled.

    Empty patterns are ignored: an empty pattern (or an empty pattern list)
    would otherwise compile to a regex that matches every URL and disable
    tracing everywhere.

    Args:
        url: The url to check.
        excluded_urls: The excluded urls, each one a regular expression.
    Returns:
        True if the url is disabled, False otherwise.
    """
    if excluded_urls is None:
        return False
    patterns = [pattern for pattern in excluded_urls if pattern]
    if not patterns:
        return False
    # Return a real bool, as annotated, instead of a Match object / None.
    return re_compile("|".join(patterns)).search(url) is not None
93
+
94
+
95
def get_excluded_urls(instrumentation: str) -> list[str]:
    """
    Get the excluded urls.

    Reads the ``<instrumentation>_EXCLUDED_URLS`` environment variable and
    parses it with ``parse_excluded_urls``.

    Args:
        instrumentation: The instrumentation to get the excluded urls for.
    Returns:
        The excluded urls.
    """
    variable_name = f"{instrumentation}_EXCLUDED_URLS"
    raw_value = os.environ.get(variable_name)
    return parse_excluded_urls(raw_value)
107
+
108
+
109
def parse_excluded_urls(excluded_urls: str) -> list[str]:
    """
    Parse the excluded urls.

    Entries are stripped of surrounding whitespace. Empty entries (from
    ``"a,,b"`` or a trailing comma) are dropped, because an empty pattern
    matches every URL when the list is later used as a regex alternation.

    Args:
        excluded_urls: Comma separated url patterns; may be None or empty.
    Returns:
        The excluded urls.
    """
    if not excluded_urls:
        return []

    stripped = (entry.strip() for entry in excluded_urls.split(","))
    return [entry for entry in stripped if entry]
125
+
126
+
127
def remove_url_credentials(url: str) -> str:
    """Return ``url`` without its ``user:password@`` part.

    Only urls that have both a scheme and a network location are rewritten;
    anything else, including input that cannot be parsed, is returned as is.
    """
    try:
        parts = urlparse(url)
        if parts.scheme and parts.netloc:
            # Everything after the last "@" is the host[:port] portion.
            host_part = parts.netloc.rpartition("@")[2]
            return urlunparse(parts._replace(netloc=host_part))
    except ValueError:  # an unparsable url was passed
        pass
    return url
148
+
149
+
150
def parser_host_port_url_from_asgi(scope: dict[str, Any]):
    """Returns (host, port, full_url) tuple.

    ``host`` carries a ``:port`` suffix unless the port is 80 or missing.
    A missing port (``None``, as in a ``(path, None)`` server entry) used to
    be rendered as ``":None"``. A missing ``server`` entry falls back to
    ``0.0.0.0:80`` and a missing scheme to ``http``.
    """
    server = scope.get("server") or ["0.0.0.0", 80]
    port = server[1]
    server_host = server[0]
    if port is not None and str(port) != "80":
        server_host = server_host + ":" + str(port)
    full_path = scope.get("path", "")
    http_url = scope.get("scheme", "http") + "://" + server_host + full_path
    return server_host, port, http_url
158
+
159
+
160
def collect_request_attributes_asgi(scope: dict[str, Any]):
    """Collect span attributes from an ASGI connection scope.

    The full URL includes the unquoted query string and has credentials
    removed. The user agent and the client address/port are added only when
    the scope provides them.
    """
    server_host, port, http_url = parser_host_port_url_from_asgi(scope)

    raw_query = scope.get("query_string")
    if raw_query and http_url:
        if isinstance(raw_query, bytes):
            raw_query = raw_query.decode("utf8")
        http_url = http_url + "?" + unquote(raw_query)

    attributes: dict[str] = {
        HTTP_REQUEST_METHOD: scope.get("method", ""),
        HTTP_FLAVOR: scope.get("http_version", ""),
        HTTP_SCHEME: scope.get("scheme", ""),
        HTTP_HOST: server_host,
        SERVER_PORT: port,
        URL_FULL: remove_url_credentials(http_url),
        URL_PATH: scope.get("path", ""),
    }

    # When the header is repeated, the last user-agent value is the one kept.
    for name, value in scope.get("headers") or ():
        if name == b"user-agent":
            attributes[HTTP_USER_AGENT] = value.decode("utf8")

    client = scope.get("client")
    if client:
        attributes[CLIENT_ADDRESS] = client[0]
        attributes[CLIENT_PORT] = client[1]

    return attributes
187
+
188
+
189
+ def _parse_url_query(url: str):
190
+ parsed_url = urlparse(url)
191
+ path = parsed_url.path
192
+ query_params = parsed_url.query
193
+ return path, query_params
aworld/trace/instrumentation/llm_metrics.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from aworld.metrics.context_manager import MetricContext
2
+ from aworld.metrics.template import MetricTemplate
3
+ from aworld.metrics.metric import MetricType
4
+
5
# Metric templates used by the record_* helpers in this module.

# Histogram of token counts; prompt and completion usage are recorded as
# separate samples distinguished by a label.
tokens_usage_histogram = MetricTemplate(
    type=MetricType.HISTOGRAM,
    name="llm_token_usage",
    unit="token",
    description="Measures number of input and output tokens used"
)

# Counter of choices returned by chat completion calls.
chat_choice_counter = MetricTemplate(
    type=MetricType.COUNTER,
    name="llm_generation_choice_counter",
    unit="choice",
    description="Number of choices returned by chat completions call"
)

# Histogram of chat call duration, recorded for both successful and failed calls.
duration_histogram = MetricTemplate(
    type=MetricType.HISTOGRAM,
    name="llm_chat_duration",
    unit="s",
    description="AI chat duration",
)

# Counter of exceptions raised during chat completions.
chat_exception_counter = MetricTemplate(
    type=MetricType.COUNTER,
    name="llm_chat_exception_counter",
    unit="time",
    description="Number of exceptions occurred during chat completions",
)

# Histogram of the delay before the first streamed token.
streaming_time_to_first_token_histogram = MetricTemplate(
    type=MetricType.HISTOGRAM,
    name="llm_streaming_time_to_first_token",
    unit="s",
    description="Time to first token in streaming chat completions",
)
# Histogram of the time between the first streamed token and completion.
# NOTE(review): unlike the other templates, this name has no "llm_" prefix.
streaming_time_to_generate_histogram = MetricTemplate(
    type=MetricType.HISTOGRAM,
    name="streaming_time_to_generate",
    unit="s",
    description="Time between first token and completion in streaming chat completions",
)
45
+
46
+
47
def record_exception_metric(exception, duration):
    '''
    Report a failed chat call: its duration plus one exception count.

    Both samples carry the exception class name under the "error.type"
    label. Nothing is recorded until the metric context is initialized.
    '''
    if not MetricContext.metric_initialized():
        return

    error_labels = {"error.type": exception.__class__.__name__}
    if duration_histogram:
        MetricContext.histogram_record(
            duration_histogram, duration, labels=error_labels)
    if chat_exception_counter:
        MetricContext.count(chat_exception_counter, 1, labels=error_labels)
61
+
62
+
63
def record_streaming_time_to_first_token(duration, labels):
    '''
    Record how long a streaming call took to produce its first token.

    No-op until the metric context is initialized.
    '''
    if not MetricContext.metric_initialized():
        return
    MetricContext.histogram_record(
        streaming_time_to_first_token_histogram,
        duration,
        labels=labels,
    )
70
+
71
+
72
def record_streaming_time_to_generate(first_token_to_generate_duration, labels):
    '''
    Record the time between the first streamed token and the end of
    generation.

    No-op until the metric context is initialized.
    '''
    if not MetricContext.metric_initialized():
        return
    MetricContext.histogram_record(
        streaming_time_to_generate_histogram,
        first_token_to_generate_duration,
        labels=labels,
    )
79
+
80
+
81
def record_chat_response_metric(attributes,
                                prompt_tokens,
                                completion_tokens,
                                duration,
                                choices=None
                                ):
    '''
    Record the metrics of a completed chat call.

    attributes: base labels attached to every sample.
    prompt_tokens / completion_tokens: token counts; falsy values are skipped.
    duration: call duration; skipped when falsy.
    choices: optional list of choice dicts. Their number is counted once,
        and each choice with a "finish_reason" is counted again with that
        reason added to the labels.

    No-op until the metric context is initialized.
    '''
    if not MetricContext.metric_initialized():
        return

    # Prompt usage is reported before completion usage.
    token_usages = (
        ("prompt_tokens", prompt_tokens),
        ("completion_tokens", completion_tokens),
    )
    for usage_type, token_count in token_usages:
        if token_count and tokens_usage_histogram:
            usage_labels = {
                **attributes,
                "llm.prompt_usage_type": usage_type
            }
            MetricContext.histogram_record(
                tokens_usage_histogram, token_count, labels=usage_labels)

    if duration and duration_histogram:
        MetricContext.histogram_record(
            duration_histogram, duration, labels=attributes)

    if not (choices and chat_choice_counter):
        return

    MetricContext.count(chat_choice_counter, len(choices), labels=attributes)
    for choice in choices:
        finish_reason = choice.get("finish_reason")
        if finish_reason:
            MetricContext.count(
                chat_choice_counter,
                1,
                labels={**attributes, "llm.finish_reason": finish_reason})
aworld/trace/instrumentation/requests.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from aworld.logs.util import logger
2
+ from aworld.trace.instrumentation.http_util import (
3
+ collect_attributes_from_request,
4
+ url_disabled,
5
+ get_excluded_urls,
6
+ parse_excluded_urls,
7
+ HTTP_RESPONSE_STATUS_CODE,
8
+ HTTP_FLAVOR
9
+ )
10
+ from aworld.metrics.context_manager import MetricContext
11
+ from aworld.metrics.template import MetricTemplate
12
+ from aworld.metrics.metric import MetricType
13
+ from aworld.trace.instrumentation import Instrumentor
14
+ from aworld.trace.propagator.carrier import DictCarrier
15
+ import functools
16
+ from timeit import default_timer
17
+ from requests import sessions
18
+ from requests.models import PreparedRequest, Response
19
+ from requests.structures import CaseInsensitiveDict
20
+ from typing import Collection, Any, Callable
21
+ from aworld.trace.base import TraceProvider, TraceContext, Tracer, SpanType, get_tracer_provider
22
+ from aworld.trace.propagator import get_global_trace_propagator
23
+
24
+
25
def _wrapped_send(
    tracer: Tracer = None,
    excluded_urls=None,
    request_hook: Callable = None,
    response_hook: Callable = None,
    duration_histogram: MetricTemplate = None
):
    """
    Build a traced replacement for ``requests.Session.send``.

    Args:
        tracer: Tracer used to open one CLIENT span per request.
        excluded_urls: Exclusion list handed to ``url_disabled``; matching
            URLs are sent without a span.
        request_hook: Optional callable invoked as
            ``request_hook(span, request)`` before the request is sent.
        response_hook: Optional callable invoked as
            ``response_hook(span, request, response)`` once a response
            object is available.
        duration_histogram: Optional metric template that receives the
            elapsed time of every traced request.

    Returns:
        A function with the calling convention of ``Session.send``. It
        returns the response and re-raises any exception raised by the
        original ``send`` after recording it on the span.
    """

    original_send = sessions.Session.send

    @functools.wraps(original_send)
    def instrumented_send(
        self: sessions.Session, request: PreparedRequest, **kwargs: Any
    ):
        if excluded_urls and url_disabled(request.url, excluded_urls):
            return original_send(self, request, **kwargs)

        method = request.method
        if method is None:
            method = "HTTP"
        span_name = method.upper()

        span_attributes = collect_attributes_from_request(request)
        with tracer.start_as_current_span(
            span_name, span_type=SpanType.CLIENT, attributes=span_attributes
        ) as span:
            exception = None
            if callable(request_hook):
                request_hook(span, request)

            # The propagator needs a mutable header mapping to inject into.
            if request.headers is None:
                request.headers = CaseInsensitiveDict()
            headers = request.headers

            trace_context = TraceContext(
                trace_id=span.get_trace_id(),
                span_id=span.get_span_id(),
            )
            propagator = get_global_trace_propagator()
            if propagator:
                propagator.inject(trace_context, DictCarrier(headers))

            start_time = default_timer()
            try:
                # Only header names are logged: the values may contain
                # credentials (Authorization, Cookie, API keys).
                logger.info("Sending header names: %s", list(request.headers))
                result = original_send(
                    self, request, **kwargs
                )  # *** PROCEED
            except Exception as exc:  # pylint: disable=W0703
                exception = exc
                # Some exceptions carry the response that triggered them.
                result = getattr(exc, "response", None)
            finally:
                elapsed_time = max(default_timer() - start_time, 0)

            if isinstance(result, Response):
                # From here on the response attributes replace the request
                # attributes, including as labels of the duration sample.
                span_attributes = {}
                span_attributes[HTTP_RESPONSE_STATUS_CODE] = result.status_code

                if result.raw is not None:
                    version = getattr(result.raw, "version", None)
                    if version:
                        # Only HTTP/1 is supported by requests
                        version_text = "1.1" if version == 11 else "1.0"
                        span_attributes[HTTP_FLAVOR] = version_text
                span.set_attributes(span_attributes)

                if callable(response_hook):
                    response_hook(span, request, result)

            if exception is not None:
                span.record_exception(exception)

        if duration_histogram is not None and MetricContext.metric_initialized():
            MetricContext.histogram_record(
                duration_histogram,
                elapsed_time,
                span_attributes
            )

        if exception is not None:
            raise exception.with_traceback(exception.__traceback__)

        return result

    return instrumented_send
116
+
117
+
118
class _InstrumentedSession(sessions.Session):
    """
    A ``requests.Session`` whose ``send`` is traced.

    The class attributes below are the shared configuration written by
    ``RequestsInstrumentor._instrument``; each new session snapshots them
    when it is constructed.
    """
    _excluded_urls = None
    _tracer_provider: TraceProvider = None
    _request_hook = None
    _response_hook = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Configuration is read from the class, not from ``self``: a plain
        # function stored as a class attribute turns into a bound method
        # when looked up on an instance, and the hooks would then receive
        # the session as an unexpected first argument.
        cls = type(self)

        tracer_provider = cls._tracer_provider
        if tracer_provider is None:
            # Session created without a configured provider: use the global one.
            tracer_provider = get_tracer_provider()
        tracer = tracer_provider.get_tracer(
            "aworld.trace.instrumentation.requests")

        duration_histogram = MetricTemplate(
            type=MetricType.HISTOGRAM,
            name="client_request_duration_histogram",
            unit="s",
            description="Duration of HTTP client requests."
        )
        # An instance-level ``send`` shadows ``Session.send`` for this session.
        self.send = functools.partial(_wrapped_send(
            tracer=tracer,
            excluded_urls=cls._excluded_urls,
            request_hook=cls._request_hook,
            response_hook=cls._response_hook,
            duration_histogram=duration_histogram
        ), self)
147
+
148
+
149
class RequestsInstrumentor(Instrumentor):
    """
    An instrumentor for the requests module.

    Instrumenting replaces ``requests.sessions.Session`` with
    ``_InstrumentedSession``; uninstrumenting restores the saved class.
    """

    def instrumentation_dependencies(self) -> Collection[str]:
        return ["requests"]

    def _instrument(self, **kwargs):
        """
        Instruments the requests module.

        Recognized keyword arguments: ``request_hook``, ``response_hook``,
        ``tracer_provider`` and ``excluded_urls``.
        """
        logger.info("requests _instrument entered.")
        self._original_session = sessions.Session

        request_hook = kwargs.get("request_hook")
        response_hook = kwargs.get("response_hook")
        # Always overwrite the hooks so that hooks from an earlier
        # instrument() call cannot survive a later call that passes none.
        _InstrumentedSession._request_hook = (
            request_hook if callable(request_hook) else None
        )
        _InstrumentedSession._response_hook = (
            response_hook if callable(response_hook) else None
        )

        _InstrumentedSession._tracer_provider = kwargs.get("tracer_provider")

        excluded_urls = kwargs.get("excluded_urls")
        # Without an explicit list, fall back to the default list for this
        # instrumentation. The key used to be "FLASK", a copy-paste slip.
        # NOTE(review): presumably the key selects a per-instrumentation
        # setting; confirm against get_excluded_urls.
        _InstrumentedSession._excluded_urls = (
            get_excluded_urls("REQUESTS")
            if excluded_urls is None
            else parse_excluded_urls(excluded_urls)
        )
        sessions.Session = _InstrumentedSession
        logger.info("requests _instrument exited.")

    def _uninstrument(self, **kwargs):
        """
        Uninstruments the requests module.
        """
        sessions.Session = self._original_session
185
+
186
+
187
def instrument_requests(excluded_urls: str = None,
                        request_hook: Callable = None,
                        response_hook: Callable = None,
                        tracer_provider: TraceProvider = None,
                        **kwargs: Any,
                        ):
    """
    Enable tracing for the requests module.
    Args:
        excluded_urls: A comma separated list of URLs to exclude from tracing.
        request_hook: Called before a request is sent, with the span and
            the request.
        response_hook: Called after a response is received, with the span,
            the request and the response.
        tracer_provider: The tracer provider to use; defaults to the global
            tracer provider.
        kwargs: Extra keyword arguments forwarded to the instrumentor.
    """
    options = {
        "excluded_urls": excluded_urls,
        "request_hook": request_hook,
        "response_hook": response_hook,
        "tracer_provider": tracer_provider or get_tracer_provider(),
    }
    options.update(kwargs)
    instrumentor = RequestsInstrumentor()
    instrumentor.instrument(**options)
    logger.info("Requests instrumented.")
aworld/trace/instrumentation/semconv.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# GenAI semconv attribute names
# Every constant is an attribute key in the "gen_ai." namespace.
GEN_AI_SYSTEM = "gen_ai.system"
# Request keys ("gen_ai.request.*").
GEN_AI_REQUEST_MODEL = "gen_ai.request.model"
GEN_AI_REQUEST_FREQUENCY_PENALTY = "gen_ai.request.frequency_penalty"
GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
GEN_AI_REQUEST_PRESENCE_PENALTY = "gen_ai.request.presence_penalty"
GEN_AI_REQUEST_STOP_SEQUENCES = "gen_ai.request.stop_sequences"
GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature"
GEN_AI_REQUEST_TOP_K = "gen_ai.request.top_k"
GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p"
GEN_AI_REQUEST_STREAMING = "gen_ai.request.streaming"
GEN_AI_REQUEST_USER = "gen_ai.request.user"
GEN_AI_REQUEST_EXTRA_HEADERS = "gen_ai.request.extra_headers"
# Prompt and completion keys ("gen_ai.prompt*", "gen_ai.completion*").
GEN_AI_PROMPT = "gen_ai.prompt"
GEN_AI_PROMPT_TOOLS = "gen_ai.prompt.tools"
GEN_AI_COMPLETION = "gen_ai.completion"
GEN_AI_COMPLETION_TOOL_CALLS = "gen_ai.completion.tool_calls"
GEN_AI_COMPLETION_CONTENT = "gen_ai.completion.content"
# Duration keys.
GEN_AI_DURATION = "gen_ai.duration"
GEN_AI_FIRST_TOKEN_DURATION = "gen_ai.first_token_duration"
# Response keys ("gen_ai.response.*").
GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons"
GEN_AI_RESPONSE_ID = "gen_ai.response.id"
GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
# Token usage keys ("gen_ai.usage.*").
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
# Operation, method and server keys.
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
GEN_AI_METHOD_NAME = "gen_ai.method.name"
GEN_AI_SERVER_ADDRESS = "gen_ai.server.address"
aworld/trace/instrumentation/threading.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from typing import Protocol, TypeVar, Any, Callable
3
+ from wrapt import wrap_function_wrapper
4
+ from concurrent import futures
5
+ import aworld.trace as trace
6
+ from aworld.trace.base import TraceContext, Span
7
+ from aworld.trace.propagator import get_global_trace_context
8
+ from aworld.trace.instrumentation import Instrumentor
9
+ from aworld.trace.instrumentation.utils import unwrap
10
+ from aworld.logs.util import logger
11
+
12
+
13
# Return type of the callables wrapped by the threading instrumentation.
R = TypeVar("R")


class HasTraceContext(Protocol):
    """Structural type for objects that carry a ``_trace_context`` attribute."""
    # Trace context stored on the object.
    _trace_context: TraceContext
18
+
19
+
20
class ThreadingInstrumentor(Instrumentor):
    '''
    Trace instrumentor for threading.

    Wraps ``Thread.start``/``Thread.run``, ``Timer.start``/``Timer.run`` and
    ``ThreadPoolExecutor.submit`` so that the trace context of the code that
    starts a thread (or submits a task) is set while the thread body (or
    task) runs.
    '''

    def instrumentation_dependencies(self) -> tuple[str, ...]:
        # Returns an empty tuple: no dependencies are declared.
        return ()

    def _instrument(self, **kwargs: Any):
        # Install all wrappers; kwargs are accepted but not used here.
        self._instrument_thread()
        self._instrument_timer()
        self._instrument_thread_pool()

    def _uninstrument(self, **kwargs: Any):
        # Remove the wrappers in the same order they were installed.
        self._uninstrument_thread()
        self._uninstrument_timer()
        self._uninstrument_thread_pool()

    @staticmethod
    def _instrument_thread():
        # start() captures the context; run() applies it.
        wrap_function_wrapper(
            threading.Thread,
            "start",
            ThreadingInstrumentor.__wrap_threading_start,
        )
        wrap_function_wrapper(
            threading.Thread,
            "run",
            ThreadingInstrumentor.__wrap_threading_run,
        )

    @staticmethod
    def _instrument_timer():
        # Same pair of wrappers, applied to threading.Timer.
        wrap_function_wrapper(
            threading.Timer,
            "start",
            ThreadingInstrumentor.__wrap_threading_start,
        )
        wrap_function_wrapper(
            threading.Timer,
            "run",
            ThreadingInstrumentor.__wrap_threading_run,
        )

    @staticmethod
    def _instrument_thread_pool():
        wrap_function_wrapper(
            futures.ThreadPoolExecutor,
            "submit",
            ThreadingInstrumentor.__wrap_thread_pool_submit,
        )

    @staticmethod
    def _uninstrument_thread():
        unwrap(threading.Thread, "start")
        unwrap(threading.Thread, "run")

    @staticmethod
    def _uninstrument_timer():
        unwrap(threading.Timer, "start")
        unwrap(threading.Timer, "run")

    @staticmethod
    def _uninstrument_thread_pool():
        unwrap(futures.ThreadPoolExecutor, "submit")

    @staticmethod
    def __wrap_threading_start(
        call_wrapped: Callable[[], None],
        instance: HasTraceContext,
        args: tuple[()],
        kwargs: dict[str, Any],
    ) -> None:
        # Wrapper for start(): when a current span exists, store its trace id
        # and span id on the thread object, then call the original start().
        span: Span = trace.get_current_span()
        if span:
            instance._trace_context = TraceContext(
                trace_id=span.get_trace_id(), span_id=span.get_span_id())
        return call_wrapped(*args, **kwargs)

    @staticmethod
    def __wrap_threading_run(
        call_wrapped: Callable[..., R],
        instance: HasTraceContext,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> R:
        # Wrapper for run(): set the context stored on the thread object (if
        # any) as the global trace context while the original run() executes.

        token = None
        try:
            if hasattr(instance, "_trace_context"):
                if instance._trace_context:
                    token = get_global_trace_context().set(instance._trace_context)
            return call_wrapped(*args, **kwargs)
        finally:
            # Reset the context that was set above, even if run() raised.
            if token:
                get_global_trace_context().reset(token)

    @staticmethod
    def __wrap_thread_pool_submit(
        call_wrapped: Callable[..., R],
        instance: futures.ThreadPoolExecutor,
        args: tuple[Callable[..., Any], ...],
        kwargs: dict[str, Any],
    ) -> R:
        # Wrapper for submit(): replace the submitted callable with one that
        # runs under the trace context captured at submission time.
        # obtain the original function and wrapped kwargs
        original_func = args[0]
        trace_context = None
        span: Span = trace.get_current_span()
        # A span with an empty trace id is not propagated.
        if span and span.get_trace_id() != "":
            trace_context = TraceContext(
                trace_id=span.get_trace_id(), span_id=span.get_span_id())

        def wrapped_func(*func_args: Any, **func_kwargs: Any) -> R:
            token = None
            try:
                if trace_context:
                    token = get_global_trace_context().set(trace_context)
                return original_func(*func_args, **func_kwargs)
            finally:
                if token:
                    get_global_trace_context().reset(token)

        # replace the original function with the wrapped function
        new_args: tuple[Callable[..., Any], ...] = (wrapped_func,) + args[1:]
        return call_wrapped(*new_args, **kwargs)
145
+
146
+
147
def instrument_theading(**kwargs: Any) -> None:
    """Enable trace-context propagation for threads, timers and thread pools.

    All keyword arguments are forwarded to ``ThreadingInstrumentor.instrument``.
    """
    instrumentor = ThreadingInstrumentor()
    instrumentor.instrument(**kwargs)
    logger.info("Threading instrumented")
aworld/trace/instrumentation/utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from importlib import import_module
2
+ from wrapt import ObjectProxy
3
+
4
+
5
def unwrap(obj: object, attr: str):
    """Undo a patch installed with wrapt.wrap_function_wrapper.

    The holder of the patched function may be given directly or as a dotted
    import path string ("package.module.Name").

    Args:
        obj: Object that holds the wrapped function, or its dotted import path
        attr (str): Name of the wrapped function

    Raises:
        ImportError: if a dotted path has no dot, or its last component
            cannot be found in the imported module.
    """
    if isinstance(obj, str):
        try:
            module_path, class_name = obj.rsplit(".", 1)
        except ValueError as exc:
            raise ImportError(
                f"Cannot parse '{obj}' as dotted import path"
            ) from exc
        module = import_module(module_path)
        try:
            obj = getattr(module, class_name)
        except AttributeError as exc:
            raise ImportError(
                f"Cannot import '{class_name}' from '{module}'"
            ) from exc

    func = getattr(obj, attr, None)
    if not func:
        # Attribute missing or falsy: nothing to restore.
        return
    if not isinstance(func, ObjectProxy):
        # Not a wrapt proxy, so it was never wrapped.
        return
    if hasattr(func, "__wrapped__"):
        setattr(obj, attr, func.__wrapped__)