Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import ast | |
| import uuid | |
| import time | |
| from pathlib import Path | |
| from collections import deque | |
| from functools import partial | |
| from typing import TYPE_CHECKING, Any, Callable, ContextManager, cast | |
| from aworld.trace.base import AttributeValueType | |
| from aworld.trace.constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY | |
| if TYPE_CHECKING: | |
| from .context_manager import TraceManager | |
| from .auto_trace import not_auto_trace | |
| def compile_source( | |
| tree: ast.AST, filename: str, module_name: str, trace_manager: TraceManager, min_duration_ns: int | |
| ) -> Callable[[dict[str, Any]], None]: | |
| """Compile a modified AST of the module's source code in the module's namespace. | |
| Returns a function which accepts module globals and executes the compiled code. | |
| The modified AST wraps the body of every function definition in `with context_factories[index]():`. | |
| `context_factories` is added to the module's namespace as `aworld_<uuid>`. | |
| `index` is a different constant number for each function definition. | |
| """ | |
| context_factories_var_name = f'aworld_{uuid.uuid4().hex}' | |
| # The variable name for storing context_factors in the module's namespace. | |
| context_factories: list[Callable[[], ContextManager[Any]]] = [] | |
| tree = rewrite_ast(tree, filename, context_factories_var_name, module_name, trace_manager, context_factories, | |
| min_duration_ns) | |
| assert isinstance(tree, ast.Module) # for type checking | |
| # dont_inherit=True is necessary to prevent the module from inheriting the __future__ import from this module. | |
| code = compile(tree, filename, 'exec', dont_inherit=True) | |
| def execute(globs: dict[str, Any]): | |
| globs[context_factories_var_name] = context_factories | |
| exec(code, globs, globs) | |
| return execute | |
| def rewrite_ast( | |
| tree: ast.AST, | |
| filename: str, | |
| context_factories_var_name: str, | |
| module_name: str, | |
| trace_manager: TraceManager, | |
| context_factories: list[Callable[[], ContextManager[Any]]], | |
| min_duration_ns: int, | |
| ) -> ast.AST: | |
| transformer = AutoTraceTransformer( | |
| context_factories_var_name, filename, module_name, trace_manager, context_factories, min_duration_ns | |
| ) | |
| return transformer.visit(tree) | |
| class AutoTraceTransformer(ast.NodeTransformer): | |
| """Trace all encountered functions except those explicitly marked with `@no_auto_trace`.""" | |
| def __init__( | |
| self, | |
| context_factories_var_name: str, | |
| filename: str, | |
| module_name: str, | |
| trace_manager: TraceManager, | |
| context_factories: list[Callable[[], ContextManager[Any]]], | |
| min_duration_ns: int, | |
| ): | |
| self._context_factories_var_name = context_factories_var_name | |
| self._filename = filename | |
| self._module_name = module_name | |
| self._trace_manager = trace_manager | |
| self._context_factories = context_factories | |
| self._min_duration_ns = min_duration_ns | |
| self._qualname_stack: list[str] = [] | |
| def visit_ClassDef(self, node: ast.ClassDef): | |
| """Visit a class definition and rewrite its methods.""" | |
| if self.check_not_auto_trace(node): | |
| return node | |
| self._qualname_stack.append(node.name) | |
| node = cast(ast.ClassDef, self.generic_visit(node)) | |
| self._qualname_stack.pop() | |
| return node | |
| def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: | |
| """Visit a function definition and rewrite it.""" | |
| if self.check_not_auto_trace(node): | |
| return node | |
| self._qualname_stack.append(node.name) | |
| qualname = '.'.join(self._qualname_stack) | |
| self._qualname_stack.append('<locals>') | |
| self.generic_visit(node) | |
| self._qualname_stack.pop() # <locals> | |
| self._qualname_stack.pop() # node.name | |
| return self.rewrite_function(node, qualname) | |
| def check_not_auto_trace(self, node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef) -> bool: | |
| """Return true if the node has a `@not_auto_trace` decorator.""" | |
| return any( | |
| ( | |
| isinstance(node, ast.Name) | |
| and node.id == not_auto_trace.__name__ | |
| # or ( | |
| # isinstance(node, ast.Attribute) | |
| # and node.attr == not_auto_trace.__name__ | |
| # and isinstance(node.value, ast.Name) | |
| # and node.value.id == xxx.__name__ | |
| # ) | |
| ) | |
| for node in node.decorator_list | |
| ) | |
| def rewrite_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.AST: | |
| """Rewrite a function definition to trace its execution.""" | |
| if has_yield(node): | |
| return node | |
| body = node.body.copy() | |
| new_body: list[ast.stmt] = [] | |
| if ( | |
| body | |
| and isinstance(body[0], ast.Expr) | |
| and isinstance(body[0].value, ast.Constant) | |
| and isinstance(body[0].value.value, str) | |
| ): | |
| new_body.append(body.pop(0)) | |
| if not body or ( | |
| len(body) == 1 | |
| and ( | |
| isinstance(body[0], ast.Pass) | |
| or (isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Constant)) | |
| ) | |
| ): | |
| return node | |
| span = ast.With( | |
| items=[ | |
| ast.withitem( | |
| context_expr=self.trace_context_method_call_node(node, qualname), | |
| ) | |
| ], | |
| body=body, | |
| type_comment=node.type_comment, | |
| ) | |
| new_body.append(span) | |
| return ast.fix_missing_locations( | |
| ast.copy_location( | |
| type(node)( # type: ignore | |
| name=node.name, | |
| args=node.args, | |
| body=new_body, | |
| decorator_list=node.decorator_list, | |
| returns=node.returns, | |
| type_comment=node.type_comment, | |
| ), | |
| node, | |
| ) | |
| ) | |
| def trace_context_method_call_node(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.Call: | |
| """Return a method call to `context_factories[index]()`.""" | |
| index = len(self._context_factories) | |
| span_factory = partial( | |
| self._trace_manager._create_auto_span, # type: ignore | |
| *self.build_create_auto_span_args(qualname, node.lineno), | |
| ) | |
| if self._min_duration_ns > 0: | |
| timer = time.time_ns | |
| min_duration = self._min_duration_ns | |
| # This needs to be as fast as possible since it's the cost of auto-tracing a function | |
| # that never actually gets instrumented because its calls are all faster than `min_duration`. | |
| class MeasureTime: | |
| __slots__ = 'start' | |
| def __enter__(_self): | |
| _self.start = timer() | |
| def __exit__(_self, *_): | |
| # the first call exceeding min_ruration will not be tracked, and subsequent calls will only be tracked | |
| if timer() - _self.start >= min_duration: | |
| self._context_factories[index] = span_factory | |
| self._context_factories.append(MeasureTime) | |
| else: | |
| self._context_factories.append(span_factory) | |
| # This node means: | |
| # context_factories[index]() | |
| # where `context_factories` is a global variable with the name `self._context_factories_var_name` | |
| # pointing to the `self.context_factories` list. | |
| return ast.Call( | |
| func=ast.Subscript( | |
| value=ast.Name(id=self._context_factories_var_name, ctx=ast.Load()), | |
| slice=ast.Index(value=ast.Constant(value=index)), # type: ignore | |
| ctx=ast.Load(), | |
| ), | |
| args=[], | |
| keywords=[], | |
| ) | |
| def build_create_auto_span_args(self, qualname: str, lineno: int) -> tuple[str, dict[str, AttributeValueType]]: | |
| """Build the arguments for `create_auto_span`.""" | |
| stack_info = { | |
| 'code.filepath': get_filepath(self._filename), | |
| 'code.lineno': lineno, | |
| 'code.function': qualname, | |
| } | |
| attributes: dict[str, AttributeValueType] = {**stack_info} # type: ignore | |
| msg_template = f'Calling {self._module_name}.{qualname}' | |
| attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] = msg_template | |
| span_name = msg_template | |
| return span_name, attributes | |
| def has_yield(node: ast.AST): | |
| """Return true if the node has a yield statement.""" | |
| queue = deque([node]) | |
| while queue: | |
| node = queue.popleft() | |
| for child in ast.iter_child_nodes(node): | |
| if isinstance(child, (ast.Yield, ast.YieldFrom)): | |
| return True | |
| if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)): | |
| queue.append(child) | |
| def get_filepath(file: str): | |
| """Return a dict with the filepath attribute.""" | |
| path = Path(file) | |
| if path.is_absolute(): | |
| try: | |
| path = path.relative_to(Path('.').resolve()) | |
| except ValueError: # pragma: no cover | |
| # happens if filename path is not within CWD | |
| pass | |
| return str(path) | |