Upload 11 files
Browse files
- aworld/trace/__init__.py +80 -0
- aworld/trace/auto_trace.py +192 -0
- aworld/trace/base.py +422 -0
- aworld/trace/config.py +93 -0
- aworld/trace/constants.py +26 -0
- aworld/trace/context_manager.py +316 -0
- aworld/trace/function_trace.py +166 -0
- aworld/trace/msg_format.py +403 -0
- aworld/trace/rewrite_ast.py +259 -0
- aworld/trace/span_cosumer.py +43 -0
- aworld/trace/stack_info.py +91 -0
aworld/trace/__init__.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
import traceback
|
4 |
+
from typing import Sequence, Union
|
5 |
+
from aworld.trace.context_manager import TraceManager
|
6 |
+
from aworld.trace.constants import RunType
|
7 |
+
from aworld.logs.util import logger
|
8 |
+
from aworld.trace.config import configure, ObservabilityConfig
|
9 |
+
|
10 |
+
|
11 |
+
def get_tool_name(tool_name: str,
                  action: Union['ActionModel', Sequence['ActionModel']]) -> tuple[str, RunType]:
    """Resolve the span name and run type for a tool invocation.

    For the generic ``"mcp"`` tool the name is taken from the first action:
    it is the part of ``action_name`` that precedes the first ``"__"``.

    Args:
        tool_name: Name of the tool being invoked.
        action: A single action or a list/tuple of actions. Only the first
            element of a list/tuple is inspected.

    Returns:
        ``(name, run_type)``. ``run_type`` is ``RunType.MCP`` when
        ``tool_name`` is ``"mcp"`` and ``action`` is truthy, otherwise
        ``RunType.TOOL``. If the MCP name cannot be derived, ``tool_name``
        itself is returned together with ``RunType.MCP``.
    """
    if tool_name != "mcp" or not action:
        return (tool_name, RunType.TOOL)

    # ``action`` is truthy here, so a list/tuple has at least one element.
    first_action = action[0] if isinstance(action, (list, tuple)) else action
    try:
        mcp_name = first_action.action_name.split("__")[0]
    except (AttributeError, TypeError, ValueError):
        # Best effort: a missing or non-string ``action_name`` must not break
        # tracing. The original handler caught only ValueError, which this
        # expression cannot raise, so these failures used to propagate.
        logger.warning(traceback.format_exc())
        return (tool_name, RunType.MCP)
    return (mcp_name, RunType.MCP)
|
23 |
+
|
24 |
+
|
25 |
+
def get_span_name_from_message(message: 'aworld.core.event.base.Message') -> tuple[str, RunType]:
    """Derive the span name and run type for an event message.

    Args:
        message: The event message. Its ``receiver``, ``id``, ``category``
            and ``payload`` attributes are read.

    Returns:
        ``(span_name, run_type)``:

        * agent messages -> ``(receiver or id, RunType.AGNET)``
        * tool messages with a usable action -> result of ``get_tool_name``
        * tool messages without one -> ``(receiver or id, RunType.TOOL)``
        * anything else -> ``(receiver or id, RunType.OTHER)``
    """
    # Imported lazily inside the function, as in the original code.
    from aworld.core.event.base import Constants

    span_name = (message.receiver or message.id)
    category = message.category
    if category == Constants.AGENT:
        # ``AGNET`` is the member name used by this module; it is defined outside this file.
        return (span_name, RunType.AGNET)
    if category != Constants.TOOL:
        return (span_name, RunType.OTHER)

    action = message.payload
    if isinstance(action, (list, tuple)):
        # Guard the empty sequence: indexing it used to raise IndexError.
        action = action[0] if action else None
    if not action:
        return (span_name, RunType.TOOL)
    return get_tool_name(action.tool_name, action)
|
39 |
+
|
40 |
+
|
41 |
+
def message_span(message: 'aworld.core.event.base.Message' = None, attributes: dict = None):
    """Open a span on the global trace manager describing an event message.

    The span is named ``"<run_type>_event_<span_name>"`` (run type value
    lower-cased) and carries the message's payload, topic, receiver, sender,
    category, id and session id as ``event.*`` attributes.

    Args:
        message: The event message to describe. Must be truthy.
        attributes: Optional extra attributes; they override the generated
            ``event.*`` entries on key collision.

    Returns:
        Whatever ``GLOBAL_TRACE_MANAGER.span`` returns.

    Raises:
        ValueError: If ``message`` is falsy.
    """
    if not message:
        raise ValueError("message_span message is None")

    span_name, run_type = get_span_name_from_message(message)
    event_attributes = {
        "event.payload": str(message.payload),
        "event.topic": message.topic or "",
        "event.receiver": message.receiver or "",
        "event.sender": message.sender or "",
        "event.category": message.category,
        "event.id": message.id,
        "event.session_id": message.session_id
    }
    if attributes:
        event_attributes.update(attributes)

    full_name = f"{run_type.value.lower()}_event_{span_name}"
    return GLOBAL_TRACE_MANAGER.span(
        span_name=full_name,
        attributes=event_attributes,
        run_type=run_type
    )
|
61 |
+
|
62 |
+
|
63 |
+
# Single module-level trace manager. The names below are bound methods of this
# instance, re-exported so callers can write ``aworld.trace.span(...)`` etc.
GLOBAL_TRACE_MANAGER: TraceManager = TraceManager()
span = GLOBAL_TRACE_MANAGER.span
func_span = GLOBAL_TRACE_MANAGER.func_span
auto_tracing = GLOBAL_TRACE_MANAGER.auto_tracing
get_current_span = GLOBAL_TRACE_MANAGER.get_current_span
# Fixed: this alias previously pointed at ``get_current_span`` (copy-paste
# slip), so ``new_manager(...)`` never created a manager.
new_manager = GLOBAL_TRACE_MANAGER.new_manager

__all__ = [
    "span",
    "func_span",
    "message_span",
    "auto_tracing",
    "get_current_span",
    "new_manager",
    "RunType",
    "configure",
    "ObservabilityConfig"
]
|
aworld/trace/auto_trace.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import re
|
3 |
+
import sys
|
4 |
+
import warnings
|
5 |
+
from importlib.abc import Loader, MetaPathFinder
|
6 |
+
from importlib.machinery import ModuleSpec
|
7 |
+
from importlib.util import spec_from_loader
|
8 |
+
from types import ModuleType
|
9 |
+
from typing import TYPE_CHECKING, Sequence, Union, Callable, Iterator, TypeVar, Any, cast
|
10 |
+
|
11 |
+
from aworld.trace.base import log_trace_error
|
12 |
+
from .rewrite_ast import compile_source
|
13 |
+
|
14 |
+
if TYPE_CHECKING:
|
15 |
+
from .context_manager import TraceManager
|
16 |
+
|
17 |
+
|
18 |
+
class AutoTraceModule:
    """A module being imported that may need to be traced automatically."""

    def __init__(self, module_name: str) -> None:
        # Fully qualified absolute name of the module being imported.
        self._module_name = module_name

    def need_auto_trace(self, prefix: Union[str, Sequence[str]]) -> bool:
        """Check whether the module name matches any of the given prefixes.

        Each prefix is turned into a regular expression by
        ``get_module_pattern``; the alternatives are joined with ``|`` and
        matched from the start of the module name.

        Args:
            prefix: One module name / pattern, or a sequence of them.

        Returns:
            True if at least one prefix matches. An empty sequence matches
            nothing.
        """
        if isinstance(prefix, str):
            prefix = (prefix,)
        if not prefix:
            # Without this guard the joined pattern is '' and ``re.match``
            # would succeed for every module name.
            return False
        pattern = '|'.join(get_module_pattern(p) for p in prefix)
        return bool(re.match(pattern, self._module_name))
|
33 |
+
|
34 |
+
|
35 |
+
class TraceImportFinder(MetaPathFinder):
    """Meta path finder that returns an auto-tracing spec for modules selected by a filter."""

    def __init__(self, trace_manager: "TraceManager", module_funcs: Callable[[AutoTraceModule], bool],
                 min_duration_ns: int) -> None:
        """Store the collaborators used when a matching module is imported.

        Args:
            trace_manager: Passed through to ``compile_source``.
            module_funcs: Predicate deciding whether a module should be traced.
            min_duration_ns: Passed through to ``compile_source``.
        """
        self._trace_manager = trace_manager
        self._modules_filter = module_funcs
        self._min_duration_ns = min_duration_ns

    def _find_plain_specs(
        self, fullname: str, path: Sequence[str] = None, target: ModuleType = None
    ) -> Iterator[ModuleSpec]:
        """Yield module specs returned by other finders on `sys.meta_path`."""
        for finder in sys.meta_path:
            # Skip this finder or any like it to avoid infinite recursion.
            if isinstance(finder, TraceImportFinder):
                continue

            try:
                plain_spec = finder.find_spec(fullname, path, target)
            except Exception:  # pragma: no cover
                # A finder that fails (or has no ``find_spec``) is simply ignored.
                continue

            if plain_spec:
                yield plain_spec

    def find_spec(self, fullname: str, path: Sequence[str], target=None) -> Union[ModuleSpec, None]:
        """Find the spec for the given module name.

        Args:
            fullname: Fully qualified name of the module being imported.
            path: Forwarded unchanged to the other finders.
            target: Forwarded unchanged to the other finders.

        Returns:
            A spec whose loader executes the rewritten module, or ``None`` so
            that the import system continues with the remaining finders.
        """
        # The filter only depends on the module name, so evaluate it before
        # asking other finders for specs and reading source code. Previously
        # the source of every imported module was loaded before this check.
        if not self._modules_filter(AutoTraceModule(fullname)):
            return None

        for plain_spec in self._find_plain_specs(fullname, path, target):
            # Only loaders that can hand out source text are usable here.
            get_source = getattr(plain_spec.loader, 'get_source', None)
            if not callable(get_source):
                continue
            try:
                source = cast(str, get_source(fullname))
            except Exception:
                continue

            if not source:
                continue

            filename = plain_spec.origin
            if not filename:
                try:
                    filename = cast('str | None', plain_spec.loader.get_filename(fullname))
                except Exception:
                    pass
            filename = filename or f'<{fullname}>'

            try:
                tree = ast.parse(source)
            except Exception:
                # Invalid source code. Try another one.
                continue

            try:
                execute = compile_source(tree, filename, fullname, self._trace_manager, self._min_duration_ns)
            except Exception:  # pragma: no cover
                # Rewriting failed: log it and give up on tracing this module.
                log_trace_error()
                return None

            loader = AutoTraceLoader(plain_spec, execute)
            return spec_from_loader(fullname, loader)

        return None
|
102 |
+
|
103 |
+
|
104 |
+
class AutoTraceLoader(Loader):
    """Loader that runs a pre-compiled, rewritten module body in the module's namespace."""

    def __init__(self, plain_spec: ModuleSpec, execute: Callable[[dict[str, Any]], None]) -> None:
        """
        Args:
            plain_spec: Spec produced by another finder for the same module.
            execute: Callable that executes the module code in a given
                globals dictionary.
        """
        self._plain_spec = plain_spec
        self._execute = execute

    def exec_module(self, module: ModuleType):
        """Execute a modified AST of the module's source code in the module's namespace."""
        self._execute(module.__dict__)

    def create_module(self, spec: ModuleSpec):
        """Return ``None`` so the default module creation is used."""
        return None

    def get_code(self, _name: str):
        """`python -m` uses the `runpy` module which calls this method instead of going through the normal protocol.

        So return some code which can be executed with the module namespace.
        Here `__loader__` will be this object, i.e. `self`.
        """
        # Fixed: these two statements used to sit inside the docstring, so the
        # method returned None. The attribute is ``_execute`` on this class.
        source = '__loader__._execute(globals())'
        return compile(source, '<string>', 'exec', dont_inherit=True)

    def __getattr__(self, item: str):
        """Forward some methods to the plain spec's loader (likely a `SourceFileLoader`) if they exist."""
        if item in {'get_filename', 'is_package'}:
            # Fixed: the attribute is ``_plain_spec``; ``self.plain_spec``
            # re-entered ``__getattr__`` and always raised AttributeError.
            return getattr(self._plain_spec.loader, item)
        raise AttributeError(item)
