Duibonduil committed on
Commit
3e56848
·
verified ·
1 Parent(s): e020370

Upload 11 files

Browse files
aworld/trace/__init__.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import traceback
4
+ from typing import Sequence, Union
5
+ from aworld.trace.context_manager import TraceManager
6
+ from aworld.trace.constants import RunType
7
+ from aworld.logs.util import logger
8
+ from aworld.trace.config import configure, ObservabilityConfig
9
+
10
+
11
def get_tool_name(tool_name: str,
                  action: Union['ActionModel', Sequence['ActionModel']]) -> tuple[str, RunType]:
    """Resolve the span name and run type for a tool invocation.

    For the generic ``"mcp"`` tool the name is taken from the text before the
    first ``"__"`` in the first action's ``action_name``.

    Args:
        tool_name: Name of the tool being invoked.
        action: A single action or a list/tuple of actions; only the first
            element of a list/tuple is inspected.

    Returns:
        ``(name, run_type)``. ``run_type`` is ``RunType.MCP`` when ``tool_name``
        is ``"mcp"`` and ``action`` is truthy, otherwise ``RunType.TOOL``.
    """
    if tool_name != "mcp" or not action:
        return (tool_name, RunType.TOOL)

    try:
        first_action = action[0] if isinstance(action, (list, tuple)) else action
        return (first_action.action_name.split("__")[0], RunType.MCP)
    except (AttributeError, IndexError, TypeError, ValueError):
        # ``str.split`` with a non-empty separator never raises ValueError; the
        # realistic failure is a missing or non-string ``action_name``. Tracing
        # must not break the caller, so log and fall back to the generic name.
        logger.warning(traceback.format_exc())
        return (tool_name, RunType.MCP)
23
+
24
+
25
def get_span_name_from_message(message: 'aworld.core.event.base.Message') -> tuple[str, RunType]:
    """Derive the span name and run type for an event message.

    Args:
        message: The event message to describe.

    Returns:
        ``(span_name, run_type)``:

        * agent messages: the receiver (or the message id when there is no
          receiver) with the agent run type;
        * tool messages: the result of ``get_tool_name`` for the first action
          in the payload, or the receiver/id with ``RunType.TOOL`` when the
          payload holds no action;
        * anything else: the receiver/id with ``RunType.OTHER``.
    """
    # Imported lazily inside the function, as in the original code
    # (presumably to avoid an import cycle -- TODO confirm).
    from aworld.core.event.base import Constants

    span_name = (message.receiver or message.id)
    if message.category == Constants.AGENT:
        # ``AGNET`` is the spelling used by the RunType enum member.
        return (span_name, RunType.AGNET)
    if message.category == Constants.TOOL:
        action = message.payload
        if isinstance(action, (list, tuple)):
            # An empty payload has no first action; indexing it would raise
            # IndexError, so treat it like a missing action instead.
            action = action[0] if action else None
        if action:
            tool_name, run_type = get_tool_name(action.tool_name, action)
            return (tool_name, run_type)
        return (span_name, RunType.TOOL)
    return (span_name, RunType.OTHER)
39
+
40
+
41
def message_span(message: 'aworld.core.event.base.Message' = None, attributes: dict = None):
    """Open a span describing the handling of an event message.

    Args:
        message: The event message; required despite the ``None`` default.
        attributes: Extra span attributes. They are applied after the
            ``event.*`` attributes, so they override them on key collisions.

    Returns:
        Whatever ``GLOBAL_TRACE_MANAGER.span`` returns for the new span.

    Raises:
        ValueError: If ``message`` is missing (falsy).
    """
    if not message:
        raise ValueError("message_span message is None")

    span_name, run_type = get_span_name_from_message(message)

    event_attributes = {
        "event.payload": str(message.payload),
        "event.topic": message.topic or "",
        "event.receiver": message.receiver or "",
        "event.sender": message.sender or "",
        "event.category": message.category,
        "event.id": message.id,
        "event.session_id": message.session_id
    }
    if attributes:
        event_attributes.update(attributes)

    full_span_name = f"{run_type.value.lower()}_event_{span_name}"
    return GLOBAL_TRACE_MANAGER.span(
        span_name=full_span_name,
        attributes=event_attributes,
        run_type=run_type
    )
61
+
62
+
63
# Process-wide trace manager. The names below re-export its bound methods as
# the module-level tracing API.
GLOBAL_TRACE_MANAGER: TraceManager = TraceManager()
span = GLOBAL_TRACE_MANAGER.span
func_span = GLOBAL_TRACE_MANAGER.func_span
auto_tracing = GLOBAL_TRACE_MANAGER.auto_tracing
get_current_span = GLOBAL_TRACE_MANAGER.get_current_span
# Previously bound to ``get_current_span`` by mistake, so calling
# ``new_manager(...)`` never created a manager.
new_manager = GLOBAL_TRACE_MANAGER.new_manager

__all__ = [
    "span",
    "func_span",
    "message_span",
    "auto_tracing",
    "get_current_span",
    "new_manager",
    "RunType",
    "configure",
    "ObservabilityConfig"
]
aworld/trace/auto_trace.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import re
3
+ import sys
4
+ import warnings
5
+ from importlib.abc import Loader, MetaPathFinder
6
+ from importlib.machinery import ModuleSpec
7
+ from importlib.util import spec_from_loader
8
+ from types import ModuleType
9
+ from typing import TYPE_CHECKING, Sequence, Union, Callable, Iterator, TypeVar, Any, cast
10
+
11
+ from aworld.trace.base import log_trace_error
12
+ from .rewrite_ast import compile_source
13
+
14
+ if TYPE_CHECKING:
15
+ from .context_manager import TraceManager
16
+
17
+
18
class AutoTraceModule:
    """Describes a module being imported that may need to be traced automatically."""

    def __init__(self, module_name: str) -> None:
        # Fully qualified absolute name of the module being imported.
        self._module_name = module_name

    def need_auto_trace(self, prefix: Union[str, Sequence[str]]) -> bool:
        """Tell whether this module's name matches any of the given prefixes.

        Args:
            prefix: One module name/pattern or a sequence of them. Each entry
                is converted with ``get_module_pattern`` and the results are
                OR-ed into one regular expression.

        Returns:
            True when the combined pattern matches at the start of the module
            name.
        """
        prefixes = (prefix,) if isinstance(prefix, str) else prefix
        combined = '|'.join(get_module_pattern(entry) for entry in prefixes)
        return re.match(combined, self._module_name) is not None
33
+
34
+
35
class TraceImportFinder(MetaPathFinder):
    """Meta path finder that substitutes a tracing loader for selected modules.

    It asks the other finders on ``sys.meta_path`` for the module's regular
    spec, and when the module passes the filter it returns a spec whose loader
    executes the source compiled by ``compile_source``.
    """

    def __init__(self, trace_manager: "TraceManager", module_funcs: Callable[[AutoTraceModule], bool],
                 min_duration_ns: int) -> None:
        """
        Args:
            trace_manager: Manager handed to ``compile_source``.
            module_funcs: Predicate deciding whether a module is traced.
            min_duration_ns: Duration threshold in nanoseconds, forwarded
                unchanged to ``compile_source``.
        """
        self._trace_manager = trace_manager
        self._modules_filter = module_funcs
        self._min_duration_ns = min_duration_ns

    def _find_plain_specs(
        self, fullname: str, path: Union[Sequence[str], None] = None, target: Union[ModuleType, None] = None
    ) -> Iterator[ModuleSpec]:
        """Yield module specs returned by other finders on `sys.meta_path`."""
        for finder in sys.meta_path:
            # Skip this finder or any like it to avoid infinite recursion.
            if isinstance(finder, TraceImportFinder):
                continue

            try:
                plain_spec = finder.find_spec(fullname, path, target)
            except Exception:  # pragma: no cover
                continue

            if plain_spec:
                yield plain_spec

    def find_spec(self, fullname: str, path: Sequence[str], target=None) -> Union[ModuleSpec, None]:
        """Find the spec for the given module name.

        Returns:
            A spec backed by ``AutoTraceLoader`` when the module should be
            traced and its source could be compiled; otherwise ``None`` so the
            remaining finders handle the import as usual.
        """
        for plain_spec in self._find_plain_specs(fullname, path, target):
            # Only loaders that can hand back source text are usable here.
            get_source = getattr(plain_spec.loader, 'get_source', None)
            if not callable(get_source):
                continue
            try:
                source = cast(str, get_source(fullname))
            except Exception:
                continue

            if not source:
                continue

            filename = plain_spec.origin
            if not filename:
                try:
                    filename = cast('str | None', plain_spec.loader.get_filename(fullname))
                except Exception:
                    # Best effort only: the synthetic name below is used instead.
                    pass
            filename = filename or f'<{fullname}>'

            if not self._modules_filter(AutoTraceModule(fullname)):
                return None

            try:
                tree = ast.parse(source)
            except Exception:
                # Invalid source code. Try another one.
                continue

            try:
                execute = compile_source(tree, filename, fullname, self._trace_manager, self._min_duration_ns)
            except Exception:  # pragma: no cover
                log_trace_error()
                return None

            loader = AutoTraceLoader(plain_spec, execute)
            return spec_from_loader(fullname, loader)

        # No finder produced a usable spec.
        return None
102
+
103
+
104
class AutoTraceLoader(Loader):
    """Loader that runs a pre-compiled, rewritten version of a module.

    Implements the `exec_module` method of the `Loader` protocol.
    """

    def __init__(self, plain_spec: ModuleSpec, execute: Callable[[dict[str, Any]], None]) -> None:
        """
        Args:
            plain_spec: The spec found by the regular import machinery.
            execute: Callable that executes the rewritten module code in the
                namespace dict it is given.
        """
        self._plain_spec = plain_spec
        self._execute = execute

    def exec_module(self, module: ModuleType):
        """Execute a modified AST of the module's source code in the module's namespace."""
        self._execute(module.__dict__)

    def create_module(self, spec: ModuleSpec):
        """Return None so the import system creates the module object itself."""
        return None

    def get_code(self, _name: str):
        """Return code that runs the rewritten module in the current namespace.

        `python -m` uses the `runpy` module, which calls this method instead of
        going through the normal protocol, so return code that can be executed
        with the module namespace.
        """
        # Here `__loader__` will be this object, i.e. `self`. These statements
        # used to sit inside the docstring, so the method returned None.
        source = '__loader__._execute(globals())'
        return compile(source, '<string>', 'exec', dont_inherit=True)

    def __getattr__(self, item: str):
        """Forward some methods to the plain spec's loader (likely a `SourceFileLoader`) if they exist."""
        if item in {'get_filename', 'is_package'}:
            # The attribute is `_plain_spec`; the old `self.plain_spec` lookup
            # re-entered this method and always raised AttributeError.
            return getattr(self._plain_spec.loader, item)
        raise AttributeError(item)
