|
from __future__ import annotations |
|
|
|
import ast |
|
import uuid |
|
import time |
|
from pathlib import Path |
|
from collections import deque |
|
from functools import partial |
|
from typing import TYPE_CHECKING, Any, Callable, ContextManager, cast |
|
|
|
from aworld.trace.base import AttributeValueType |
|
from aworld.trace.constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY |
|
|
|
if TYPE_CHECKING: |
|
from .context_manager import TraceManager |
|
from .auto_trace import not_auto_trace |
|
|
|
|
|
def compile_source(
    tree: ast.AST, filename: str, module_name: str, trace_manager: TraceManager, min_duration_ns: int
) -> Callable[[dict[str, Any]], None]:
    """Instrument a module's AST and compile it, leaving execution to the caller.

    The rewritten AST wraps the body of each rewritten function definition in
    ``with <factories>[index]():``. ``<factories>`` is a list that is placed in
    the module namespace under a unique ``aworld_<uuid>`` name, and ``index`` is
    a distinct constant for each function definition.

    Args:
        tree: Parsed module AST to instrument.
        filename: File name passed to ``compile`` and to the AST rewriter.
        module_name: Module name forwarded to the AST rewriter.
        trace_manager: Forwarded to the AST rewriter.
        min_duration_ns: Forwarded to the AST rewriter.

    Returns:
        A function that accepts the module globals, stores the factory list in
        them and executes the compiled code in that namespace.
    """
    factories: list[Callable[[], ContextManager[Any]]] = []
    factories_name = f'aworld_{uuid.uuid4().hex}'

    rewritten = rewrite_ast(
        tree,
        filename,
        factories_name,
        module_name,
        trace_manager,
        factories,
        min_duration_ns,
    )
    assert isinstance(rewritten, ast.Module)

    # dont_inherit=True keeps this file's own `__future__` flags from leaking
    # into the module being compiled.
    compiled = compile(rewritten, filename, 'exec', dont_inherit=True)

    def execute(globs: dict[str, Any]):
        # The generated code looks the factories up by name at call time,
        # so the list has to be present before the module body runs.
        globs[factories_name] = factories
        exec(compiled, globs, globs)

    return execute
|
|
|
|
|
def rewrite_ast(
    tree: ast.AST,
    filename: str,
    context_factories_var_name: str,
    module_name: str,
    trace_manager: TraceManager,
    context_factories: list[Callable[[], ContextManager[Any]]],
    min_duration_ns: int,
) -> ast.AST:
    """Apply an `AutoTraceTransformer` to `tree` and return the result of the visit.

    All arguments other than `tree` are handed to the transformer's constructor.
    `context_factories` is filled in by the transformer as a side effect.
    """
    return AutoTraceTransformer(
        context_factories_var_name=context_factories_var_name,
        filename=filename,
        module_name=module_name,
        trace_manager=trace_manager,
        context_factories=context_factories,
        min_duration_ns=min_duration_ns,
    ).visit(tree)
