# app/core/model_manager.py
"""Presence checks and download helpers for the models and corpora the app needs.

Three kinds of artefacts are managed here:

* Hugging Face Hub repositories (cached under ``MODELS_DIR / "hf_cache"``),
* spaCy pipeline packages,
* NLTK corpora (stored under ``NLTK_DATA_DIR``).

Every blocking download runs in a worker thread via ``asyncio.to_thread`` so
that the event loop stays responsive.
"""
import logging
import os
import asyncio
from pathlib import Path
from typing import Callable, Optional, Dict, List

# Imports for downloading specific model types
import nltk
from huggingface_hub import snapshot_download
import spacy.cli

# Internal application imports
from app.core.config import (
    MODELS_DIR,
    NLTK_DATA_DIR,
    SPACY_MODEL_ID,
    SENTENCE_TRANSFORMER_MODEL_ID,
    TONE_MODEL_ID,
    TRANSLATION_MODEL_ID,
    WORDNET_NLTK_ID,
    APP_NAME,
)
from app.core.exceptions import ModelNotDownloadedError, ModelDownloadFailedError, ServiceError

logger = logging.getLogger(f"{APP_NAME}.core.model_manager")

# Type alias for progress callback: (model_id, status, progress, message)
ProgressCallback = Callable[[str, str, float, Optional[str]], None]


def _dir_has_entries(path: Path) -> bool:
    """Return True when *path* is an existing directory with at least one entry."""
    return path.is_dir() and any(path.iterdir())


def _get_hf_model_local_path(model_id: str) -> Path:
    """Return the cache folder ``snapshot_download`` uses for *model_id*.

    With ``cache_dir=MODELS_DIR / "hf_cache"`` the Hub client stores a model
    repository in ``<cache_dir>/models--<org>--<name>`` (for example
    ``models--bert-base-uncased``). The ``models--`` prefix is required:
    without it the returned path never matches what the download creates.
    """
    return MODELS_DIR / "hf_cache" / f"models--{model_id.replace('/', '--')}"


def check_model_exists(model_id: str, model_type: str) -> bool:
    """Check whether a model or NLTK corpus is already available locally.

    Args:
        model_id: Hub repo id, spaCy package name, or NLTK corpus id.
        model_type: One of ``"huggingface"``, ``"spacy"`` or ``"nltk"``.

    Returns:
        True when the artefact was found, False otherwise. Unknown
        ``model_type`` values are logged and reported as missing.
    """
    if model_type == "huggingface":
        # A finished snapshot_download leaves at least one populated revision
        # folder under `snapshots/`. This is a presence check only; it does
        # not verify that every file of the repository was fetched.
        snapshots_dir = _get_hf_model_local_path(model_id) / "snapshots"
        if not snapshots_dir.is_dir():
            return False
        return any(_dir_has_entries(revision) for revision in snapshots_dir.iterdir())

    if model_type == "spacy":
        # Keep honouring a model directory placed under MODELS_DIR ...
        if _dir_has_entries(MODELS_DIR / model_id):
            return True
        # ... but `spacy.cli.download` installs the pipeline as a Python
        # package, so also ask spaCy whether that package is installed.
        return bool(spacy.util.is_package(model_id))

    if model_type == "nltk":
        # Search our own data directory explicitly, so the result does not
        # depend on NLTK_DATA_DIR already being on `nltk.data.path`.
        search_paths = [str(NLTK_DATA_DIR), *nltk.data.path]
        try:
            nltk.data.find(f"corpora/{model_id}", paths=search_paths)
        except LookupError:
            return False
        return True

    logger.warning(f"Unknown model type for check_model_exists: {model_type}")
    return False


# --- Download Functions ---