134
+
135
+
136
def convert_to_modules_func(modules: Sequence[str]) -> Callable[[AutoTraceModule], bool]:
    """Build a module filter from a sequence of module names.

    Args:
        modules: Module names (or patterns) to trace.

    Returns:
        A predicate that returns ``module.need_auto_trace(modules)`` for the
        ``AutoTraceModule`` it is given.
    """
    def matches_any(module: AutoTraceModule) -> bool:
        return module.need_auto_trace(modules)

    return matches_any
140
+
141
+
142
def get_module_pattern(module: str):
    """Return the regex pattern used to match the given module name.

    A plain dotted name (word characters and dots only) is escaped and anchored
    so that it matches the module itself or any of its submodules. Anything
    else is assumed to already be a regular expression and is returned as is.
    """
    is_plain_name = re.match(r'[\w.]+$', module, re.UNICODE) is not None
    if is_plain_name:
        # Name must end here or continue with a dot (a submodule).
        return re.escape(module) + r'($|\.)'
    return module
151
+
152
+
153
def install_auto_tracing(trace_manager: "TraceManager",
                         modules: Union[Sequence[str],
                                        Callable[[AutoTraceModule], bool]],
                         min_duration_seconds: float
                         ) -> None:
    """Install an import hook that traces the selected modules.

    Args:
        trace_manager: Manager from which a dedicated ``'auto_tracing'``
            manager is derived via ``new_manager``.
        modules: Module names to trace, or a predicate over
            ``AutoTraceModule``.
        min_duration_seconds: Duration threshold in seconds; it is converted to
            nanoseconds and handed to the import finder.

    Raises:
        TypeError: If ``modules`` is neither a sequence nor a callable.
    """
    module_funcs = convert_to_modules_func(modules) if isinstance(modules, Sequence) else modules
    if not callable(module_funcs):
        raise TypeError('modules must be a list of strings or a callable')

    # Modules imported before the hook is installed are not affected by it, so
    # warn about every already-loaded module the filter would have selected.
    for loaded in list(sys.modules.values()):
        try:
            candidate = AutoTraceModule(loaded.__name__)
        except Exception:
            continue

        if not module_funcs(candidate):
            continue
        warnings.warn(f'The module {loaded.__name__!r} matches modules to trace, but it has already been imported. '
                      f'Call `auto_tracing` earlier',
                      stacklevel=2,
                      )

    min_duration_ns = int(min_duration_seconds * 1_000_000_000)
    auto_trace_manager = trace_manager.new_manager('auto_tracing')
    # Insert at the front so this finder is consulted before the others.
    sys.meta_path.insert(0, TraceImportFinder(auto_trace_manager, module_funcs, min_duration_ns))
185
+
186
+
187
# Generic type so the decorator preserves the decorated object's type.
T = TypeVar('T')


def not_auto_trace(x: T) -> T:
    """Decorator marking a function/class as excluded from `auto_tracing`.

    The object is returned unchanged. How the marker is honoured is not visible
    in this function (presumably the AST rewriter looks for the decorator --
    TODO confirm).
    """
    return x
aworld/trace/base.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Optional, Any, Iterator, Union, Sequence, Protocol, Iterable
3
+ from enum import Enum
4
+ from weakref import WeakSet
5
+ from dataclasses import dataclass, field
6
+ from aworld.logs.util import trace_logger
7
+
8
+
9
class TraceProvider(ABC):
    """Abstract source of `Tracer` objects and owner of the export pipeline."""

    @abstractmethod
    def get_tracer(
            self,
            name: str,
            version: Optional[str] = None
    ) -> "Tracer":
        """Return a `Tracer` for the given instrumentation name.

        Implementations may return different `Tracer` types (e.g. a no-op
        tracer vs. a functional tracer).

        Args:
            name: Uniquely identifiable name of the instrumentation scope, such
                as the instrumentation library, package, module or class name.
                Prefer a fixed string that can be imported and reused over
                ``__name__``, which yields different tracer names in different
                files. It should name the module doing the instrumentation, not
                the module being instrumented, e.g.
                ``"opentelemetry.instrumentation.requests"`` rather than
                ``"requests"``.
            version: Optional version string of the instrumenting library,
                usually ``importlib.metadata.version(instrumenting_library_name)``.
        """

    @abstractmethod
    def shutdown(self) -> None:
        """Shut down the provider and all its resources.

        Intended to be called when the application is shutting down.
        """

    @abstractmethod
    def force_flush(self, timeout: Optional[float] = None) -> bool:
        """Force all pending data to be sent to the backend.

        Intended to be called when the application is shutting down.

        Args:
            timeout: The maximum time to wait for the data to be sent.

        Returns:
            True if the data was sent successfully, False otherwise.
        """

    @abstractmethod
    def get_current_span(self) -> Optional["Span"]:
        """Return the current span from the current context."""
62
+
63
+
64
class SpanType(Enum):
    """Describes how a span relates to its parent span."""

    # Default. The span is used internally in the application.
    INTERNAL = 0

    # The span describes an operation that handles a remote request.
    SERVER = 1

    # The span describes a request to some remote service.
    CLIENT = 2

    # The span describes a producer sending a message to a broker. Unlike
    # client and server, there is usually no direct critical path latency
    # relationship between producer and consumer spans.
    PRODUCER = 3

    # The span describes a consumer receiving a message from a broker. Unlike
    # client and server, there is usually no direct critical path latency
    # relationship between producer and consumer spans.
    CONSUMER = 4
88
+
89
+
90
# Types accepted as span attribute values: a scalar (str/bool/int/float) or a
# sequence whose items are all of one of those scalar types.
AttributeValueType = Union[
    str, bool, int, float,
    Sequence[str], Sequence[bool], Sequence[int], Sequence[float],
]
100
+
101
+
102
class Tracer(ABC):
    """Creates spans and handles in-process context propagation."""

    @abstractmethod
    def start_span(
            self,
            name: str,
            span_type: SpanType = SpanType.INTERNAL,
            attributes: dict[str, AttributeValueType] = None,
            start_time: Optional[int] = None,
            record_exception: bool = True,
            set_status_on_exception: bool = True,
            trace_context: Optional["TraceContext"] = None,
    ) -> "Span":
        """Start and return a new Span.

        Args:
            name: The name of the span.
            span_type: The span's kind (relationship to its parent). It is
                meaningful even if there is no parent.
            attributes: The span's attributes.
            start_time: Sets the start time of the span.
            record_exception: Whether to record any exception raised within the
                context as an error event on the span.
            set_status_on_exception: Only relevant if the returned span is used
                in a with/context manager. Defines whether the span status is
                automatically set to ERROR when an uncaught exception is raised
                in the span's with block. The status is not set by this
                mechanism if it was previously set manually.
            trace_context: The trace context to use for the span. If not
                provided, the current trace context is used.
        """

    @abstractmethod
    def start_as_current_span(
            self,
            name: str,
            span_type: SpanType = SpanType.INTERNAL,
            attributes: dict[str, AttributeValueType] = None,
            start_time: Optional[int] = None,
            record_exception: bool = True,
            set_status_on_exception: bool = True,
            end_on_exit: bool = True,
            trace_context: Optional['TraceContext'] = None
    ) -> Iterator["Span"]:
        """Context manager that creates a span and makes it the current span
        in this tracer's context.

        Example::

            with tracer.start_as_current_span("one") as parent:
                parent.add_event("parent's event")
                with tracer.start_as_current_span("two") as child:
                    child.add_event("child's event")
                    trace.get_current_span()  # returns child
                trace.get_current_span()      # returns parent
            trace.get_current_span()          # returns previously active span

        This can also be used as a decorator::

            @tracer.start_as_current_span("name")
            def function():
                ...

        Args:
            name: The name of the span to be created.
            span_type: The span's kind (relationship to its parent). It is
                meaningful even if there is no parent.
            attributes: The span's attributes.
            start_time: Sets the start time of the span.
            record_exception: Whether to record any exception raised within the
                context as an error event on the span.
            set_status_on_exception: Only relevant if the returned span is used
                in a with/context manager. Defines whether the span status is
                automatically set to ERROR when an uncaught exception is raised
                in the span's with block. The status is not set by this
                mechanism if it was previously set manually.
            end_on_exit: Whether to end the span automatically when leaving the
                context manager.
            trace_context: The trace context to use for the span. If not
                provided, the current trace context is used.
        """
182
+
183
+
184
class Span(ABC):
    """A single operation within a trace."""

    @abstractmethod
    def end(self, end_time: Optional[int] = None) -> None:
        """Set the span's end time (the current time by default).

        The span's end time is the wall time at which the operation finished.

        Only the first call to `end` should modify the span; implementations
        are free to ignore or raise on further calls.
        """

    @abstractmethod
    def set_attribute(self, key: str, value: Any) -> None:
        """Set a single attribute on the span.

        Args:
            key: The attribute key.
            value: The attribute value.
        """

    @abstractmethod
    def set_attributes(self, attributes: dict[str, Any]) -> None:
        """Set several attributes on the span at once.

        Args:
            attributes: A dictionary of attributes to set.
        """

    @abstractmethod
    def is_recording(self) -> bool:
        """Return whether this span will be recorded.

        True means the span is active and records information such as
        attributes passed to `set_attribute`.
        """

    @abstractmethod
    def record_exception(
            self,
            exception: BaseException,
            attributes: dict[str, Any] = None,
            timestamp: Optional[int] = None,
            escaped: bool = False,
    ) -> None:
        """Record an exception in the span.

        Args:
            exception: The exception to record.
            attributes: A dictionary of attributes to set on the exception event.
            timestamp: The timestamp of the exception.
            escaped: Whether the exception was escaped.
        """

    @abstractmethod
    def get_trace_id(self) -> str:
        """Return the trace ID of the span."""

    @abstractmethod
    def get_span_id(self) -> str:
        """Return the ID of the span."""

    def _add_to_open_spans(self) -> None:
        """Register this span in the module-level ``_OPEN_SPANS`` set."""
        _OPEN_SPANS.add(self)

    def _remove_from_open_spans(self) -> None:
        """Drop this span from ``_OPEN_SPANS``; does nothing if it is absent."""
        _OPEN_SPANS.discard(self)
