Spaces:
Running
Running
Commit
·
375dbf3
1
Parent(s):
67b29cd
Add download model simplifier
Browse files- app.py +9 -24
- download_model.py +12 -13
app.py
CHANGED
@@ -58,31 +58,16 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
58 |
# os.makedirs(project_dir, exist_ok=True)
|
59 |
# huggingface_cache_dir = os.path.join(project_dir, "huggingface_models")
|
60 |
# os.environ["HF_HOME"] = huggingface_cache_dir # Use this folder for HF cache
|
61 |
-
#
|
62 |
-
|
63 |
-
#
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
os.environ["SENTENCE_TRANSFORMERS_HOME"] = hf_cache_dir
|
68 |
-
# 3. Download (or load from cache) the SentenceTransformer model
|
69 |
-
from huggingface_hub import snapshot_download
|
70 |
-
print("⏳ Checking or downloading the all-MiniLM-L6-v2 model from huggingface_hub...")
|
71 |
-
# st.write("⏳ Checking or downloading the all-MiniLM-L6-v2 model from huggingface_hub...")
|
72 |
-
# a) First, try loading from our copied cache
|
73 |
-
if os.path.exists(model_cache_dir) and os.listdir(model_cache_dir): # Check if model folder exists and is not empty
|
74 |
-
print(f"✅ Found cached model at {model_cache_dir}")
|
75 |
-
model_loc = model_cache_dir
|
76 |
-
# b) Else, try loading backup from snapshot_download
|
77 |
else:
|
78 |
-
print(f"❌ Model not found in {
|
79 |
-
|
80 |
-
|
81 |
-
repo_id="sentence-transformers/all-MiniLM-L6-v2",
|
82 |
-
cache_dir=hf_cache_dir,
|
83 |
-
local_files_only=True # Change to `False` for fallback to online download
|
84 |
-
)
|
85 |
-
# 4. Load the model to application
|
86 |
from sentence_transformers import SentenceTransformer
|
87 |
print("📥 **Loading Embedding Model...**")
|
88 |
# st.write("📥 **Loading Embedding Model...**")
|
|
|
58 |
# os.makedirs(project_dir, exist_ok=True)
|
59 |
# huggingface_cache_dir = os.path.join(project_dir, "huggingface_models")
|
60 |
# os.environ["HF_HOME"] = huggingface_cache_dir # Use this folder for HF cache
|
61 |
+
# 2. Set up the Hugging Face Space model cache
# Directory that is checked for the pre-downloaded embedding model files.
MODEL_CACHE_DIR = "/app/model_cache"
# The presence of config.json is used as the marker that the model files exist.
if os.path.exists(os.path.join(MODEL_CACHE_DIR, "config.json")):
    print(f"✅ Found cached model at {MODEL_CACHE_DIR}")
    model_loc = MODEL_CACHE_DIR
else:
    print(f"❌ Model not found in {MODEL_CACHE_DIR}. Critical error!")
    # Raise SystemExit directly instead of calling the bare `exit` helper:
    # `exit` is injected by the `site` module for interactive sessions and is
    # not guaranteed to be defined in every runtime. The exit status stays 1.
    raise SystemExit(1)
|
70 |
+
# 3. Load the model to application
|
|
|
|
|
|
|
|
|
|
|
71 |
from sentence_transformers import SentenceTransformer
|
72 |
print("📥 **Loading Embedding Model...**")
|
73 |
# st.write("📥 **Loading Embedding Model...**")
|
download_model.py
CHANGED
@@ -2,20 +2,19 @@ import os
|
|
2 |
import shutil
|
3 |
from huggingface_hub import snapshot_download
|
4 |
|
5 |
-
#
|
|
|
6 |
MODEL_CACHE_DIR = "/app/model_cache"
|
7 |
|
8 |
-
# Download
|
9 |
print("⏳ Downloading the SentenceTransformer model...")
|
10 |
-
model_path = snapshot_download(repo_id=
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
if os.path.exists(
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
shutil.move(os.path.join(snapshot_dir, filename), MODEL_CACHE_DIR)
|
21 |
-
print(f"✅ Model downloaded and stored in {MODEL_CACHE_DIR}")
|
|
|
2 |
import shutil
|
3 |
from huggingface_hub import snapshot_download
|
4 |
|
5 |
+
# Dir setup
# Hub repository to fetch, and the directory the model files must end up in.
MODEL_REPO = "sentence-transformers/all-MiniLM-L6-v2"
MODEL_CACHE_DIR = "/app/model_cache"

# Download snapshots
print("⏳ Downloading the SentenceTransformer model...")
# snapshot_download returns the path of the snapshot folder itself
# (<cache_dir>/models--<org>--<name>/snapshots/<revision>), not the repo root,
# so there is no further "snapshots" sub-directory to look for underneath it.
model_path = snapshot_download(repo_id=MODEL_REPO, cache_dir=MODEL_CACHE_DIR)

# Ensure the model structure is correct
# Flatten the snapshot into MODEL_CACHE_DIR so that files such as config.json
# sit directly in that directory. copytree follows symlinks by default, so the
# real file contents are copied; moving the cache's relative symlinks instead
# would leave them dangling at the new location.
shutil.copytree(model_path, MODEL_CACHE_DIR, dirs_exist_ok=True)

# Fail the download step loudly if the expected marker file is still missing,
# rather than reporting success for an unusable model directory.
if not os.path.isfile(os.path.join(MODEL_CACHE_DIR, "config.json")):
    raise SystemExit(f"❌ config.json not found in {MODEL_CACHE_DIR} after download")

# Complete
print(f"✅ Model downloaded and extracted to {MODEL_CACHE_DIR}")
|
|
|
|