async def download_hf_model_async(
    model_id: str,
    feature_name: str,
    progress_callback: Optional[ProgressCallback] = None
) -> None:
    """Download a Hugging Face Hub model into ``MODELS_DIR / "hf_cache"``.

    Args:
        model_id: Hub repository id.
        feature_name: Human-readable feature name, used in logs and errors.
        progress_callback: Optional callback receiving
            ``(model_id, status, progress, message)``. Only coarse updates are
            sent (start, completion, failure).

    Raises:
        ModelDownloadFailedError: If the download raises any exception.
    """
    logger.info(f"Initiating download for Hugging Face model '{model_id}' for '{feature_name}'...")

    if check_model_exists(model_id, "huggingface"):
        logger.info(f"Hugging Face model '{model_id}' already exists locally. Skipping download.")
        if progress_callback:
            progress_callback(model_id, "completed", 1.0, "Already downloaded.")
        return

    def _blocking_download() -> None:
        snapshot_download(
            repo_id=model_id,
            cache_dir=str(MODELS_DIR / "hf_cache"),  # Explicitly set cache directory
            # NOTE(review): presumably only relevant together with `local_dir`;
            # kept as-is to preserve the original call - confirm against the
            # installed huggingface_hub version.
            local_dir_use_symlinks=False,
        )
        logger.info(f"Hugging Face model '{model_id}' download complete.")

    try:
        if progress_callback:
            progress_callback(model_id, "downloading", 0.05, "Starting download...")

        # snapshot_download blocks, so run it off the event loop.
        await asyncio.to_thread(_blocking_download)

        if progress_callback:
            progress_callback(model_id, "completed", 1.0, "Download successful.")
    except Exception as e:
        logger.error(f"Failed to download Hugging Face model '{model_id}': {e}", exc_info=True)
        if progress_callback:
            progress_callback(model_id, "failed", 0.0, f"Error: {e}")
        raise ModelDownloadFailedError(model_id, feature_name, original_error=str(e)) from e


async def download_spacy_model_async(
    model_id: str,
    feature_name: str,
    progress_callback: Optional[ProgressCallback] = None
) -> None:
    """Download a spaCy pipeline package.

    ``SPACY_DATA`` is pointed at ``MODELS_DIR`` for the duration of the call
    and restored afterwards.

    Args:
        model_id: spaCy package name.
        feature_name: Human-readable feature name, used in logs and errors.
        progress_callback: Optional callback receiving
            ``(model_id, status, progress, message)``.

    Raises:
        ModelDownloadFailedError: If the download fails, including the case
            where spaCy's CLI tries to exit the interpreter.
    """
    logger.info(f"Initiating download for spaCy model '{model_id}' for '{feature_name}'...")

    # NOTE(review): os.environ is process-global, so two concurrent calls could
    # observe each other's temporary SPACY_DATA value - confirm callers
    # serialise spaCy downloads.
    original_spacy_data = os.environ.get("SPACY_DATA")
    try:
        os.environ["SPACY_DATA"] = str(MODELS_DIR)

        if check_model_exists(model_id, "spacy"):
            logger.info(f"SpaCy model '{model_id}' already exists locally. Skipping download.")
            if progress_callback:
                progress_callback(model_id, "completed", 1.0, "Already downloaded.")
            return

        def _blocking_download() -> None:
            # spaCy's CLI can signal failure via sys.exit(). SystemExit is not
            # an `Exception`, so it would bypass the handler below and escape
            # from the awaiting coroutine; turn it into a regular error here.
            try:
                spacy.cli.download(model_id)
            except SystemExit as exit_request:
                if exit_request.code not in (0, None):
                    raise RuntimeError(
                        f"spaCy download exited with status {exit_request.code}"
                    ) from exit_request
            logger.info(f"SpaCy model '{model_id}' download complete.")

        if progress_callback:
            progress_callback(model_id, "downloading", 0.05, "Starting download...")

        await asyncio.to_thread(_blocking_download)

        if progress_callback:
            progress_callback(model_id, "completed", 1.0, "Download successful.")
    except Exception as e:
        logger.error(f"Failed to download spaCy model '{model_id}': {e}", exc_info=True)
        if progress_callback:
            progress_callback(model_id, "failed", 0.0, f"Error: {e}")
        raise ModelDownloadFailedError(model_id, feature_name, original_error=str(e)) from e
    finally:
        # Put SPACY_DATA back exactly as we found it (set or unset).
        if original_spacy_data is not None:
            os.environ["SPACY_DATA"] = original_spacy_data
        else:
            os.environ.pop("SPACY_DATA", None)


