File size: 4,636 Bytes
6ab4ced
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# ----------- START download_models.py -----------
import os
import logging

# Configure logging similar to app.py
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("ModelDownloader")

# --- Model IDs (MUST MATCH app.py) ---
ENHANCEMENT_MODEL_ID = "speechbrain/sepformer-whamr-enhancement"
# NOTE(review): confirm this repo id exists on the Hub and equals the value app.py uses.
SEPARATION_MODEL_ID = "facebook/demucs_quantized"  # Or the exact one used in app.py

# Cache directory shared with app.py / Dockerfile (MUST MATCH both).
# An unset *or empty* HF_HOME falls back to the default path: an empty string
# would otherwise reach os.makedirs("") below, which raises FileNotFoundError.
# NOTE(review): this path is later passed as ``cache_dir`` to from_pretrained, so
# files are stored directly under it rather than under "$HF_HOME/hub" (the Hub's
# default layout, per the huggingface_hub docs). Confirm app.py passes the same
# cache_dir, otherwise it will not find the pre-downloaded files.
HF_CACHE_DIR = os.environ.get("HF_HOME") or "/app/hf_cache"
os.makedirs(HF_CACHE_DIR, exist_ok=True)
logger.info(f"Using Hugging Face cache directory: {HF_CACHE_DIR}")


def download_model(model_id: str) -> bool:
    """Pre-download a Hub model into ``HF_CACHE_DIR`` via transformers (best effort).

    The processor and the model weights are requested independently. A failure
    of either is logged as a warning instead of being raised, so one model that
    transformers cannot load does not abort the whole pre-download run.

    Args:
        model_id: Hugging Face Hub repository id, e.g. ``"org/name"``.

    Returns:
        True if ``AutoModelForSpeechSeq2Seq.from_pretrained`` succeeded for
        ``model_id``, False otherwise (transformers unavailable, model not
        loadable by that class, or any other error). The processor outcome is
        deliberately ignored because many repos ship no processor. Callers that
        discard the return value behave exactly as before.
    """
    logger.info("--- Attempting to download model: %s ---", model_id)

    try:
        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
    except ImportError as import_err:
        # Also raised by transformers versions that lack one of these classes,
        # so surface the underlying message instead of assuming it is missing.
        logger.error("Transformers library not found. Cannot download models. (%s)", import_err)
        return False
    except Exception as e:
        # A broken install (e.g. a torch/transformers mismatch) can fail with
        # non-ImportError errors; stay best-effort instead of crashing the script.
        logger.error(
            "An unexpected error occurred during download attempt for %s: %s",
            model_id, e, exc_info=True,
        )
        return False

    try:
        logger.info("Downloading processor for %s...", model_id)
        AutoProcessor.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
        logger.info("Processor download attempt finished for %s.", model_id)
    except Exception as proc_err:
        logger.warning(
            "Could not download processor using AutoProcessor for %s (might be normal): %s",
            model_id, proc_err,
        )

    model_ok = False
    try:
        logger.info("Downloading model for %s...", model_id)
        # Add trust_remote_code=True if the repo requires custom code from the Hub.
        AutoModelForSpeechSeq2Seq.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
        model_ok = True
        logger.info("Model download attempt finished for %s.", model_id)
    except Exception as model_err:
        logger.warning("Could not download model using AutoModel* for %s: %s", model_id, model_err)
        logger.warning(
            "This might be okay if the model requires a different loading method "
            "(e.g., SpeechBrain or Demucs library)."
        )

    # TODO: the configured IDs look like SpeechBrain / Demucs checkpoints, which
    # presumably are not loadable by AutoModelForSpeechSeq2Seq. If app.py loads
    # them through those libraries, pre-download here with the same loader and
    # save location app.py uses (e.g. SpeechBrain's
    # ``from_hparams(source=model_id, savedir=...)``) so the cache actually matches.

    logger.info("--- Finished download attempt for model: %s ---", model_id)
    return model_ok


if __name__ == "__main__":
    logger.info("Starting pre-download of Hugging Face models...")

    # Every model ID the app loads; extend this tuple when app.py gains a model.
    for candidate in (ENHANCEMENT_MODEL_ID, SEPARATION_MODEL_ID):
        if not candidate:
            # Guard against an ID constant that was left blank.
            logger.warning("Skipping empty model ID.")
            continue
        download_model(candidate)

    logger.info("Model pre-download process finished.")
# ----------- END download_models.py -----------