256
+
257
+
258
class NoOpSpan(Span):
    """`Span` implementation that records nothing."""

    def end(self, end_time: Optional[int] = None) -> None:
        """Do nothing."""
        return None

    def set_attribute(self, key: str, value: Any) -> None:
        """Ignore the attribute."""
        return None

    def set_attributes(self, attributes: dict[str, Any]) -> None:
        """Ignore the attributes."""
        return None

    def is_recording(self) -> bool:
        """Always False: a no-op span never records."""
        return False

    def record_exception(
            self,
            exception: BaseException,
            attributes: dict[str, Any] = None,
            timestamp: Optional[int] = None,
            escaped: bool = False,
    ) -> None:
        """Ignore the exception."""
        return None

    def get_trace_id(self) -> str:
        """Return an empty string; a no-op span has no trace ID."""
        return ""

    def get_span_id(self) -> str:
        """Return an empty string; a no-op span has no span ID."""
        return ""
287
+
288
+
289
class NoOpTracer(Tracer):
    """`Tracer` implementation whose spans are all `NoOpSpan` objects."""

    def start_span(
            self,
            name: str,
            span_type: SpanType = SpanType.INTERNAL,
            attributes: dict[str, AttributeValueType] = None,
            start_time: Optional[int] = None,
            record_exception: bool = True,
            set_status_on_exception: bool = True,
            trace_context: Optional["TraceContext"] = None,
    ) -> Span:
        """Return a fresh `NoOpSpan`; every argument is ignored."""
        noop_span = NoOpSpan()
        return noop_span

    def start_as_current_span(
            self,
            name: str,
            span_type: SpanType = SpanType.INTERNAL,
            attributes: dict[str, AttributeValueType] = None,
            start_time: Optional[int] = None,
            record_exception: bool = True,
            set_status_on_exception: bool = True,
            end_on_exit: bool = True,
            trace_context: Optional['TraceContext'] = None
    ) -> Iterator[Span]:
        """Yield a single `NoOpSpan`; every argument is ignored.

        NOTE(review): this is a plain generator function, so calling it returns
        a generator rather than a context manager. Confirm that no caller uses
        the result directly in a ``with`` statement.
        """
        noop_span = NoOpSpan()
        yield noop_span
316
+
317
+
318
class Carrier(Protocol):
    """Structural interface of an object that carries trace context as string
    key/value pairs.
    """

    def get(self, key: str) -> Optional[str]:
        """Return the value stored under ``key`` in the carrier.

        Args:
            key: The key to get the value for.

        Returns:
            The value of the given key from the carrier.
        """
        ...

    def set(self, key: str, value: str) -> None:
        """Store ``value`` under ``key`` in the carrier.

        Args:
            key: The key to set the value for.
            value: The value to set.
        """
        ...

    def keys(self) -> Iterable[str]:
        """Return an iterable of the keys in the carrier."""
        ...
342
+
343
+
344
@dataclass(frozen=True)
class TraceContext:
    """Immutable description of a trace context.

    The ``version`` / ``trace_flags`` defaults look like the fields of a W3C
    ``traceparent`` header -- presumably that is the intended format; confirm
    against the propagator implementation.
    """
    # Identifier of the trace the context belongs to.
    trace_id: str
    # Identifier of the span the context refers to.
    span_id: str
    # Format version, defaulting to "00".
    version: str = "00"
    # Trace flags, defaulting to "01".
    trace_flags: str = "01"
    # Free-form extra attributes; each instance gets its own dict.
    attributes: dict[str, Any] = field(default_factory=dict)
353
+
354
+
355
class Propagator(ABC):
    """Abstract base for objects that move a trace context into and out of a
    carrier.
    """

    def _get_value(self, carrier: Carrier, name: str) -> Optional[str]:
        """Read a value from the carrier.

        The plain ``name`` is tried first. When it yields nothing, a second
        lookup uses ``'HTTP_'`` plus the upper-cased name with dashes replaced
        by underscores (presumably the WSGI/CGI environ spelling of an HTTP
        header -- confirm against callers).

        Args:
            carrier: The carrier to get the value from.
            name: The name of the value.

        Returns:
            The value found, or ``None`` when neither key yields a value.
        """
        return carrier.get(name) or carrier.get('HTTP_' + name.upper().replace('-', '_'))

    @abstractmethod
    def extract(self, carrier: Carrier) -> Optional[TraceContext]:
        """Extract a trace context from the given carrier.

        Args:
            carrier: The carrier to extract the trace context from.

        Returns:
            The trace context extracted from the carrier.
        """

    @abstractmethod
    def inject(self, trace_context: TraceContext, carrier: Carrier) -> None:
        """Inject a trace context into the given carrier.

        Args:
            trace_context: The trace context to inject.
            carrier: The carrier to inject the trace context into.
        """
385
+
386
+
387
# Process-wide tracer provider; stays None until set_tracer_provider() is called.
_GLOBAL_TRACER_PROVIDER: Optional[TraceProvider] = None
# Weak set of spans registered through Span._add_to_open_spans().
_OPEN_SPANS: WeakSet[Span] = WeakSet()
389
+
390
+
391
def set_tracer_provider(provider: TraceProvider):
    """Install ``provider`` as the global tracer provider.

    Any previously installed provider is simply replaced; it is not shut down
    here.
    """
    global _GLOBAL_TRACER_PROVIDER
    _GLOBAL_TRACER_PROVIDER = provider
397
+
398
+
399
def get_tracer_provider() -> TraceProvider:
    """Return the global tracer provider.

    Raises:
        Exception: If no provider has been installed with
            ``set_tracer_provider``.
    """
    provider = _GLOBAL_TRACER_PROVIDER
    if provider is not None:
        return provider
    raise Exception("No tracer provider has been set.")
407
+
408
+
409
def get_tracer_provider_silent():
    """Return the global tracer provider, or ``None`` instead of raising when
    ``get_tracer_provider`` fails.
    """
    try:
        provider = get_tracer_provider()
    except Exception:
        provider = None
    return provider
414
+
415
+
416
def log_trace_error():
    """Log an internal tracing error through ``trace_logger.exception``.

    ``Logger.exception`` attaches the traceback of the exception currently
    being handled, so call this from inside an ``except`` block.
    """
    trace_logger.exception('This is logging the trace internal error.')
aworld/trace/config.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pydantic import BaseModel
3
+ from typing import Sequence, Optional
4
+ from aworld.trace.span_cosumer import SpanConsumer
5
+ from logging import Logger
6
+ from aworld.logs.util import trace_logger
7
+ from aworld.trace.context_manager import trace_configure
8
+ from aworld.metrics.context_manager import MetricContext
9
+ from aworld.logs.log import set_log_provider, instrument_logging
10
+ from aworld.trace.instrumentation.uni_llmmodel import LLMModelInstrumentor
11
+ from aworld.trace.instrumentation.eventbus import EventBusInstrumentor
12
+
13
+
14
class ObservabilityConfig(BaseModel):
    """Observability configuration covering traces, metrics and logs."""

    class Config:
        # Fields such as SpanConsumer and Logger are not pydantic types.
        arbitrary_types_allowed = True

    # --- tracing ---
    trace_provider: Optional[str] = "otlp"
    trace_backends: Optional[Sequence[str]] = ["memory"]
    trace_base_url: Optional[str] = None
    trace_write_token: Optional[str] = None
    trace_span_consumers: Optional[Sequence[SpanConsumer]] = None
    # whether to start the trace service
    trace_server_enabled: Optional[bool] = False
    trace_server_port: Optional[int] = 7079

    # --- metrics ---
    metrics_provider: Optional[str] = None
    metrics_backend: Optional[str] = None
    metrics_base_url: Optional[str] = None
    metrics_write_token: Optional[str] = None
    # whether to instrument system metrics
    metrics_system_enabled: Optional[bool] = False

    # --- logs ---
    logs_provider: Optional[str] = None
    logs_backend: Optional[str] = None
    logs_base_url: Optional[str] = None
    logs_write_token: Optional[str] = None
    # The loggers that need to record the log as a span
    logs_trace_instrumented_loggers: Sequence[Logger] = [trace_logger]
40
+
41
+
42
def configure(config: ObservabilityConfig = None):
    """Set up tracing, metrics and logging, then enable the built-in
    instrumentors.

    Args:
        config: Observability settings; a default ``ObservabilityConfig`` is
            used when omitted.
    """
    effective = ObservabilityConfig() if config is None else config

    # Order matters: traces first, then metrics, then logs.
    for setup in (_trace_configure, _metrics_configure, _log_configure):
        setup(effective)

    for instrumentor_cls in (LLMModelInstrumentor, EventBusInstrumentor):
        instrumentor_cls().instrument()
50
+
51
+
52
def _trace_configure(config: ObservabilityConfig):
    """Fill in trace settings from the environment and configure tracing.

    For the ``"otlp"`` provider:

    * with the ``"logfire"`` backend, a missing write token is read from
      ``LOGFIRE_WRITE_TOKEN``;
    * otherwise, when no base URL is configured and ``OTLP_TRACES_ENDPOINT`` is
      set, that endpoint becomes the base URL and the ``"other_otlp"`` backend
      is added.

    Args:
        config: Observability settings; may be updated in place.
    """
    # Work on a private list: the configured value may be None or an immutable
    # sequence, for which ``in`` / ``.append`` would fail.
    backends = list(config.trace_backends or [])

    if config.trace_provider == "otlp":
        if "logfire" in backends:
            # The token used to be written into ``trace_base_url``; it belongs
            # in the write-token field, mirroring the log configuration.
            if not config.trace_write_token:
                config.trace_write_token = os.getenv("LOGFIRE_WRITE_TOKEN")
        elif not config.trace_base_url and os.getenv("OTLP_TRACES_ENDPOINT"):
            config.trace_base_url = os.getenv("OTLP_TRACES_ENDPOINT")
            backends.append("other_otlp")
            config.trace_backends = backends

    trace_configure(
        provider=config.trace_provider,
        backends=config.trace_backends,
        base_url=config.trace_base_url,
        write_token=config.trace_write_token,
        span_consumers=config.trace_span_consumers,
        server_enabled=config.trace_server_enabled,
        server_port=config.trace_server_port
    )
69
+
70
+
71
def _metrics_configure(config: ObservabilityConfig):
    """Configure metrics when both a provider and a backend are set.

    Args:
        config: Observability settings; nothing happens unless both
            ``metrics_provider`` and ``metrics_backend`` are truthy.
    """
    if not (config.metrics_provider and config.metrics_backend):
        return

    MetricContext.configure(
        provider=config.metrics_provider,
        backend=config.metrics_backend,
        base_url=config.metrics_base_url,
        write_token=config.metrics_write_token,
        metrics_system_enabled=config.metrics_system_enabled
    )