|
134 |
+
|
135 |
+
|
136 |
+
def convert_to_modules_func(modules: Sequence[str]) -> Callable[[AutoTraceModule], bool]:
    """Build a predicate from a sequence of module names.

    Args:
        modules: Module names / patterns to match against.

    Returns:
        A callable that takes an ``AutoTraceModule`` and returns the result of
        its ``need_auto_trace(modules)``.
    """
    def matches_any(module: AutoTraceModule) -> bool:
        return module.need_auto_trace(modules)

    return matches_any
|
140 |
+
|
141 |
+
|
142 |
+
def get_module_pattern(module: str):
    """Turn a module name (or a ready-made regex) into a regex pattern string.

    A value consisting only of word characters and dots is treated as a plain
    module name: it is escaped and extended so that it matches the module
    itself or any of its submodules. Any other value is assumed to already be
    a regular expression and is returned unchanged.

    Args:
        module: Module name or regular expression.

    Returns:
        The regex pattern string.
    """
    is_plain_name = re.match(r'[\w.]+$', module, re.UNICODE) is not None
    if is_plain_name:
        escaped = re.escape(module)
        return rf'{escaped}($|\.)'
    return module
|
151 |
+
|
152 |
+
|
153 |
+
def install_auto_tracing(trace_manager: "TraceManager",
                         modules: Union[Sequence[str],
                                        Callable[[AutoTraceModule], bool]],
                         min_duration_seconds: float
                         ) -> None:
    """Install an import hook that auto-traces matching modules.

    A ``TraceImportFinder`` is inserted at the front of ``sys.meta_path``.
    Modules that match but are already present in ``sys.modules`` only
    produce a warning.

    Args:
        trace_manager: Manager whose ``new_manager('auto_tracing')`` result is
            handed to the finder.
        modules: Sequence of module names / patterns, or a predicate taking an
            ``AutoTraceModule``.
        min_duration_seconds: Converted to integer nanoseconds and handed to
            the finder.

    Raises:
        TypeError: If ``modules`` is neither a sequence nor a callable.
    """
    module_funcs = convert_to_modules_func(modules) if isinstance(modules, Sequence) else modules
    if not callable(module_funcs):
        raise TypeError('modules must be a list of strings or a callable')

    # Iterate over a snapshot of the currently loaded modules.
    for loaded in list(sys.modules.values()):
        try:
            candidate = AutoTraceModule(loaded.__name__)
        except Exception:
            # Entries without a usable ``__name__`` are skipped.
            continue

        if not module_funcs(candidate):
            continue
        warnings.warn(f'The module {loaded.__name__!r} matches modules to trace, but it has already been imported. '
                      f'Call `auto_tracing` earlier',
                      stacklevel=2,
                      )

    min_duration_ns = int(min_duration_seconds * 1_000_000_000)
    auto_trace_manager = trace_manager.new_manager('auto_tracing')
    sys.meta_path.insert(0, TraceImportFinder(auto_trace_manager, module_funcs, min_duration_ns))
|
185 |
+
|
186 |
+
|
187 |
+
T = TypeVar('T')


def not_auto_trace(x: T) -> T:
    """Decorator to prevent a function/class from being traced by `auto_tracing`.

    The decorated object is returned unchanged.
    NOTE(review): the exclusion itself is presumably implemented where the
    module source is rewritten, not here -- confirm against rewrite_ast.

    Args:
        x: The function or class to mark.

    Returns:
        ``x`` itself.
    """
    return x
|
aworld/trace/base.py
ADDED
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Optional, Any, Iterator, Union, Sequence, Protocol, Iterable
|
3 |
+
from enum import Enum
|
4 |
+
from weakref import WeakSet
|
5 |
+
from dataclasses import dataclass, field
|
6 |
+
from aworld.logs.util import trace_logger
|
7 |
+
|
8 |
+
|
9 |
+
class TraceProvider(ABC):
    """Abstract entry point of a tracing backend: hands out tracers and manages their lifecycle."""

    @abstractmethod
    def get_tracer(
        self,
        name: str,
        version: Optional[str] = None
    ) -> "Tracer":
        """Return a `Tracer` for the given instrumentation name.

        Implementations may return different `Tracer` types (for example a
        no-op tracer or a functional one).

        Args:
            name: Uniquely identifies the instrumentation scope, such as the
                instrumentation library, package, module or class name.
                Prefer a fixed string that is imported wherever it is needed
                over ``__name__``, which differs from file to file.
                Use the name of the module doing the instrumentation, not the
                module being instrumented, e.g.
                ``"opentelemetry.instrumentation.requests"`` rather than
                ``"requests"``.
            version: Optional version string of the instrumenting library,
                usually ``importlib.metadata.version(instrumenting_library_name)``.
        """

    @abstractmethod
    def shutdown(self) -> None:
        """Shut down the provider and all its resources.

        Intended to be called when the application is shutting down.
        """

    @abstractmethod
    def force_flush(self, timeout: Optional[float] = None) -> bool:
        """Force all pending data to be sent to the backend.

        Intended to be called when the application is shutting down.

        Args:
            timeout: The maximum time to wait for the data to be sent.

        Returns:
            True if the data was sent successfully, False otherwise.
        """

    @abstractmethod
    def get_current_span(self) -> Optional["Span"]:
        """Return the current span from the current context.

        Returns:
            The current span from the current context.
        """
|
62 |
+
|
63 |
+
|
64 |
+
class SpanType(Enum):
    """Describes how a span relates to its parent span."""

    #: Default. The span is used internally in the application.
    INTERNAL = 0

    #: The span describes an operation that handles a remote request.
    SERVER = 1

    #: The span describes a request to some remote service.
    CLIENT = 2

    #: The span describes a producer sending a message to a broker.
    #: Unlike client and server, there is usually no direct critical path
    #: latency relationship between producer and consumer spans.
    PRODUCER = 3

    #: The span describes a consumer receiving a message from a broker.
    #: Unlike client and server, there is usually no direct critical path
    #: latency relationship between producer and consumer spans.
    CONSUMER = 4
|
88 |
+
|
89 |
+
|
90 |
+
# Value types accepted for span attributes: a str/bool/int/float scalar,
# or a sequence of one of those scalar types.
AttributeValueType = Union[
    str, bool, int, float,
    Sequence[str], Sequence[bool], Sequence[int], Sequence[float],
]
|
100 |
+
|
101 |
+
|
102 |
+
class Tracer(ABC):
    """Creates spans and handles in-process context propagation."""

    @abstractmethod
    def start_span(
        self,
        name: str,
        span_type: SpanType = SpanType.INTERNAL,
        attributes: dict[str, AttributeValueType] = None,
        start_time: Optional[int] = None,
        record_exception: bool = True,
        set_status_on_exception: bool = True,
        trace_context: Optional["TraceContext"] = None,
    ) -> "Span":
        """Start and return a new Span.

        Args:
            name: The name of the span.
            span_type: The span's relationship to its parent. It is
                meaningful even if there is no parent.
            attributes: The span's attributes.
            start_time: Sets the start time of the span.
            record_exception: Whether to record any exception raised within
                the context as an error event on the span.
            set_status_on_exception: Only relevant if the returned span is
                used in a with/context manager. Defines whether the span
                status is automatically set to ERROR when an uncaught
                exception is raised in the span's with block. The status is
                not set by this mechanism if it was previously set manually.
            trace_context: The trace context to use for the span. If not
                provided, the current trace context is used.
        """

    @abstractmethod
    def start_as_current_span(
        self,
        name: str,
        span_type: SpanType = SpanType.INTERNAL,
        attributes: dict[str, AttributeValueType] = None,
        start_time: Optional[int] = None,
        record_exception: bool = True,
        set_status_on_exception: bool = True,
        end_on_exit: bool = True,
        trace_context: Optional['TraceContext'] = None
    ) -> Iterator["Span"]:
        """Context manager that creates a new span and makes it the current
        span in this tracer's context.

        Example::

            with tracer.start_as_current_span("one") as parent:
                parent.add_event("parent's event")
                with tracer.start_as_current_span("two") as child:
                    child.add_event("child's event")
                    trace.get_current_span()  # returns child
                trace.get_current_span()      # returns parent
            trace.get_current_span()          # returns previously active span

        This can also be used as a decorator::

            @tracer.start_as_current_span("name")
            def function():
                ...

        Args:
            name: The name of the span to be created.
            span_type: The span's relationship to its parent. It is
                meaningful even if there is no parent.
            attributes: The span's attributes.
            start_time: Sets the start time of the span.
            record_exception: Whether to record any exception raised within
                the context as an error event on the span.
            set_status_on_exception: Only relevant if the returned span is
                used in a with/context manager. Defines whether the span
                status is automatically set to ERROR when an uncaught
                exception is raised in the span's with block. The status is
                not set by this mechanism if it was previously set manually.
            end_on_exit: Whether to end the span automatically when leaving
                the context manager.
            trace_context: The trace context to use for the span. If not
                provided, the current trace context is used.
        """
|
182 |
+
|
183 |
+
|
184 |
+
class Span(ABC):
    """A single operation within a trace."""

    @abstractmethod
    def end(self, end_time: Optional[int] = None) -> None:
        """Set the span's end time.

        The end time is the wall time at which the operation finished.
        Only the first call to `end` should modify the span; implementations
        are free to ignore or raise on further calls.

        Args:
            end_time: Optional explicit end time.
        """

    @abstractmethod
    def set_attribute(self, key: str, value: Any) -> None:
        """Set a single attribute on the span.

        Args:
            key: The attribute key.
            value: The attribute value.
        """

    @abstractmethod
    def set_attributes(self, attributes: dict[str, Any]) -> None:
        """Set several attributes on the span at once.

        Args:
            attributes: A dictionary of attributes to set.
        """

    @abstractmethod
    def is_recording(self) -> bool:
        """Tell whether this span will be recorded.

        Returns:
            True if this span is active and recording information such as
            attributes set through ``set_attribute``.
        """

    @abstractmethod
    def record_exception(
        self,
        exception: BaseException,
        attributes: dict[str, Any] = None,
        timestamp: Optional[int] = None,
        escaped: bool = False,
    ) -> None:
        """Record an exception on the span.

        Args:
            exception: The exception to record.
            attributes: A dictionary of attributes to set on the exception event.
            timestamp: The timestamp of the exception.
            escaped: Whether the exception was escaped.
        """

    @abstractmethod
    def get_trace_id(self) -> str:
        """Return the trace ID of the span.

        Returns:
            The trace ID of the span.
        """

    @abstractmethod
    def get_span_id(self) -> str:
        """Return the ID of the span.

        Returns:
            The ID of the span.
        """

    def _add_to_open_spans(self) -> None:
        """Register this span in the module-level ``_OPEN_SPANS`` set."""
        _OPEN_SPANS.add(self)

    def _remove_from_open_spans(self) -> None:
        """Drop this span from ``_OPEN_SPANS``; a no-op if it is not present."""
        _OPEN_SPANS.discard(self)
|
256 |
+
|
257 |
+
|
258 |
+
class NoOpSpan(Span):
    """`Span` implementation that records nothing.

    Every mutator does nothing, ``is_recording`` is always False and both IDs
    are empty strings.
    """

    def end(self, end_time: Optional[int] = None) -> None:
        return None

    def set_attribute(self, key: str, value: Any) -> None:
        return None

    def set_attributes(self, attributes: dict[str, Any]) -> None:
        return None

    def is_recording(self) -> bool:
        return False

    def record_exception(
        self,
        exception: BaseException,
        attributes: dict[str, Any] = None,
        timestamp: Optional[int] = None,
        escaped: bool = False,
    ) -> None:
        return None

    def get_trace_id(self) -> str:
        return ""

    def get_span_id(self) -> str:
        return ""
