Spaces:
Sleeping
Sleeping
| # coding: utf-8 | |
| # Copyright (c) 2025 inclusionAI. | |
| import asyncio | |
| import inspect | |
| import os | |
| import pkgutil | |
| import re | |
| import socket | |
| import sys | |
| import threading | |
| import time | |
| from functools import wraps | |
| from types import FunctionType | |
| from typing import Callable, Any, Tuple, List, Iterator, Dict, Union | |
| from aworld.logs.util import logger | |
def convert_to_snake(name: str) -> str:
    """Convert a CamelCase class name to snake_case.

    A name that already contains an underscore is only lower-cased; otherwise an
    underscore is first inserted at every lowercase-to-uppercase boundary.
    """
    converted = name if '_' in name else re.sub(r'([a-z])([A-Z])', r'\1_\2', name)
    return converted.lower()
def snake_to_camel(snake):
    """Convert a snake_case string to CamelCase by capitalizing each ``_``-separated word."""
    pieces = snake.split('_')
    return ''.join(piece.capitalize() for piece in pieces)
def is_abstract_method(cls, method_name):
    """Tell whether ``cls.<method_name>`` is marked abstract.

    The attribute counts as abstract if it carries a truthy ``__isabstractmethod__``
    flag (as set by ``abc.abstractmethod``). Failing that, it counts as abstract
    if it is a plain function that lists itself in its own ``__abstractmethods__``.
    """
    method = getattr(cls, method_name)
    flagged = getattr(method, '__isabstractmethod__', False)
    if flagged:
        return flagged
    return (isinstance(method, FunctionType)
            and hasattr(method, '__abstractmethods__')
            and method in method.__abstractmethods__)
def override_in_subclass(name: str, sub_cls: object, base_cls: object) -> bool:
    """Judge whether a subclass overrides a specified method.

    Args:
        name: The method name of sub class and base class.
        sub_cls: Specify subclasses of the base class.
        base_cls: The parent class of the subclass.

    Returns:
        True if ``sub_cls`` provides its own definition of ``name``. When
        ``sub_cls`` is ``base_cls``, True if the class defines a non-abstract
        ``name``. False otherwise, including when ``sub_cls`` is not a subclass
        of ``base_cls`` (a warning is logged in that case).

    Raises:
        AttributeError: If ``name`` cannot be found on one of the classes.
    """
    if not issubclass(sub_cls, base_cls):
        logger.warning(f"{sub_cls} is not sub class of {base_cls}")
        return False

    if sub_cls == base_cls and hasattr(sub_cls, name) and not is_abstract_method(sub_cls, name):
        return True

    # Compare the raw class attributes rather than ``getattr`` results:
    # ``getattr`` creates a new bound method on every access to a classmethod,
    # so an inherited classmethod would never be identical and would always be
    # reported as overridden.
    this_method = inspect.getattr_static(sub_cls, name)
    base_method = inspect.getattr_static(base_cls, name)
    return this_method is not base_method
def convert_to_subclass(obj, subclass):
    """Re-class ``obj`` in place as an instance of ``subclass`` and return that same object."""
    setattr(obj, '__class__', subclass)
    return obj
def _walk_to_root(path: str) -> Iterator[str]:
    """Yield absolute directories from ``path`` up to the filesystem root, inclusive.

    If ``path`` names a file, the walk starts at its containing directory. If
    ``path`` does not exist, an empty string is yielded first and the walk then
    proceeds from ``os.path.abspath(path)`` anyway.
    """
    if not os.path.exists(path):
        yield ''
    start = os.path.dirname(path) if os.path.isfile(path) else path
    directory = os.path.abspath(start)
    while True:
        yield directory
        parent = os.path.abspath(os.path.join(directory, os.path.pardir))
        # The root is the only directory that is its own parent.
        if parent == directory:
            return
        directory = parent
