|
import ast |
|
import re |
|
import sys |
|
import warnings |
|
from importlib.abc import Loader, MetaPathFinder |
|
from importlib.machinery import ModuleSpec |
|
from importlib.util import spec_from_loader |
|
from types import ModuleType |
|
from typing import TYPE_CHECKING, Sequence, Union, Callable, Iterator, TypeVar, Any, cast |
|
|
|
from aworld.trace.base import log_trace_error |
|
from .rewrite_ast import compile_source |
|
|
|
if TYPE_CHECKING: |
|
from .context_manager import TraceManager |
|
|
|
|
|
class AutoTraceModule:
    """A module being imported that may need to be traced automatically.

    Instances are passed to the module filter used by auto-tracing so it can decide,
    from the module's name, whether the module should be instrumented.
    """

    def __init__(self, module_name: str) -> None:
        # Fully qualified absolute name of the module being imported.
        self._module_name = module_name

    def need_auto_trace(self, prefix: Union[str, Sequence[str]]) -> bool:
        """Return whether the module name matches any of the given module names/patterns.

        Each entry is converted with `get_module_pattern`: a plain dotted name matches the
        module itself and its submodules, and any other string is used as a regular
        expression matched at the start of the module name.

        Args:
            prefix: A single module name/pattern, or a sequence of them.

        Returns:
            ``True`` if any entry matches, and ``False`` for an empty sequence.
        """
        prefixes = (prefix,) if isinstance(prefix, str) else tuple(prefix)
        if not prefixes:
            # Joining zero alternatives gives the empty pattern '', which re.match
            # accepts for every name -- an empty list must match nothing instead.
            return False
        pattern = '|'.join(get_module_pattern(p) for p in prefixes)
        return bool(re.match(pattern, self._module_name))
|
|
|
|
|
class TraceImportFinder(MetaPathFinder):
    """Meta-path finder that returns instrumented specs for modules selected for auto-tracing.

    For each import, it first applies the module filter. It then asks the other finders on
    ``sys.meta_path`` for a spec, reads and parses the module source, rewrites it with
    `compile_source`, and returns a spec backed by an `AutoTraceLoader`. Returning ``None``
    leaves the module to the regular import machinery, untouched.
    """

    def __init__(self, trace_manager: "TraceManager", module_funcs: Callable[[AutoTraceModule], bool],
                 min_duration_ns: int) -> None:
        # Forwarded to `compile_source` for every instrumented module.
        self._trace_manager = trace_manager
        # Predicate deciding, from the module name, whether a module should be traced.
        self._modules_filter = module_funcs
        # Forwarded to `compile_source`; presumably a minimum span duration -- TODO confirm there.
        self._min_duration_ns = min_duration_ns

    def _find_plain_specs(
        self,
        fullname: str,
        path: Union[Sequence[str], None] = None,
        target: Union[ModuleType, None] = None,
    ) -> Iterator[ModuleSpec]:
        """Yield the specs returned for ``fullname`` by the other finders on `sys.meta_path`.

        Finders of this class are skipped so the lookup does not recurse into itself.
        Finders that raise are ignored.
        """
        for finder in sys.meta_path:
            if isinstance(finder, TraceImportFinder):
                continue
            try:
                plain_spec = finder.find_spec(fullname, path, target)
            except Exception:
                continue
            if plain_spec:
                yield plain_spec

    @staticmethod
    def _read_source(plain_spec: ModuleSpec, fullname: str) -> Union[str, None]:
        """Return the module source from the spec's loader, or ``None`` if it is unavailable."""
        get_source = getattr(plain_spec.loader, 'get_source', None)
        if not callable(get_source):
            return None
        try:
            return cast(str, get_source(fullname))
        except Exception:
            return None

    @staticmethod
    def _resolve_filename(plain_spec: ModuleSpec, fullname: str) -> str:
        """Return the filename to compile the rewritten code under, falling back to ``<fullname>``."""
        filename = plain_spec.origin
        if not filename:
            try:
                filename = cast('str | None', plain_spec.loader.get_filename(fullname))
            except Exception:
                pass
        return filename or f'<{fullname}>'

    def find_spec(self, fullname: str, path: Union[Sequence[str], None], target=None) -> Union[ModuleSpec, None]:
        """Return an instrumented spec for ``fullname``, or ``None`` to defer to the other finders.

        Args:
            fullname: Fully qualified name of the module being imported.
            path: Parent package's search path (``None`` for top-level imports).
            target: Module object being reloaded, if any.

        Returns:
            A spec whose loader runs the rewritten code, or ``None`` when the module is not
            selected or no usable, parseable source is found, or when rewriting fails (that
            failure is logged).
        """
        # Decide on the name alone, before querying other finders and reading source:
        # this finder is consulted for every import in the process.
        try:
            wanted = self._modules_filter(AutoTraceModule(fullname))
        except Exception:
            # A faulty user predicate must not break unrelated imports.
            log_trace_error()
            return None
        if not wanted:
            return None

        for plain_spec in self._find_plain_specs(fullname, path, target):
            source = self._read_source(plain_spec, fullname)
            if not source:
                continue
            try:
                tree = ast.parse(source)
            except Exception:
                # Unparseable source: give the next finder's spec a chance.
                continue
            filename = self._resolve_filename(plain_spec, fullname)
            try:
                execute = compile_source(tree, filename, fullname, self._trace_manager, self._min_duration_ns)
            except Exception:
                log_trace_error()
                return None
            return spec_from_loader(fullname, AutoTraceLoader(plain_spec, execute))
        return None