|
287 |
+
|
288 |
+
|
289 |
+
class NoOpTracer(Tracer):
    """`Tracer` implementation that only ever produces `NoOpSpan` objects."""

    def start_span(
        self,
        name: str,
        span_type: SpanType = SpanType.INTERNAL,
        attributes: dict[str, AttributeValueType] = None,
        start_time: Optional[int] = None,
        record_exception: bool = True,
        set_status_on_exception: bool = True,
        trace_context: Optional["TraceContext"] = None,
    ) -> Span:
        """Return a fresh `NoOpSpan`; all arguments are ignored."""
        noop_span = NoOpSpan()
        return noop_span

    def start_as_current_span(
        self,
        name: str,
        span_type: SpanType = SpanType.INTERNAL,
        attributes: dict[str, AttributeValueType] = None,
        start_time: Optional[int] = None,
        record_exception: bool = True,
        set_status_on_exception: bool = True,
        end_on_exit: bool = True,
        trace_context: Optional['TraceContext'] = None
    ) -> Iterator[Span]:
        """Yield a single fresh `NoOpSpan`; all arguments are ignored.

        NOTE(review): this is a plain generator function. It is not wrapped
        with a context-manager decorator, so the returned object cannot be
        used directly in a ``with`` statement -- confirm how callers use it.
        """
        noop_span = NoOpSpan()
        yield noop_span
|
316 |
+
|
317 |
+
|
318 |
+
class Carrier(Protocol):
    """Structural interface of an object that carries trace context as string key/value pairs."""

    def get(self, key: str) -> Optional[str]:
        """Look up a value in the carrier.

        Args:
            key: The key to get the value for.

        Returns:
            The value of the given key from the carrier.
        """
        ...

    def set(self, key: str, value: str) -> None:
        """Store a value in the carrier.

        Args:
            key: The key to set the value for.
            value: The value to set.
        """
        ...

    def keys(self) -> Iterable[str]:
        """List the keys present in the carrier.

        Returns:
            An iterable of keys in the carrier.
        """
        ...
|
342 |
+
|
343 |
+
|
344 |
+
@dataclass(frozen=True)
class TraceContext:
    """Immutable value object holding the identifiers of a trace context.

    NOTE(review): ``version`` / ``trace_flags`` look like the fields of a W3C
    ``traceparent`` header -- confirm against the propagator implementation.
    """
    # Identifier of the trace.
    trace_id: str
    # Identifier of the span.
    span_id: str
    # Defaults to "00".
    version: str = "00"
    # Defaults to "01".
    trace_flags: str = "01"
    # Free-form extra values; every instance gets its own dict by default.
    attributes: dict[str, Any] = field(default_factory=dict)
|
353 |
+
|
354 |
+
|
355 |
+
class Propagator(ABC):
    """Abstract base class for objects that move a trace context in and out of a carrier."""

    def _get_value(self, carrier: Carrier, name: str) -> Optional[str]:
        """Get a value from the carrier, with a fallback key.

        The value is looked up under ``name`` first. If that yields nothing
        truthy, it is looked up again under ``"HTTP_"`` followed by the
        upper-cased name with ``-`` replaced by ``_`` (presumably to support
        CGI/WSGI-style header keys).

        Args:
            carrier: The carrier to get the value from.
            name: The name of the value.

        Returns:
            The value found, or whatever the fallback lookup returns
            (``None`` when the carrier has neither key).
        """
        return carrier.get(name) or carrier.get('HTTP_' + name.upper().replace('-', '_'))

    @abstractmethod
    def extract(self, carrier: Carrier) -> Optional[TraceContext]:
        """Extract a trace context from the given carrier.

        Args:
            carrier: The carrier to extract the trace context from.

        Returns:
            The trace context extracted from the carrier.
        """

    @abstractmethod
    def inject(self, trace_context: TraceContext, carrier: Carrier) -> None:
        """Inject a trace context into the given carrier.

        Args:
            trace_context: The trace context to inject.
            carrier: The carrier to inject the trace context into.
        """
|
385 |
+
|
386 |
+
|
387 |
+
# Provider installed through ``set_tracer_provider``; None until one is set.
_GLOBAL_TRACER_PROVIDER: Optional[TraceProvider] = None
# Spans registered via ``Span._add_to_open_spans``. Held weakly, so membership
# alone does not keep a span alive.
_OPEN_SPANS: WeakSet[Span] = WeakSet()
|
389 |
+
|
390 |
+
|
391 |
+
def set_tracer_provider(provider: TraceProvider):
    """Install ``provider`` as the global tracer provider.

    Any previously installed provider is replaced.

    Args:
        provider: The provider to return from ``get_tracer_provider``.
    """
    global _GLOBAL_TRACER_PROVIDER
    _GLOBAL_TRACER_PROVIDER = provider
|
397 |
+
|
398 |
+
|
399 |
+
def get_tracer_provider() -> TraceProvider:
    """Return the global tracer provider.

    Returns:
        The provider installed through ``set_tracer_provider``.

    Raises:
        Exception: If no provider has been set.
    """
    provider = _GLOBAL_TRACER_PROVIDER
    if provider is not None:
        return provider
    raise Exception("No tracer provider has been set.")
|
407 |
+
|
408 |
+
|
409 |
+
def get_tracer_provider_silent():
    """Return the global tracer provider, or ``None`` instead of raising.

    Returns:
        The result of ``get_tracer_provider()``, or ``None`` if that call
        raises any ``Exception``.
    """
    try:
        provider = get_tracer_provider()
    except Exception:
        provider = None
    return provider
|
414 |
+
|
415 |
+
|
416 |
+
def log_trace_error():
    """Log the exception currently being handled via ``trace_logger.exception``.

    Meant to be called from inside an ``except`` block so that traceback
    information is attached to the log record.
    """
    message = 'This is logging the trace internal error.'
    trace_logger.exception(message)
|
aworld/trace/config.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from typing import Sequence, Optional
|
4 |
+
from aworld.trace.span_cosumer import SpanConsumer
|
5 |
+
from logging import Logger
|
6 |
+
from aworld.logs.util import trace_logger
|
7 |
+
from aworld.trace.context_manager import trace_configure
|
8 |
+
from aworld.metrics.context_manager import MetricContext
|
9 |
+
from aworld.logs.log import set_log_provider, instrument_logging
|
10 |
+
from aworld.trace.instrumentation.uni_llmmodel import LLMModelInstrumentor
|
11 |
+
from aworld.trace.instrumentation.eventbus import EventBusInstrumentor
|
12 |
+
|
13 |
+
|
14 |
+
class ObservabilityConfig(BaseModel):
    """Settings for tracing, metrics and logging, consumed by ``configure``."""

    class Config:
        # Needed because some fields hold non-pydantic objects
        # (``SpanConsumer`` and ``logging.Logger`` instances).
        arbitrary_types_allowed = True

    # --- tracing ---
    trace_provider: Optional[str] = "otlp"
    trace_backends: Optional[Sequence[str]] = ["memory"]
    trace_base_url: Optional[str] = None
    trace_write_token: Optional[str] = None
    trace_span_consumers: Optional[Sequence[SpanConsumer]] = None
    # Whether to start the trace service, and on which port.
    trace_server_enabled: Optional[bool] = False
    trace_server_port: Optional[int] = 7079

    # --- metrics (only configured when both provider and backend are set) ---
    metrics_provider: Optional[str] = None
    metrics_backend: Optional[str] = None
    metrics_base_url: Optional[str] = None
    metrics_write_token: Optional[str] = None
    # Whether to instrument system metrics.
    metrics_system_enabled: Optional[bool] = False

    # --- logs (only configured when both provider and backend are set) ---
    logs_provider: Optional[str] = None
    logs_backend: Optional[str] = None
    logs_base_url: Optional[str] = None
    logs_write_token: Optional[str] = None
    # The loggers whose records should also be recorded as spans.
    logs_trace_instrumented_loggers: Sequence[Logger] = [trace_logger]
|
40 |
+
|
41 |
+
|
42 |
+
def configure(config: ObservabilityConfig = None):
    """Apply an observability configuration.

    Configures tracing, metrics and logging (in that order) and then
    instruments the LLM model layer and the event bus.

    Args:
        config: The settings to apply; a default ``ObservabilityConfig`` is
            created when omitted.
    """
    settings = ObservabilityConfig() if config is None else config
    for apply_section in (_trace_configure, _metrics_configure, _log_configure):
        apply_section(settings)
    LLMModelInstrumentor().instrument()
    EventBusInstrumentor().instrument()
|
50 |
+
|
51 |
+
|
52 |
+
def _trace_configure(config: ObservabilityConfig):
    """Resolve environment-derived trace settings and configure the trace provider.

    When the provider is "otlp" and no base URL was given:

    * with the "logfire" backend, a missing write token is read from the
      ``LOGFIRE_WRITE_TOKEN`` environment variable;
    * otherwise, if ``OTLP_TRACES_ENDPOINT`` is set, it becomes the base URL
      and the "other_otlp" backend is added.

    Args:
        config: Observability settings; updated in place with the resolved values.
    """
    # Work on a copy: `trace_backends` may be None, a tuple, or a list that
    # the caller still holds a reference to.
    backends = list(config.trace_backends or [])

    if not config.trace_base_url and config.trace_provider == "otlp":
        if "logfire" in backends:
            # LOGFIRE_WRITE_TOKEN is a credential, not an endpoint, so it is
            # stored as the write token (matching _log_configure) rather than
            # as the base URL.
            if not config.trace_write_token:
                config.trace_write_token = os.getenv("LOGFIRE_WRITE_TOKEN")
        else:
            otlp_endpoint = os.getenv("OTLP_TRACES_ENDPOINT")
            if otlp_endpoint:
                config.trace_base_url = otlp_endpoint
                if "other_otlp" not in backends:
                    backends.append("other_otlp")
                config.trace_backends = backends

    trace_configure(
        provider=config.trace_provider,
        backends=config.trace_backends,
        base_url=config.trace_base_url,
        write_token=config.trace_write_token,
        span_consumers=config.trace_span_consumers,
        server_enabled=config.trace_server_enabled,
        server_port=config.trace_server_port
    )
|
69 |
+
|
70 |
+
|
71 |
+
def _metrics_configure(config: ObservabilityConfig):
    """Configure metrics when both a provider and a backend are set; otherwise do nothing.

    Args:
        config: Observability settings whose ``metrics_*`` fields are forwarded
            to `MetricContext.configure`.
    """
    if not (config.metrics_provider and config.metrics_backend):
        return

    metric_options = dict(
        provider=config.metrics_provider,
        backend=config.metrics_backend,
        base_url=config.metrics_base_url,
        write_token=config.metrics_write_token,
        metrics_system_enabled=config.metrics_system_enabled,
    )
    MetricContext.configure(**metric_options)
|
80 |
+
|
81 |
+
|
82 |
+
def _log_configure(config: ObservabilityConfig):
    """Set the log provider (if configured) and instrument the selected loggers.

    The log provider is only set when both ``logs_provider`` and
    ``logs_backend`` are given. For the "logfire" backend a missing write token
    is read from the ``LOGFIRE_WRITE_TOKEN`` environment variable.

    Args:
        config: Observability settings; ``logs_write_token`` may be filled in.
    """
    provider = config.logs_provider
    backend = config.logs_backend
    if provider and backend:
        needs_env_token = backend == "logfire" and not config.logs_write_token
        if needs_env_token:
            config.logs_write_token = os.getenv("LOGFIRE_WRITE_TOKEN")
        set_log_provider(
            provider=provider,
            backend=backend,
            base_url=config.logs_base_url,
            write_token=config.logs_write_token,
        )

    # Loggers listed here have their records recorded as spans.
    for target_logger in config.logs_trace_instrumented_loggers or ():
        instrument_logging(target_logger)