def find_file(filename: str) -> str:
    """Find file from the folders for the given file.

    The search starts from the first of these directories that applies:

    1. the current working directory, if ``filename`` exists directly in it;
    2. the directory of the ``__main__`` script, if ``filename`` exists
       directly in it;
    3. otherwise the directory of the nearest calling source file outside this
       module. If no such file can be found (e.g. in an interactive session),
       it falls back to the current working directory.

    From the chosen directory, it and every ancestor up to the filesystem root
    are checked for a regular file named ``filename``.

    Args:
        filename: The file name that you want to search.

    Returns:
        The path of the first matching file, or ``''`` if none is found.
    """

    def run_dir() -> str:
        # ``__main__`` always exists, but has no ``__file__`` in interactive
        # sessions (REPL, notebooks, ``python -c``); use the cwd there.
        main = sys.modules.get('__main__')
        main_file = getattr(main, '__file__', None)
        return os.path.dirname(main_file) if main_file else os.getcwd()

    cwd = os.getcwd()
    main_dir = run_dir()
    if os.path.exists(os.path.join(cwd, filename)):
        path = cwd
    elif os.path.exists(os.path.join(main_dir, filename)):
        path = main_dir
    else:
        frame = inspect.currentframe()
        try:
            # Skip frames from this module and frames without a real source file
            # (e.g. "<stdin>", "<string>") to reach the nearest caller's file.
            while frame is not None and (
                frame.f_code.co_filename == __file__
                or not os.path.exists(frame.f_code.co_filename)
            ):
                frame = frame.f_back
            if frame is None:
                path = cwd
            else:
                path = os.path.dirname(os.path.abspath(frame.f_code.co_filename))
        finally:
            # Drop the frame reference explicitly so no reference cycle keeps
            # the frames (and their locals) alive.
            del frame

    for dirname in _walk_to_root(path):
        if not dirname:
            continue
        check_path = os.path.join(dirname, filename)
        if os.path.isfile(check_path):
            return check_path
    return ''
def search_in_module(module: object, base_classes: List[type]) -> List[Tuple[str, type]]:
    """Find all classes in ``module`` that inherit from any of ``base_classes``.

    Each matching class is reported once, even if it derives from several of the
    given bases. A listed base class is never reported as a match of itself, but
    it is reported if it subclasses another listed base.

    Args:
        module: Object (typically a module) whose class attributes are inspected.
        base_classes: Base classes to match against.

    Returns:
        ``(attribute_name, class)`` pairs in ``inspect.getmembers`` order
        (sorted by attribute name).
    """
    return [
        (name, obj)
        for name, obj in inspect.getmembers(module, inspect.isclass)
        if any(obj is not base and issubclass(obj, base) for base in base_classes)
    ]
def _scan_package(package_name: str, base_classes: List[type],
                  results: Union[List[Tuple[str, type]], None] = None) -> List[Tuple[str, type]]:
    """Recursively collect subclasses of ``base_classes`` from a package's submodules.

    The scan is best-effort: submodules that fail to import or inspect are
    skipped (and logged at debug level).

    Args:
        package_name: Dotted name of a package that is already present in
            ``sys.modules``. Anything else (not loaded, or a plain module
            without ``__path__``) is ignored.
        base_classes: Base classes to search for.
        results: Optional list to append ``(name, class)`` pairs to. A fresh
            list is created when omitted; a shared mutable default would leak
            results from one call into the next.

    Returns:
        ``results``, which is the same list object when one was passed in.
    """
    import importlib

    if results is None:
        results = []

    package = sys.modules.get(package_name)
    search_path = getattr(package, '__path__', None)
    if search_path is None:
        return results

    # iter_modules lists only direct children; subpackages are handled by the
    # explicit recursion below. walk_packages without a prefix would also try to
    # import subpackages by their bare name, which can load unrelated top-level
    # modules and duplicates this recursion.
    for _finder, name, is_pkg in pkgutil.iter_modules(search_path):
        full_name = f"{package_name}.{name}"
        try:
            module = importlib.import_module(full_name)
        except (Exception, SystemExit) as e:
            # SystemExit is included because script-like modules may exit on import.
            logger.debug(f"Skip {full_name} during package scan: {e}")
            continue
        if is_pkg:
            _scan_package(full_name, base_classes, results)
        try:
            results.extend(search_in_module(module, base_classes))
        except Exception as e:
            logger.debug(f"Skip inspecting {full_name} during package scan: {e}")
    return results
def scan_packages(package: str, base_classes: List[type]) -> List[Tuple[str, type]]:
    """Collect ``(name, class)`` pairs for subclasses of ``base_classes`` under ``package``.

    Delegates the traversal to ``_scan_package``, which appends into a list
    owned by this call; that list is returned.
    """
    found: List[Tuple[str, type]] = []
    _scan_package(package, base_classes, found)
    return found
