Spaces:
Sleeping
Sleeping
| """Translation provider factory for creating and managing translation providers.""" | |
| import logging | |
| from typing import Dict, List, Optional, Type | |
| from enum import Enum | |
| from ..base.translation_provider_base import TranslationProviderBase | |
| from .nllb_provider import NLLBTranslationProvider | |
| from ...domain.exceptions import TranslationFailedException | |
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class TranslationProviderType(Enum):
    """Closed set of translation provider kinds known to this module.

    The value of each member is the provider's short lowercase string name.
    """

    NLLB = "nllb"

    # Placeholders for providers that may be added later:
    #   GOOGLE = "google"
    #   AZURE = "azure"
    #   AWS = "aws"
class TranslationProviderFactory:
    """Factory for creating and managing translation provider instances.

    Provider classes and their default configurations live in class-level
    registries that are shared by every factory. Created provider instances
    and availability results are cached per factory instance.
    """

    # Registry of available provider classes (shared by all factories).
    _PROVIDER_REGISTRY: Dict[TranslationProviderType, Type[TranslationProviderBase]] = {
        TranslationProviderType.NLLB: NLLBTranslationProvider,
    }

    # Default constructor keyword arguments per provider type (shared).
    _DEFAULT_CONFIGS: Dict[TranslationProviderType, Dict] = {
        TranslationProviderType.NLLB: {
            'model_name': 'facebook/nllb-200-3.3B',
            'max_chunk_length': 1000,
        },
    }

    def __init__(self):
        """Initialize the translation provider factory with empty caches."""
        # Provider instances keyed by the string from _generate_cache_key().
        self._provider_cache: Dict[str, TranslationProviderBase] = {}
        # Result of the last availability probe for each provider type.
        self._availability_cache: Dict[TranslationProviderType, bool] = {}

    def create_provider(
        self,
        provider_type: TranslationProviderType,
        config: Optional[Dict] = None,
        use_cache: bool = True
    ) -> TranslationProviderBase:
        """
        Create a translation provider instance.

        The provider's default configuration is shallow-copied and then
        overridden by ``config``; the merged mapping is passed to the provider
        class as keyword arguments.

        Args:
            provider_type: The type of provider to create
            config: Optional configuration parameters for the provider
            use_cache: Whether to reuse (and store) cached provider instances

        Returns:
            TranslationProviderBase: The created (or cached) provider instance

        Raises:
            TranslationFailedException: If the provider type is not registered
                or provider creation fails for any other reason
        """
        try:
            cache_key = self._generate_cache_key(provider_type, config)

            # Return the cached instance if available and requested.
            if use_cache and cache_key in self._provider_cache:
                logger.debug(f"Returning cached {provider_type.value} provider")
                return self._provider_cache[cache_key]

            if provider_type not in self._PROVIDER_REGISTRY:
                raise TranslationFailedException(
                    f"Unknown translation provider type: {provider_type.value}. "
                    f"Available types: {[t.value for t in self._PROVIDER_REGISTRY]}"
                )

            provider_class = self._PROVIDER_REGISTRY[provider_type]

            # Copy the defaults so per-call overrides never leak into the
            # shared class-level configuration.
            final_config = self._DEFAULT_CONFIGS.get(provider_type, {}).copy()
            if config:
                final_config.update(config)

            logger.info(f"Creating {provider_type.value} translation provider")
            logger.debug(f"Provider config: {final_config}")

            provider = provider_class(**final_config)

            if use_cache:
                self._provider_cache[cache_key] = provider

            logger.info(f"Successfully created {provider_type.value} translation provider")
            return provider
        except Exception as e:
            # Any failure (including the unknown-type error raised above) is
            # reported to callers as a single TranslationFailedException.
            logger.error(f"Failed to create {provider_type.value} provider: {str(e)}")
            raise TranslationFailedException(
                f"Failed to create {provider_type.value} provider: {str(e)}"
            ) from e

    def get_available_providers(self, force_check: bool = False) -> List[TranslationProviderType]:
        """
        Get list of available translation providers.

        Args:
            force_check: Whether to force availability check (ignore cache)

        Returns:
            List[TranslationProviderType]: Available provider types, in
            registry order
        """
        available_providers = [
            provider_type
            for provider_type in self._PROVIDER_REGISTRY
            if self._is_provider_available(provider_type, force_check)
        ]
        logger.info(f"Available translation providers: {[p.value for p in available_providers]}")
        return available_providers

    def get_default_provider(self, config: Optional[Dict] = None) -> TranslationProviderBase:
        """
        Get the default translation provider.

        NLLB is preferred; if it is not available the first available
        provider is used instead.

        Args:
            config: Optional configuration for the provider

        Returns:
            TranslationProviderBase: The default provider instance

        Raises:
            TranslationFailedException: If no providers are available or the
                chosen provider cannot be created
        """
        available_providers = self.get_available_providers()
        if not available_providers:
            raise TranslationFailedException("No translation providers are available")

        default_type = TranslationProviderType.NLLB
        if default_type not in available_providers:
            default_type = available_providers[0]

        logger.info(f"Using {default_type.value} as default translation provider")
        return self.create_provider(default_type, config)

    def get_provider_with_fallback(
        self,
        preferred_types: List[TranslationProviderType],
        config: Optional[Dict] = None
    ) -> TranslationProviderBase:
        """
        Get a provider with fallback options.

        Args:
            preferred_types: List of preferred provider types in order of preference
            config: Optional configuration for the provider

        Returns:
            TranslationProviderBase: The first available preferred provider,
            or the first available provider of any type if none of the
            preferred ones is available

        Raises:
            TranslationFailedException: If no provider at all is available or
                the chosen provider cannot be created
        """
        available_providers = self.get_available_providers()

        for provider_type in preferred_types:
            if provider_type in available_providers:
                logger.info(f"Using {provider_type.value} translation provider")
                return self.create_provider(provider_type, config)

        # None of the preferred providers is available: try any other one.
        if available_providers:
            fallback_type = available_providers[0]
            logger.warning(
                f"None of preferred providers {[p.value for p in preferred_types]} available. "
                f"Falling back to {fallback_type.value}"
            )
            return self.create_provider(fallback_type, config)

        raise TranslationFailedException(
            "None of the preferred translation providers are available: "
            f"{[p.value for p in preferred_types]}"
        )

    def clear_cache(self) -> None:
        """Clear all cached provider instances and availability results."""
        self._provider_cache.clear()
        self._availability_cache.clear()
        logger.info("Translation provider cache cleared")

    def get_provider_info(self, provider_type: TranslationProviderType) -> Dict:
        """
        Get information about a specific provider type.

        Args:
            provider_type: The provider type to get info for

        Returns:
            dict: Provider information with the keys ``type``, ``class_name``,
            ``module``, ``available``, ``default_config`` and ``description``

        Raises:
            TranslationFailedException: If the provider type is not registered
        """
        if provider_type not in self._PROVIDER_REGISTRY:
            raise TranslationFailedException(f"Unknown provider type: {provider_type.value}")

        provider_class = self._PROVIDER_REGISTRY[provider_type]
        # Hand out a copy so callers cannot mutate the shared class-level
        # defaults through the returned info dict.
        default_config = dict(self._DEFAULT_CONFIGS.get(provider_type, {}))
        is_available = self._is_provider_available(provider_type)

        return {
            'type': provider_type.value,
            'class_name': provider_class.__name__,
            'module': provider_class.__module__,
            'available': is_available,
            'default_config': default_config,
            'description': provider_class.__doc__ or "No description available"
        }

    def get_all_providers_info(self) -> Dict[str, Dict]:
        """
        Get information about all registered providers.

        Returns:
            dict: Mapping of provider type value to its info dict
        """
        return {
            provider_type.value: self.get_provider_info(provider_type)
            for provider_type in self._PROVIDER_REGISTRY
        }

    def _is_provider_available(self, provider_type: TranslationProviderType, force_check: bool = False) -> bool:
        """
        Check if a provider type is available.

        The check instantiates the provider class with its default
        configuration and calls ``is_available()`` on it. Any exception during
        the check is logged and treated as "not available".

        Args:
            provider_type: The provider type to check
            force_check: Whether to force availability check (ignore cache)

        Returns:
            bool: True if provider is available, False otherwise
        """
        if not force_check and provider_type in self._availability_cache:
            return self._availability_cache[provider_type]

        try:
            provider_class = self._PROVIDER_REGISTRY[provider_type]
            default_config = self._DEFAULT_CONFIGS.get(provider_type, {})
            # Temporary instance used only for the availability probe.
            temp_provider = provider_class(**default_config)
            is_available = temp_provider.is_available()

            self._availability_cache[provider_type] = is_available
            logger.debug(f"Provider {provider_type.value} availability: {is_available}")
            return is_available
        except Exception as e:
            # Best-effort probe: failures mark the provider as unavailable
            # rather than propagating.
            logger.warning(f"Error checking {provider_type.value} availability: {str(e)}")
            self._availability_cache[provider_type] = False
            return False

    def _generate_cache_key(self, provider_type: TranslationProviderType, config: Optional[Dict]) -> str:
        """
        Generate a cache key for provider instances.

        Args:
            provider_type: The provider type
            config: The provider configuration (may be None or empty)

        Returns:
            str: ``"<type>_"`` followed by ``key=value`` pairs joined by
            ``_`` in sorted key order (nothing after the underscore when
            there is no config)
        """
        config_str = ""
        if config:
            # Sort items so the key does not depend on dict insertion order.
            config_str = "_".join(f"{k}={v}" for k, v in sorted(config.items()))
        return f"{provider_type.value}_{config_str}"

    @classmethod
    def register_provider(
        cls,
        provider_type: TranslationProviderType,
        provider_class: Type[TranslationProviderBase],
        default_config: Optional[Dict] = None
    ) -> None:
        """
        Register a new translation provider type.

        The registration is stored on the class, so it is visible to every
        factory instance. Caches held by existing factory instances are not
        touched; call ``clear_cache()`` on them if a type is re-registered.

        Args:
            provider_type: The provider type enum
            provider_class: The provider class
            default_config: Default configuration for the provider; when
                omitted or empty, any existing default for the type is kept
        """
        cls._PROVIDER_REGISTRY[provider_type] = provider_class
        if default_config:
            # Store a copy so later changes to the caller's dict do not
            # silently alter the registered defaults.
            cls._DEFAULT_CONFIGS[provider_type] = dict(default_config)
        logger.info(f"Registered translation provider: {provider_type.value}")

    @classmethod
    def get_supported_provider_types(cls) -> List[TranslationProviderType]:
        """
        Get all supported provider types.

        Returns:
            List[TranslationProviderType]: Registered provider types, in
            registry order
        """
        return list(cls._PROVIDER_REGISTRY)
# Global factory instance for convenience: created once at import time and
# shared by the module-level helper functions defined after it.
translation_provider_factory = TranslationProviderFactory()
def create_translation_provider(
    provider_type: TranslationProviderType = TranslationProviderType.NLLB,
    config: Optional[Dict] = None
) -> TranslationProviderBase:
    """
    Build a translation provider through the module-level factory.

    Args:
        provider_type: Which provider to build (NLLB when omitted)
        config: Optional configuration parameters forwarded to the factory

    Returns:
        TranslationProviderBase: Whatever the shared factory's
        ``create_provider`` returns for these arguments
    """
    shared_factory = translation_provider_factory
    return shared_factory.create_provider(provider_type, config)
def get_default_translation_provider(config: Optional[Dict] = None) -> TranslationProviderBase:
    """
    Fetch the default translation provider from the module-level factory.

    Args:
        config: Optional configuration parameters forwarded to the factory

    Returns:
        TranslationProviderBase: Whatever the shared factory's
        ``get_default_provider`` returns for this configuration
    """
    shared_factory = translation_provider_factory
    return shared_factory.get_default_provider(config)