|
aworld/trace/constants.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
ATTRIBUTES_NAMESPACE = 'aworld'
|
4 |
+
"""Namespace within OTEL attributes used by aworld."""
|
5 |
+
|
6 |
+
ATTRIBUTES_MESSAGE_KEY = f'{ATTRIBUTES_NAMESPACE}.msg'
|
7 |
+
"""The formatted message for a log."""
|
8 |
+
|
9 |
+
ATTRIBUTES_MESSAGE_TEMPLATE_KEY = f'{ATTRIBUTES_NAMESPACE}.msg_template'
|
10 |
+
"""The template for a log message."""
|
11 |
+
|
12 |
+
ATTRIBUTES_MESSAGE_RUN_TYPE_KEY = f'{ATTRIBUTES_NAMESPACE}.run_type'
|
13 |
+
"""The run type of a span (a `RunType` value)."""
|
14 |
+
|
15 |
+
MESSAGE_FORMATTED_VALUE_LENGTH_LIMIT = 128
|
16 |
+
"""Maximum number of characters for formatted values in a trace message."""
|
17 |
+
|
18 |
+
|
19 |
+
class RunType(Enum):
    '''Span run type supported in the framework.

    `AGNET` is a historical misspelling that existing callers reference, so it
    is kept as the canonical member; `AGENT` is an alias for it.
    '''
    AGNET = "AGENT"
    # Correctly spelled alias: same value, so `RunType.AGENT is RunType.AGNET`.
    AGENT = "AGENT"
    TOOL = "TOOL"
    MCP = "MCP"
    LLM = "LLM"
    OTHER = "OTHER"
|
aworld/trace/context_manager.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import types
|
2 |
+
import inspect
|
3 |
+
from typing import Union, Optional, Any, Type, Sequence, Callable, Iterable
|
4 |
+
from aworld.trace.base import (
|
5 |
+
AttributeValueType,
|
6 |
+
NoOpSpan,
|
7 |
+
Span, Tracer,
|
8 |
+
NoOpTracer,
|
9 |
+
get_tracer_provider,
|
10 |
+
get_tracer_provider_silent,
|
11 |
+
log_trace_error
|
12 |
+
)
|
13 |
+
from aworld.trace.span_cosumer import SpanConsumer
|
14 |
+
from aworld.version_gen import __version__
|
15 |
+
from aworld.trace.auto_trace import AutoTraceModule, install_auto_tracing
|
16 |
+
from aworld.trace.stack_info import get_user_stack_info
|
17 |
+
from aworld.trace.constants import (
|
18 |
+
ATTRIBUTES_MESSAGE_KEY,
|
19 |
+
ATTRIBUTES_MESSAGE_RUN_TYPE_KEY,
|
20 |
+
ATTRIBUTES_MESSAGE_TEMPLATE_KEY,
|
21 |
+
RunType
|
22 |
+
)
|
23 |
+
from aworld.trace.msg_format import (
|
24 |
+
chunks_formatter,
|
25 |
+
warn_formatting,
|
26 |
+
FStringAwaitError,
|
27 |
+
KnownFormattingError,
|
28 |
+
warn_fstring_await
|
29 |
+
)
|
30 |
+
from aworld.trace.function_trace import trace_func
|
31 |
+
from .opentelemetry.opentelemetry_adapter import configure_otlp_provider
|
32 |
+
from aworld.logs.util import logger
|
33 |
+
|
34 |
+
|
35 |
+
def trace_configure(provider: str = "otlp",
                    backends: Sequence[str] = None,
                    base_url: str = None,
                    write_token: str = None,
                    span_consumers: Optional[Sequence[SpanConsumer]] = None,
                    **kwargs
                    ) -> None:
    """
    Configure the trace provider, shutting down any provider configured earlier.

    Args:
        provider: The trace provider to use; only "otlp" is supported.
        backends: The trace backends to use.
        base_url: The base URL of the trace backend.
        write_token: The write token of the trace backend.
        span_consumers: The span consumers to use.
        **kwargs: Additional arguments to pass to the trace provider.
    Returns:
        None
    Raises:
        ValueError: If `provider` is not a known trace provider.
    """
    current_provider = get_tracer_provider_silent()
    if current_provider:
        logger.info("Trace provider already configured, shutting down...")
        current_provider.shutdown()

    # The previous provider is shut down before the provider name is checked.
    if provider != "otlp":
        raise ValueError(f"Unknown trace provider: {provider}")

    otlp_options = dict(
        backends=backends,
        base_url=base_url,
        write_token=write_token,
        span_consumers=span_consumers,
    )
    configure_otlp_provider(**otlp_options, **kwargs)
|
63 |
+
|
64 |
+
|
65 |
+
class TraceManager:
    """
    TraceManager creates spans for tracing the execution of code, either as
    context managers (`span`) or through a function decorator (`func_span`).
    """

    def __init__(self, tracer_name: str = None) -> None:
        self._tracer_name = tracer_name or "aworld"
        self._version = __version__

    def _create_auto_span(self,
                          name: str,
                          attributes: dict[str, AttributeValueType] = None
                          ) -> Span:
        """
        Create a auto trace span with the given name and attributes.
        """
        return self._create_context_span(name, attributes)

    def _create_context_span(self,
                             name: str,
                             attributes: dict[str, AttributeValueType] = None) -> Span:
        """
        Create a `ContextSpan` bound to this manager's tracer.

        Falls back to a `NoOpTracer` if the tracer cannot be obtained.
        """
        try:
            tracer = get_tracer_provider().get_tracer(
                name=self._tracer_name, version=self._version)
            return ContextSpan(span_name=name, tracer=tracer, attributes=attributes)
        except Exception:
            return ContextSpan(span_name=name, tracer=NoOpTracer(), attributes=attributes)

    def get_current_span(self) -> Span:
        """
        Get the current span, or a `NoOpSpan` if it cannot be obtained.
        """
        try:
            return get_tracer_provider().get_current_span()
        except Exception:
            return NoOpSpan()

    def new_manager(self, tracer_name_suffix: str = None) -> "TraceManager":
        """
        Create a new TraceManager with the given tracer name suffix.
        """
        tracer_name = self._tracer_name if not tracer_name_suffix else f"{self._tracer_name}.{tracer_name_suffix}"
        return TraceManager(tracer_name=tracer_name)

    def auto_tracing(self,
                     modules: Union[Sequence[str], Callable[[AutoTraceModule], bool]],
                     min_duration: float) -> None:
        """
        Automatically trace the execution of a function.
        Args:
            modules: A list of module names or a callable that takes a `AutoTraceModule` and returns a boolean.
            min_duration: The minimum duration of a function to be traced.
        Returns:
            None
        """
        install_auto_tracing(self, modules, min_duration)

    def span(self,
             msg_template: str = "",
             attributes: dict[str, AttributeValueType] = None,
             *,
             span_name: str = None,
             run_type: RunType = RunType.OTHER) -> "ContextSpan":
        """
        Create a span whose message is formatted from `msg_template`.

        Args:
            msg_template: The message template; may contain `{}` fields.
            attributes: The attributes to add to the span.
            span_name: The span name; defaults to the final message template.
            run_type: The run type recorded on the span.
        Returns:
            A `ContextSpan`; a no-op one if building the span fails.
        """
        try:
            attributes = attributes or {}
            # Retrieve stack information of user code and add it to the attributes
            stack_info = get_user_stack_info()
            merged_attributes = {**stack_info, **attributes}

            # Only inspect the caller's frame when the template may contain fields.
            # NOTE(review): ChunksFormatter._fstring_chunks steps one more frame
            # back from the frame passed here - confirm that the resulting frame
            # is the intended caller.
            if any(c in msg_template for c in ('{', '}')):
                fstring_frame = inspect.currentframe().f_back
            else:
                fstring_frame = None
            log_message, extra_attrs, msg_template = format_span_msg(
                msg_template,
                merged_attributes,
                fstring_frame=fstring_frame,
            )
            merged_attributes[ATTRIBUTES_MESSAGE_KEY] = log_message
            merged_attributes.update(extra_attrs)
            merged_attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] = msg_template
            merged_attributes[ATTRIBUTES_MESSAGE_RUN_TYPE_KEY] = run_type.value
            span_name = span_name or msg_template

            return self._create_context_span(span_name, merged_attributes)

        except Exception:
            log_trace_error()
            # The failure may have happened before `span_name` was defaulted,
            # so fall back to the template instead of passing None.
            return ContextSpan(span_name=span_name or msg_template,
                               tracer=NoOpTracer(),
                               attributes=attributes)

    def func_span(self,
                  msg_template: Union[str, Callable] = None,
                  *,
                  attributes: dict[str, AttributeValueType] = None,
                  span_name: str = None,
                  extract_args: Union[bool, Iterable[str]] = False,
                  **kwargs) -> Callable:
        """
        A decorator that traces the execution of a function.
        Args:
            msg_template: The message template to use.
            attributes: The attributes to add to the span.
            span_name: The name of the span.
            extract_args: Whether to extract arguments from the function call.
            **kwargs: Additional attributes to add to the span.
        Returns:
            A decorator that traces the execution of a function.
        """
        if callable(msg_template):
            # Used bare, without parentheses:
            # @trace_func
            # def foo():
            return self.func_span()(msg_template)

        # Merge into a new dict so the caller's `attributes` is never mutated.
        merged_attributes = {**(attributes or {}), **kwargs}
        return trace_func(self, msg_template, merged_attributes, span_name, extract_args)
|
182 |
+
|
183 |
+
|
184 |
+
class ContextSpan(Span):
    """A lazily started `Span` usable as a sync or async context manager.

    The underlying span is started from `tracer` when the context is entered.
    Entering returns this `ContextSpan`; exiting records a raised exception on
    the underlying span (if it is recording) and ends it. Until the context is
    entered, the delegating methods do nothing.

    Args:
        span_name: Name given to the underlying span.
        tracer: Tracer used to start the underlying span.
        attributes: Attributes passed to the underlying span when it starts.
    """

    def __init__(self,
                 span_name: str,
                 tracer: Tracer,
                 attributes: dict[str, AttributeValueType] = None) -> None:
        self._span_name = span_name
        self._tracer = tracer
        self._attributes = attributes
        self._span: Span = None
        self._coro_context = None

    def _start(self):
        """Start the underlying span once; later calls are no-ops."""
        if self._span is None:
            self._span = self._tracer.start_span(
                name=self._span_name,
                attributes=self._attributes,
            )

    def __enter__(self) -> "Span":
        self._start()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        traceback: Optional[Any],
    ) -> None:
        """Ends context manager and calls `end` on the `Span`."""
        self._handle_exit(exc_type, exc_val, traceback)

    async def __aenter__(self) -> "Span":
        self._start()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        traceback: Optional[Any],
    ) -> None:
        """Async counterpart of `__exit__`."""
        self._handle_exit(exc_type, exc_val, traceback)

    def _handle_exit(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        traceback: Optional[Any],
    ) -> None:
        """Record `exc_val` on the underlying span (if any) and end the span."""
        active = self._span
        try:
            should_record = active and active.is_recording() and isinstance(exc_val, BaseException)
            if should_record:
                active.record_exception(exc_val, escaped=True)
        except ValueError as err:
            logger.warning(f"Failed to record_exception: {err}")
        finally:
            # The span is ended even if recording the exception failed.
            if active:
                active.end()

    def end(self, end_time: Optional[int] = None) -> None:
        if not self._span:
            return
        self._span.end(end_time)

    def set_attribute(self, key: str, value: AttributeValueType) -> None:
        if not self._span:
            return
        self._span.set_attribute(key, value)

    def set_attributes(self, attributes: dict[str, AttributeValueType]) -> None:
        if not self._span:
            return
        self._span.set_attributes(attributes)

    def is_recording(self) -> bool:
        return self._span.is_recording() if self._span else False

    def record_exception(
        self,
        exception: BaseException,
        attributes: dict[str, Any] = None,
        timestamp: Optional[int] = None,
        escaped: bool = False,
    ) -> None:
        if not self._span:
            return
        self._span.record_exception(exception, attributes, timestamp, escaped)

    def get_trace_id(self) -> str:
        # Returns None while the underlying span has not been started.
        return self._span.get_trace_id() if self._span else None

    def get_span_id(self) -> str:
        # Returns None while the underlying span has not been started.
        return self._span.get_span_id() if self._span else None