class ReturnThread(threading.Thread):
    """Daemon thread that runs a coroutine function to completion with ``asyncio.run``.

    After ``join()``, ``result`` holds the coroutine's return value and
    ``exception`` holds the exception it raised (``None`` on success). This lets
    the caller surface failures instead of silently receiving ``result=None``.
    """

    def __init__(self, func, *args, **kwargs):
        threading.Thread.__init__(self)
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.result = None
        self.exception = None
        self.daemon = True

    def run(self):
        try:
            # asyncio.run creates a fresh event loop for this thread, so this
            # works even when the starting thread already has a running loop.
            self.result = asyncio.run(self.func(*self.args, **self.kwargs))
        except BaseException as e:
            self.exception = e
            # Re-raise so threading.excepthook still reports the failure.
            raise
def asyncio_loop():
    """Return the event loop running in the current thread, or ``None`` if none is running."""
    running = None
    try:
        running = asyncio.get_running_loop()
    except RuntimeError:
        pass
    return running
def sync_exec(async_func: Callable[..., Any], *args, **kwargs):
    """Async function to sync execution.

    Non-coroutine functions are simply called. For coroutine functions:

    * If an event loop is already running in this thread, the coroutine runs on
      a helper ``ReturnThread`` with its own loop, because ``run_until_complete``
      cannot be used on a running loop. An exception raised there is re-raised
      here.
    * Otherwise the thread's current event loop is used. A new loop is created
      and installed if there is none (e.g. after a previous ``asyncio.run()``,
      in a non-main thread, or on Python 3.14+) or if the current one is closed.

    Returns:
        The function's (or coroutine's) return value.
    """
    if not asyncio.iscoroutinefunction(async_func):
        return async_func(*args, **kwargs)

    loop = asyncio_loop()
    if loop and loop.is_running():
        thread = ReturnThread(async_func, *args, **kwargs)
        thread.start()
        thread.join()
        # getattr keeps this working with thread objects that do not record errors.
        error = getattr(thread, "exception", None)
        if error is not None:
            raise error
        return thread.result

    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(async_func(*args, **kwargs))
def nest_dict_counter(usage: Dict[str, Union[int, Dict[str, int]]],
                      other: Dict[str, Union[int, Dict[str, int]]],
                      ignore_zero: bool = True):
    """Add counts from two dicts or nest dicts.

    Works like ``collections.Counter.__add__``, extended to one level of
    nesting (dict values are merged recursively). Neither input is modified.

    Args:
        usage: Base counts.
        other: Counts to add to ``usage``.
        ignore_zero: If True (default), integer entries whose resulting count is
            not positive are dropped. Nested dict entries are always kept, even
            if they end up empty.

    Returns:
        A new dict with the summed counts.
    """
    result = {}
    for elem, count in usage.items():
        if isinstance(count, dict):
            # Propagate ignore_zero so nested levels follow the same rule.
            result[elem] = nest_dict_counter(count, other.get(elem, {}), ignore_zero)
            continue
        newcount = count + other.get(elem, 0)
        if not ignore_zero or newcount > 0:
            result[elem] = newcount

    # Keys present only in ``other``: keep them (as Counter does) unless
    # ignore_zero filters them out.
    for elem, count in other.items():
        if elem in usage:
            continue
        if isinstance(count, dict):
            result[elem] = nest_dict_counter({}, count, ignore_zero)
        elif not ignore_zero or count > 0:
            result[elem] = count
    return result
def get_class(module_class: str):
    """Resolve a dotted ``"package.module.Name"`` path to the named object.

    Nested attributes may be appended with ``:``. For example,
    ``"package.module.Outer:Inner"`` returns ``Outer.Inner``. Surrounding
    whitespace is ignored.

    Raises:
        AssertionError: If ``module_class`` is empty.
        Exception: If the path contains no ``.`` separating module and name.
    """
    import importlib

    assert module_class
    module_class = module_class.strip()
    module_name, dot, attr_path = module_class.rpartition('.')
    if not dot:
        raise Exception("{} can not find!".format(module_class))

    target = importlib.import_module(module_name)
    for attr_name in attr_path.split(":"):
        target = getattr(target, attr_name)
    return target
