Spaces:
Build error
Build error
"""Translation provider factory for creating and managing translation providers.""" | |
import logging | |
from typing import Dict, List, Optional, Type | |
from enum import Enum | |
from ..base.translation_provider_base import TranslationProviderBase | |
from .nllb_provider import NLLBTranslationProvider | |
from ...domain.exceptions import TranslationFailedException | |
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class TranslationProviderType(Enum):
    """Identifiers for the translation backends the factory knows how to build.

    Each member's value is the short string name used in log and error messages.
    """

    NLLB = "nllb"

    # Planned backends, not implemented yet:
    # GOOGLE = "google"
    # AZURE = "azure"
    # AWS = "aws"
class TranslationProviderFactory:
    """Factory for creating and managing translation provider instances.

    Provider classes and their default constructor kwargs live in class-level
    registries shared by every factory. Each factory instance additionally
    keeps its own cache of constructed providers and of availability results.
    """

    # Registry of available provider classes (shared across all instances).
    _PROVIDER_REGISTRY: Dict[TranslationProviderType, Type[TranslationProviderBase]] = {
        TranslationProviderType.NLLB: NLLBTranslationProvider,
    }

    # Default provider constructor kwargs; caller-supplied config is layered
    # on top of these in create_provider().
    _DEFAULT_CONFIGS: Dict[TranslationProviderType, Dict] = {
        TranslationProviderType.NLLB: {
            'model_name': 'facebook/nllb-200-3.3B',
            'max_chunk_length': 1000,
        }
    }

    def __init__(self):
        """Initialize the translation provider factory with empty caches."""
        # Constructed providers, keyed by _generate_cache_key().
        self._provider_cache: Dict[str, TranslationProviderBase] = {}
        # Memoized results of _is_provider_available().
        self._availability_cache: Dict[TranslationProviderType, bool] = {}

    @staticmethod
    def _type_name(provider_type) -> str:
        """Return a printable name for a provider type, even if it is not an enum member."""
        return str(getattr(provider_type, 'value', provider_type))

    def create_provider(
        self,
        provider_type: TranslationProviderType,
        config: Optional[Dict] = None,
        use_cache: bool = True
    ) -> TranslationProviderBase:
        """
        Create a translation provider instance.

        The caller's ``config`` is merged over the provider's default config.
        The *merged* config is what the cache is keyed on, so requests that
        resolve to the same settings share one instance.

        Args:
            provider_type: The type of provider to create
            config: Optional configuration parameters for the provider
            use_cache: Whether to use cached provider instances

        Returns:
            TranslationProviderBase: The created provider instance

        Raises:
            TranslationFailedException: If the provider type is not registered
                or provider construction fails
        """
        type_name = self._type_name(provider_type)
        try:
            # Merge defaults with caller overrides first so the cache key
            # reflects the settings the provider is actually built with.
            final_config = self._DEFAULT_CONFIGS.get(provider_type, {}).copy()
            if config:
                final_config.update(config)

            cache_key = self._generate_cache_key(provider_type, final_config)

            # Return cached instance if available and requested
            if use_cache and cache_key in self._provider_cache:
                logger.info(f"Returning cached {type_name} provider")
                return self._provider_cache[cache_key]

            # Check if provider type is registered
            if provider_type not in self._PROVIDER_REGISTRY:
                raise TranslationFailedException(
                    f"Unknown translation provider type: {type_name}. "
                    f"Available types: {[t.value for t in self._PROVIDER_REGISTRY]}"
                )

            provider_class = self._PROVIDER_REGISTRY[provider_type]

            logger.info(f"Creating {type_name} translation provider")
            logger.info(f"Provider config: {final_config}")

            provider = provider_class(**final_config)

            if use_cache:
                self._provider_cache[cache_key] = provider

            logger.info(f"Successfully created {type_name} translation provider")
            return provider

        except TranslationFailedException:
            # Already a domain error with a specific message; don't re-wrap it.
            raise
        except Exception as e:
            logger.error(f"Failed to create {type_name} provider: {str(e)}")
            raise TranslationFailedException(
                f"Failed to create {type_name} provider: {str(e)}"
            ) from e

    def get_available_providers(self, force_check: bool = False) -> List[TranslationProviderType]:
        """
        Get list of available translation providers.

        Args:
            force_check: Whether to force availability check (ignore cache)

        Returns:
            List[TranslationProviderType]: Registered provider types whose
                availability check succeeded, in registry order
        """
        available_providers = [
            provider_type
            for provider_type in self._PROVIDER_REGISTRY
            if self._is_provider_available(provider_type, force_check)
        ]
        logger.info(f"Available translation providers: {[p.value for p in available_providers]}")
        return available_providers

    def get_default_provider(self, config: Optional[Dict] = None) -> TranslationProviderBase:
        """
        Get the default translation provider.

        NLLB is preferred; otherwise the first available provider is used.

        Args:
            config: Optional configuration for the provider

        Returns:
            TranslationProviderBase: The default provider instance

        Raises:
            TranslationFailedException: If no providers are available
        """
        available_providers = self.get_available_providers()
        if not available_providers:
            raise TranslationFailedException("No translation providers are available")

        default_type = TranslationProviderType.NLLB
        if default_type not in available_providers:
            default_type = available_providers[0]

        logger.info(f"Using {default_type.value} as default translation provider")
        return self.create_provider(default_type, config)

    def get_provider_with_fallback(
        self,
        preferred_types: List[TranslationProviderType],
        config: Optional[Dict] = None
    ) -> TranslationProviderBase:
        """
        Get a provider with fallback options.

        Args:
            preferred_types: List of preferred provider types in order of preference
            config: Optional configuration for the provider

        Returns:
            TranslationProviderBase: The first available preferred provider, or
                else the first available provider of any type

        Raises:
            TranslationFailedException: If no provider at all is available
        """
        available_providers = self.get_available_providers()

        for provider_type in preferred_types:
            if provider_type in available_providers:
                logger.info(f"Using {provider_type.value} translation provider")
                return self.create_provider(provider_type, config)

        # If no preferred providers are available, try any available provider
        if available_providers:
            fallback_type = available_providers[0]
            logger.warning(
                f"None of preferred providers {[p.value for p in preferred_types]} available. "
                f"Falling back to {fallback_type.value}"
            )
            return self.create_provider(fallback_type, config)

        raise TranslationFailedException(
            f"None of the preferred translation providers are available: "
            f"{[p.value for p in preferred_types]}"
        )

    def clear_cache(self) -> None:
        """Clear all cached provider instances and availability results."""
        self._provider_cache.clear()
        self._availability_cache.clear()
        logger.info("Translation provider cache cleared")

    def get_provider_info(self, provider_type: TranslationProviderType) -> Dict:
        """
        Get information about a specific provider type.

        Args:
            provider_type: The provider type to get info for

        Returns:
            dict: Provider information (type, class, module, availability,
                a copy of the default config, and the class docstring)

        Raises:
            TranslationFailedException: If the provider type is not registered
        """
        if provider_type not in self._PROVIDER_REGISTRY:
            raise TranslationFailedException(
                f"Unknown provider type: {self._type_name(provider_type)}"
            )

        provider_class = self._PROVIDER_REGISTRY[provider_type]
        # Copy so callers mutating the returned info cannot alter the
        # class-wide defaults used by create_provider().
        default_config = dict(self._DEFAULT_CONFIGS.get(provider_type, {}))
        is_available = self._is_provider_available(provider_type)

        return {
            'type': provider_type.value,
            'class_name': provider_class.__name__,
            'module': provider_class.__module__,
            'available': is_available,
            'default_config': default_config,
            'description': provider_class.__doc__ or "No description available"
        }

    def get_all_providers_info(self) -> Dict[str, Dict]:
        """
        Get information about all registered providers.

        Returns:
            dict: Provider info keyed by provider type value
        """
        return {
            provider_type.value: self.get_provider_info(provider_type)
            for provider_type in self._PROVIDER_REGISTRY
        }

    def _is_provider_available(self, provider_type: TranslationProviderType, force_check: bool = False) -> bool:
        """
        Check if a provider type is available.

        Availability is determined by constructing a throwaway instance with
        the default config and calling its ``is_available()``. Any exception
        counts as "unavailable". The result is memoized per factory instance.

        Args:
            provider_type: The provider type to check
            force_check: Whether to force availability check (ignore cache)

        Returns:
            bool: True if provider is available, False otherwise
        """
        if not force_check and provider_type in self._availability_cache:
            return self._availability_cache[provider_type]

        try:
            provider_class = self._PROVIDER_REGISTRY[provider_type]
            default_config = self._DEFAULT_CONFIGS.get(provider_type, {})
            temp_provider = provider_class(**default_config)
            is_available = temp_provider.is_available()

            self._availability_cache[provider_type] = is_available
            logger.info(f"Provider {provider_type.value} availability: {is_available}")
            return is_available

        except Exception as e:
            # Best-effort probe: treat any failure as unavailable rather than crash.
            logger.warning(f"Error checking {self._type_name(provider_type)} availability: {str(e)}")
            self._availability_cache[provider_type] = False
            return False

    def _generate_cache_key(self, provider_type: TranslationProviderType, config: Optional[Dict]) -> str:
        """
        Generate a cache key for provider instances.

        Config items are sorted by (stringified) key and rendered with
        ``repr``. This keeps the key unambiguous even when values contain
        separator characters such as ``_`` or ``=``.

        Args:
            provider_type: The provider type
            config: The provider configuration

        Returns:
            str: Cache key
        """
        config_str = ""
        if config:
            items = sorted(config.items(), key=lambda item: str(item[0]))
            config_str = repr(items)
        return f"{self._type_name(provider_type)}|{config_str}"

    @classmethod
    def register_provider(
        cls,
        provider_type: TranslationProviderType,
        provider_class: Type[TranslationProviderBase],
        default_config: Optional[Dict] = None
    ) -> None:
        """
        Register a new translation provider type.

        The registries are class-level, so the registration is visible to all
        factory instances. Existing instances keep any previously cached
        providers/availability results; call ``clear_cache()`` on them when
        re-registering a type.

        Args:
            provider_type: The provider type enum
            provider_class: The provider class
            default_config: Default configuration for the provider
        """
        cls._PROVIDER_REGISTRY[provider_type] = provider_class
        if default_config:
            # Store a copy so later mutation of the caller's dict can't
            # silently change the registered defaults.
            cls._DEFAULT_CONFIGS[provider_type] = dict(default_config)

        logger.info(f"Registered translation provider: {cls._type_name(provider_type)}")

    @classmethod
    def get_supported_provider_types(cls) -> List[TranslationProviderType]:
        """
        Get all supported provider types.

        Returns:
            List[TranslationProviderType]: List of registered provider types
        """
        return list(cls._PROVIDER_REGISTRY.keys())
# Shared module-level factory; its provider and availability caches start
# empty and fill as the convenience functions in this module use it.
translation_provider_factory: TranslationProviderFactory = TranslationProviderFactory()
def create_translation_provider(
    provider_type: TranslationProviderType = TranslationProviderType.NLLB,
    config: Optional[Dict] = None
) -> TranslationProviderBase:
    """
    Build (or fetch from cache) a provider through the module-level factory.

    Args:
        provider_type: Which backend to construct; NLLB when omitted
        config: Optional overrides passed through to the factory

    Returns:
        TranslationProviderBase: The provider instance
    """
    factory = translation_provider_factory
    return factory.create_provider(provider_type=provider_type, config=config)
def get_default_translation_provider(config: Optional[Dict] = None) -> TranslationProviderBase:
    """
    Return the module-level factory's default provider.

    Args:
        config: Optional overrides passed through to the factory

    Returns:
        TranslationProviderBase: The default provider instance
    """
    factory = translation_provider_factory
    return factory.get_default_provider(config=config)