|
288 |
+
|
289 |
+
|
290 |
+
def format_span_msg(
        format_string: str,
        kwargs: dict[str, Any],
        fstring_frame: types.FrameType = None,
) -> tuple[str, dict[str, Any], str]:
    """Format a span message from a template and its field values.

    Returns:
        A 3-tuple of:
        1. The formatted message.
        2. A dictionary of extra attributes to add to the span/log.
           These can come from evaluating values in f-strings.
        3. The final message template, which may differ from `format_string`
           if it was an f-string.

        If formatting fails, the failure is reported and
        `(format_string, {}, format_string)` is returned.
    """
    try:
        pieces, extra_attrs, new_template = chunks_formatter.chunks(
            format_string,
            kwargs,
            fstring_frame=fstring_frame
        )
        message = ''.join([piece['v'] for piece in pieces])
    except KnownFormattingError as err:
        warn_formatting(str(err) or str(err.__cause__))
    except FStringAwaitError as err:
        warn_fstring_await(str(err))
    except Exception:
        log_trace_error()
    else:
        return message, extra_attrs, new_template

    # Formatting failed, so just use the original format string as the message.
    return format_string, {}, format_string
|
aworld/trace/function_trace.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import contextlib
|
3 |
+
import functools
|
4 |
+
from typing import TYPE_CHECKING, Callable, Any, Union, Iterable
|
5 |
+
from aworld.trace.base import (
|
6 |
+
AttributeValueType
|
7 |
+
)
|
8 |
+
|
9 |
+
from aworld.trace.stack_info import get_filepath_attribute
|
10 |
+
from aworld.trace.constants import (
|
11 |
+
ATTRIBUTES_MESSAGE_TEMPLATE_KEY
|
12 |
+
)
|
13 |
+
|
14 |
+
if TYPE_CHECKING:
|
15 |
+
from aworld.trace.context_manager import TraceManager, ContextSpan
|
16 |
+
|
17 |
+
|
18 |
+
def trace_func(trace_manager: "TraceManager",
               msg_template: str = None,
               attributes: dict[str, AttributeValueType] = None,
               span_name: str = None,
               extract_args: Union[bool, Iterable[str]] = False):
    """A decorator that traces the execution of a function.

    Plain functions, coroutine functions, generator functions and async
    generator functions are each wrapped so that the span stays open for the
    whole execution (including iteration, for generators).

    Args:
        trace_manager: The trace manager to use.
        msg_template: The message template to use.
        attributes: The attributes to use.
        span_name: The span name to use.
        extract_args: Whether to extract arguments from the function call.

    Returns:
        The decorated function.
    """

    def decorator(func: Callable) -> Callable:
        func_meta = get_function_meta(func, msg_template)
        func_meta.update(attributes or {})
        final_span_name = span_name or func_meta.get(ATTRIBUTES_MESSAGE_TEMPLATE_KEY) or func.__name__

        def open_span(args, kwargs):
            # Shared by all wrapper flavours: one span per call.
            return open_func_span(trace_manager, func_meta, final_span_name,
                                  get_func_args(func, extract_args, *args, **kwargs))

        if inspect.isgeneratorfunction(func):
            def wrapper(*args, **kwargs):
                with open_span(args, kwargs):
                    # `yield from` forwards send()/throw()/close() to the wrapped
                    # generator and propagates its return value, which a plain
                    # `for ... yield` loop silently drops.
                    return (yield from func(*args, **kwargs))
        elif inspect.isasyncgenfunction(func):
            async def wrapper(*args, **kwargs):
                with open_span(args, kwargs):
                    async for item in func(*args, **kwargs):
                        yield item
        elif inspect.iscoroutinefunction(func):
            async def wrapper(*args, **kwargs):
                with open_span(args, kwargs):
                    return await func(*args, **kwargs)
        else:
            def wrapper(*args, **kwargs):
                with open_span(args, kwargs):
                    return func(*args, **kwargs)

        wrapper = functools.wraps(func)(wrapper)  # type: ignore
        return wrapper

    return decorator
|
68 |
+
|
69 |
+
|
70 |
+
def open_func_span(trace_manager: "TraceManager",
                   func_meta: dict[str, AttributeValueType],
                   span_name: str,
                   func_args: dict[str, AttributeValueType]):
    """Open a function span.

    Args:
        trace_manager: The trace manager that creates the span.
        func_meta: The function meta information. Not modified.
        span_name: The span name.
        func_args: The call arguments to add to the span attributes; on a key
            clash they take precedence over `func_meta`.

    Returns:
        The function span.
    """
    # `func_meta` is built once per decorated function and reused for every
    # call, so merge into a fresh dict: updating it in place would leak one
    # call's arguments into the next (and race between concurrent calls).
    span_attributes = {**func_meta, **func_args}
    return trace_manager._create_auto_span(name=span_name, attributes=span_attributes)
|
85 |
+
|
86 |
+
|
87 |
+
def get_func_args(func: Callable,
                  extract_args: Union[bool, Iterable[str]] = False,
                  *args,
                  **kwargs):
    """Get the arguments of a function call that should be recorded.

    Args:
        func: The function to get the arguments of.
        extract_args: Whether to extract arguments from the function call.
            `False` extracts nothing, `True` extracts every bound argument,
            and a parameter name (or an iterable of names) extracts only those.
        *args: The positional arguments.
        **kwargs: The keyword arguments.

    Returns:
        A dict mapping parameter names to the values bound for this call.
        Parameters left at their default are not included.

    Raises:
        TypeError: If the arguments do not match the signature of `func`.
    """
    # Previously `False` fell through and returned every bound argument, which
    # made the flag a no-op; "do not extract" now really means nothing.
    if extract_args is False:
        return {}

    func_sig = inspect.signature(func)
    if not func_sig.parameters:
        return {}

    bound_args = func_sig.bind(*args, **kwargs).arguments
    if extract_args is True:
        return dict(bound_args)

    # A bare string is one parameter name, not a collection of characters;
    # building a set also avoids consuming a one-shot iterable per key.
    if isinstance(extract_args, str):
        wanted = {extract_args}
    else:
        wanted = set(extract_args)
    return {name: value for name, value in bound_args.items() if name in wanted}
|
111 |
+
|
112 |
+
|
113 |
+
def get_function_meta(func: Any,
                      msg_template: str = None) -> dict[str, AttributeValueType]:
    """Get the meta information of a function.

    Args:
        func: The function to get the meta information of.
        msg_template: The message template to use. Defaults to
            ``'Calling <module>.<function name>'``.

    Returns:
        The meta information of the function: its name and message template,
        plus line number, file path and parameter list when they can be
        determined.
    """
    func = inspect.unwrap(func)
    if not inspect.isfunction(func) and hasattr(func, '__call__'):
        # Callable objects: describe their `__call__` implementation.
        func = func.__call__
        func = inspect.unwrap(func)

    # Resolve the name lazily so the repr-based fallback is only computed
    # when neither `__qualname__` nor `__name__` exists.
    missing = object()
    func_name = getattr(func, '__qualname__', missing)
    if func_name is missing:
        func_name = getattr(func, '__name__', missing)
    if func_name is missing:
        func_name = build_func_name(func)

    if not msg_template:
        try:
            msg_template = f'Calling {inspect.getmodule(func).__name__}.{func_name}'  # type: ignore
        except Exception:  # pragma: no cover
            msg_template = f'Calling {func_name}'
    meta: dict[str, AttributeValueType] = {
        'code.function': func_name,
        ATTRIBUTES_MESSAGE_TEMPLATE_KEY: msg_template,
    }
    with contextlib.suppress(Exception):
        meta['code.lineno'] = func.__code__.co_firstlineno
    with contextlib.suppress(Exception):
        # get code.filepath
        meta.update(get_filepath_attribute(inspect.getsourcefile(func)))

    # Like the other metadata, the parameter list is best-effort:
    # `inspect.signature` raises for callables it cannot introspect.
    try:
        func_sig = inspect.signature(func)
    except (TypeError, ValueError):
        func_sig = None
    if func_sig is not None and func_sig.parameters:
        meta['func.args'] = [str(param) for param in func_sig.parameters.values()
                             if param.name != 'self']
    return meta
|
150 |
+
|
151 |
+
|
152 |
+
def build_func_name(func: Any) -> str:
    """Build a display name for a callable from its `repr`.

    Args:
        func: The function to build the name of.

    Returns:
        `repr(func)`, or `'<TypeName object>'` if calling `repr` raises.
    """
    try:
        return repr(func)
    except Exception:
        type_name = type(func).__name__
        return f'<{type_name} object>'
|
aworld/trace/msg_format.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import inspect
|
3 |
+
import sys
|
4 |
+
import types
|
5 |
+
import warnings
|
6 |
+
import executing
|
7 |
+
from functools import lru_cache
|
8 |
+
from string import Formatter
|
9 |
+
from types import CodeType
|
10 |
+
from typing import Any, Literal, TypeVar
|
11 |
+
from typing_extensions import NotRequired, TypedDict
|
12 |
+
from .constants import MESSAGE_FORMATTED_VALUE_LENGTH_LIMIT
|
13 |
+
from .stack_info import get_user_frame_and_stacklevel
|
14 |
+
|
15 |
+
Truncatable = TypeVar('Truncatable', str, bytes, 'list[Any]', 'tuple[Any, ...]')
|
16 |
+
|
17 |
+
class LiteralChunk(TypedDict):
    """A chunk of literal text in a formatted message."""
    # Chunk type tag: always 'lit'.
    t: Literal['lit']
    # The literal text.
    v: str
|
20 |
+
|
21 |
+
|
22 |
+
class ArgChunk(TypedDict):
    """A chunk of a formatted message produced from a formatted field value."""
    # Chunk type tag: always 'arg'.
    t: Literal['arg']
    # The formatted value as a string.
    v: str
    # Optional; presumably the field's format spec - TODO confirm against the formatter.
    spec: NotRequired[str]
|
26 |
+
|
27 |
+
|
28 |
+
class KnownFormattingError(Exception):
    """An error raised when there's something wrong with a format string or the field values.

    In other words, this should correspond to errors that would be raised when using `str.format`,
    and generally indicates a user error, most likely that they weren't trying to pass a template string at all.
    """
|
34 |
+
|
35 |
+
|
36 |
+
class FStringAwaitError(Exception):
    """An error raised when an await expression is found in an f-string.

    This is a specific case that can't be handled by f-string introspection; the await
    expression has to be evaluated before logging.
    """
|
42 |
+
|
43 |
+
|
44 |
+
class FormattingFailedWarning(UserWarning):
    """A `UserWarning` subclass; presumably issued when message formatting fails (usage not visible here - confirm)."""
    pass
|
46 |
+
|
47 |
+
class InspectArgumentsFailedWarning(Warning):
    """A `Warning` subclass; presumably issued when f-string argument inspection fails (usage not visible here - confirm)."""
    pass