def new_instance(module_class: str, *args, **kwargs):
    """Instantiate the class named by the dotted path ``module_class``.

    The path is resolved with ``get_class``; all remaining positional and
    keyword arguments are forwarded to the constructor.
    """
    target_cls = get_class(module_class)
    return target_cls(*args, **kwargs)
def retryable(tries: int = 3, delay: int = 1):
    """Decorator factory that retries a function when it raises.

    The decorated function is called at most ``tries`` times in total (always
    at least once). After each failed attempt except the last, a warning is
    logged and the call is retried after sleeping ``delay`` seconds. An
    exception from the final attempt propagates to the caller.

    Args:
        tries: Total number of attempts; values below 1 behave like 1.
        delay: Seconds to sleep between attempts.
    """
    def inner_retry(f):
        @wraps(f)
        def f_retry(*args, **kwargs):
            mtries, mdelay = tries, delay
            # Stop looping one attempt early: the last attempt happens below so
            # its exception is not swallowed and no needless sleep follows it.
            while mtries > 1:
                try:
                    return f(*args, **kwargs)
                except Exception as e:
                    msg = f"{str(e)}, Retrying in {mdelay} seconds..."
                    logger.warning(msg)
                    time.sleep(mdelay)
                    mtries -= 1
            return f(*args, **kwargs)
        return f_retry
    return inner_retry
def get_local_ip():
    """Return this host's outbound IPv4 address, or ``"127.0.0.1"`` on any failure.

    ``connect()`` on a UDP socket only sets the default destination; no data is
    sent. The OS picks the local interface that would route to that address,
    and ``getsockname()`` reports it.
    """
    try:
        # The context manager closes the socket even when connect() raises.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))
            return s.getsockname()[0]
    except Exception:
        return "127.0.0.1"
_ENV_VAR_PATTERN = re.compile(r'\${([^}]+)}')


def _substitute_env_vars(text: str) -> str:
    """Return ``text`` with every ``${NAME}`` replaced by the value of env var ``NAME``.

    Placeholders whose variable is unset are left untouched. Substitution is a
    single pass, so ``${...}`` sequences inside substituted values are not
    expanded again.
    """
    return _ENV_VAR_PATTERN.sub(lambda m: os.getenv(m.group(1), m.group(0)), text)


def replace_env_variables(config) -> Any:
    """Replace environment variables in configuration.

    Environment variables should be in the format ${ENV_VAR_NAME}. Dicts and
    lists are updated in place, recursively. Placeholders for unset variables
    are kept as-is. Any other value is returned unchanged.

    Args:
        config: Configuration to process (dict, list, or other value)

    Returns:
        Processed configuration with environment variables replaced (the same
        object that was passed in).
    """
    if isinstance(config, dict):
        slots = list(config.items())
    elif isinstance(config, list):
        slots = list(enumerate(config))
    else:
        return config

    for slot, value in slots:
        if isinstance(value, str):
            resolved = _substitute_env_vars(value)
            if resolved != value:
                config[slot] = resolved
                # Log only the template, never the resolved text: environment
                # variables commonly hold secrets such as API keys.
                logger.info(f"Replaced environment variables in {value}")
        elif isinstance(value, (dict, list)):
            replace_env_variables(value)
    return config
def get_local_hostname():
    """
    Get the local hostname.

    First try ``socket.gethostname()``. A result containing a ``.`` is treated
    as a fully qualified name and returned directly. Otherwise, try a reverse
    DNS lookup of the local IP. If that fails for any network reason, return
    the short hostname (or the IP if the hostname is empty). Any unexpected
    error falls back to ``"localhost"``.
    """
    try:
        hostname = socket.gethostname()
        # Simple validation - a hostname containing '.' is treated as an FQDN.
        if hostname and '.' in hostname:
            return hostname

        local_ip = get_local_ip()
        if local_ip:
            try:
                fqdn, _, _ = socket.gethostbyaddr(local_ip)
                return fqdn
            except OSError:
                # herror/gaierror and also timeouts etc. (all OSError
                # subclasses): keep the short hostname instead of falling
                # through to "localhost".
                pass

        return hostname if hostname else local_ip
    except Exception:
        # Final fallback strategy
        return "localhost"