import json
import os
import warnings
from typing import Dict, Any, List, Union, Type, get_origin, get_args

from .variables.default import DEFAULT_CONFIG
from .variables.base import BaseConfig
from ..retrievers.utils import get_all_retriever_names


class Config:
    """Config class for GPT Researcher."""

    CONFIG_DIR = os.path.join(os.path.dirname(__file__), "variables")

    def __init__(self, config_path: str | None = None):
        """Initialize the config class."""
        self.config_path = config_path
        self.llm_kwargs: Dict[str, Any] = {}
        self.embedding_kwargs: Dict[str, Any] = {}

        config_to_use = self.load_config(config_path)
        self._set_attributes(config_to_use)
        self._set_embedding_attributes()
        self._set_llm_attributes()
        self._handle_deprecated_attributes()
        self._set_doc_path(config_to_use)

    def _set_attributes(self, config: Dict[str, Any]) -> None:
        """Set config values as lowercase attributes, letting environment
        variables override file values."""
        for key, value in config.items():
            env_value = os.getenv(key)
            if env_value is not None:
                value = self.convert_env_value(
                    key, env_value, BaseConfig.__annotations__[key]
                )
            setattr(self, key.lower(), value)

        # Handle RETRIEVER with a default value
        retriever_env = os.environ.get("RETRIEVER", config.get("RETRIEVER", "tavily"))
        try:
            self.retrievers = self.parse_retrievers(retriever_env)
        except ValueError as e:
            print(f"Warning: {str(e)}. Defaulting to 'tavily' retriever.")
            self.retrievers = ["tavily"]

    def _set_embedding_attributes(self) -> None:
        """Split the EMBEDDING setting into provider and model."""
        self.embedding_provider, self.embedding_model = self.parse_embedding(
            self.embedding
        )

    def _set_llm_attributes(self) -> None:
        """Split the FAST_LLM, SMART_LLM and STRATEGIC_LLM settings into
        (provider, model) pairs."""
        self.fast_llm_provider, self.fast_llm_model = self.parse_llm(self.fast_llm)
        self.smart_llm_provider, self.smart_llm_model = self.parse_llm(self.smart_llm)
        self.strategic_llm_provider, self.strategic_llm_model = self.parse_llm(
            self.strategic_llm
        )
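
    # Note: FAST_LLM, SMART_LLM, STRATEGIC_LLM and EMBEDDING values are
    # expected in "provider:model" form, e.g. "openai:gpt-4o-mini" or
    # "openai:text-embedding-3-large". parse_llm() and parse_embedding()
    # below split on the first ":" only, so model names may themselves
    # contain colons.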
    def _handle_deprecated_attributes(self) -> None:
        """Honor deprecated environment variables while warning that they
        will be removed."""
        if os.getenv("EMBEDDING_PROVIDER") is not None:
            warnings.warn(
                "EMBEDDING_PROVIDER is deprecated and will be removed soon. "
                "Use EMBEDDING instead.",
                FutureWarning,
                stacklevel=2,
            )
            self.embedding_provider = (
                os.environ["EMBEDDING_PROVIDER"] or self.embedding_provider
            )

            match os.environ["EMBEDDING_PROVIDER"]:
                case "ollama":
                    self.embedding_model = os.environ["OLLAMA_EMBEDDING_MODEL"]
                case "custom":
                    self.embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL", "custom")
                case "openai":
                    self.embedding_model = "text-embedding-3-large"
                case "azure_openai":
                    self.embedding_model = "text-embedding-3-large"
                case "huggingface":
                    self.embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
                case _:
                    raise Exception("Embedding provider not found.")

        _deprecation_warning = (
            "LLM_PROVIDER, FAST_LLM_MODEL and SMART_LLM_MODEL are deprecated and "
            "will be removed soon. Use FAST_LLM and SMART_LLM instead."
        )
        if os.getenv("LLM_PROVIDER") is not None:
            warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
            self.fast_llm_provider = (
                os.environ["LLM_PROVIDER"] or self.fast_llm_provider
            )
            self.smart_llm_provider = (
                os.environ["LLM_PROVIDER"] or self.smart_llm_provider
            )
        if os.getenv("FAST_LLM_MODEL") is not None:
            warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
            self.fast_llm_model = os.environ["FAST_LLM_MODEL"] or self.fast_llm_model
        if os.getenv("SMART_LLM_MODEL") is not None:
            warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
            self.smart_llm_model = os.environ["SMART_LLM_MODEL"] or self.smart_llm_model

    def _set_doc_path(self, config: Dict[str, Any]) -> None:
        """Set doc_path from the config, falling back to the default on error."""
        self.doc_path = config["DOC_PATH"]
        if self.doc_path:
            try:
                self.validate_doc_path()
            except Exception as e:
                print(
                    f"Warning: Error validating doc_path: {str(e)}. "
                    "Using default doc_path."
                )
                self.doc_path = DEFAULT_CONFIG["DOC_PATH"]

    @classmethod
    def load_config(cls, config_path: str | None) -> Dict[str, Any]:
        """Load a configuration by name."""
        if config_path is None:
            return DEFAULT_CONFIG

        # config_path = os.path.join(cls.CONFIG_DIR, config_path)
        if not os.path.exists(config_path):
            if config_path and config_path != "default":
                print(
                    f"Warning: Configuration not found at '{config_path}'. "
                    "Using default configuration."
                )
                if not config_path.endswith(".json"):
                    print(f"Did you mean '{config_path}.json'?")
            return DEFAULT_CONFIG

        with open(config_path, "r") as f:
            custom_config = json.load(f)

        # Merge with the default config to ensure all keys are present
        merged_config = DEFAULT_CONFIG.copy()
        merged_config.update(custom_config)
        return merged_config

    @classmethod
    def list_available_configs(cls) -> List[str]:
        """List all available configuration names."""
        configs = ["default"]
        for file in os.listdir(cls.CONFIG_DIR):
            if file.endswith(".json"):
                configs.append(file[:-5])  # Remove the .json extension
        return configs

    def parse_retrievers(self, retriever_str: str) -> List[str]:
        """Parse a comma-separated retriever string into a validated list."""
        retrievers = [retriever.strip() for retriever in retriever_str.split(",")]
        valid_retrievers = get_all_retriever_names() or []
        invalid_retrievers = [r for r in retrievers if r not in valid_retrievers]
        if invalid_retrievers:
            raise ValueError(
                f"Invalid retriever(s) found: {', '.join(invalid_retrievers)}. "
                f"Valid options are: {', '.join(valid_retrievers)}."
            )
        return retrievers
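
    # Example (names other than "tavily" are assumptions about which
    # retrievers are registered): RETRIEVER="tavily, google" parses to
    # ["tavily", "google"], while an unknown name raises ValueError, which
    # _set_attributes() catches, printing a warning and falling back to
    # ["tavily"].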
    @staticmethod
    def parse_llm(llm_str: str | None) -> tuple[str | None, str | None]:
        """Parse an llm string into (llm_provider, llm_model)."""
        from gpt_researcher.llm_provider.generic.base import _SUPPORTED_PROVIDERS

        if llm_str is None:
            return None, None
        try:
            llm_provider, llm_model = llm_str.split(":", 1)
            assert llm_provider in _SUPPORTED_PROVIDERS, (
                f"Unsupported {llm_provider}.\nSupported llm providers are: "
                + ", ".join(_SUPPORTED_PROVIDERS)
            )
            return llm_provider, llm_model
        except ValueError:
            raise ValueError(
                "Set SMART_LLM or FAST_LLM = '<llm_provider>:<llm_model>', "
                "e.g. 'openai:gpt-4o-mini'"
            )

    @staticmethod
    def parse_embedding(embedding_str: str | None) -> tuple[str | None, str | None]:
        """Parse an embedding string into (embedding_provider, embedding_model)."""
        from gpt_researcher.memory.embeddings import _SUPPORTED_PROVIDERS

        if embedding_str is None:
            return None, None
        try:
            embedding_provider, embedding_model = embedding_str.split(":", 1)
            assert embedding_provider in _SUPPORTED_PROVIDERS, (
                f"Unsupported {embedding_provider}.\nSupported embedding providers are: "
                + ", ".join(_SUPPORTED_PROVIDERS)
            )
            return embedding_provider, embedding_model
        except ValueError:
            raise ValueError(
                "Set EMBEDDING = '<embedding_provider>:<embedding_model>', "
                "e.g. 'openai:text-embedding-3-large'"
            )

    def validate_doc_path(self):
        """Ensure that the folder at the doc path exists."""
        os.makedirs(self.doc_path, exist_ok=True)

    @staticmethod
    def convert_env_value(key: str, env_value: str, type_hint: Type) -> Any:
        """Convert an environment variable to the appropriate type based on
        the type hint."""
        origin = get_origin(type_hint)
        args = get_args(type_hint)

        if origin is Union:
            # Handle Union types (e.g., Union[str, None])
            for arg in args:
                if arg is type(None):
                    if env_value.lower() in ("none", "null", ""):
                        return None
                else:
                    try:
                        return Config.convert_env_value(key, env_value, arg)
                    except ValueError:
                        continue
            raise ValueError(f"Cannot convert {env_value} to any of {args}")

        if type_hint is bool:
            return env_value.lower() in ("true", "1", "yes", "on")
        elif type_hint is int:
            return int(env_value)
        elif type_hint is float:
            return float(env_value)
        elif type_hint in (str, Any):
            return env_value
        elif origin is list or origin is List:
            return json.loads(env_value)
        else:
            raise ValueError(f"Unsupported type {type_hint} for key {key}")
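

# A minimal usage sketch: it assumes this module lives inside the
# gpt_researcher package (run it as a module, e.g. via `python -m`, so the
# relative imports resolve), and the environment values below are
# illustrative, not project defaults.
if __name__ == "__main__":
    os.environ["RETRIEVER"] = "tavily"
    os.environ["FAST_LLM"] = "openai:gpt-4o-mini"

    cfg = Config()  # no path given: DEFAULT_CONFIG plus env-var overrides
    print(cfg.retrievers)         # -> ['tavily']
    print(cfg.fast_llm_provider)  # -> 'openai'
    print(cfg.fast_llm_model)     # -> 'gpt-4o-mini'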