wellsaid / app /core /model_manager.py
iamspruce
fixed the api
73a6a7e
# app/core/model_manager.py
import logging
import os
import asyncio
from pathlib import Path
from typing import Callable, Optional, Dict, List
# Imports for downloading specific model types
import nltk
from huggingface_hub import snapshot_download
import spacy.cli
# Internal application imports
from app.core.config import (
MODELS_DIR,
NLTK_DATA_DIR,
SPACY_MODEL_ID,
SENTENCE_TRANSFORMER_MODEL_ID,
TONE_MODEL_ID,
TRANSLATION_MODEL_ID,
WORDNET_NLTK_ID,
APP_NAME
)
from app.core.exceptions import ModelNotDownloadedError, ModelDownloadFailedError, ServiceError
logger = logging.getLogger(f"{APP_NAME}.core.model_manager")
# Type alias for progress callback
ProgressCallback = Callable[[str, str, float, Optional[str]], None] # (model_id, status, progress, message)
def _get_hf_model_local_path(model_id: str) -> Path:
"""Helper to get the expected local path for a Hugging Face model."""
# snapshot_download creates a specific folder structure inside MODELS_DIR/hf_cache
# For example, for "bert-base-uncased", it might be MODELS_DIR/hf_cache/models--bert-base-uncased
# The actual model files are inside that.
# The `transformers` library usually handles this resolution.
# We just need to check if the directory created by snapshot_download exists.
# A robust check involves looking inside that directory.
return MODELS_DIR / "hf_cache" / model_id.replace("/", "--") # Standard HF cache path logic
def check_model_exists(model_id: str, model_type: str) -> bool:
"""
Checks if a specific model or NLTK data is already downloaded locally.
"""
if model_type == "huggingface":
local_path = _get_hf_model_local_path(model_id)
# Check if the directory exists and contains some files
return local_path.is_dir() and any(local_path.iterdir())
elif model_type == "spacy":
# spaCy models are symlinked or copied into a specific site-packages location
# The easiest check is to try loading it, or check spacy.util.is_package
# For our purposes, we'll check if the directory created by `spacy download` exists
# within our MODELS_DIR, assuming we direct spaCy there.
# However, `spacy.load` is the most reliable. For pre-check, we'll rely on the
# existence check in load_spacy_model. This is a simplified check.
# The actual loading process in app.services.base handles the `is_package` check.
# For `spacy.cli.download` to work with MODELS_DIR, it often requires setting SPACY_DATA.
spacy_target_path = MODELS_DIR / model_id
return spacy_target_path.is_dir() and any(spacy_target_path.iterdir())
elif model_type == "nltk":
# NLTK data check
try:
return nltk.data.find(f"corpora/{model_id}") is not None
except LookupError:
return False
else:
logger.warning(f"Unknown model type for check_model_exists: {model_type}")
return False
# --- Download Functions ---
async def download_hf_model_async(
model_id: str,
feature_name: str,
progress_callback: Optional[ProgressCallback] = None
) -> None:
"""
Asynchronously downloads a Hugging Face model from the Hub.
"""
logger.info(f"Initiating download for Hugging Face model '{model_id}' for '{feature_name}'...")
if check_model_exists(model_id, "huggingface"):
logger.info(f"Hugging Face model '{model_id}' already exists locally. Skipping download.")
if progress_callback:
progress_callback(model_id, "completed", 1.0, "Already downloaded.")
return
# Use a thread pool for blocking download operation
try:
def _blocking_download():
# This downloads to MODELS_DIR/hf_cache by default if HF_HOME is set to MODELS_DIR
# Otherwise, specify cache_dir.
# For simplicity, we rely on `settings.MODELS_DIR` handling HF_HOME in config.py
snapshot_download(
repo_id=model_id,
cache_dir=str(MODELS_DIR / "hf_cache"), # Explicitly set cache directory
local_dir_use_symlinks=False, # Use False for better self-contained app
# The `_` prefix means it's an internal parameter not typically exposed.
# `progress_callback` in `snapshot_download` is not directly exposed for live updates.
# We log at beginning and end.
)
logger.info(f"Hugging Face model '{model_id}' download complete.")
if progress_callback:
progress_callback(model_id, "downloading", 0.05, "Starting download...")
await asyncio.to_thread(_blocking_download) # Run blocking download in a separate thread
if progress_callback:
progress_callback(model_id, "completed", 1.0, "Download successful.")
except Exception as e:
logger.error(f"Failed to download Hugging Face model '{model_id}': {e}", exc_info=True)
if progress_callback:
progress_callback(model_id, "failed", 0.0, f"Error: {e}")
raise ModelDownloadFailedError(model_id, feature_name, original_error=str(e))
async def download_spacy_model_async(
model_id: str,
feature_name: str,
progress_callback: Optional[ProgressCallback] = None
) -> None:
"""
Asynchronously downloads a spaCy model.
"""
logger.info(f"Initiating download for spaCy model '{model_id}' for '{feature_name}'...")
# Check if the model package is already installed/available in the spacy data path
# NOTE: This check might not be sufficient if SPACY_DATA isn't correctly pointing.
# The `spacy.util.is_package` would be more robust but requires `import spacy` first.
# For now, we trust `spacy.cli.download` to handle the check or fail gracefully.
# We must ensure SPACY_DATA environment variable is set to MODELS_DIR
# for spacy.cli.download to put it in our custom path.
original_spacy_data = os.environ.get("SPACY_DATA")
try:
os.environ["SPACY_DATA"] = str(MODELS_DIR)
if check_model_exists(model_id, "spacy"): # Using our own simplified check
logger.info(f"SpaCy model '{model_id}' already exists locally. Skipping download.")
if progress_callback:
progress_callback(model_id, "completed", 1.0, "Already downloaded.")
return
def _blocking_download():
# spacy.cli.download attempts to download and link/copy
# It will raise an error if already downloaded if it can't link, etc.
# We're relying on our check_model_exists before this.
spacy.cli.download(model_id)
logger.info(f"SpaCy model '{model_id}' download complete.")
if progress_callback:
progress_callback(model_id, "downloading", 0.05, "Starting download...")
await asyncio.to_thread(_blocking_download)
if progress_callback:
progress_callback(model_id, "completed", 1.0, "Download successful.")
except Exception as e:
logger.error(f"Failed to download spaCy model '{model_id}': {e}", exc_info=True)
if progress_callback:
progress_callback(model_id, "failed", 0.0, f"Error: {e}")
raise ModelDownloadFailedError(model_id, feature_name, original_error=str(e))
finally:
# Restore original SPACY_DATA if it was set
if original_spacy_data is not None:
os.environ["SPACY_DATA"] = original_spacy_data
else:
if "SPACY_DATA" in os.environ:
del os.environ["SPACY_DATA"]
async def download_nltk_data_async(
data_id: str,
feature_name: str,
progress_callback: Optional[ProgressCallback] = None
) -> None:
"""
Asynchronously downloads NLTK data.
"""
logger.info(f"Initiating download for NLTK data '{data_id}' for '{feature_name}'...")
# NLTK data path should be set by NLTK_DATA environment variable in config.py
# `nltk.download` will use this path.
if check_model_exists(data_id, "nltk"):
logger.info(f"NLTK data '{data_id}' already exists locally. Skipping download.")
if progress_callback:
progress_callback(data_id, "completed", 1.0, "Already downloaded.")
return
def _blocking_download():
# NLTK downloader can show a GUI, so ensure it's not trying to do that
# `download_dir` should be set by NLTK_DATA env variable.
# `quiet=True` is important for programmatic download.
nltk.download(data_id, download_dir=str(NLTK_DATA_DIR), quiet=True)
logger.info(f"NLTK data '{data_id}' download complete.")
try:
if progress_callback:
progress_callback(data_id, "downloading", 0.05, "Starting download...")
await asyncio.to_thread(_blocking_download)
if progress_callback:
progress_callback(data_id, "completed", 1.0, "Download successful.")
except Exception as e:
logger.error(f"Failed to download NLTK data '{data_id}': {e}", exc_info=True)
if progress_callback:
progress_callback(data_id, "failed", 0.0, f"Error: {e}")
raise ModelDownloadFailedError(data_id, feature_name, original_error=str(e))
# --- Comprehensive Model Management ---
def get_all_required_models() -> List[Dict]:
"""
Returns a list of all models required by the application, with their type and feature.
"""
return [
{"id": SPACY_MODEL_ID, "type": "spacy", "feature": "Text Processing (General)"},
{"id": SENTENCE_TRANSFORMER_MODEL_ID, "type": "huggingface", "feature": "Sentence Embeddings"},
{"id": TONE_MODEL_ID, "type": "huggingface", "feature": "Tone Classification"},
{"id": TRANSLATION_MODEL_ID, "type": "huggingface", "feature": "Translation"},
{"id": WORDNET_NLTK_ID, "type": "nltk", "feature": "Synonym Suggestion"},
# Add any other models here as your application grows
]
async def download_all_required_models(progress_callback: Optional[ProgressCallback] = None) -> Dict[str, str]:
"""
Attempts to download all required models.
Returns a dictionary of download statuses.
"""
required_models = get_all_required_models()
download_statuses = {}
for model_info in required_models:
model_id = model_info["id"]
model_type = model_info["type"]
feature_name = model_info["feature"]
if check_model_exists(model_id, model_type):
status_message = f"'{model_id}' ({feature_name}) already downloaded."
logger.info(status_message)
download_statuses[model_id] = "already_downloaded"
if progress_callback:
progress_callback(model_id, "completed", 1.0, status_message)
continue
logger.info(f"Attempting to download '{model_id}' ({feature_name})...")
try:
if model_type == "huggingface":
await download_hf_model_async(model_id, feature_name, progress_callback)
elif model_type == "spacy":
await download_spacy_model_async(model_id, feature_name, progress_callback)
elif model_type == "nltk":
await download_nltk_data_async(model_id, feature_name, progress_callback)
else:
raise ValueError(f"Unsupported model type: {model_type}")
status_message = f"'{model_id}' ({feature_name}) downloaded successfully."
logger.info(status_message)
download_statuses[model_id] = "success"
except ModelDownloadFailedError as e:
status_message = f"Failed to download '{model_id}' ({feature_name}): {e.original_error}"
logger.error(status_message)
download_statuses[model_id] = "failed"
# The progress_callback is already called within the specific download functions on failure
except Exception as e:
status_message = f"An unexpected error occurred while downloading '{model_id}' ({feature_name}): {e}"
logger.error(status_message, exc_info=True)
download_statuses[model_id] = "failed"
if progress_callback:
progress_callback(model_id, "failed", 0.0, status_message)
logger.info("Finished attempting to download all required models.")
return download_statuses