|
49 |
+
|
50 |
+
class ChunksFormatter(Formatter):
|
51 |
+
def chunks(
|
52 |
+
self,
|
53 |
+
format_string: str,
|
54 |
+
kwargs: dict[str, Any],
|
55 |
+
*,
|
56 |
+
fstring_frame: types.FrameType = None,
|
57 |
+
) -> tuple[list[LiteralChunk | ArgChunk], dict[str, Any], str]:
|
58 |
+
# Returns
|
59 |
+
# 1. A list of chunks
|
60 |
+
# 2. A dictionary of extra attributes to add to the span/log.
|
61 |
+
# These can come from evaluating values in f-strings,
|
62 |
+
# or from noting scrubbed values.
|
63 |
+
# 3. The final message template, which may differ from `format_string` if it was an f-string.
|
64 |
+
if fstring_frame:
|
65 |
+
result = self._fstring_chunks(kwargs, fstring_frame)
|
66 |
+
if result: # returns None if faile
|
67 |
+
return result
|
68 |
+
|
69 |
+
chunks = self._vformat_chunks(
|
70 |
+
format_string,
|
71 |
+
kwargs=kwargs
|
72 |
+
)
|
73 |
+
# When there's no f-string magic, there's no changes in the template string.
|
74 |
+
return chunks, {}, format_string
|
75 |
+
|
76 |
+
def _fstring_chunks(
|
77 |
+
self,
|
78 |
+
kwargs: dict[str, Any],
|
79 |
+
frame: types.FrameType,
|
80 |
+
) -> tuple[list[LiteralChunk | ArgChunk], dict[str, Any], str]:
|
81 |
+
# `frame` is the frame of the method that's being called by the user
|
82 |
+
# called_code = frame.f_code
|
83 |
+
frame = frame.f_back or frame # type: ignore
|
84 |
+
assert frame is not None
|
85 |
+
# This is where the magic happens. It has caching.
|
86 |
+
ex = executing.Source.executing(frame)
|
87 |
+
|
88 |
+
call_node = ex.node
|
89 |
+
if call_node is None: # type: ignore[reportUnnecessaryComparison]
|
90 |
+
# `executing` failed to find a node.
|
91 |
+
# This shouldn't happen in most cases, but it's best not to rely on it always working.
|
92 |
+
if not ex.source.text:
|
93 |
+
# This is a very likely cause.
|
94 |
+
# There's nothing we could possibly do to make magic work here,
|
95 |
+
# and it's a clear case where the user should turn the magic off.
|
96 |
+
warn_inspect_arguments(
|
97 |
+
'No source code available. '
|
98 |
+
'This happens when running in an interactive shell, '
|
99 |
+
'using exec(), or running .pyc files without the source .py files.',
|
100 |
+
get_stacklevel(frame),
|
101 |
+
)
|
102 |
+
return None
|
103 |
+
|
104 |
+
msg = '`executing` failed to find a node.'
|
105 |
+
if sys.version_info[:2] < (3, 11): # pragma: no cover
|
106 |
+
# inspect_arguments is only on by default for 3.11+ for this reason.
|
107 |
+
# The AST modifications made by auto-tracing
|
108 |
+
# mean that the bytecode doesn't match the source code seen by `executing`.
|
109 |
+
# In 3.11+, a different algorithm is used by `executing` which can deal with this.
|
110 |
+
msg += ' This may be caused by a combination of using Python < 3.11 and auto-tracing.'
|
111 |
+
|
112 |
+
# Try a simple fallback heuristic to find the node which should work in most cases.
|
113 |
+
main_nodes: list[ast.AST] = []
|
114 |
+
for statement in ex.statements:
|
115 |
+
if isinstance(statement, ast.With):
|
116 |
+
# Only look at the 'header' of a with statement, not its body.
|
117 |
+
main_nodes += statement.items
|
118 |
+
else:
|
119 |
+
main_nodes.append(statement)
|
120 |
+
call_nodes = [
|
121 |
+
node
|
122 |
+
for main_node in main_nodes
|
123 |
+
for node in ast.walk(main_node)
|
124 |
+
if isinstance(node, ast.Call)
|
125 |
+
if node.args or node.keywords
|
126 |
+
]
|
127 |
+
if len(call_nodes) != 1:
|
128 |
+
warn_inspect_arguments(msg, get_stacklevel(frame))
|
129 |
+
return None
|
130 |
+
|
131 |
+
[call_node] = call_nodes
|
132 |
+
|
133 |
+
if not isinstance(call_node, ast.Call): # pragma: no cover
|
134 |
+
# Very unlikely.
|
135 |
+
warn_inspect_arguments(
|
136 |
+
'`executing` unexpectedly identified a non-Call node.',
|
137 |
+
get_stacklevel(frame),
|
138 |
+
)
|
139 |
+
return None
|
140 |
+
|
141 |
+
if call_node.args:
|
142 |
+
arg_node = call_node.args[0]
|
143 |
+
else:
|
144 |
+
# Very unlikely.
|
145 |
+
warn_inspect_arguments(
|
146 |
+
"Couldn't identify the `msg_template` argument in the call.",
|
147 |
+
get_stacklevel(frame),
|
148 |
+
)
|
149 |
+
return None
|
150 |
+
|
151 |
+
if not isinstance(arg_node, ast.JoinedStr):
|
152 |
+
# Not an f-string, not a problem.
|
153 |
+
# Just use normal formatting.
|
154 |
+
return None
|
155 |
+
|
156 |
+
# We have an f-string AST node.
|
157 |
+
# Now prepare the namespaces that we will use to evaluate the components.
|
158 |
+
global_vars = frame.f_globals
|
159 |
+
local_vars = {**frame.f_locals, **kwargs}
|
160 |
+
|
161 |
+
# Now for the actual formatting!
|
162 |
+
result: list[LiteralChunk | ArgChunk] = []
|
163 |
+
|
164 |
+
# We construct the message template (i.e. the span name) from the AST.
|
165 |
+
# We don't use the source code of the f-string because that gets messy
|
166 |
+
# if there's escaped quotes or implicit joining of adjacent strings.
|
167 |
+
new_template = ''
|
168 |
+
|
169 |
+
extra_attrs: dict[str, Any] = {}
|
170 |
+
for node_value in arg_node.values:
|
171 |
+
if isinstance(node_value, ast.Constant):
|
172 |
+
# These are the parts of the f-string not enclosed by `{}`, e.g. 'foo ' in f'foo {bar}'
|
173 |
+
value: str = node_value.value
|
174 |
+
result.append({'v': value, 't': 'lit'})
|
175 |
+
new_template += value
|
176 |
+
else:
|
177 |
+
# These are the parts of the f-string enclosed by `{}`, e.g. 'bar' in f'foo {bar}'
|
178 |
+
assert isinstance(node_value, ast.FormattedValue)
|
179 |
+
|
180 |
+
# This is cached.
|
181 |
+
source, value_code, formatted_code = compile_formatted_value(node_value, ex.source)
|
182 |
+
|
183 |
+
# Note that this doesn't include:
|
184 |
+
# - The format spec, e.g. `:0.2f`
|
185 |
+
# - The conversion, e.g. `!r`
|
186 |
+
# - The '=' sign within the braces, e.g. `{bar=}`.
|
187 |
+
# The AST represents f'{bar = }' as f'bar = {bar}' which is how the template will look.
|
188 |
+
new_template += '{' + source + '}'
|
189 |
+
|
190 |
+
# The actual value of the expression.
|
191 |
+
value = eval(value_code, global_vars, local_vars)
|
192 |
+
extra_attrs[source] = value
|
193 |
+
|
194 |
+
# Format the value according to the format spec, converting to a string.
|
195 |
+
formatted = eval(formatted_code, global_vars, {**local_vars, '@fvalue': value})
|
196 |
+
formatted = self._clean_value(formatted)
|
197 |
+
result.append({'v': formatted, 't': 'arg'})
|
198 |
+
|
199 |
+
return result, extra_attrs, new_template
|
200 |
+
|
201 |
+
def _vformat_chunks(
    self,
    format_string: str,
    kwargs: dict[str, Any],
    *,
    recursion_depth: int = 2,
) -> list[LiteralChunk | ArgChunk]:
    """Split a `str.format`-style template into literal and argument chunks.

    Copied from `string.Formatter._vformat` https://github.com/python/cpython/blob/v3.11.4/Lib/string.py#L198-L247 then altered.

    Args:
        format_string: The template to parse.
        kwargs: Values available to the template fields.
        recursion_depth: Remaining levels of nested format-spec expansion.

    Returns:
        A list of chunk dicts. Literal chunks are `{'v': text, 't': 'lit'}`;
        argument chunks are `{'v': formatted, 't': 'arg'}` plus a `'spec'` key
        when the format spec is non-empty.

    Raises:
        KnownFormattingError: For empty or numeric fields, undefined fields,
            failures while converting or formatting a field, or when the
            format-spec recursion limit is exceeded.
    """
    if recursion_depth < 0:
        raise KnownFormattingError('Max format spec recursion exceeded')

    chunks: list[LiteralChunk | ArgChunk] = []
    # We currently don't use positional arguments
    positional = ()

    for literal, field, spec, conv in self.parse(format_string):
        # Literal text preceding the field (if any) is emitted as-is.
        if literal:
            chunks.append({'v': literal, 't': 'lit'})

        # No field in this segment: nothing more to do.
        if field is None:
            continue

        if field == '':
            raise KnownFormattingError('Empty curly brackets `{}` are not allowed. A field name is required.')

        # ADDED BY US:
        # A trailing '=' puts the field name (including the '=') into the literal output,
        # then the field is looked up without the '='.
        if field.endswith('='):
            last = chunks[-1] if chunks else None
            if last is not None and last['t'] == 'lit':
                last['v'] += field
            else:
                chunks.append({'v': field, 't': 'lit'})
            field = field[:-1]

        # Resolve the object the field refers to.
        try:
            target, _used = self.get_field(field, positional, kwargs)
        except IndexError:
            raise KnownFormattingError('Numeric field names are not allowed.')
        except KeyError as lookup_error:
            if str(lookup_error) == repr(field):
                raise KnownFormattingError(f'The field {{{field}}} is not defined.') from lookup_error

            try:
                # field is something like 'a.b' or 'a[b]'
                # Evaluating that expression failed, so now just try getting the whole thing from kwargs.
                # In particular, OTEL attributes with dots in their names are normal and handled here.
                target = kwargs[field]
            except KeyError as whole_key_error:
                # e.g. neither 'a' nor 'a.b' is defined
                raise KnownFormattingError(
                    f'The fields {lookup_error} and {whole_key_error} are not defined.'
                ) from whole_key_error
        except Exception as exc:
            raise KnownFormattingError(f'Error getting field {{{field}}}: {exc}') from exc

        # Apply the `!r` / `!s` / `!a` conversion when one was given.
        if conv is not None:
            try:
                target = self.convert_field(target, conv)
            except Exception as exc:
                raise KnownFormattingError(f'Error converting field {{{field}}}: {exc}') from exc

        # The format spec may itself contain fields, so expand it recursively first.
        spec_chunks = self._vformat_chunks(spec or '', kwargs, recursion_depth=recursion_depth - 1)
        spec = ''.join(piece['v'] for piece in spec_chunks)

        try:
            rendered = self.format_field(target, spec)
        except Exception as exc:
            raise KnownFormattingError(f'Error formatting field {{{field}}}: {exc}') from exc

        arg_chunk: ArgChunk = {'v': self._clean_value(rendered), 't': 'arg'}
        if spec:
            arg_chunk['spec'] = spec
        chunks.append(arg_chunk)

    return chunks
|
280 |
+
|
281 |
+
def _clean_value(self, value: str) -> str:
    """Truncate `value` to `MESSAGE_FORMATTED_VALUE_LENGTH_LIMIT`, marking the cut with '...'."""
    limit = MESSAGE_FORMATTED_VALUE_LENGTH_LIMIT
    return truncate_sequence(value, max_length=limit, middle='...')
|
283 |
+
|
284 |
+
def warn_inspect_arguments(msg: str, stacklevel: int):
    """Emit an `InspectArgumentsFailedWarning` explaining that introspection failed.

    Kept as a standalone function so that several call sites can share it.

    Args:
        msg: Description of the underlying problem, appended to a fixed preamble.
        stacklevel: Passed straight through to `warnings.warn`.
    """
    preamble = (
        'Failed to introspect calling code. '
        'Falling back to normal message formatting '
        'which may result in loss of information if using an f-string. '
        'The problem was:\n'
    )
    warnings.warn(preamble + msg, InspectArgumentsFailedWarning, stacklevel=stacklevel)
