# Spaces:
# Sleeping
# Sleeping
# ----------- START download_models.py -----------
import logging
import os

# Log format intentionally mirrors the one configured in app.py.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger("ModelDownloader")

# Hub repository ids to pre-fetch. These MUST MATCH the ids app.py loads.
ENHANCEMENT_MODEL_ID = "speechbrain/sepformer-whamr-enhancement"
# NOTE(review): confirm this repo id exists on the Hub and is the one app.py uses.
SEPARATION_MODEL_ID = "facebook/demucs_quantized"

# Cache root: $HF_HOME when set, otherwise /app/hf_cache (MUST MATCH app.py/Dockerfile).
# NOTE(review): this value is later passed as `cache_dir`, which stores repos directly
# under $HF_HOME, while the Hub's default cache is $HF_HOME/hub. app.py must therefore
# pass the same cache_dir to reuse these files. Confirm.
HF_CACHE_DIR = os.getenv("HF_HOME", "/app/hf_cache")
os.makedirs(HF_CACHE_DIR, exist_ok=True)  # created at import time so downloads have a target
logger.info(f"Using Hugging Face cache directory: {HF_CACHE_DIR}")
def download_model(model_id: str, *, trust_remote_code: bool = False) -> bool:
    """Best-effort pre-download of one Hub model into ``HF_CACHE_DIR``.

    ``transformers`` loaders are used as the download trigger.
    ``AutoProcessor`` is tried first, and its failure is tolerated and only
    logged. The weights are then loaded via ``AutoModelForSpeechSeq2Seq``.
    Every failure is logged rather than raised, so one problematic model
    does not stop a caller from fetching the rest.

    Args:
        model_id: Hugging Face Hub repository id, e.g. ``"org/name"``.
        trust_remote_code: When True, forwarded to both ``from_pretrained``
            calls. Enabling it runs code downloaded from the Hub, so only set
            it for repositories you trust.

    Returns:
        True if the model load succeeded, so its files are now cached.
        False if ``transformers`` is not importable or the model could not be
        loaded with ``AutoModelForSpeechSeq2Seq``. The value is informational;
        callers that ignore it behave as before.
    """
    logger.info(f"--- Attempting to download model: {model_id} ---")

    # Guard only the import; loader failures are handled per step below.
    try:
        from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
    except ImportError:
        logger.error("Transformers library not found. Cannot download models.")
        return False

    load_kwargs = {"cache_dir": HF_CACHE_DIR}
    if trust_remote_code:
        # Pass the flag only when requested, so default calls stay identical
        # to the plain from_pretrained(model_id, cache_dir=...) form.
        load_kwargs["trust_remote_code"] = True

    # Step 1: processor. Failure is only logged; the warning notes it may be normal.
    try:
        logger.info(f"Downloading processor for {model_id}...")
        AutoProcessor.from_pretrained(model_id, **load_kwargs)
        logger.info(f"Processor download attempt finished for {model_id}.")
    except Exception as proc_err:
        logger.warning(f"Could not download processor using AutoProcessor for {model_id} (might be normal): {proc_err}")

    # Step 2: model weights. The outcome of this step decides the return value.
    model_ok = False
    try:
        logger.info(f"Downloading model for {model_id}...")
        AutoModelForSpeechSeq2Seq.from_pretrained(model_id, **load_kwargs)
        logger.info(f"Model download attempt finished for {model_id}.")
        model_ok = True
    except Exception as model_err:
        logger.warning(f"Could not download model using AutoModel* for {model_id}: {model_err}")
        logger.warning("This might be okay if the model requires a different loading method (e.g., SpeechBrain or Demucs library).")
        # Checkpoints that are not transformers models need their own library's
        # loader to populate the cache. The sketches below are disabled.
        # NOTE(review): the class/function names in them (SepformerEnhancement,
        # demucs.apply.load_model) are unverified. Check each library's docs
        # before enabling.
        #
        # Example for SpeechBrain (if library installed):
        # if "speechbrain" in model_id:
        #     try:
        #         from speechbrain.pretrained import SepformerEnhancement
        #         logger.info(f"Attempting SpeechBrain specific download for {model_id}...")
        #         SepformerEnhancement.from_hparams(
        #             source=model_id,
        #             savedir=os.path.join(HF_CACHE_DIR, "speechbrain", model_id.split('/')[-1]),
        #             # Don't specify run_opts here, just download
        #         )
        #         logger.info(f"SpeechBrain download attempt finished for {model_id}.")
        #     except Exception as sb_err:
        #         logger.error(f"Failed SpeechBrain specific download for {model_id}: {sb_err}")
        #
        # Example for Demucs (if library installed):
        # if "demucs" in model_id:
        #     try:
        #         import demucs.separate
        #         logger.info(f"Attempting Demucs specific download for {model_id}...")
        #         # This might involve loading the model which triggers download
        #         demucs.apply.load_model(model_id)  # Check correct function
        #         logger.info(f"Demucs download attempt finished for {model_id}.")
        #     except Exception as demucs_err:
        #         logger.error(f"Failed Demucs specific download for {model_id}: {demucs_err}")

    logger.info(f"--- Finished download attempt for model: {model_id} ---")
    return model_ok
if __name__ == "__main__":
    logger.info("Starting pre-download of Hugging Face models...")

    # Repository ids to fetch; append any other model app.py loads.
    requested_ids = [
        ENHANCEMENT_MODEL_ID,
        SEPARATION_MODEL_ID,
    ]

    for repo_id in requested_ids:
        # Guard clause: an unset/empty id is reported and skipped.
        if not repo_id:
            logger.warning("Skipping empty model ID.")
            continue
        download_model(repo_id)

    logger.info("Model pre-download process finished.")
# ----------- END download_models.py -----------