|
|
|
|
|
class AutoTraceTransformer(ast.NodeTransformer):
    """Trace all encountered functions except those explicitly marked with `@not_auto_trace`.

    Each eligible `def` / `async def` has its body wrapped in
    ``with <context_factories_var_name>[index]():`` where ``index`` is the
    position of the matching entry appended to `context_factories`.
    """

    def __init__(
        self,
        context_factories_var_name: str,
        filename: str,
        module_name: str,
        trace_manager: TraceManager,
        context_factories: list[Callable[[], ContextManager[Any]]],
        min_duration_ns: int,
    ):
        """Store the rewrite configuration.

        Args:
            context_factories_var_name: Name the generated code uses to look up the factory list.
            filename: Source file name, recorded in the span attributes.
            module_name: Module name, used in the span name.
            trace_manager: Object whose `_create_auto_span` builds the span context managers.
            context_factories: List that receives one factory per rewritten function.
            min_duration_ns: When > 0, a function only starts producing spans after one
                of its calls has lasted at least this many nanoseconds.
        """
        self._context_factories_var_name = context_factories_var_name
        self._filename = filename
        self._module_name = module_name
        self._trace_manager = trace_manager
        self._context_factories = context_factories
        self._min_duration_ns = min_duration_ns
        # Names of the enclosing classes/functions, used to build qualified names.
        self._qualname_stack: list[str] = []

    def visit_ClassDef(self, node: ast.ClassDef):
        """Visit a class definition and rewrite its methods.

        A class decorated with `@not_auto_trace` is returned untouched, including
        everything nested inside it.
        """
        if self.check_not_auto_trace(node):
            return node

        self._qualname_stack.append(node.name)
        node = cast(ast.ClassDef, self.generic_visit(node))
        self._qualname_stack.pop()
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST:
        """Visit a function definition and rewrite it.

        Nested definitions are visited first, with ``<locals>`` pushed on the
        qualname stack. A function decorated with `@not_auto_trace` is returned
        untouched and its nested definitions are not visited.
        """
        if self.check_not_auto_trace(node):
            return node

        self._qualname_stack.append(node.name)
        qualname = '.'.join(self._qualname_stack)
        self._qualname_stack.append('<locals>')
        self.generic_visit(node)
        self._qualname_stack.pop()
        self._qualname_stack.pop()
        return self.rewrite_function(node, qualname)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
        """Visit an `async def` exactly like a regular function definition.

        Without this hook `ast.NodeTransformer` falls back to `generic_visit`, so
        coroutines would never be traced and functions nested inside them would be
        given a qualified name that lacks the enclosing function and ``<locals>``.
        """
        return self.visit_FunctionDef(node)

    def check_not_auto_trace(self, node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef) -> bool:
        """Return true if the node has a `@not_auto_trace` decorator.

        Only the bare-name form is recognised; attribute access
        (``@module.not_auto_trace``) and call forms are not detected.
        """
        return any(
            isinstance(decorator, ast.Name) and decorator.id == not_auto_trace.__name__
            for decorator in node.decorator_list
        )

    def rewrite_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.AST:
        """Rewrite a function definition to trace its execution.

        The node is returned unchanged when it contains a `yield` in its own
        scope, or when its body (ignoring a leading docstring) is empty, a single
        `pass`, or a single constant expression. Otherwise a new node of the same
        type is returned whose body is the optional docstring followed by one
        `with` statement holding the original statements.
        """
        if has_yield(node):
            return node

        body = node.body.copy()
        new_body: list[ast.stmt] = []
        if (
            body
            and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)
        ):
            # Keep the docstring as the first statement so it stays the function's `__doc__`.
            new_body.append(body.pop(0))

        if not body or (
            len(body) == 1
            and (
                isinstance(body[0], ast.Pass)
                or (isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Constant))
            )
        ):
            # Nothing worth tracing.
            return node

        span = ast.With(
            items=[
                ast.withitem(
                    context_expr=self.trace_context_method_call_node(node, qualname),
                )
            ],
            body=body,
            type_comment=node.type_comment,
        )
        new_body.append(span)

        # Carry over every field present on the original node instead of a fixed
        # list, so fields added by newer Python versions (e.g. `type_params` for
        # PEP 695 generics on 3.12+) are not silently dropped.
        fields = dict(ast.iter_fields(node))
        fields['body'] = new_body
        return ast.fix_missing_locations(ast.copy_location(type(node)(**fields), node))

    def trace_context_method_call_node(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.Call:
        """Return a method call to `context_factories[index]()`.

        As a side effect one entry is appended to the factory list: the span
        factory itself, or, when `min_duration_ns > 0`, a timing context manager
        that swaps itself for the span factory once a call is slow enough.
        """
        index = len(self._context_factories)
        span_factory = partial(
            self._trace_manager._create_auto_span,
            *self.build_create_auto_span_args(qualname, node.lineno),
        )
        if self._min_duration_ns > 0:
            # Bind to locals: this runs on every call of a function that has not
            # been promoted to real tracing yet, so it should stay cheap.
            # A monotonic clock is used because wall-clock time can jump
            # (e.g. on clock adjustments), which would distort the measured duration.
            timer = time.perf_counter_ns
            min_duration = self._min_duration_ns

            class MeasureTime:
                __slots__ = ('start',)

                def __enter__(_self):
                    _self.start = timer()

                def __exit__(_self, *_):
                    # The call that first reaches the threshold is not traced itself;
                    # later calls pick up the span factory from the list.
                    if timer() - _self.start >= min_duration:
                        self._context_factories[index] = span_factory

            self._context_factories.append(MeasureTime)
        else:
            self._context_factories.append(span_factory)

        return ast.Call(
            func=ast.Subscript(
                value=ast.Name(id=self._context_factories_var_name, ctx=ast.Load()),
                # `ast.Index` is deprecated; on Python 3.9+ it just returns the wrapped
                # value. It is kept so the subscript stays valid on older interpreters.
                slice=ast.Index(value=ast.Constant(value=index)),
                ctx=ast.Load(),
            ),
            args=[],
            keywords=[],
        )

    def build_create_auto_span_args(self, qualname: str, lineno: int) -> tuple[str, dict[str, AttributeValueType]]:
        """Build the arguments for `create_auto_span`.

        Returns:
            A ``(span_name, attributes)`` pair. The span name is
            ``'Calling <module_name>.<qualname>'``; the attributes hold the code
            location plus the same text under `ATTRIBUTES_MESSAGE_TEMPLATE_KEY`.
        """
        msg_template = f'Calling {self._module_name}.{qualname}'
        attributes: dict[str, AttributeValueType] = {
            'code.filepath': get_filepath(self._filename),
            'code.lineno': lineno,
            'code.function': qualname,
            ATTRIBUTES_MESSAGE_TEMPLATE_KEY: msg_template,
        }
        span_name = msg_template
        return span_name, attributes
|
|
|
|
|
def has_yield(node: ast.AST) -> bool:
    """Return True if `node` contains a `yield` or `yield from` in its own scope.

    The descendants of `node` are searched breadth-first. Nested function
    definitions and lambdas are not descended into, so a `yield` that belongs
    to an inner function does not count.

    Args:
        node: Root of the AST subtree to inspect (typically a function definition).

    Returns:
        True when a `Yield` / `YieldFrom` node is found, False otherwise.
    """
    queue = deque([node])
    while queue:
        current = queue.popleft()
        for child in ast.iter_child_nodes(current):
            if isinstance(child, (ast.Yield, ast.YieldFrom)):
                return True
            if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)):
                queue.append(child)
    # Explicit result instead of falling off the end and returning None.
    return False
|
|
|
|
|
def get_filepath(file: str):
    """Return `file` as a string, made relative to the current working directory when possible.

    A relative `file` is returned as its `Path` string form. An absolute `file`
    is returned relative to the resolved working directory, or unchanged when it
    does not live under that directory.
    """
    candidate = Path(file)
    if not candidate.is_absolute():
        return str(candidate)

    try:
        return str(candidate.relative_to(Path('.').resolve()))
    except ValueError:
        # `file` is outside the working directory: keep the absolute path.
        return str(candidate)
|
|