80
+
81
+
82
def _log_configure(config: ObservabilityConfig):
    """Configure the log provider and instrument the selected loggers.

    Args:
        config: Observability settings. The log provider is only set when both
            ``logs_provider`` and ``logs_backend`` are truthy. For the
            ``"logfire"`` backend a missing write token is read from
            ``LOGFIRE_WRITE_TOKEN`` and stored back on ``config``.
    """
    provider_ready = bool(config.logs_provider and config.logs_backend)
    if provider_ready:
        token_missing = config.logs_backend == "logfire" and not config.logs_write_token
        if token_missing:
            config.logs_write_token = os.getenv("LOGFIRE_WRITE_TOKEN")
        set_log_provider(provider=config.logs_provider,
                         backend=config.logs_backend,
                         base_url=config.logs_base_url,
                         write_token=config.logs_write_token)

    # Loggers are instrumented whether or not a log provider was configured.
    for instrumented_logger in config.logs_trace_instrumented_loggers or ():
        instrument_logging(instrumented_logger)
aworld/trace/constants.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
# Namespace within OTEL attributes used by aworld.
ATTRIBUTES_NAMESPACE = 'aworld'

# Attribute key for the formatted message of a log.
ATTRIBUTES_MESSAGE_KEY = f'{ATTRIBUTES_NAMESPACE}.msg'

# Attribute key for the template of a log message.
ATTRIBUTES_MESSAGE_TEMPLATE_KEY = f'{ATTRIBUTES_NAMESPACE}.msg_template'

# Attribute key for the run type (presumably a `RunType` value -- confirm at
# the call sites).
ATTRIBUTES_MESSAGE_RUN_TYPE_KEY = f'{ATTRIBUTES_NAMESPACE}.run_type'

# Maximum number of characters for formatted values in a trace message.
MESSAGE_FORMATTED_VALUE_LENGTH_LIMIT = 128
17
+
18
+
19
class RunType(Enum):
    """Span run types supported in the framework."""

    # ``AGNET`` is a misspelling, but it stays the canonical member because
    # existing code refers to it by this name.
    AGNET = "AGENT"
    # Correctly spelled alias. It shares the value above, so Enum makes it the
    # very same member object as ``AGNET``; iteration still yields five members.
    AGENT = "AGENT"
    TOOL = "TOOL"
    MCP = "MCP"
    LLM = "LLM"
    OTHER = "OTHER"
aworld/trace/context_manager.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+ import inspect
3
+ from typing import Union, Optional, Any, Type, Sequence, Callable, Iterable
4
+ from aworld.trace.base import (
5
+ AttributeValueType,
6
+ NoOpSpan,
7
+ Span, Tracer,
8
+ NoOpTracer,
9
+ get_tracer_provider,
10
+ get_tracer_provider_silent,
11
+ log_trace_error
12
+ )
13
+ from aworld.trace.span_cosumer import SpanConsumer
14
+ from aworld.version_gen import __version__
15
+ from aworld.trace.auto_trace import AutoTraceModule, install_auto_tracing
16
+ from aworld.trace.stack_info import get_user_stack_info
17
+ from aworld.trace.constants import (
18
+ ATTRIBUTES_MESSAGE_KEY,
19
+ ATTRIBUTES_MESSAGE_RUN_TYPE_KEY,
20
+ ATTRIBUTES_MESSAGE_TEMPLATE_KEY,
21
+ RunType
22
+ )
23
+ from aworld.trace.msg_format import (
24
+ chunks_formatter,
25
+ warn_formatting,
26
+ FStringAwaitError,
27
+ KnownFormattingError,
28
+ warn_fstring_await
29
+ )
30
+ from aworld.trace.function_trace import trace_func
31
+ from .opentelemetry.opentelemetry_adapter import configure_otlp_provider
32
+ from aworld.logs.util import logger
33
+
34
+
35
def trace_configure(provider: str = "otlp",
                    backends: Sequence[str] = None,
                    base_url: str = None,
                    write_token: str = None,
                    span_consumers: Optional[Sequence[SpanConsumer]] = None,
                    **kwargs
                    ) -> None:
    """
    Configure the trace provider, replacing any provider that is already set.

    An existing provider (as reported by `get_tracer_provider_silent`) is shut
    down first. Only the "otlp" provider is supported.

    Args:
        provider: The trace provider to use.
        backends: The trace backends to use.
        base_url: The base URL of the trace backend.
        write_token: The write token of the trace backend.
        span_consumers: The span consumers to use.
        **kwargs: Additional arguments to pass to the trace provider.
    Returns:
        None
    Raises:
        ValueError: If `provider` is not "otlp".
    """
    previous_provider = get_tracer_provider_silent()
    if previous_provider:
        logger.info("Trace provider already configured, shutting down...")
        previous_provider.shutdown()

    if provider != "otlp":
        raise ValueError(f"Unknown trace provider: {provider}")

    configure_otlp_provider(
        backends=backends,
        base_url=base_url,
        write_token=write_token,
        span_consumers=span_consumers,
        **kwargs,
    )
63
+
64
+
65
class TraceManager:
    """
    TraceManager creates spans and tracing decorators under one tracer name.

    Spans are obtained from the configured tracer provider; whenever that
    lookup fails a no-op tracer/span is used instead, so tracing problems do
    not propagate to the caller.
    """

    def __init__(self, tracer_name: str = None) -> None:
        """
        Args:
            tracer_name: Name passed to the tracer provider, "aworld" when omitted.
        """
        self._tracer_name = tracer_name or "aworld"
        self._version = __version__

    def _create_auto_span(self,
                          name: str,
                          attributes: dict[str, AttributeValueType] = None
                          ) -> Span:
        """
        Create an auto trace span with the given name and attributes.
        """
        return self._create_context_span(name, attributes)

    def _create_context_span(self,
                             name: str,
                             attributes: dict[str, AttributeValueType] = None) -> Span:
        """
        Build a `ContextSpan` on the provider's tracer, or on a `NoOpTracer`
        when the provider/tracer cannot be obtained.
        """
        try:
            tracer = get_tracer_provider().get_tracer(
                name=self._tracer_name, version=self._version)
            return ContextSpan(span_name=name, tracer=tracer, attributes=attributes)
        except Exception:
            return ContextSpan(span_name=name, tracer=NoOpTracer(), attributes=attributes)

    def get_current_span(self) -> Span:
        """
        Get the current span, or a `NoOpSpan` when it cannot be retrieved.
        """
        try:
            return get_tracer_provider().get_current_span()
        except Exception:
            return NoOpSpan()

    def new_manager(self, tracer_name_suffix: str = None) -> "TraceManager":
        """
        Create a new TraceManager with the given tracer name suffix.

        The new tracer name is "<current name>.<suffix>"; without a suffix the
        current tracer name is reused.
        """
        tracer_name = self._tracer_name if not tracer_name_suffix else f"{self._tracer_name}.{tracer_name_suffix}"
        return TraceManager(tracer_name=tracer_name)

    def auto_tracing(self,
                     modules: Union[Sequence[str], Callable[[AutoTraceModule], bool]],
                     min_duration: float) -> None:
        """
        Automatically trace the execution of a function.
        Args:
            modules: A list of module names or a callable that takes a `AutoTraceModule` and returns a boolean.
            min_duration: The minimum duration of a function to be traced.
        Returns:
            None
        """
        install_auto_tracing(self, modules, min_duration)

    def span(self,
             msg_template: str = "",
             attributes: dict[str, AttributeValueType] = None,
             *,
             span_name: str = None,
             run_type: RunType = RunType.OTHER) -> "ContextSpan":
        """
        Create a span whose message is rendered from `msg_template`.

        Args:
            msg_template: Message template (or f-string) for the span.
            attributes: Attributes to add to the span; they override the
                stack-info attributes on key collisions.
            span_name: Name of the span, the final message template when omitted.
            run_type: Run type recorded on the span.
        Returns:
            A `ContextSpan`. If anything fails while preparing the attributes,
            the error is logged and a span on a `NoOpTracer` is returned.
        """
        try:
            attributes = attributes or {}
            # Retrieve stack information of user code and add it to the attributes.
            stack_info = get_user_stack_info()
            merged_attributes = {**stack_info, **attributes}

            # The caller's frame is only needed when the template may contain fields.
            if any(c in msg_template for c in ('{', '}')):
                fstring_frame = inspect.currentframe().f_back
            else:
                fstring_frame = None
            log_message, extra_attrs, msg_template = format_span_msg(
                msg_template,
                merged_attributes,
                fstring_frame=fstring_frame,
            )
            merged_attributes[ATTRIBUTES_MESSAGE_KEY] = log_message
            merged_attributes.update(extra_attrs)
            merged_attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] = msg_template
            merged_attributes[ATTRIBUTES_MESSAGE_RUN_TYPE_KEY] = run_type.value
            span_name = span_name or msg_template

            return self._create_context_span(span_name, merged_attributes)

        except Exception:
            log_trace_error()
            # span_name may still be None here; fall back to the template so the
            # no-op span is not created with a None name.
            return ContextSpan(span_name=span_name or msg_template,
                               tracer=NoOpTracer(),
                               attributes=attributes)

    def func_span(self,
                  msg_template: Union[str, Callable] = None,
                  *,
                  attributes: dict[str, AttributeValueType] = None,
                  span_name: str = None,
                  extract_args: Union[bool, Iterable[str]] = False,
                  **kwargs) -> Callable:
        """
        A decorator that traces the execution of a function.
        Args:
            msg_template: The message template to use.
            attributes: The attributes to add to the span.
            span_name: The name of the span.
            extract_args: Whether to extract arguments from the function call.
            **kwargs: Additional attributes to add to the span.
        Returns:
            A decorator that traces the execution of a function.
        """
        if callable(msg_template):
            # Used as a bare decorator:
            # @func_span
            # def foo():
            return self.func_span()(msg_template)

        # Merge into a new dict so the caller's `attributes` is never mutated.
        merged_attributes = {**(attributes or {}), **kwargs}
        return trace_func(self, msg_template, merged_attributes, span_name, extract_args)