|
295 |
+
|
296 |
+
|
297 |
+
def get_stacklevel(frame: types.FrameType):
    """Return a stacklevel for `warn_inspect_arguments` that points at `frame`.

    The count starts at 0 for this function's own frame and grows by one for
    each step towards the caller. If `frame` is never reached, the number of
    frames walked is returned.
    """
    depth = 0
    cursor = inspect.currentframe()
    while cursor:  # pragma: no branch
        if cursor == frame:
            return depth
        cursor = cursor.f_back
        depth += 1
    return depth
|
309 |
+
|
310 |
+
@lru_cache
def compile_formatted_value(node: ast.FormattedValue, ex_source: executing.Source) -> tuple[str, CodeType, CodeType]:
    """Compute (and cache) the expensive pieces needed to evaluate one f-string field.

    Returns:
        A 3-tuple of:
        1. The source text of the field's expression (without the format spec).
        2. A code object that evaluates that expression.
        3. A code object that formats an already-computed value using the
           field's conversion and format spec.

    Raises:
        FStringAwaitError: If the expression contains an `await`.
    """
    source = get_node_source_text(node.value, ex_source)

    # Expressions containing `await` are rejected before any compilation is attempted.
    if any(isinstance(child, ast.Await) for child in ast.walk(node.value)):
        raise FStringAwaitError(source)

    value_code = compile(source, '<fvalue1>', 'eval')

    # Mirror the original FormattedValue, but read the value from a plain variable
    # so the expression is not evaluated a second time.
    # The '@' in the name guarantees it cannot clash with a real variable;
    # the caller supplies it to eval() after running `value_code`.
    placeholder = ast.FormattedValue(
        value=ast.Name(id='@fvalue', ctx=ast.Load()),
        conversion=node.conversion,
        format_spec=node.format_spec,
    )
    formatting_tree = ast.Expression(ast.JoinedStr(values=[placeholder]))
    ast.fix_missing_locations(formatting_tree)
    formatted_code = compile(formatting_tree, '<fvalue2>', 'eval')

    return source, value_code, formatted_code
|
347 |
+
|
348 |
+
def get_node_source_text(node: ast.AST, ex_source: executing.Source):
|
349 |
+
"""Returns some Python source code representing `node`.
|
350 |
+
|
351 |
+
Preferably the actual original code given by `ast.get_source_segment`,
|
352 |
+
but falling back to `ast.unparse(node)` if the former is incorrect.
|
353 |
+
This happens sometimes due to Python bugs (especially for older Python versions)
|
354 |
+
in the source positions of AST nodes inside f-strings.
|
355 |
+
"""
|
356 |
+
# ast.unparse is not available in Python 3.8, which is why inspect_arguments is forbidden in 3.8.
|
357 |
+
source_unparsed = ast.unparse(node)
|
358 |
+
source_segment = ast.get_source_segment(ex_source.text, node) or ''
|
359 |
+
try:
|
360 |
+
# Verify that the source segment is correct by checking that the AST is equivalent to what we have.
|
361 |
+
source_segment_unparsed = ast.unparse(ast.parse(source_segment, mode='eval'))
|
362 |
+
except Exception: # probably SyntaxError, but ast.parse can raise other exceptions too
|
363 |
+
source_segment_unparsed = ''
|
364 |
+
return source_segment if source_unparsed == source_segment_unparsed else source_unparsed
|
365 |
+
|
366 |
+
|
367 |
+
def truncate_sequence(seq: Truncatable, *, max_length: int, middle: Truncatable) -> Truncatable:
    """Return a sequence with `len()` at most `max_length`, with `middle` in the middle if truncated.

    Args:
        seq: The sequence to shorten.
        max_length: Upper bound on the length of the result.
        middle: Marker inserted between the kept head and tail when truncating.

    Returns:
        `seq` itself when it already fits. Otherwise an equal-sized head and
        tail of `seq` joined by `middle`. When there is no room for any head
        or tail, `middle` alone is returned if it fits, else a plain prefix of
        `seq`.
    """
    if len(seq) <= max_length:
        return seq

    remaining_length = max_length - len(middle)
    half = remaining_length // 2

    if half <= 0:
        # No room for a head and a tail. This case must be handled explicitly:
        # `seq[-0:]` is the WHOLE sequence rather than an empty tail, and a negative
        # `half` slices from the wrong end, so the result would exceed `max_length`.
        if len(middle) <= max_length:
            return middle
        return seq[: max(max_length, 0)]

    return seq[:half] + middle + seq[-half:]
|
374 |
+
|
375 |
+
def warn_at_user_stacklevel(msg: str, category: type[Warning]):
    """Emit `msg` as a `category` warning at the stacklevel of the first user-code frame.

    The stacklevel is the one reported by `get_user_frame_and_stacklevel`.
    """
    user_stacklevel = get_user_frame_and_stacklevel()[1]
    warnings.warn(msg, category=category, stacklevel=user_stacklevel)
|
380 |
+
|
381 |
+
def warn_formatting(msg: str):
    """Emit a `FormattingFailedWarning` with usage guidance followed by `msg`."""
    guidance = (
        '\n'
        ' Ensure you are either:\n'
        ' (1) passing an f-string directly, or\n'
        ' (2) passing a literal `str.format`-style template, not a preformatted string.\n'
        f' The problem was: {msg}'
    )
    warn_at_user_stacklevel(guidance, category=FormattingFailedWarning)
|
392 |
+
|
393 |
+
def warn_fstring_await(msg: str):
    """Emit a `FormattingFailedWarning` about an `await` expression inside an f-string.

    Args:
        msg: The offending f-string value, included verbatim in the warning.
    """
    explanation = (
        '\n'
        ' Cannot evaluate await expression in f-string. Pre-evaluate the expression before logging.\n'
        f' The problematic f-string value was: {msg}'
    )
    warn_at_user_stacklevel(explanation, category=FormattingFailedWarning)
|
402 |
+
|
403 |
+
# Module-level `ChunksFormatter` instance, created once when this module is imported.
chunks_formatter = ChunksFormatter()
|
aworld/trace/rewrite_ast.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import ast
|
4 |
+
import uuid
|
5 |
+
import time
|
6 |
+
from pathlib import Path
|
7 |
+
from collections import deque
|
8 |
+
from functools import partial
|
9 |
+
from typing import TYPE_CHECKING, Any, Callable, ContextManager, cast
|
10 |
+
|
11 |
+
from aworld.trace.base import AttributeValueType
|
12 |
+
from aworld.trace.constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY
|
13 |
+
|
14 |
+
if TYPE_CHECKING:
|
15 |
+
from .context_manager import TraceManager
|
16 |
+
from .auto_trace import not_auto_trace
|
17 |
+
|
18 |
+
|
19 |
+
def compile_source(
    tree: ast.AST, filename: str, module_name: str, trace_manager: TraceManager, min_duration_ns: int
) -> Callable[[dict[str, Any]], None]:
    """Compile an auto-traced version of a module's AST.

    The AST is rewritten so that the body of every function definition is
    wrapped in `with context_factories[index]():`, where `index` is a distinct
    constant per function. The `context_factories` list is injected into the
    module's namespace under a unique `aworld_<uuid>` name.

    Returns:
        A function that takes the module globals, installs the factories list
        in them and executes the compiled code in that namespace.
    """
    # Name under which the factories list is stored in the module's namespace.
    factories_name = f'aworld_{uuid.uuid4().hex}'
    factories: list[Callable[[], ContextManager[Any]]] = []

    rewritten = rewrite_ast(
        tree, filename, factories_name, module_name, trace_manager, factories, min_duration_ns
    )
    assert isinstance(rewritten, ast.Module)  # for type checking

    # dont_inherit=True is necessary to prevent the module from inheriting the __future__ import from this module.
    module_code = compile(rewritten, filename, 'exec', dont_inherit=True)

    def execute(globs: dict[str, Any]):
        globs[factories_name] = factories
        exec(module_code, globs, globs)

    return execute
|
46 |
+
|
47 |
+
|
48 |
+
def rewrite_ast(
    tree: ast.AST,
    filename: str,
    context_factories_var_name: str,
    module_name: str,
    trace_manager: TraceManager,
    context_factories: list[Callable[[], ContextManager[Any]]],
    min_duration_ns: int,
) -> ast.AST:
    """Run an `AutoTraceTransformer` over `tree` and return the transformed tree.

    `context_factories` is passed to the transformer, which appends one entry
    to it for each function it instruments.
    """
    return AutoTraceTransformer(
        context_factories_var_name=context_factories_var_name,
        filename=filename,
        module_name=module_name,
        trace_manager=trace_manager,
        context_factories=context_factories,
        min_duration_ns=min_duration_ns,
    ).visit(tree)
|
61 |
+
|
62 |
+
|
63 |
+
class AutoTraceTransformer(ast.NodeTransformer):
    """Trace all encountered functions except those explicitly marked with `@not_auto_trace`.

    Each instrumented function body is wrapped in `with <factories>[index]():`,
    where `<factories>` is the module-level list named by
    `context_factories_var_name` and `index` is the position of the factory
    appended for that function. Both sync and async function definitions are
    handled; functions containing `yield` are left untouched.
    """

    # The opt-out decorator is matched by its bare name in the decorator list.
    # A literal is used so the check does not need `not_auto_trace` to be importable at runtime.
    # NOTE(review): the import of `not_auto_trace` in this module appears to sit under
    # `if TYPE_CHECKING:`, which would make a runtime reference fail - confirm.
    _NOT_AUTO_TRACE_NAME = 'not_auto_trace'

    def __init__(
        self,
        context_factories_var_name: str,
        filename: str,
        module_name: str,
        trace_manager: TraceManager,
        context_factories: list[Callable[[], ContextManager[Any]]],
        min_duration_ns: int,
    ):
        """Store the rewrite configuration.

        Args:
            context_factories_var_name: Global name the generated code uses to reach `context_factories`.
            filename: Source file name, recorded in span attributes.
            module_name: Module name, used in span names.
            trace_manager: Object providing `_create_auto_span`.
            context_factories: List that receives one factory per instrumented function.
            min_duration_ns: When > 0, a function only starts producing spans after one
                of its calls has lasted at least this long.
        """
        self._context_factories_var_name = context_factories_var_name
        self._filename = filename
        self._module_name = module_name
        self._trace_manager = trace_manager
        self._context_factories = context_factories
        self._min_duration_ns = min_duration_ns
        # Names of the enclosing classes/functions, used to build qualified names.
        self._qualname_stack: list[str] = []

    def visit_ClassDef(self, node: ast.ClassDef):
        """Visit a class definition and rewrite its methods."""

        if self.check_not_auto_trace(node):
            return node

        self._qualname_stack.append(node.name)
        node = cast(ast.ClassDef, self.generic_visit(node))
        self._qualname_stack.pop()
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST:
        """Visit a function definition and rewrite it."""

        if self.check_not_auto_trace(node):
            return node

        self._qualname_stack.append(node.name)
        qualname = '.'.join(self._qualname_stack)
        # Nested definitions get `<outer>.<locals>.<inner>` names, like `__qualname__`.
        self._qualname_stack.append('<locals>')
        self.generic_visit(node)
        self._qualname_stack.pop()  # <locals>
        self._qualname_stack.pop()  # node.name
        return self.rewrite_function(node, qualname)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
        """Visit an async function definition; it is handled exactly like a sync one.

        Without this method `ast.NodeTransformer` would only descend generically
        into async functions: they would never be traced, their `@not_auto_trace`
        marker would be ignored, and nested functions would get qualified names
        that omit the async function.
        """
        return self.visit_FunctionDef(node)

    def check_not_auto_trace(self, node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef) -> bool:
        """Return true if the node has a `@not_auto_trace` decorator.

        Only the bare-name form (`@not_auto_trace`) is recognised.
        """
        return any(
            isinstance(decorator, ast.Name) and decorator.id == self._NOT_AUTO_TRACE_NAME
            for decorator in node.decorator_list
        )

    def rewrite_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.AST:
        """Rewrite a function definition to trace its execution.

        Returns the original node unchanged when the function contains `yield`,
        or when its body (ignoring a leading docstring) is empty, a lone `pass`,
        or a lone constant expression.
        """

        if has_yield(node):
            return node

        body = node.body.copy()
        new_body: list[ast.stmt] = []
        # Keep a leading docstring outside the `with` block so it stays the first statement.
        if (
            body
            and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)
        ):
            new_body.append(body.pop(0))

        # Nothing worth tracing.
        if not body or (
            len(body) == 1
            and (
                isinstance(body[0], ast.Pass)
                or (isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Constant))
            )
        ):
            return node

        span = ast.With(
            items=[
                ast.withitem(
                    context_expr=self.trace_context_method_call_node(node, qualname),
                )
            ],
            body=body,
            type_comment=node.type_comment,
        )
        new_body.append(span)

        fields: dict[str, Any] = {
            'name': node.name,
            'args': node.args,
            'body': new_body,
            'decorator_list': node.decorator_list,
            'returns': node.returns,
            'type_comment': node.type_comment,
        }
        # Python 3.12+ function nodes carry PEP 695 type parameters (`def f[T](...)`);
        # they must be copied or generic functions lose them in the rewritten module.
        if hasattr(node, 'type_params'):
            fields['type_params'] = node.type_params

        return ast.fix_missing_locations(
            ast.copy_location(
                type(node)(**fields),  # type: ignore
                node,
            )
        )

    def trace_context_method_call_node(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.Call:
        """Return a method call to `context_factories[index]()`.

        As a side effect, appends the factory for this function to `context_factories`.
        """

        index = len(self._context_factories)
        span_factory = partial(
            self._trace_manager._create_auto_span,  # type: ignore
            *self.build_create_auto_span_args(qualname, node.lineno),
        )
        if self._min_duration_ns > 0:

            timer = time.time_ns
            min_duration = self._min_duration_ns

            # This needs to be as fast as possible since it's the cost of auto-tracing a function
            # that never actually gets instrumented because its calls are all faster than `min_duration`.
            class MeasureTime:
                __slots__ = ('start',)

                def __enter__(_self):
                    _self.start = timer()

                def __exit__(_self, *_):
                    # The first call that exceeds min_duration is not itself traced;
                    # it swaps in the real span factory so that subsequent calls are.
                    if timer() - _self.start >= min_duration:
                        self._context_factories[index] = span_factory

            self._context_factories.append(MeasureTime)
        else:
            self._context_factories.append(span_factory)

        # This node means:
        #   context_factories[index]()
        # where `context_factories` is a global variable with the name `self._context_factories_var_name`
        # pointing to the `self._context_factories` list.
        # The subscript index is passed directly: the `ast.Index` wrapper is deprecated since Python 3.9.
        return ast.Call(
            func=ast.Subscript(
                value=ast.Name(id=self._context_factories_var_name, ctx=ast.Load()),
                slice=ast.Constant(value=index),
                ctx=ast.Load(),
            ),
            args=[],
            keywords=[],
        )

    def build_create_auto_span_args(self, qualname: str, lineno: int) -> tuple[str, dict[str, AttributeValueType]]:
        """Build the arguments for `create_auto_span`.

        Returns:
            `(span_name, attributes)`, where the attributes hold the code location
            and the message template, and the span name equals the message template.
        """

        stack_info = {
            'code.filepath': get_filepath(self._filename),
            'code.lineno': lineno,
            'code.function': qualname,
        }
        attributes: dict[str, AttributeValueType] = {**stack_info}  # type: ignore

        msg_template = f'Calling {self._module_name}.{qualname}'
        attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] = msg_template

        span_name = msg_template

        return span_name, attributes
