# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import os
import traceback
import uuid
from collections import OrderedDict
from enum import Enum
from pathlib import Path
from typing import Optional, List, Dict, Any

import yaml
from pydantic import BaseModel, Field

from aworld.logs.util import logger


def load_config(file_name: str, dir_name: Optional[str] = None) -> Dict[str, Any]:
    """Dynamically load a config file from the given directory or the current path.

    Args:
        file_name: Config file name.
        dir_name: Config file directory.

    Returns:
        Config dict.
    """
    if dir_name:
        file_path = os.path.join(dir_name, file_name)
    else:
        # Load the config from the current path.
        current_dir = Path(__file__).parent.absolute()
        file_path = os.path.join(current_dir, file_name)
    if not os.path.exists(file_path):
        logger.debug(f"{file_path} does not exist, please check it.")

    configs = dict()
    try:
        with open(file_path, "r") as file:
            yaml_data = yaml.safe_load(file)
            # `safe_load` returns None for an empty file; guard before updating.
            if yaml_data:
                configs.update(yaml_data)
    except FileNotFoundError:
        logger.debug(f"Can not find the file: {file_path}")
    except Exception:
        logger.warning(f"{file_name} read failed.\n{traceback.format_exc()}")
    return configs
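
# Usage sketch (hypothetical file name): load `agent.yaml` from this module's
# directory, or from an explicit directory.
#   conf = load_config("agent.yaml")
#   conf = load_config("agent.yaml", dir_name="/etc/aworld")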


def wipe_secret_info(config: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
    """Wipe secret info from the config and return it as a plain dict, used for logging.

    Note: matching keys are masked in place, at any nesting depth.
    """

    def _wipe_secret(conf):
        def _wipe_secret_plain_value(v):
            if isinstance(v, List):
                return [_wipe_secret_plain_value(e) for e in v]
            elif isinstance(v, Dict):
                return _wipe_secret(v)
            else:
                return v

        for key in list(conf.keys()):
            if key.strip('"') in keys:
                conf[key] = '-^_^-'
            else:
                _wipe_secret_plain_value(conf[key])
        return conf

    if not config:
        return config
    return _wipe_secret(config)
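
# Usage sketch: secrets are masked at any nesting depth, in place.
#   conf = {"llm_api_key": "k1", "nested": {"llm_api_key": "k2"}}
#   wipe_secret_info(conf, ["llm_api_key"])
#   # -> {"llm_api_key": "-^_^-", "nested": {"llm_api_key": "-^_^-"}}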


class ClientType(Enum):
    SDK = "sdk"
    HTTP = "http"


class ConfigDict(dict):
    """Dict with attribute-style access; read possibly missing keys via the `get` method."""
    __setattr__ = dict.__setitem__
    __getattr__ = dict.__getitem__

    def __init__(self, seq: dict = None, **kwargs):
        if seq is None:
            seq = OrderedDict()
        super(ConfigDict, self).__init__(seq, **kwargs)
        self.nested(self)

    def nested(self, seq: dict):
        """Recursively convert nested plain dicts into ConfigDict.

        Args:
            seq: Python original format dict
        """
        for k, v in seq.items():
            if isinstance(v, dict):
                # ConfigDict.__init__ already nests recursively.
                seq[k] = ConfigDict(v)
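
# Usage sketch: nested plain dicts become ConfigDicts with attribute access.
#   cd = ConfigDict({"llm": {"model_name": "qwen-max"}})
#   assert cd.llm.model_name == "qwen-max"
#   cd.get("missing")  # -> None; `cd.missing` would raise KeyError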


class BaseConfig(BaseModel):
    def config_dict(self) -> ConfigDict:
        return ConfigDict(self.model_dump())


class ModelConfig(BaseConfig):
    llm_provider: Optional[str] = None
    llm_model_name: Optional[str] = None
    llm_temperature: float = 1.0
    llm_base_url: Optional[str] = None
    llm_api_key: Optional[str] = None
    llm_client_type: ClientType = ClientType.SDK
    llm_sync_enabled: bool = True
    llm_async_enabled: bool = True
    max_retries: int = 3
    max_model_len: Optional[int] = None  # Maximum model context length
    model_type: Optional[str] = 'qwen'  # Model type determines the tokenizer and maximum length

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
        # Init max_model_len when not set explicitly: 128k for qwen (and other
        # default model types), 200k for claude.
        if self.max_model_len is None:
            self.max_model_len = 128000
            if self.model_type == 'claude':
                self.max_model_len = 200000
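
# Usage sketch: model_type drives the default context window when
# max_model_len is not given (128k for the default `qwen`, 200k for `claude`).
#   mc = ModelConfig(model_type="claude")
#   assert mc.max_model_len == 200000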


class LlmCompressionConfig(BaseConfig):
    enabled: bool = False
    compress_type: str = 'llm'  # llm, llmlingua
    trigger_compress_token_length: int = 10000  # Trigger compression when exceeding this length
    compress_model: Optional[ModelConfig] = None


class OptimizationConfig(BaseConfig):
    enabled: bool = False
    max_token_budget_ratio: float = 0.5  # Maximum context length ratio


class ContextRuleConfig(BaseConfig):
    """Context interference rule configuration."""
    # ===== Performance optimization configuration =====
    optimization_config: OptimizationConfig = OptimizationConfig()
    # ===== LLM conversation compression configuration =====
    llm_compression_config: LlmCompressionConfig = LlmCompressionConfig()
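
# Usage sketch (hypothetical values): enable token-budget optimization and
# LLM-based conversation compression.
#   rule = ContextRuleConfig(
#       optimization_config=OptimizationConfig(enabled=True, max_token_budget_ratio=0.5),
#       llm_compression_config=LlmCompressionConfig(enabled=True),
#   )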


class AgentConfig(BaseConfig):
    name: Optional[str] = None
    desc: Optional[str] = None
    llm_config: ModelConfig = ModelConfig()
    # For compatibility, the llm_* fields mirror ModelConfig.
    llm_provider: Optional[str] = None
    llm_model_name: Optional[str] = None
    llm_temperature: float = 1.0
    llm_base_url: Optional[str] = None
    llm_api_key: Optional[str] = None
    llm_client_type: ClientType = ClientType.SDK
    llm_sync_enabled: bool = True
    llm_async_enabled: bool = True
    max_retries: int = 3
    max_model_len: Optional[int] = None  # Maximum model context length
    model_type: Optional[str] = 'qwen'  # Model type determines the tokenizer and maximum length

    # Reset on the first initialization by default.
    need_reset: bool = True
    # Use the vision model.
    use_vision: bool = True
    max_steps: int = 10
    max_input_tokens: int = 128000
    max_actions_per_step: int = 10
    system_prompt: Optional[str] = None
    agent_prompt: Optional[str] = None
    working_dir: Optional[str] = None
    enable_recording: bool = False
    use_tools_in_prompt: bool = False
    exit_on_failure: bool = False
    ext: dict = {}
    human_tools: List[str] = []
    # Context rule
    context_rule: ContextRuleConfig = ContextRuleConfig()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Apply all provided kwargs to the instance.
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
        # Synchronize model configuration between AgentConfig and llm_config.
        self._sync_model_config()
        # Initialize max_model_len if not set: 128k for the default (qwen),
        # 200k for claude.
        if self.max_model_len is None:
            self.max_model_len = 128000
            if self.model_type == 'claude':
                self.max_model_len = 200000

    def _sync_model_config(self):
        """Synchronize model configuration between AgentConfig and llm_config."""
        # Ensure llm_config is initialized.
        if self.llm_config is None:
            self.llm_config = ModelConfig()
        # Dynamically get all field names from ModelConfig.
        model_fields = list(ModelConfig.model_fields.keys())
        # Filter to only the fields that also exist on AgentConfig.
        agent_fields = set(type(self).model_fields.keys())
        filtered_model_fields = [field for field in model_fields if field in agent_fields]
        # Whichever side has llm_model_name set decides the sync direction.
        agent_has_model_name = getattr(self, 'llm_model_name', None) is not None
        llm_config_has_model_name = getattr(self.llm_config, 'llm_model_name', None) is not None
        if agent_has_model_name:
            # AgentConfig wins: sync all non-None fields from AgentConfig to llm_config.
            for field in filtered_model_fields:
                agent_value = getattr(self, field, None)
                if agent_value is not None:
                    setattr(self.llm_config, field, agent_value)
        elif llm_config_has_model_name:
            # llm_config wins: sync all non-None fields from llm_config to AgentConfig.
            for field in filtered_model_fields:
                llm_config_value = getattr(self.llm_config, field, None)
                if llm_config_value is not None:
                    setattr(self, field, llm_config_value)
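
# Usage sketch: llm_* fields set on AgentConfig are mirrored into llm_config
# (and vice versa when only llm_config carries llm_model_name).
#   cfg = AgentConfig(llm_model_name="qwen-max", llm_temperature=0.2)
#   assert cfg.llm_config.llm_model_name == "qwen-max"
#   assert cfg.llm_config.llm_temperature == 0.2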


class TaskConfig(BaseConfig):
    # `default_factory` gives each instance a fresh id; a plain default would
    # be evaluated once at class-definition time and shared by all tasks.
    task_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    task_name: Optional[str] = None
    max_steps: int = 100
    max_actions_per_step: int = 10
    stream: bool = False
    exit_on_failure: bool = False
    ext: dict = {}
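
# Usage sketch: with default_factory, every task gets its own id.
#   t1, t2 = TaskConfig(), TaskConfig()
#   assert t1.task_id != t2.task_id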


class ToolConfig(BaseConfig):
    name: Optional[str] = None
    custom_executor: bool = False
    enable_recording: bool = False
    working_dir: str = ""
    max_retry: int = 3
    llm_config: Optional[ModelConfig] = None
    reuse: bool = False
    use_async: bool = False
    exit_on_failure: bool = False
    ext: dict = {}


class RunConfig(BaseConfig):
    name: str = 'local'
    worker_num: int = 1
    reuse_process: bool = True
    cls: Optional[str] = None
    event_bus: Optional[Dict[str, Any]] = None
    tracer: Optional[Dict[str, Any]] = None
    replay_buffer: Optional[Dict[str, Any]] = None


class EvaluationConfig(BaseConfig):
    work_dir: Optional[str] = None
    run_times: int = 1
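

if __name__ == '__main__':
    # Minimal end-to-end sketch (hypothetical values): build an agent config
    # and log it with the API key masked.
    agent_conf = AgentConfig(llm_model_name="qwen-max", llm_api_key="hypothetical-key")
    logger.info(f"{wipe_secret_info(agent_conf.config_dict(), ['llm_api_key'])}")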