182
+
183
+
184
class ContextSpan(Span):
    """A lazily started span usable as a sync or async context manager.

    No span exists at construction time. Entering the context starts one via
    `tracer.start_span` (only once) and returns this object; leaving the
    context records an in-flight exception on the span and ends it.
    The `Span` methods are forwarded to the underlying span and do nothing
    (or return `False` / `None`) while there is no underlying span.

    Args:
        span_name: Name used when the span is started.
        tracer: The `Tracer` that starts the span.
        attributes: Attributes passed to `start_span`.
    """

    def __init__(self,
                 span_name: str,
                 tracer: Tracer,
                 attributes: dict[str, AttributeValueType] = None) -> None:
        self._span_name = span_name
        self._tracer = tracer
        self._attributes = attributes
        self._span: Span = None
        self._coro_context = None

    def _start(self):
        """Start the underlying span unless one was already started."""
        if self._span is None:
            self._span = self._tracer.start_span(
                name=self._span_name,
                attributes=self._attributes,
            )

    def __enter__(self) -> "Span":
        self._start()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        traceback: Optional[Any],
    ) -> None:
        """Ends context manager and calls `end` on the `Span`."""
        self._handle_exit(exc_type, exc_val, traceback)

    async def __aenter__(self) -> "Span":
        self._start()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        traceback: Optional[Any],
    ) -> None:
        """Async counterpart of `__exit__`."""
        self._handle_exit(exc_type, exc_val, traceback)

    def _handle_exit(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        traceback: Optional[Any],
    ) -> None:
        """Record `exc_val` on a recording span, then always end the span.

        A `ValueError` raised while recording is logged as a warning.
        """
        active_span = self._span
        try:
            if active_span and active_span.is_recording() and isinstance(exc_val, BaseException):
                active_span.record_exception(exc_val, escaped=True)
        except ValueError as e:
            logger.warning(f"Failed to record_exception: {e}")
        finally:
            if active_span:
                active_span.end()

    def end(self, end_time: Optional[int] = None) -> None:
        """End the underlying span, if any."""
        if not self._span:
            return
        self._span.end(end_time)

    def set_attribute(self, key: str, value: AttributeValueType) -> None:
        """Set one attribute on the underlying span, if any."""
        if not self._span:
            return
        self._span.set_attribute(key, value)

    def set_attributes(self, attributes: dict[str, AttributeValueType]) -> None:
        """Set several attributes on the underlying span, if any."""
        if not self._span:
            return
        self._span.set_attributes(attributes)

    def is_recording(self) -> bool:
        """Return the underlying span's recording state, `False` without a span."""
        if not self._span:
            return False
        return self._span.is_recording()

    def record_exception(
        self,
        exception: BaseException,
        attributes: dict[str, Any] = None,
        timestamp: Optional[int] = None,
        escaped: bool = False,
    ) -> None:
        """Forward the exception to the underlying span, if any."""
        if not self._span:
            return
        self._span.record_exception(
            exception, attributes, timestamp, escaped)

    def get_trace_id(self) -> str:
        """Return the underlying span's trace id, or `None` without a span."""
        return self._span.get_trace_id() if self._span else None

    def get_span_id(self) -> str:
        """Return the underlying span's span id, or `None` without a span."""
        return self._span.get_span_id() if self._span else None
288
+
289
+
290
def format_span_msg(
    format_string: str,
    kwargs: dict[str, Any],
    fstring_frame: types.FrameType = None,
) -> tuple[str, dict[str, Any], str]:
    """Render a span message from a template.

    Returns a 3-tuple:
    1. The formatted message.
    2. A dictionary of extra attributes to add to the span/log.
       These can come from evaluating values in f-strings.
    3. The final message template, which may differ from `format_string` if it was an f-string.

    Formatting problems never propagate: they are reported as warnings (known
    formatting errors, `await` inside an f-string) or through `log_trace_error`,
    and the unformatted `format_string` is returned as message and template.
    """
    try:
        pieces, extra_attrs, new_template = chunks_formatter.chunks(
            format_string,
            kwargs,
            fstring_frame=fstring_frame
        )
        message = ''.join(piece['v'] for piece in pieces)
        return message, extra_attrs, new_template
    except KnownFormattingError as known_error:
        warn_formatting(str(known_error) or str(known_error.__cause__))
    except FStringAwaitError as await_error:
        warn_fstring_await(str(await_error))
    except Exception:
        log_trace_error()

    # Formatting failed, so just use the original format string as the message.
    return format_string, {}, format_string
