Aiaudio / download_models.py
Athspi's picture
Create download_models.py
6ab4ced verified
raw
history blame
4.64 kB
# ----------- START download_models.py -----------
import os
import logging
# Logging is configured the same way as in app.py so both scripts emit
# uniformly formatted records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("ModelDownloader")

# --- Model IDs (MUST MATCH app.py) ---
# Speech-enhancement checkpoint (a SpeechBrain repository, judging by its ID).
ENHANCEMENT_MODEL_ID = "speechbrain/sepformer-whamr-enhancement"
# Source-separation checkpoint; keep in sync with the exact ID app.py loads.
SEPARATION_MODEL_ID = "facebook/demucs_quantized"

# Cache location shared with app.py and the Dockerfile (MUST MATCH both).
# NOTE(review): if app.py relies on the library's default cache instead of
# passing cache_dir explicitly, Hub files are looked up under "$HF_HOME/hub",
# not "$HF_HOME" itself — confirm app.py passes this same directory.
# NOTE(review): an HF_HOME set to the empty string is used as-is, and
# os.makedirs("") raises FileNotFoundError at import time.
_DEFAULT_HF_CACHE_DIR = "/app/hf_cache"
HF_CACHE_DIR = os.environ.get("HF_HOME", _DEFAULT_HF_CACHE_DIR)
os.makedirs(HF_CACHE_DIR, exist_ok=True)
logger.info(f"Using Hugging Face cache directory: {HF_CACHE_DIR}")
def download_model(model_id: str, *, trust_remote_code: bool = False) -> bool:
    """Best-effort pre-download of ``model_id`` into ``HF_CACHE_DIR`` via transformers.

    Calling ``from_pretrained(..., cache_dir=HF_CACHE_DIR)`` fetches the
    repository files into the cache as a side effect; the loaded objects are
    discarded. The processor is tried with ``AutoProcessor``. The model is
    tried with ``AutoModelForSpeechSeq2Seq`` first and the generic
    ``AutoModel`` as a fallback. Every failure is logged rather than raised, so
    a caller looping over several models keeps going.

    Args:
        model_id: Hugging Face Hub repository id (``"owner/name"``).
        trust_remote_code: Forwarded to every ``from_pretrained`` call. Leave
            False unless the repository's custom code is trusted.

    Returns:
        True if the processor or one of the model classes loaded successfully,
        False otherwise (including when transformers is not installed).
    """
    logger.info(f"--- Attempting to download model: {model_id} ---")
    try:
        from transformers import AutoModel, AutoModelForSpeechSeq2Seq, AutoProcessor
    except ImportError:
        logger.error("Transformers library not found. Cannot download models.")
        return False

    load_kwargs = {"cache_dir": HF_CACHE_DIR, "trust_remote_code": trust_remote_code}
    success = False

    try:
        logger.info(f"Downloading processor for {model_id}...")
        AutoProcessor.from_pretrained(model_id, **load_kwargs)
        logger.info(f"Processor download attempt finished for {model_id}.")
        success = True
    except Exception as proc_err:
        # Many repositories ship no processor; this alone is not fatal.
        logger.warning(
            f"Could not download processor using AutoProcessor for {model_id} "
            f"(might be normal): {proc_err}"
        )

    logger.info(f"Downloading model for {model_id}...")
    # Try the class app.py is expected to use first, then the generic fallback;
    # stop at the first one that loads.
    for model_cls in (AutoModelForSpeechSeq2Seq, AutoModel):
        try:
            model_cls.from_pretrained(model_id, **load_kwargs)
        except Exception as model_err:
            logger.warning(
                f"Could not download model using {model_cls.__name__} for {model_id}: {model_err}"
            )
            continue
        logger.info(f"Model download attempt finished for {model_id} ({model_cls.__name__}).")
        success = True
        break
    else:
        # Reached only when no transformers class could load the repository.
        logger.warning(
            "This might be okay if the model requires a different loading method "
            "(e.g., SpeechBrain or Demucs library)."
        )

    # TODO: SpeechBrain / Demucs checkpoints are presumably not transformers
    # models, so the loaders above may not cache them. Pre-fetch those with
    # their own libraries (e.g. SpeechBrain's `from_hparams(source=..., savedir=...)`;
    # Demucs' pretrained-model loader, exact API to be confirmed), using the
    # same save directory app.py loads from.

    status = "succeeded" if success else "failed (see warnings above)"
    logger.info(f"--- Finished download attempt for model: {model_id} ({status}) ---")
    return success
if __name__ == "__main__":
    logger.info("Starting pre-download of Hugging Face models...")

    # Models the app needs at runtime; add any others it uses here.
    wanted = (
        ENHANCEMENT_MODEL_ID,
        SEPARATION_MODEL_ID,
    )

    for repo_id in wanted:
        # A blank ID (e.g. an unfilled placeholder) is reported and skipped
        # instead of being sent to the downloader.
        if not repo_id:
            logger.warning("Skipping empty model ID.")
            continue
        download_model(repo_id)

    logger.info("Model pre-download process finished.")
# ----------- END download_models.py -----------