|
|
|
|
|
class AutoTraceLoader(Loader):
    """Loader that runs a module's instrumented code instead of its original code.

    It wraps the spec produced by a regular finder (``plain_spec``) and a callable
    (``execute``) that runs the rewritten module body inside a namespace dict.
    """

    def __init__(self, plain_spec: ModuleSpec, execute: Callable[[dict[str, Any]], None]) -> None:
        # Spec from the regular finder; its loader supplies `get_filename` / `is_package`.
        self._plain_spec = plain_spec
        # Runs the rewritten module code in the namespace dict it is given.
        self._execute = execute

    def exec_module(self, module: ModuleType) -> None:
        """Execute the modified AST of the module's source in the module's namespace."""
        self._execute(module.__dict__)

    def create_module(self, spec: ModuleSpec):
        """Return ``None`` so the import system creates the module object itself."""
        return None

    def get_code(self, _name: str):
        """Return code that runs the instrumented module; used by ``python -m`` via `runpy`.

        `runpy` calls this method instead of going through the normal loader protocol and
        executes the result in the module namespace, where ``__loader__`` is this object.
        """
        source = '__loader__._execute(globals())'
        return compile(source, '<string>', 'exec', dont_inherit=True)

    def __getattr__(self, item: str):
        """Forward `get_filename` and `is_package` to the plain spec's loader when it has them.

        `importlib.util.spec_from_loader` probes for these methods. Forwarding them gives
        traced modules a file origin (``__file__``) and, for packages, submodule search
        locations.
        """
        if item in {'get_filename', 'is_package'}:
            # `_plain_spec` is a real instance attribute; if it is ever missing, the
            # lookup re-enters here with a name outside the set and raises AttributeError.
            return getattr(self._plain_spec.loader, item)
        raise AttributeError(item)
|
|
|
|
|
def convert_to_modules_func(modules: Sequence[str]) -> Callable[[AutoTraceModule], bool]:
    """Build a module filter that accepts modules matching any entry of ``modules``.

    The returned callable delegates to ``AutoTraceModule.need_auto_trace``, passing the
    captured ``modules`` object itself (a reference, not a copy).
    """

    def _matches(module):
        return module.need_auto_trace(modules)

    return _matches
|
|
|
|
|
def get_module_pattern(module: str):
    """Return the regex used to match ``module`` against imported module names.

    A plain dotted name (word characters and dots only) is escaped and anchored so it
    matches that module and its submodules but not longer names sharing the prefix
    (``foo`` matches ``foo`` and ``foo.bar``, not ``foobar``). Any other string is
    treated as a user-supplied regex and returned as-is.
    """
    looks_like_name = re.match(r'[\w.]+$', module, re.UNICODE) is not None
    if looks_like_name:
        return rf'{re.escape(module)}($|\.)'
    return module
|
|
|
|
|
def install_auto_tracing(trace_manager: "TraceManager",
                         modules: Union[Sequence[str],
                                        Callable[[AutoTraceModule], bool]],
                         min_duration_seconds: float
                         ) -> None:
    """Install an import hook that auto-traces matching modules imported from now on.

    Args:
        trace_manager: Manager whose ``new_manager('auto_tracing')`` result is handed to
            the import finder.
        modules: Either a sequence of module names/patterns, turned into a filter by
            `convert_to_modules_func`, or a predicate taking an `AutoTraceModule`.
        min_duration_seconds: Converted to integer nanoseconds and passed to the finder.

    Raises:
        TypeError: If ``modules`` is neither a sequence nor callable.

    A warning is issued for every already-imported module that matches, because the hook
    is only consulted for modules that have not been imported yet.
    """
    module_funcs = convert_to_modules_func(modules) if isinstance(modules, Sequence) else modules
    if not callable(module_funcs):
        raise TypeError('modules must be a list of strings or a callable')

    # Matching modules that are already loaded will not go through the hook -- tell the user.
    for module in list(sys.modules.values()):
        try:
            candidate = AutoTraceModule(module.__name__)
        except Exception:
            continue
        if not module_funcs(candidate):
            continue
        warnings.warn(f'The module {module.__name__!r} matches modules to trace, but it has already been imported. '
                      f'Call `auto_tracing` earlier',
                      stacklevel=2,
                      )

    min_duration_ns = int(min_duration_seconds * 1_000_000_000)
    auto_manager = trace_manager.new_manager('auto_tracing')
    # Insert at the front so this finder is asked before every other finder.
    sys.meta_path.insert(0, TraceImportFinder(auto_manager, module_funcs, min_duration_ns))
|
|
|
|
|
T = TypeVar('T')


def not_auto_trace(x: T) -> T:
    """Mark a function or class as excluded from `auto_tracing`; return ``x`` unchanged.

    The decorator is a no-op at runtime. NOTE(review): presumably the AST rewriter in
    `rewrite_ast` recognizes this decorator by name and skips the decorated definition;
    that logic is not visible in this file, so confirm it there.
    """
    decorated = x
    return decorated
|
|