LiamKhoaLe committed on
Commit
375dbf3
·
1 Parent(s): 67b29cd

Add download model simplifier

Browse files
Files changed (2) hide show
  1. app.py +9 -24
  2. download_model.py +12 -13
app.py CHANGED
@@ -58,31 +58,16 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
58
  # os.makedirs(project_dir, exist_ok=True)
59
  # huggingface_cache_dir = os.path.join(project_dir, "huggingface_models")
60
  # os.environ["HF_HOME"] = huggingface_cache_dir # Use this folder for HF cache
61
- # 2b) Setup Hugging Face Cloud project model cache
62
- hf_cache_dir = "/home/user/.cache/huggingface"
63
- # Model storage location
64
- hf_cache_dir = "/home/user/.cache/huggingface"
65
- model_cache_dir = "/app/model_cache"
66
- os.environ["HF_HOME"] = hf_cache_dir
67
- os.environ["SENTENCE_TRANSFORMERS_HOME"] = hf_cache_dir
68
- # 3. Download (or load from cache) the SentenceTransformer model
69
- from huggingface_hub import snapshot_download
70
- print("⏳ Checking or downloading the all-MiniLM-L6-v2 model from huggingface_hub...")
71
- # st.write("⏳ Checking or downloading the all-MiniLM-L6-v2 model from huggingface_hub...")
72
- # a) First, try loading from our copied cache
73
- if os.path.exists(model_cache_dir) and os.listdir(model_cache_dir): # Check if model folder exists and is not empty
74
- print(f"✅ Found cached model at {model_cache_dir}")
75
- model_loc = model_cache_dir
76
- # b) Else, try loading backup from snapshot_download
77
  else:
78
- print(f"❌ Model not found in {model_cache_dir}. This should not happen!")
79
- print("⚠️ Retrying with snapshot_download...")
80
- model_loc = snapshot_download(
81
- repo_id="sentence-transformers/all-MiniLM-L6-v2",
82
- cache_dir=hf_cache_dir,
83
- local_files_only=True # Change to `False` for fallback to online download
84
- )
85
- # 4. Load the model to application
86
  from sentence_transformers import SentenceTransformer
87
  print("📥 **Loading Embedding Model...**")
88
  # st.write("📥 **Loading Embedding Model...**")
 
58
  # os.makedirs(project_dir, exist_ok=True)
59
  # huggingface_cache_dir = os.path.join(project_dir, "huggingface_models")
60
  # os.environ["HF_HOME"] = huggingface_cache_dir # Use this folder for HF cache
61
+ # 2. Setup Hugging Face Cloud project model cache
62
+ MODEL_CACHE_DIR = "/app/model_cache"
63
+ # Check if the required model files exist
64
+ if os.path.exists(os.path.join(MODEL_CACHE_DIR, "config.json")):
65
+ print(f"✅ Found cached model at {MODEL_CACHE_DIR}")
66
+ model_loc = MODEL_CACHE_DIR
 
 
 
 
 
 
 
 
 
 
67
  else:
68
+ print(f"❌ Model not found in {MODEL_CACHE_DIR}. Critical error!")
69
+ exit(1) # Exit since the model is missing
70
+ # 3. Load the model to application
 
 
 
 
 
71
  from sentence_transformers import SentenceTransformer
72
  print("📥 **Loading Embedding Model...**")
73
  # st.write("📥 **Loading Embedding Model...**")
download_model.py CHANGED
@@ -2,20 +2,19 @@ import os
2
  import shutil
3
  from huggingface_hub import snapshot_download
4
 
5
- # Define the target cache directory
 
6
  MODEL_CACHE_DIR = "/app/model_cache"
7
 
8
- # Download model
9
  print("⏳ Downloading the SentenceTransformer model...")
10
- model_path = snapshot_download(repo_id="sentence-transformers/all-MiniLM-L6-v2", cache_dir=MODEL_CACHE_DIR)
11
 
12
- # Find the snapshot folder
13
- snapshots_dir = os.path.join(model_path, "snapshots")
14
- if os.path.exists(snapshots_dir):
15
- snapshot_subdirs = os.listdir(snapshots_dir)
16
- if snapshot_subdirs:
17
- snapshot_dir = os.path.join(snapshots_dir, snapshot_subdirs[0])
18
- # Move all files to the main model cache directory
19
- for filename in os.listdir(snapshot_dir):
20
- shutil.move(os.path.join(snapshot_dir, filename), MODEL_CACHE_DIR)
21
- print(f"✅ Model downloaded and stored in {MODEL_CACHE_DIR}")
 
2
  import shutil
3
  from huggingface_hub import snapshot_download
4
 
5
+ # Dir setup
6
+ MODEL_REPO = "sentence-transformers/all-MiniLM-L6-v2"
7
  MODEL_CACHE_DIR = "/app/model_cache"
8
 
9
+ # Download snapshots
10
  print("⏳ Downloading the SentenceTransformer model...")
11
+ model_path = snapshot_download(repo_id=MODEL_REPO, cache_dir=MODEL_CACHE_DIR)
12
 
13
+ # Ensure the model structure is correct
14
+ snapshot_folders = os.path.join(model_path, "snapshots")
15
+ if os.path.exists(snapshot_folders):
16
+ snapshot_dir = os.path.join(snapshot_folders, os.listdir(snapshot_folders)[0]) # Get first snapshot folder
17
+ for filename in os.listdir(snapshot_dir):
18
+ shutil.move(os.path.join(snapshot_dir, filename), MODEL_CACHE_DIR) # Move files to /app/model_cache
19
+ # Complete
20
+ print(f"✅ Model downloaded and extracted to {MODEL_CACHE_DIR}")