# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import os
import traceback
import uuid
from collections import OrderedDict
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Optional, List, Dict, Any

import yaml
from pydantic import BaseModel, Field

from aworld.logs.util import logger


def load_config(file_name: str, dir_name: Optional[str] = None) -> Dict[str, Any]:
    """Dynamically load a config file from the given directory, or from the current path.

    Args:
        file_name: Config file name.
        dir_name: Config file directory.

    Returns:
        Config dict; empty if the file is missing or unreadable.
    """
    if dir_name:
        file_path = os.path.join(dir_name, file_name)
    else:
        # load conf from the current path
        current_dir = Path(__file__).parent.absolute()
        file_path = os.path.join(current_dir, file_name)
    if not os.path.exists(file_path):
        logger.debug(f"{file_path} does not exist, please check it.")

    configs = dict()
    try:
        with open(file_path, "r") as file:
            yaml_data = yaml.safe_load(file)
        # yaml.safe_load returns None for an empty file
        if yaml_data:
            configs.update(yaml_data)
    except FileNotFoundError:
        logger.debug(f"Can not find the file: {file_path}")
    except Exception:
        logger.warning(f"{file_name} read fail.\n{traceback.format_exc()}")
    return configs
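
# A minimal usage sketch (the file name below is hypothetical, for illustration only):
#
#   conf = load_config("agent.yaml", dir_name="/etc/aworld")  # explicit directory
#   conf = load_config("agent.yaml")                          # falls back to this module's directory
#   model = conf.get("llm_model_name")                        # a missing file yields an empty dict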


def wipe_secret_info(config: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
    """Return a deep copy of this config as a plain Dict with secret info wiped, used for logging."""

    def _wipe_secret(conf):
        def _wipe_secret_plain_value(v):
            if isinstance(v, List):
                return [_wipe_secret_plain_value(e) for e in v]
            elif isinstance(v, Dict):
                return _wipe_secret(v)
            else:
                return v

        for key in list(conf.keys()):
            if key.strip('"') in keys:
                conf[key] = '-^_^-'
            else:
                conf[key] = _wipe_secret_plain_value(conf[key])
        return conf

    if not config:
        return config
    # deep copy so the caller's config keeps its secrets intact
    return _wipe_secret(deepcopy(config))
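
# Illustrative example (assumed key names): wiping credentials before logging.
#
#   conf = {"llm_api_key": "sk-xxx", "nested": {"llm_api_key": "sk-xxx"}, "name": "demo"}
#   wipe_secret_info(conf, keys=["llm_api_key"])
#   # -> {"llm_api_key": "-^_^-", "nested": {"llm_api_key": "-^_^-"}, "name": "demo"}
#   # the original conf is untouched thanks to the deep copy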


class ClientType(Enum):
    SDK = "sdk"
    HTTP = "http"


class ConfigDict(dict):
    """Dict subclass with attribute-style access; missing keys can still be read safely via `get`."""
    __setattr__ = dict.__setitem__
    __getattr__ = dict.__getitem__

    def __init__(self, seq: dict = None, **kwargs):
        if seq is None:
            seq = OrderedDict()
        super(ConfigDict, self).__init__(seq, **kwargs)
        self.nested(self)

    def nested(self, seq: dict):
        """Recursively convert nested dicts to ConfigDict.

        Args:
            seq: Python original format dict
        """
        for k, v in seq.items():
            if isinstance(v, dict):
                # ConfigDict.__init__ already nests its own values recursively
                seq[k] = ConfigDict(v)
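
# Attribute-style access on nested config, e.g.:
#
#   cd = ConfigDict({"llm": {"provider": "openai"}})
#   cd.llm.provider     # -> "openai"
#   cd.get("missing")   # -> None (plain attribute access on a missing key raises KeyError)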


class BaseConfig(BaseModel):
    def config_dict(self) -> ConfigDict:
        return ConfigDict(self.model_dump())


class ModelConfig(BaseConfig):
    llm_provider: Optional[str] = None
    llm_model_name: Optional[str] = None
    llm_temperature: float = 1.0
    llm_base_url: Optional[str] = None
    llm_api_key: Optional[str] = None
    llm_client_type: ClientType = ClientType.SDK
    llm_sync_enabled: bool = True
    llm_async_enabled: bool = True
    max_retries: int = 3
    max_model_len: Optional[int] = None  # Maximum model context length
    model_type: Optional[str] = 'qwen'  # Model type determines the tokenizer and maximum length

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # init max_model_len: 128k for qwen and other default model types, 200k for claude
        if self.max_model_len is None:
            self.max_model_len = 200000 if self.model_type == 'claude' else 128000
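
# Defaulting behavior sketch:
#
#   ModelConfig(model_type="claude").max_model_len   # -> 200000
#   ModelConfig().max_model_len                      # -> 128000 for the default 'qwen' type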


class LlmCompressionConfig(BaseConfig):
    enabled: bool = False
    compress_type: str = 'llm'  # llm, llmlingua
    trigger_compress_token_length: int = 10000  # Trigger compression when exceeding this length
    compress_model: Optional[ModelConfig] = None


class OptimizationConfig(BaseConfig):
    enabled: bool = False
    max_token_budget_ratio: float = 0.5  # Maximum ratio of the model context length to use


class ContextRuleConfig(BaseConfig):
    """Context interference rule configuration."""
    # ===== Performance optimization configuration =====
    optimization_config: OptimizationConfig = OptimizationConfig()

    # ===== LLM conversation compression configuration =====
    llm_compression_config: LlmCompressionConfig = LlmCompressionConfig()
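
# Composing a rule with both features on (the values here are illustrative):
#
#   rule = ContextRuleConfig(
#       optimization_config=OptimizationConfig(enabled=True, max_token_budget_ratio=0.8),
#       llm_compression_config=LlmCompressionConfig(enabled=True, trigger_compress_token_length=8000),
#   )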


class AgentConfig(BaseConfig):
    name: Optional[str] = None
    desc: Optional[str] = None
    llm_config: ModelConfig = ModelConfig()
    # for compatibility: mirrors of the ModelConfig fields, kept in sync by _sync_model_config
    llm_provider: Optional[str] = None
    llm_model_name: Optional[str] = None
    llm_temperature: float = 1.0
    llm_base_url: Optional[str] = None
    llm_api_key: Optional[str] = None
    llm_client_type: ClientType = ClientType.SDK
    llm_sync_enabled: bool = True
    llm_async_enabled: bool = True
    max_retries: int = 3
    max_model_len: Optional[int] = None  # Maximum model context length
    model_type: Optional[str] = 'qwen'  # Model type determines the tokenizer and maximum length

    # reset state on the first init by default
    need_reset: bool = True
    # use a vision model
    use_vision: bool = True
    max_steps: int = 10
    max_input_tokens: int = 128000
    max_actions_per_step: int = 10
    system_prompt: Optional[str] = None
    agent_prompt: Optional[str] = None
    working_dir: Optional[str] = None
    enable_recording: bool = False
    use_tools_in_prompt: bool = False
    exit_on_failure: bool = False
    ext: dict = {}
    human_tools: List[str] = []
    # context rule
    context_rule: ContextRuleConfig = ContextRuleConfig()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Synchronize model configuration between AgentConfig and llm_config
        self._sync_model_config()
        # Initialize max_model_len if not set: 128k for qwen and other default model types, 200k for claude
        if self.max_model_len is None:
            self.max_model_len = 200000 if self.model_type == 'claude' else 128000

    def _sync_model_config(self):
        """Synchronize model configuration between AgentConfig and llm_config."""
        # Ensure llm_config is initialized
        if self.llm_config is None:
            self.llm_config = ModelConfig()
        # Dynamically get all field names from ModelConfig
        model_fields = list(ModelConfig.model_fields.keys())
        # Filter to only the fields that also exist on AgentConfig
        agent_fields = set(type(self).model_fields.keys())
        filtered_model_fields = [field for field in model_fields if field in agent_fields]
        # Check which configuration has llm_model_name set
        agent_has_model_name = getattr(self, 'llm_model_name', None) is not None
        llm_config_has_model_name = getattr(self.llm_config, 'llm_model_name', None) is not None
        if agent_has_model_name:
            # AgentConfig wins: sync all set fields from AgentConfig to llm_config
            for field in filtered_model_fields:
                agent_value = getattr(self, field, None)
                if agent_value is not None:
                    setattr(self.llm_config, field, agent_value)
        elif llm_config_has_model_name:
            # llm_config wins: sync all set fields from llm_config to AgentConfig
            for field in filtered_model_fields:
                llm_config_value = getattr(self.llm_config, field, None)
                if llm_config_value is not None:
                    setattr(self, field, llm_config_value)
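
# Sync direction sketch (model name and key below are illustrative):
#
#   a = AgentConfig(llm_model_name="qwen-max", llm_api_key="sk-xxx")
#   a.llm_config.llm_model_name   # -> "qwen-max", copied into llm_config
#
#   b = AgentConfig(llm_config=ModelConfig(llm_model_name="qwen-max"))
#   b.llm_model_name              # -> "qwen-max", copied back to the top level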


class TaskConfig(BaseConfig):
    # default_factory so each instance gets a fresh id (a plain default would be shared by all tasks)
    task_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    task_name: Optional[str] = None
    max_steps: int = 100
    max_actions_per_step: int = 10
    stream: bool = False
    exit_on_failure: bool = False
    ext: dict = {}
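
# Each TaskConfig draws a fresh uuid:
#
#   TaskConfig().task_id != TaskConfig().task_id   # -> True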


class ToolConfig(BaseConfig):
    name: Optional[str] = None
    custom_executor: bool = False
    enable_recording: bool = False
    working_dir: str = ""
    max_retry: int = 3
    llm_config: Optional[ModelConfig] = None
    reuse: bool = False
    use_async: bool = False
    exit_on_failure: bool = False
    ext: dict = {}


class RunConfig(BaseConfig):
    name: str = 'local'
    worker_num: int = 1
    reuse_process: bool = True
    cls: Optional[str] = None
    event_bus: Optional[Dict[str, Any]] = None
    tracer: Optional[Dict[str, Any]] = None
    replay_buffer: Optional[Dict[str, Any]] = None


class EvaluationConfig(BaseConfig):
    work_dir: Optional[str] = None
    run_times: int = 1