async def download_nltk_data_async(
    data_id: str,
    feature_name: str,
    progress_callback: Optional[ProgressCallback] = None
) -> None:
    """Download an NLTK data package into ``NLTK_DATA_DIR``.

    Args:
        data_id: NLTK package identifier (e.g. a corpus id).
        feature_name: Human-readable feature name, used in logs and errors.
        progress_callback: Optional callback receiving
            ``(data_id, status, progress, message)``.

    Raises:
        ModelDownloadFailedError: If the download raises or reports failure.
    """
    logger.info(f"Initiating download for NLTK data '{data_id}' for '{feature_name}'...")

    if check_model_exists(data_id, "nltk"):
        logger.info(f"NLTK data '{data_id}' already exists locally. Skipping download.")
        if progress_callback:
            progress_callback(data_id, "completed", 1.0, "Already downloaded.")
        return

    def _blocking_download() -> None:
        # `quiet=True` keeps the downloader non-interactive. In that mode a
        # failed download is reported through the return value rather than an
        # exception, so the result has to be checked explicitly.
        succeeded = nltk.download(data_id, download_dir=str(NLTK_DATA_DIR), quiet=True)
        if not succeeded:
            raise RuntimeError(f"nltk.download reported a failure for '{data_id}'")
        logger.info(f"NLTK data '{data_id}' download complete.")

    try:
        if progress_callback:
            progress_callback(data_id, "downloading", 0.05, "Starting download...")

        await asyncio.to_thread(_blocking_download)

        if progress_callback:
            progress_callback(data_id, "completed", 1.0, "Download successful.")
    except Exception as e:
        logger.error(f"Failed to download NLTK data '{data_id}': {e}", exc_info=True)
        if progress_callback:
            progress_callback(data_id, "failed", 0.0, f"Error: {e}")
        raise ModelDownloadFailedError(data_id, feature_name, original_error=str(e)) from e


# --- Comprehensive Model Management ---

def get_all_required_models() -> List[Dict]:
    """Return the models the application requires.

    Each entry is a dict with the keys ``"id"``, ``"type"`` (``"spacy"``,
    ``"huggingface"`` or ``"nltk"``) and ``"feature"`` (display name).
    """
    return [
        {"id": SPACY_MODEL_ID, "type": "spacy", "feature": "Text Processing (General)"},
        {"id": SENTENCE_TRANSFORMER_MODEL_ID, "type": "huggingface", "feature": "Sentence Embeddings"},
        {"id": TONE_MODEL_ID, "type": "huggingface", "feature": "Tone Classification"},
        {"id": TRANSLATION_MODEL_ID, "type": "huggingface", "feature": "Translation"},
        {"id": WORDNET_NLTK_ID, "type": "nltk", "feature": "Synonym Suggestion"},
        # Add any other models here as your application grows
    ]


async def download_all_required_models(progress_callback: Optional[ProgressCallback] = None) -> Dict[str, str]:
    """Try to download every required model, one after another.

    A failure for one model is recorded and does not stop the remaining
    downloads.

    Args:
        progress_callback: Optional callback receiving
            ``(model_id, status, progress, message)``.

    Returns:
        Mapping of model id to ``"already_downloaded"``, ``"success"`` or
        ``"failed"``.
    """
    downloaders = {
        "huggingface": download_hf_model_async,
        "spacy": download_spacy_model_async,
        "nltk": download_nltk_data_async,
    }
    download_statuses: Dict[str, str] = {}

    for model_info in get_all_required_models():
        model_id = model_info["id"]
        model_type = model_info["type"]
        feature_name = model_info["feature"]

        if check_model_exists(model_id, model_type):
            status_message = f"'{model_id}' ({feature_name}) already downloaded."
            logger.info(status_message)
            download_statuses[model_id] = "already_downloaded"
            if progress_callback:
                progress_callback(model_id, "completed", 1.0, status_message)
            continue

        logger.info(f"Attempting to download '{model_id}' ({feature_name})...")
        try:
            downloader = downloaders.get(model_type)
            if downloader is None:
                raise ValueError(f"Unsupported model type: {model_type}")
            await downloader(model_id, feature_name, progress_callback)

            status_message = f"'{model_id}' ({feature_name}) downloaded successfully."
            logger.info(status_message)
            download_statuses[model_id] = "success"
        except ModelDownloadFailedError as e:
            # The specific download function has already notified the callback.
            status_message = f"Failed to download '{model_id}' ({feature_name}): {e.original_error}"
            logger.error(status_message)
            download_statuses[model_id] = "failed"
        except Exception as e:
            status_message = f"An unexpected error occurred while downloading '{model_id}' ({feature_name}): {e}"
            logger.error(status_message, exc_info=True)
            download_statuses[model_id] = "failed"
            if progress_callback:
                progress_callback(model_id, "failed", 0.0, status_message)

    logger.info("Finished attempting to download all required models.")
    return download_statuses