Athspi committed on
Commit
6ab4ced
·
verified ·
1 Parent(s): 3ef3c9e

Create download_models.py

Browse files
Files changed (1) hide show
  1. download_models.py +98 -0
download_models.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ----------- START download_models.py -----------
2
+ import os
3
+ import logging
4
+
5
# Configure logging similar to app.py
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("ModelDownloader")

# --- Model IDs (MUST MATCH app.py) ---
ENHANCEMENT_MODEL_ID = "speechbrain/sepformer-whamr-enhancement"
SEPARATION_MODEL_ID = "facebook/demucs_quantized"  # Or the exact one used in app.py

# Cache directory, taken from HF_HOME when set (MUST MATCH app.py/Dockerfile).
# `or` is used instead of a `.get()` default so that an HF_HOME that is set
# but empty also falls back to the default; otherwise `os.makedirs("")`
# would raise FileNotFoundError as soon as this module is loaded.
# NOTE(review): this path is later passed as `cache_dir=`; confirm app.py
# reads models from the same directory - TODO confirm.
HF_CACHE_DIR = os.environ.get("HF_HOME") or "/app/hf_cache"
os.makedirs(HF_CACHE_DIR, exist_ok=True)
logger.info(f"Using Hugging Face cache directory: {HF_CACHE_DIR}")
17
+
18
+
19
def _fetch_processor(processor_cls, model_id: str) -> None:
    """Best-effort fetch of the processor for ``model_id`` into HF_CACHE_DIR.

    Any exception is logged as a warning and swallowed; the log message
    itself notes that a failure here "might be normal".
    """
    try:
        logger.info(f"Downloading processor for {model_id}...")
        processor_cls.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
        logger.info(f"Processor download attempt finished for {model_id}.")
    except Exception as failure:
        logger.warning(f"Could not download processor using AutoProcessor for {model_id} (might be normal): {failure}")


def _fetch_weights(model_cls, model_id: str) -> None:
    """Best-effort fetch of the model for ``model_id`` into HF_CACHE_DIR.

    Any exception is logged as two warnings and swallowed, since the model
    may need a different loader than the transformers class passed in.
    """
    try:
        logger.info(f"Downloading model for {model_id}...")
        # Pass trust_remote_code=True here if the model needs custom code
        # from the HF Hub.
        model_cls.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
        logger.info(f"Model download attempt finished for {model_id}.")
    except Exception as failure:
        logger.warning(f"Could not download model using AutoModel* for {model_id}: {failure}")
        logger.warning("This might be okay if the model requires a different loading method (e.g., SpeechBrain or Demucs library).")


def download_model(model_id: str):
    """Try to pre-download ``model_id`` through the transformers library.

    The processor and the model are fetched independently, so a failure of
    one does not prevent the attempt on the other. Nothing is raised to the
    caller and nothing is returned: every outcome is reported through the
    module logger only.

    Args:
        model_id: Hugging Face repository id, e.g. "org/name".
    """
    logger.info(f"--- Attempting to download model: {model_id} ---")
    try:
        # Imported lazily so that a missing transformers install is reported
        # via the ImportError handler below instead of breaking the module.
        from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

        _fetch_processor(AutoProcessor, model_id)
        _fetch_weights(AutoModelForSpeechSeq2Seq, model_id)

        # Library-specific fetches are not wired in. If app.py loads a model
        # through SpeechBrain (e.g. SepformerEnhancement.from_hparams with
        # source=model_id and a savedir under HF_CACHE_DIR) or through
        # Demucs (exact loader function still to be confirmed), add the
        # matching best-effort call here.

        logger.info(f"--- Finished download attempt for model: {model_id} ---")
    except ImportError:
        logger.error("Transformers library not found. Cannot download models.")
    except Exception as e:
        logger.error(f"An unexpected error occurred during download attempt for {model_id}: {e}", exc_info=True)
79
+
80
+
81
if __name__ == "__main__":
    # Script entry point: attempt each configured model in turn. Failures
    # are only logged by download_model, so the loop always runs to the end.
    logger.info("Starting pre-download of Hugging Face models...")

    # Models to fetch; add any other models the app uses.
    models_to_download = [
        ENHANCEMENT_MODEL_ID,
        SEPARATION_MODEL_ID,
    ]

    for model_id in models_to_download:
        # Guard against a blank ID rather than handing it to the loader.
        if not model_id:
            logger.warning("Skipping empty model ID.")
            continue
        download_model(model_id)

    logger.info("Model pre-download process finished.")
98
+ # ----------- END download_models.py -----------