|
234 |
+
|
235 |
+
|
236 |
+
def has_yield(node: ast.AST):
    """Return true if the node has a yield statement.

    The search is breadth-first over the node's descendants and does not
    descend into nested function definitions or lambdas, so a `yield` that
    belongs to an inner function does not count.

    Returns:
        `True` if a `yield` or `yield from` is found, otherwise `False`.
    """

    queue = deque([node])
    while queue:
        node = queue.popleft()
        for child in ast.iter_child_nodes(node):
            if isinstance(child, (ast.Yield, ast.YieldFrom)):
                return True
            # Nested scopes own their yields; skip them.
            if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)):
                queue.append(child)
    # Explicit result for the "not found" case instead of an implicit None.
    return False
|
247 |
+
|
248 |
+
|
249 |
+
def get_filepath(file: str):
    """Return `file` as a string path, relative to the current working directory when possible.

    Relative paths are returned as given. Absolute paths are made relative to
    the resolved current directory; if the file is not inside it, the absolute
    path is returned.
    """

    candidate = Path(file)
    if not candidate.is_absolute():
        return str(candidate)

    try:
        return str(candidate.relative_to(Path('.').resolve()))
    except ValueError:  # pragma: no cover
        # happens if filename path is not within CWD
        return str(candidate)
|
aworld/trace/span_cosumer.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Sequence
|
3 |
+
from aworld.trace.base import Span
|
4 |
+
|
5 |
+
|
6 |
+
class SpanConsumer(ABC):
    """Abstract base class for objects that receive batches of spans.

    Subclasses must implement `consume`.
    """

    @abstractmethod
    def consume(self, spans: Sequence[Span]) -> None:
        """Consume a batch of spans.

        Args:
            spans: The spans to consume.
        """
|
15 |
+
|
16 |
+
|
17 |
+
# Maps a registered consumer class name to a `(class, default constructor kwargs)` pair.
_SPAN_CONSUMER_REGISTRY = {}
|
18 |
+
|
19 |
+
|
20 |
+
def register_span_consumer(default_kwargs=None):
    """Return a class decorator that registers a span consumer class.

    The decorated class is stored in the registry under its `__name__`,
    together with the keyword arguments to use when it is instantiated.
    Registering a second class with the same name replaces the first.

    Args:
        default_kwargs: A dictionary of default keyword arguments to pass to the span consumer.

    Returns:
        A decorator that registers the class it is applied to and returns that
        class unchanged.
    """

    # Snapshot the defaults so that later changes to the caller's dict
    # cannot silently alter the registered configuration.
    defaults = dict(default_kwargs or {})

    def decorator(cls):
        _SPAN_CONSUMER_REGISTRY[cls.__name__] = (cls, defaults)
        return cls

    return decorator
|
33 |
+
|
34 |
+
|
35 |
+
def get_span_consumers() -> Sequence[SpanConsumer]:
    """Instantiate every registered span consumer.

    Returns:
        A new list with one instance per registry entry, each built by calling
        the registered class with its registered default keyword arguments.
    """
    consumers = []
    for consumer_cls, default_kwargs in _SPAN_CONSUMER_REGISTRY.values():
        consumers.append(consumer_cls(**default_kwargs))
    return consumers
|
aworld/trace/stack_info.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import sys
|
3 |
+
import aworld.trace as atrace
|
4 |
+
from types import CodeType, FrameType
|
5 |
+
from typing import Optional, TypedDict, Union
|
6 |
+
from functools import lru_cache
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
# Code-location attributes of a frame or code object; every key is optional (`total=False`).
StackInfo = TypedDict('StackInfo', {'code.filepath': str, 'code.lineno': int, 'code.function': str}, total=False)
|
10 |
+
|
11 |
+
# Absolute path prefixes whose code is not treated as user code by `is_user_code`.
# Extended through `add_non_user_code_prefix`.
NON_USER_CODE_PREFIXES: tuple[str, ...] = ()
|
12 |
+
|
13 |
+
def add_non_user_code_prefix(path: Union[str, Path]) -> None:
    """Append the absolute form of `path` to `NON_USER_CODE_PREFIXES`.

    NOTE(review): `is_user_code` caches its results, so a prefix added after a
    code object has already been classified does not change that cached answer.
    """
    global NON_USER_CODE_PREFIXES
    absolute_prefix = str(Path(path).absolute())
    NON_USER_CODE_PREFIXES = NON_USER_CODE_PREFIXES + (absolute_prefix,)
|
17 |
+
|
18 |
+
# Code located in the directory that contains the `inspect` module is not user code.
add_non_user_code_prefix(Path(inspect.__file__).parent)
|
19 |
+
# Code located in the `aworld.trace` package directory is not user code either.
add_non_user_code_prefix(Path(atrace.__file__).parent)
|
20 |
+
|
21 |
+
def get_user_stack_info() -> StackInfo:
    """Return the stack info of the first calling frame that is user code.

    What counts as user code is decided by `is_user_code`.
    An empty dict is returned when no such frame exists.
    """
    user_frame = get_user_frame_and_stacklevel()[0]
    return get_stack_info_from_frame(user_frame) if user_frame else {}
|
31 |
+
|
32 |
+
|
33 |
+
def get_user_frame_and_stacklevel() -> tuple[Optional[FrameType], int]:
    """Find the first calling frame in user code, with a stacklevel usable by `warnings.warn`.

    The walk starts at this function's own frame (stacklevel 0) and moves
    towards the callers, testing each frame's code object with `is_user_code`.

    Returns:
        `(frame, stacklevel)` for the first matching frame, or `(None, 0)` if
        no frame matches.
    """
    depth = 0
    candidate = inspect.currentframe()
    while candidate:
        if is_user_code(candidate.f_code):
            return candidate, depth
        depth += 1
        candidate = candidate.f_back
    return None, 0
|
47 |
+
|
48 |
+
def get_stack_info_from_frame(frame: FrameType) -> StackInfo:
    """Return the code-object info for `frame`, with `code.lineno` set to the frame's current line.

    A copy is taken, so the dict held by the `get_code_object_info` cache is not modified.
    """
    info = get_code_object_info(frame.f_code).copy()
    info['code.lineno'] = frame.f_lineno
    return info
|
53 |
+
|
54 |
+
@lru_cache(maxsize=2048)
def get_code_object_info(code: CodeType) -> StackInfo:
    """Return (cached) file path, function name and first line number for a code object.

    `code.function` is omitted for module-level code. On Python 3.11+ the
    qualified name is used, otherwise the plain name.
    """
    info = get_filepath_attribute(code.co_filename)
    is_module_level = code.co_name == '<module>'
    if not is_module_level:  # pragma: no branch
        if sys.version_info >= (3, 11):
            info['code.function'] = code.co_qualname
        else:
            info['code.function'] = code.co_name
    info['code.lineno'] = code.co_firstlineno
    return info
|
61 |
+
|
62 |
+
def get_filepath_attribute(file: str) -> StackInfo:
    """Return `{'code.filepath': ...}` for `file`, relative to the current directory when possible.

    Relative paths are kept as given. Absolute paths are made relative to the
    resolved current directory; if the file is not inside it, the absolute
    path is kept.
    """
    location = Path(file)
    if location.is_absolute():
        base = Path('.').resolve()
        try:
            location = location.relative_to(base)
        except ValueError:  # pragma: no cover
            # happens if filename path is not within CWD
            pass
    return {'code.filepath': str(location)}
|
71 |
+
|
72 |
+
@lru_cache(maxsize=8192)
|
73 |
+
def is_user_code(code: CodeType) -> bool:
|
74 |
+
"""Check if the code object is from user code.
|
75 |
+
|
76 |
+
A code object is not user code if:
|
77 |
+
- It is from a file in
|
78 |
+
- the standard library
|
79 |
+
- site-packages (specifically wherever opentelemetry is installed)
|
80 |
+
- an unknown location (e.g. a dynamically generated code object) indicated by a filename starting with '<'
|
81 |
+
|
82 |
+
- It is a list/dict/set comprehension.
|
83 |
+
These are artificial frames only created before Python 3.12,
|
84 |
+
and they are always called directly from the enclosing function so it makes sense to skip them.
|
85 |
+
On the other hand, generator expressions and lambdas might be called far away from where they are defined.
|
86 |
+
"""
|
87 |
+
return not (
|
88 |
+
str(Path(code.co_filename).absolute()).startswith(NON_USER_CODE_PREFIXES)
|
89 |
+
or code.co_filename.startswith('<')
|
90 |
+
or code.co_name in ('<listcomp>', '<dictcomp>', '<setcomp>')
|
91 |
+
)
|