aworld/trace/function_trace.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import contextlib
3
+ import functools
4
+ from typing import TYPE_CHECKING, Callable, Any, Union, Iterable
5
+ from aworld.trace.base import (
6
+ AttributeValueType
7
+ )
8
+
9
+ from aworld.trace.stack_info import get_filepath_attribute
10
+ from aworld.trace.constants import (
11
+ ATTRIBUTES_MESSAGE_TEMPLATE_KEY
12
+ )
13
+
14
+ if TYPE_CHECKING:
15
+ from aworld.trace.context_manager import TraceManager, ContextSpan
16
+
17
+
18
def trace_func(trace_manager: "TraceManager",
               msg_template: str = None,
               attributes: dict[str, AttributeValueType] = None,
               span_name: str = None,
               extract_args: Union[bool, Iterable[str]] = False):
    """Build a decorator that runs the decorated callable inside a span.

    Plain functions, generator functions, coroutine functions and async
    generator functions each get a wrapper of the matching kind, so the span
    stays open for the whole execution / iteration.

    Args:
        trace_manager: The trace manager to use.
        msg_template: The message template to use.
        attributes: The attributes to use.
        span_name: The span name to use.
        extract_args: Whether to extract arguments from the function call.

    Returns:
        The decorator.
    """

    def decorator(func: Callable) -> Callable:
        func_meta = get_function_meta(func, msg_template)
        func_meta.update(attributes or {})
        # Span name priority: explicit name, then message template, then function name.
        final_span_name = span_name or func_meta.get(ATTRIBUTES_MESSAGE_TEMPLATE_KEY) or func.__name__

        def enter_span(call_args, call_kwargs):
            # Opens the span for one invocation of `func`.
            return open_func_span(trace_manager, func_meta, final_span_name,
                                  get_func_args(func, extract_args, *call_args, **call_kwargs))

        if inspect.isgeneratorfunction(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                with enter_span(args, kwargs):
                    for item in func(*args, **kwargs):
                        yield item
        elif inspect.isasyncgenfunction(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                with enter_span(args, kwargs):
                    async for item in func(*args, **kwargs):
                        yield item
        elif inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                with enter_span(args, kwargs):
                    return await func(*args, **kwargs)
        else:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                with enter_span(args, kwargs):
                    return func(*args, **kwargs)

        return wrapper

    return decorator
68
+
69
+
70
def open_func_span(trace_manager: "TraceManager",
                   func_meta: dict[str, AttributeValueType],
                   span_name: str,
                   func_args: dict[str, AttributeValueType]):
    """Open a function span.

    Args:
        trace_manager: The trace manager that creates the span.
        func_meta: The function meta information. It is not modified.
        span_name: The span name.
        func_args: The call arguments to attach to the span; on key collisions
            they override entries of `func_meta`.

    Returns:
        The function span.
    """
    # Merge into a new dict instead of updating `func_meta` in place: the
    # decorator builds `func_meta` once per decorated function and passes the
    # same object on every call, so an in-place update would carry one call's
    # arguments over into the spans of later calls.
    span_attributes = {**func_meta, **func_args}
    return trace_manager._create_auto_span(name=span_name, attributes=span_attributes)
85
+
86
+
87
def get_func_args(func: Callable,
                  extract_args: Union[bool, Iterable[str]] = False,
                  *args,
                  **kwargs):
    """Get the arguments of a function call that should be recorded.

    Args:
        func: The function to get the arguments of.
        extract_args: ``False`` records nothing, ``True`` records every bound
            argument, an iterable of names records only those parameters.
            A single string is treated as one parameter name.
        *args: The positional arguments.
        **kwargs: The keyword arguments.

    Returns:
        A dict mapping parameter names to the values bound for this call.
        Parameters left at their default are not included because defaults
        are not applied when binding.

    Raises:
        TypeError: If extraction is requested and the arguments do not match
            the signature of `func`.
    """
    if extract_args is False:
        # Nothing requested: skip signature inspection and binding entirely.
        return {}

    func_sig = inspect.signature(func)
    if not func_sig.parameters:
        return {}

    bound_args = func_sig.bind(*args, **kwargs).arguments
    if isinstance(extract_args, bool):
        return dict(bound_args)

    if isinstance(extract_args, str):
        # `name in "some string"` is a substring test; wrap the string so it
        # is compared as one whole parameter name.
        wanted = {extract_args}
    else:
        wanted = set(extract_args)
    return {name: value for name, value in bound_args.items() if name in wanted}
111
+
112
+
113
def get_function_meta(func: Any,
                      msg_template: str = None) -> dict[str, AttributeValueType]:
    """Collect static span attributes describing a callable.

    The callable is unwrapped first; for a non-function callable its
    `__call__` is inspected instead.

    Args:
        func: The function to get the meta information of.
        msg_template: The message template to use. When empty,
            "Calling <module>.<name>" (or "Calling <name>") is used.

    Returns:
        A dict with 'code.function' and the message template, plus
        'code.lineno', the file path attribute and 'func.args' (parameters
        other than `self`) when they can be determined.
    """
    target = inspect.unwrap(func)
    if not inspect.isfunction(target) and hasattr(target, '__call__'):
        target = inspect.unwrap(target.__call__)

    # Name preference: __qualname__, then __name__, then a repr-based fallback.
    fallback_name = build_func_name(target)
    plain_name = getattr(target, '__name__', fallback_name)
    func_name = getattr(target, '__qualname__', plain_name)

    if not msg_template:
        try:
            msg_template = f'Calling {inspect.getmodule(target).__name__}.{func_name}'  # type: ignore
        except Exception:  # pragma: no cover
            msg_template = f'Calling {func_name}'

    meta: dict[str, AttributeValueType] = {
        'code.function': func_name,
        ATTRIBUTES_MESSAGE_TEMPLATE_KEY: msg_template,
    }
    # Line number and file path are best-effort; failures are ignored.
    with contextlib.suppress(Exception):
        meta['code.lineno'] = target.__code__.co_firstlineno
    with contextlib.suppress(Exception):
        meta.update(get_filepath_attribute(inspect.getsourcefile(target)))

    parameters = inspect.signature(target).parameters
    if parameters:
        meta['func.args'] = [
            str(param)
            for param in parameters.values()
            if param.name != 'self'
        ]
    return meta
150
+
151
+
152
def build_func_name(func: Any) -> str:
    """Build a display name for a callable that has no usable name attribute.

    Args:
        func: The function to build the name of.

    Returns:
        `repr(func)`, or '<TypeName object>' when `repr` itself raises.
    """
    try:
        return repr(func)
    except Exception:
        return f'<{type(func).__name__} object>'
aworld/trace/msg_format.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import inspect
3
+ import sys
4
+ import types
5
+ import warnings
6
+ import executing
7
+ from functools import lru_cache
8
+ from string import Formatter
9
+ from types import CodeType
10
+ from typing import Any, Literal, TypeVar
11
+ from typing_extensions import NotRequired, TypedDict
12
+ from .constants import MESSAGE_FORMATTED_VALUE_LENGTH_LIMIT
13
+ from .stack_info import get_user_frame_and_stacklevel
14
+
15
+ Truncatable = TypeVar('Truncatable', str, bytes, 'list[Any]', 'tuple[Any, ...]')
16
+
17
+ class LiteralChunk(TypedDict):
18
+ t: Literal['lit']
19
+ v: str
20
+
21
+
22
+ class ArgChunk(TypedDict):
23
+ t: Literal['arg']
24
+ v: str
25
+ spec: NotRequired[str]
26
+
27
+
28
class KnownFormattingError(Exception):
    """Raised when a format string or its field values are invalid.

    These are the kinds of errors `str.format` would raise. They generally
    point at a user error, most likely that the message passed in was never
    meant to be a template string.
    """
34
+
35
+
36
class FStringAwaitError(Exception):
    """Raised when an f-string value contains an `await` expression.

    f-string introspection cannot evaluate such a value; the awaited
    expression has to be evaluated by the caller before logging.
    """
42
+
43
+
44
class FormattingFailedWarning(UserWarning):
    """Warning category used when a span message could not be formatted."""
46
+
47
class InspectArgumentsFailedWarning(Warning):
    """Warning category used when the calling code could not be introspected."""
49
+
50
class ChunksFormatter(Formatter):
    """A `string.Formatter` that renders a message as a list of chunks.

    Each chunk is either literal template text ('lit') or a formatted field
    value ('arg'). Field values are truncated to
    MESSAGE_FORMATTED_VALUE_LENGTH_LIMIT characters.
    """

    def chunks(
        self,
        format_string: str,
        kwargs: dict[str, Any],
        *,
        fstring_frame: types.FrameType = None,
    ) -> tuple[list[LiteralChunk | ArgChunk], dict[str, Any], str]:
        # Returns
        # 1. A list of chunks
        # 2. A dictionary of extra attributes to add to the span/log.
        #    These can come from evaluating values in f-strings.
        # 3. The final message template, which may differ from `format_string` if it was an f-string.
        if fstring_frame:
            # Try f-string introspection first; fall back to str.format-style formatting.
            result = self._fstring_chunks(kwargs, fstring_frame)
            if result:  # returns None if failed
                return result

        chunks = self._vformat_chunks(
            format_string,
            kwargs=kwargs
        )
        # When there's no f-string magic, there's no changes in the template string.
        return chunks, {}, format_string

    def _fstring_chunks(
        self,
        kwargs: dict[str, Any],
        frame: types.FrameType,
    ) -> tuple[list[LiteralChunk | ArgChunk], dict[str, Any], str]:
        # Inspects the source of the call executing in the caller's frame and, if its
        # first positional argument is an f-string, evaluates the f-string parts itself.
        # Returns the same 3-tuple as `chunks`, or None (despite the annotation) when
        # introspection fails or the argument is not an f-string.
        #
        # `frame` is the frame of the method that's being called by the user;
        # step back one frame (when there is one) to reach the calling code.
        frame = frame.f_back or frame  # type: ignore
        assert frame is not None
        # This is where the magic happens. It has caching.
        ex = executing.Source.executing(frame)

        call_node = ex.node
        if call_node is None:  # type: ignore[reportUnnecessaryComparison]
            # `executing` failed to find a node.
            # This shouldn't happen in most cases, but it's best not to rely on it always working.
            if not ex.source.text:
                # This is a very likely cause.
                # There's nothing we could possibly do to make magic work here,
                # and it's a clear case where the user should turn the magic off.
                warn_inspect_arguments(
                    'No source code available. '
                    'This happens when running in an interactive shell, '
                    'using exec(), or running .pyc files without the source .py files.',
                    get_stacklevel(frame),
                )
                return None

            msg = '`executing` failed to find a node.'
            if sys.version_info[:2] < (3, 11):  # pragma: no cover
                # inspect_arguments is only on by default for 3.11+ for this reason.
                # The AST modifications made by auto-tracing
                # mean that the bytecode doesn't match the source code seen by `executing`.
                # In 3.11+, a different algorithm is used by `executing` which can deal with this.
                msg += ' This may be caused by a combination of using Python < 3.11 and auto-tracing.'

            # Try a simple fallback heuristic to find the node which should work in most cases:
            # accept the statement only if it contains exactly one call that has arguments.
            main_nodes: list[ast.AST] = []
            for statement in ex.statements:
                if isinstance(statement, ast.With):
                    # Only look at the 'header' of a with statement, not its body.
                    main_nodes += statement.items
                else:
                    main_nodes.append(statement)
            call_nodes = [
                node
                for main_node in main_nodes
                for node in ast.walk(main_node)
                if isinstance(node, ast.Call)
                if node.args or node.keywords
            ]
            if len(call_nodes) != 1:
                warn_inspect_arguments(msg, get_stacklevel(frame))
                return None

            [call_node] = call_nodes

        if not isinstance(call_node, ast.Call):  # pragma: no cover
            # Very unlikely.
            warn_inspect_arguments(
                '`executing` unexpectedly identified a non-Call node.',
                get_stacklevel(frame),
            )
            return None

        if call_node.args:
            # The message template is taken to be the first positional argument.
            arg_node = call_node.args[0]
        else:
            # Very unlikely.
            warn_inspect_arguments(
                "Couldn't identify the `msg_template` argument in the call.",
                get_stacklevel(frame),
            )
            return None

        if not isinstance(arg_node, ast.JoinedStr):
            # Not an f-string, not a problem.
            # Just use normal formatting.
            return None

        # We have an f-string AST node.
        # Now prepare the namespaces that we will use to evaluate the components.
        # Entries of `kwargs` shadow the frame's local variables of the same name.
        global_vars = frame.f_globals
        local_vars = {**frame.f_locals, **kwargs}

        # Now for the actual formatting!
        result: list[LiteralChunk | ArgChunk] = []

        # We construct the message template (i.e. the span name) from the AST.
        # We don't use the source code of the f-string because that gets messy
        # if there's escaped quotes or implicit joining of adjacent strings.
        new_template = ''

        extra_attrs: dict[str, Any] = {}
        for node_value in arg_node.values:
            if isinstance(node_value, ast.Constant):
                # These are the parts of the f-string not enclosed by `{}`, e.g. 'foo ' in f'foo {bar}'
                value: str = node_value.value
                result.append({'v': value, 't': 'lit'})
                new_template += value
            else:
                # These are the parts of the f-string enclosed by `{}`, e.g. 'bar' in f'foo {bar}'
                assert isinstance(node_value, ast.FormattedValue)

                # This is cached.
                source, value_code, formatted_code = compile_formatted_value(node_value, ex.source)

                # Note that this doesn't include:
                # - The format spec, e.g. `:0.2f`
                # - The conversion, e.g. `!r`
                # - The '=' sign within the braces, e.g. `{bar=}`.
                #   The AST represents f'{bar = }' as f'bar = {bar}' which is how the template will look.
                new_template += '{' + source + '}'

                # The actual value of the expression.
                # NOTE: this evaluates the expression text found in the caller's own source code.
                value = eval(value_code, global_vars, local_vars)
                extra_attrs[source] = value

                # Format the value according to the format spec, converting to a string.
                # '@fvalue' is the placeholder name compiled into `formatted_code`.
                formatted = eval(formatted_code, global_vars, {**local_vars, '@fvalue': value})
                formatted = self._clean_value(formatted)
                result.append({'v': formatted, 't': 'arg'})

        return result, extra_attrs, new_template

    def _vformat_chunks(
        self,
        format_string: str,
        kwargs: dict[str, Any],
        *,
        recursion_depth: int = 2,
    ) -> list[LiteralChunk | ArgChunk]:
        """Copied from `string.Formatter._vformat` https://github.com/python/cpython/blob/v3.11.4/Lib/string.py#L198-L247 then altered."""
        if recursion_depth < 0:
            raise KnownFormattingError('Max format spec recursion exceeded')
        result: list[LiteralChunk | ArgChunk] = []
        # We currently don't use positional arguments
        args = ()

        for literal_text, field_name, format_spec, conversion in self.parse(format_string):
            # output the literal text
            if literal_text:
                result.append({'v': literal_text, 't': 'lit'})

            # if there's a field, output it
            if field_name is not None:
                # this is some markup, find the object and do
                # the formatting
                if field_name == '':
                    raise KnownFormattingError('Empty curly brackets `{}` are not allowed. A field name is required.')

                # ADDED BY US:
                # Support the f-string style `{name=}`: emit 'name=' as literal text
                # and then format the field `name`.
                if field_name.endswith('='):
                    if result and result[-1]['t'] == 'lit':
                        result[-1]['v'] += field_name
                    else:
                        result.append({'v': field_name, 't': 'lit'})
                    field_name = field_name[:-1]

                # given the field_name, find the object it references
                # and the argument it came from
                try:
                    obj, _arg_used = self.get_field(field_name, args, kwargs)
                except IndexError:
                    raise KnownFormattingError('Numeric field names are not allowed.')
                except KeyError as exc1:
                    if str(exc1) == repr(field_name):
                        raise KnownFormattingError(f'The field {{{field_name}}} is not defined.') from exc1

                    try:
                        # field_name is something like 'a.b' or 'a[b]'
                        # Evaluating that expression failed, so now just try getting the whole thing from kwargs.
                        # In particular, OTEL attributes with dots in their names are normal and handled here.
                        obj = kwargs[field_name]
                    except KeyError as exc2:
                        # e.g. neither 'a' nor 'a.b' is defined
                        raise KnownFormattingError(f'The fields {exc1} and {exc2} are not defined.') from exc2
                except Exception as exc:
                    raise KnownFormattingError(f'Error getting field {{{field_name}}}: {exc}') from exc

                # do any conversion on the resulting object
                if conversion is not None:
                    try:
                        obj = self.convert_field(obj, conversion)
                    except Exception as exc:
                        raise KnownFormattingError(f'Error converting field {{{field_name}}}: {exc}') from exc

                # expand the format spec, if needed
                format_spec_chunks = self._vformat_chunks(
                    format_spec or '', kwargs, recursion_depth=recursion_depth - 1
                )
                format_spec = ''.join(chunk['v'] for chunk in format_spec_chunks)

                try:
                    value = self.format_field(obj, format_spec)
                except Exception as exc:
                    raise KnownFormattingError(f'Error formatting field {{{field_name}}}: {exc}') from exc
                value = self._clean_value(value)
                d: ArgChunk = {'v': value, 't': 'arg'}
                if format_spec:
                    d['spec'] = format_spec
                result.append(d)

        return result

    def _clean_value(self, value: str) -> str:
        # Truncate long formatted values, marking the cut with '...'.
        return truncate_sequence(seq=value, max_length=MESSAGE_FORMATTED_VALUE_LENGTH_LIMIT, middle='...')
283
+
284
def warn_inspect_arguments(msg: str, stacklevel: int):
    """Emit an `InspectArgumentsFailedWarning` about failed call introspection.

    Args:
        msg: Description of the specific problem, appended to a fixed preamble.
        stacklevel: The `stacklevel` passed to `warnings.warn`.
    """
    preamble = (
        'Failed to introspect calling code. '
        'Falling back to normal message formatting '
        'which may result in loss of information if using an f-string. '
        'The problem was:\n'
    )
    warnings.warn(preamble + msg, InspectArgumentsFailedWarning, stacklevel=stacklevel)
295
+
296
+
297
def get_stacklevel(frame: types.FrameType):
    """Count the frames between this function and `frame`.

    The walk starts at this function's own frame (level 0) and follows
    `f_back` until `frame` is reached. The result can be passed to
    `warn_inspect_arguments` so the warning points at the given frame,
    where the f-string was found. If `frame` is not on the stack, the total
    number of frames walked is returned.
    """
    depth = 0
    cursor = inspect.currentframe()
    while cursor and cursor != frame:  # pragma: no branch
        depth += 1
        cursor = cursor.f_back
    return depth
309
+
310
@lru_cache
def compile_formatted_value(node: ast.FormattedValue, ex_source: executing.Source) -> tuple[str, CodeType, CodeType]:
    """Compile one `{...}` part of an f-string; results are cached.

    Returns three things that can be expensive to compute:

    1. Source code corresponding to the node value (excluding the format spec).
    2. A compiled code object which can be evaluated to calculate the value.
    3. Another code object which formats the value.

    Raises:
        FStringAwaitError: If the value expression contains `await`.
    """
    source = get_node_source_text(node.value, ex_source)

    # An expression containing await cannot be evaluated here, so refuse before compiling.
    if any(isinstance(sub_node, ast.Await) for sub_node in ast.walk(node.value)):
        raise FStringAwaitError(source)

    value_code = compile(source, '<fvalue1>', 'eval')

    # Similar to the original FormattedValue node, but the actual expression is
    # replaced by a simple variable lookup so that it doesn't need to be evaluated again.
    # The '@' in the variable name makes a clash with a normal variable impossible.
    # The variable's value is supplied in the eval() call by the caller and comes
    # from evaluating value_code.
    placeholder = ast.FormattedValue(
        value=ast.Name(id='@fvalue', ctx=ast.Load()),
        conversion=node.conversion,
        format_spec=node.format_spec,
    )
    expr = ast.Expression(ast.JoinedStr(values=[placeholder]))
    ast.fix_missing_locations(expr)
    formatted_code = compile(expr, '<fvalue2>', 'eval')
    return source, value_code, formatted_code
347
+
348
def get_node_source_text(node: ast.AST, ex_source: executing.Source):
    """Return Python source code representing `node`.

    The original text from `ast.get_source_segment` is preferred. It is only
    trusted when re-parsing and unparsing it gives the same result as
    `ast.unparse(node)`; otherwise `ast.unparse(node)` is returned.
    The source positions of AST nodes inside f-strings are sometimes wrong
    due to Python bugs (especially in older Python versions).
    """
    # ast.unparse is not available in Python 3.8.
    unparsed_node = ast.unparse(node)
    original_text = ast.get_source_segment(ex_source.text, node) or ''
    try:
        # Round-trip the original text so it can be compared with the node.
        unparsed_original = ast.unparse(ast.parse(original_text, mode='eval'))
    except Exception:  # probably SyntaxError, but ast.parse can raise other exceptions too
        unparsed_original = ''

    if unparsed_node == unparsed_original:
        return original_text
    return unparsed_node
365
+
366
+
367
def truncate_sequence(seq: Truncatable, *, max_length: int, middle: Truncatable) -> Truncatable:
    """Return a sequence with `len()` at most `max_length`, with `middle` in the middle if truncated.

    Args:
        seq: The sequence to shorten.
        max_length: Maximum length of the result.
        middle: Marker inserted between the kept head and tail of `seq`.

    Returns:
        `seq` itself when it already fits. Otherwise an equally long head and
        tail of `seq` around `middle`. When there is no room for any head/tail,
        only `middle` is returned; when even `middle` does not fit, `seq` is
        simply cut to `max_length` items.
    """
    if len(seq) <= max_length:
        return seq

    remaining_length = max_length - len(middle)
    if remaining_length < 0:
        # The marker alone is already too long: plain truncation.
        return seq[:max(max_length, 0)]

    half = remaining_length // 2
    if half == 0:
        # `seq[-0:]` would be the whole sequence, so the tail must be skipped.
        return seq[:0] + middle
    return seq[:half] + middle + seq[-half:]
374
+
375
def warn_at_user_stacklevel(msg: str, category: type[Warning]):
    """Emit a warning using the stack level reported for the user's frame.

    Args:
        msg: The warning message.
        category: The warning category.
    """
    user_location = get_user_frame_and_stacklevel()
    user_stacklevel = user_location[1]
    warnings.warn(msg, stacklevel=user_stacklevel, category=category)
380
+
381
def warn_formatting(msg: str):
    """Warn, at the user's stack level, that a message template could not be formatted.

    Args:
        msg: Description of the formatting problem.
    """
    advice = (
        '\n'
        ' Ensure you are either:\n'
        ' (1) passing an f-string directly, or\n'
        ' (2) passing a literal `str.format`-style template, not a preformatted string.\n'
        f' The problem was: {msg}'
    )
    warn_at_user_stacklevel(advice, category=FormattingFailedWarning)
392
+
393
def warn_fstring_await(msg: str):
    """Warn, at the user's stack level, about an await expression in an f-string.

    Args:
        msg: The f-string value that contains the await expression.
    """
    advice = (
        '\n'
        ' Cannot evaluate await expression in f-string. Pre-evaluate the expression before logging.\n'
        f' The problematic f-string value was: {msg}'
    )
    warn_at_user_stacklevel(advice, category=FormattingFailedWarning)
402
+
403
+ chunks_formatter = ChunksFormatter()
aworld/trace/rewrite_ast.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import uuid
5
+ import time
6
+ from pathlib import Path
7
+ from collections import deque
8
+ from functools import partial
9
+ from typing import TYPE_CHECKING, Any, Callable, ContextManager, cast
10
+
11
+ from aworld.trace.base import AttributeValueType
12
+ from aworld.trace.constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY
13
+
14
+ if TYPE_CHECKING:
15
+ from .context_manager import TraceManager
16
+ from .auto_trace import not_auto_trace
17
+
18
+
19
+ def compile_source(
20
+ tree: ast.AST, filename: str, module_name: str, trace_manager: TraceManager, min_duration_ns: int
21
+ ) -> Callable[[dict[str, Any]], None]:
22
+ """Compile a modified AST of the module's source code in the module's namespace.
23
+
24
+ Returns a function which accepts module globals and executes the compiled code.
25
+
26
+ The modified AST wraps the body of every function definition in `with context_factories[index]():`.
27
+ `context_factories` is added to the module's namespace as `aworld_<uuid>`.
28
+ `index` is a different constant number for each function definition.
29
+ """
30
+
31
+ context_factories_var_name = f'aworld_{uuid.uuid4().hex}'
32
+ # The variable name for storing context_factories in the module's namespace.
33
+
34
+ context_factories: list[Callable[[], ContextManager[Any]]] = []
35
+ tree = rewrite_ast(tree, filename, context_factories_var_name, module_name, trace_manager, context_factories,
36
+ min_duration_ns)
37
+ assert isinstance(tree, ast.Module) # for type checking
38
+ # dont_inherit=True is necessary to prevent the module from inheriting the __future__ import from this module.
39
+ code = compile(tree, filename, 'exec', dont_inherit=True)
40
+
41
+ def execute(globs: dict[str, Any]):
42
+ globs[context_factories_var_name] = context_factories
43
+ exec(code, globs, globs)
44
+
45
+ return execute
46
+
47
+
48
+ def rewrite_ast(
49
+ tree: ast.AST,
50
+ filename: str,
51
+ context_factories_var_name: str,
52
+ module_name: str,
53
+ trace_manager: TraceManager,
54
+ context_factories: list[Callable[[], ContextManager[Any]]],
55
+ min_duration_ns: int,
56
+ ) -> ast.AST:
57
+ transformer = AutoTraceTransformer(
58
+ context_factories_var_name, filename, module_name, trace_manager, context_factories, min_duration_ns
59
+ )
60
+ return transformer.visit(tree)
61
+
62
+
63
+ class AutoTraceTransformer(ast.NodeTransformer):
64
+ """Trace all encountered functions except those explicitly marked with `@not_auto_trace`."""
65
+
66
+ def __init__(
67
+ self,
68
+ context_factories_var_name: str,
69
+ filename: str,
70
+ module_name: str,
71
+ trace_manager: TraceManager,
72
+ context_factories: list[Callable[[], ContextManager[Any]]],
73
+ min_duration_ns: int,
74
+ ):
75
+ self._context_factories_var_name = context_factories_var_name
76
+ self._filename = filename
77
+ self._module_name = module_name
78
+ self._trace_manager = trace_manager
79
+ self._context_factories = context_factories
80
+ self._min_duration_ns = min_duration_ns
81
+ self._qualname_stack: list[str] = []
82
+
83
+ def visit_ClassDef(self, node: ast.ClassDef):
84
+ """Visit a class definition and rewrite its methods."""
85
+
86
+ if self.check_not_auto_trace(node):
87
+ return node
88
+
89
+ self._qualname_stack.append(node.name)
90
+ node = cast(ast.ClassDef, self.generic_visit(node))
91
+ self._qualname_stack.pop()
92
+ return node
93
+
94
+ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
95
+ """Visit a function definition and rewrite it."""
96
+
97
+ if self.check_not_auto_trace(node):
98
+ return node
99
+
100
+ self._qualname_stack.append(node.name)
101
+ qualname = '.'.join(self._qualname_stack)
102
+ self._qualname_stack.append('<locals>')
103
+ self.generic_visit(node)
104
+ self._qualname_stack.pop() # <locals>
105
+ self._qualname_stack.pop() # node.name
106
+ return self.rewrite_function(node, qualname)
107
+
108
+ def check_not_auto_trace(self, node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef) -> bool:
109
+ """Return true if the node has a `@not_auto_trace` decorator."""
110
+ return any(
111
+ (
112
+ isinstance(node, ast.Name)
113
+ and node.id == not_auto_trace.__name__
114
+ # or (
115
+ # isinstance(node, ast.Attribute)
116
+ # and node.attr == not_auto_trace.__name__
117
+ # and isinstance(node.value, ast.Name)
118
+ # and node.value.id == xxx.__name__
119
+ # )
120
+ )
121
+ for node in node.decorator_list
122
+ )
123
+
124
+ def rewrite_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.AST:
125
+ """Rewrite a function definition to trace its execution."""
126
+
127
+ if has_yield(node):
128
+ return node
129
+
130
+ body = node.body.copy()
131
+ new_body: list[ast.stmt] = []
132
+ if (
133
+ body
134
+ and isinstance(body[0], ast.Expr)
135
+ and isinstance(body[0].value, ast.Constant)
136
+ and isinstance(body[0].value.value, str)
137
+ ):
138
+ new_body.append(body.pop(0))
139
+
140
+ if not body or (
141
+ len(body) == 1
142
+ and (
143
+ isinstance(body[0], ast.Pass)
144
+ or (isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Constant))
145
+ )
146
+ ):
147
+ return node
148
+
149
+ span = ast.With(
150
+ items=[
151
+ ast.withitem(
152
+ context_expr=self.trace_context_method_call_node(node, qualname),
153
+ )
154
+ ],
155
+ body=body,
156
+ type_comment=node.type_comment,
157
+ )
158
+ new_body.append(span)
159
+
160
+ return ast.fix_missing_locations(
161
+ ast.copy_location(
162
+ type(node)( # type: ignore
163
+ name=node.name,
164
+ args=node.args,
165
+ body=new_body,
166
+ decorator_list=node.decorator_list,
167
+ returns=node.returns,
168
+ type_comment=node.type_comment,
169
+ ),
170
+ node,
171
+ )
172
+ )
173
+
174
+ def trace_context_method_call_node(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.Call:
175
+ """Return a method call to `context_factories[index]()`."""
176
+
177
+ index = len(self._context_factories)
178
+ span_factory = partial(
179
+ self._trace_manager._create_auto_span, # type: ignore
180
+ *self.build_create_auto_span_args(qualname, node.lineno),
181
+ )
182
+ if self._min_duration_ns > 0:
183
+
184
+ timer = time.time_ns
185
+ min_duration = self._min_duration_ns
186
+
187
+ # This needs to be as fast as possible since it's the cost of auto-tracing a function
188
+ # that never actually gets instrumented because its calls are all faster than `min_duration`.
189
+ class MeasureTime:
190
+ __slots__ = 'start'
191
+
192
+ def __enter__(_self):
193
+ _self.start = timer()
194
+
195
+ def __exit__(_self, *_):
196
+ # The first call exceeding min_duration is not traced itself; it only swaps in span_factory so that subsequent calls are traced.
197
+ if timer() - _self.start >= min_duration:
198
+ self._context_factories[index] = span_factory
199
+
200
+ self._context_factories.append(MeasureTime)
201
+ else:
202
+ self._context_factories.append(span_factory)
203
+
204
+ # This node means:
205
+ # context_factories[index]()
206
+ # where `context_factories` is a global variable with the name `self._context_factories_var_name`
207
+ # pointing to the `self._context_factories` list.
208
+ return ast.Call(
209
+ func=ast.Subscript(
210
+ value=ast.Name(id=self._context_factories_var_name, ctx=ast.Load()),
211
+ slice=ast.Index(value=ast.Constant(value=index)), # type: ignore
212
+ ctx=ast.Load(),
213
+ ),
214
+ args=[],
215
+ keywords=[],
216
+ )
217
+
218
+ def build_create_auto_span_args(self, qualname: str, lineno: int) -> tuple[str, dict[str, AttributeValueType]]:
219
+ """Build the positional arguments (span name, attributes) for `_create_auto_span`."""
220
+
221
+ stack_info = {
222
+ 'code.filepath': get_filepath(self._filename),
223
+ 'code.lineno': lineno,
224
+ 'code.function': qualname,
225
+ }
226
+ attributes: dict[str, AttributeValueType] = {**stack_info} # type: ignore
227
+
228
+ msg_template = f'Calling {self._module_name}.{qualname}'
229
+ attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] = msg_template
230
+
231
+ span_name = msg_template
232
+
233
+ return span_name, attributes
234
+
235
+
236
+ def has_yield(node: ast.AST):
237
+ """Return true if the node has a yield statement."""
238
+
239
+ queue = deque([node])
240
+ while queue:
241
+ node = queue.popleft()
242
+ for child in ast.iter_child_nodes(node):
243
+ if isinstance(child, (ast.Yield, ast.YieldFrom)):
244
+ return True
245
+ if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)):
246
+ queue.append(child)
247
+
248
+
249
+ def get_filepath(file: str):
250
+ """Return the file path as a string, relative to the current working directory when possible."""
251
+
252
+ path = Path(file)
253
+ if path.is_absolute():
254
+ try:
255
+ path = path.relative_to(Path('.').resolve())
256
+ except ValueError: # pragma: no cover
257
+ # happens if filename path is not within CWD
258
+ pass
259
+ return str(path)
aworld/trace/span_cosumer.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Sequence
3
+ from aworld.trace.base import Span
4
+
5
+
6
+ class SpanConsumer(ABC):
7
+ """SpanConsumer is an abstract base class that represents a consumer for spans.
8
+ """
9
+ @abstractmethod
10
+ def consume(self, spans: Sequence[Span]) -> None:
11
+ """Consumes a sequence of spans.
12
+ Args:
13
+ spans: The spans to consume.
14
+ """
15
+
16
+
17
+ _SPAN_CONSUMER_REGISTRY = {}
18
+
19
+
20
+ def register_span_consumer(default_kwargs=None) -> None:
21
+ """Registers a span consumer.
22
+ Args:
23
+ default_kwargs: A dictionary of default keyword arguments to pass to the span consumer.
24
+ """
25
+
26
+ default_kwargs = default_kwargs or {}
27
+
28
+ def decorator(cls):
29
+ _SPAN_CONSUMER_REGISTRY[cls.__name__] = (cls, default_kwargs)
30
+ return cls
31
+
32
+ return decorator
33
+
34
+
35
+ def get_span_consumers() -> Sequence[SpanConsumer]:
36
+ """Returns a list of span consumers.
37
+ Returns:
38
+ A list of span consumers.
39
+ """
40
+ return [
41
+ cls(**kwargs)
42
+ for cls, kwargs in _SPAN_CONSUMER_REGISTRY.values()
43
+ ]
aworld/trace/stack_info.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import sys
3
+ import aworld.trace as atrace
4
+ from types import CodeType, FrameType
5
+ from typing import Optional, TypedDict, Union
6
+ from functools import lru_cache
7
+ from pathlib import Path
8
+
9
+ StackInfo = TypedDict('StackInfo', {'code.filepath': str, 'code.lineno': int, 'code.function': str}, total=False)
10
+
11
+ NON_USER_CODE_PREFIXES: tuple[str, ...] = ()
12
+
13
+ def add_non_user_code_prefix(path: Union[str, Path]) -> None:
14
+ global NON_USER_CODE_PREFIXES
15
+ path = str(Path(path).absolute())
16
+ NON_USER_CODE_PREFIXES += (path,)
17
+
18
+ add_non_user_code_prefix(Path(inspect.__file__).parent)
19
+ add_non_user_code_prefix(Path(atrace.__file__).parent)
20
+
21
+ def get_user_stack_info() -> StackInfo:
22
+ """Get the stack info for the first calling frame in user code.
23
+
24
+ See is_user_code for details.
25
+ Returns an empty dict if no such frame is found.
26
+ """
27
+ frame, _stacklevel = get_user_frame_and_stacklevel()
28
+ if frame:
29
+ return get_stack_info_from_frame(frame)
30
+ return {}
31
+
32
+
33
+ def get_user_frame_and_stacklevel() -> tuple[Optional[FrameType], int]:
34
+ """Get the first calling frame in user code and a corresponding stacklevel that can be passed to `warnings.warn`.
35
+
36
+ See is_user_code for details.
37
+ Returns `(None, 0)` if no such frame is found.
38
+ """
39
+ frame = inspect.currentframe()
40
+ stacklevel = 0
41
+ while frame:
42
+ if is_user_code(frame.f_code):
43
+ return frame, stacklevel
44
+ frame = frame.f_back
45
+ stacklevel += 1
46
+ return None, 0
47
+
48
+ def get_stack_info_from_frame(frame: FrameType) -> StackInfo:
49
+ return {
50
+ **get_code_object_info(frame.f_code),
51
+ 'code.lineno': frame.f_lineno,
52
+ }
53
+
54
+ @lru_cache(maxsize=2048)
55
+ def get_code_object_info(code: CodeType) -> StackInfo:
56
+ result = get_filepath_attribute(code.co_filename)
57
+ if code.co_name != '<module>': # pragma: no branch
58
+ result['code.function'] = code.co_qualname if sys.version_info >= (3, 11) else code.co_name
59
+ result['code.lineno'] = code.co_firstlineno
60
+ return result
61
+
62
+ def get_filepath_attribute(file: str) -> StackInfo:
63
+ path = Path(file)
64
+ if path.is_absolute():
65
+ try:
66
+ path = path.relative_to(Path('.').resolve())
67
+ except ValueError: # pragma: no cover
68
+ # happens if filename path is not within CWD
69
+ pass
70
+ return {'code.filepath': str(path)}
71
+
72
+ @lru_cache(maxsize=8192)
73
+ def is_user_code(code: CodeType) -> bool:
74
+ """Check if the code object is from user code.
75
+
76
+ A code object is not user code if:
77
+ - It is from a file in
78
+ - the standard library
79
+ - any directory registered via `add_non_user_code_prefix` (by default the `aworld.trace` package directory)
80
+ - an unknown location (e.g. a dynamically generated code object) indicated by a filename starting with '<'
81
+
82
+ - It is a list/dict/set comprehension.
83
+ These are artificial frames only created before Python 3.12,
84
+ and they are always called directly from the enclosing function so it makes sense to skip them.
85
+ On the other hand, generator expressions and lambdas might be called far away from where they are defined.
86
+ """
87
+ return not (
88
+ str(Path(code.co_filename).absolute()).startswith(NON_USER_CODE_PREFIXES)
89
+ or code.co_filename.startswith('<')
90
+ or code.co_name in ('<listcomp>', '<dictcomp>', '<setcomp>')
91
+ )