LiamKhoaLe committed
Commit 4f5341e · 1 Parent(s): a8b5cb5

Lazy load models

Files changed (4):
  1. Dockerfile +5 -1
  2. download_model.py +24 -26
  3. memory.py +1 -3
  4. warmup.py +10 -1
Dockerfile CHANGED
@@ -23,7 +23,11 @@ ENV MEDGEMMA_HOME="/home/user/.cache/huggingface/sentence-transformers"
 RUN mkdir -p /app/model_cache /home/user/.cache/huggingface/sentence-transformers && \
     chown -R user:user /app/model_cache /home/user/.cache/huggingface
 
-# Pre-load model in a separate script
+# Control preloading to avoid exhausting build disk on HF Spaces
+ENV PRELOAD_TRANSLATORS="0"
+ENV EMBEDDING_HALF="0"
+
+# Pre-load model in a separate script (translation preload disabled by default)
 RUN python /app/download_model.py && python /app/warmup.py
 
 # Ensure ownership and permissions remain intact
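Both flags default to off for the build: PRELOAD_TRANSLATORS is read by download_model.py below to decide whether the translators are fetched at image-build time, and EMBEDDING_HALF is read by warmup.py to decide whether the embedding model is cast to half precision. Since Docker ENV values persist into the running container, the same defaults presumably govern the app at runtime unless the Space overrides them.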
download_model.py CHANGED
@@ -7,34 +7,21 @@ from huggingface_hub import snapshot_download
 # Set up paths
 MODEL_REPO = "sentence-transformers/all-MiniLM-L6-v2"
 MODEL_CACHE_DIR = "/app/model_cache"
+HF_CACHE_DIR = os.getenv("HF_HOME", "/home/user/.cache/huggingface")
 
 print("⏳ Downloading the SentenceTransformer model...")
-model_path = snapshot_download(repo_id=MODEL_REPO, cache_dir=MODEL_CACHE_DIR)
+# Download directly into /app/model_cache to avoid duplicating files from HF cache
+model_path = snapshot_download(
+    repo_id=MODEL_REPO,
+    cache_dir=HF_CACHE_DIR,         # Store HF cache in user cache dir
+    local_dir=MODEL_CACHE_DIR,      # Place usable model here
+    local_dir_use_symlinks=False    # Copy files into local_dir (no symlinks)
+)
 
 print("Model path: ", model_path)
-
-# Ensure the directory exists
 if not os.path.exists(MODEL_CACHE_DIR):
     os.makedirs(MODEL_CACHE_DIR)
 
-# Move all contents from the snapshot folder
-if os.path.exists(model_path):
-    print(f"📂 Moving model files from {model_path} to {MODEL_CACHE_DIR}...")
-
-    for item in os.listdir(model_path):
-        source = os.path.join(model_path, item)
-        destination = os.path.join(MODEL_CACHE_DIR, item)
-
-        if os.path.isdir(source):
-            shutil.copytree(source, destination, dirs_exist_ok=True)
-        else:
-            shutil.copy2(source, destination)
-
-    print(f"✅ Model extracted and flattened in {MODEL_CACHE_DIR}")
-else:
-    print("❌ No snapshot directory found!")
-    exit(1)
-
 # Verify structure after moving
 print("\n📂 LLM Model Structure (Build Level):")
 for root, dirs, files in os.walk(MODEL_CACHE_DIR):
@@ -44,8 +31,19 @@ for root, dirs, files in os.walk(MODEL_CACHE_DIR):
 
 
 ### --- B. translation modules ---
-from transformers import pipeline
-print("⏬ Downloading Vietnamese–English translator...")
-_ = pipeline("translation", model="VietAI/envit5-translation", src_lang="vi", tgt_lang="en")
-print("⏬ Downloading Chinese–English translator...")
-_ = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en")
+# Optional pre-download of translation models. These can be very large and
+# may exceed build storage limits on constrained environments (e.g., HF Spaces).
+# Control with env var PRELOAD_TRANSLATORS ("1" to enable; default: disabled).
+PRELOAD_TRANSLATORS = os.getenv("PRELOAD_TRANSLATORS", "0")
+if PRELOAD_TRANSLATORS == "1":
+    try:
+        from transformers import pipeline
+        print("⏬ Pre-downloading Vietnamese–English translator...")
+        _ = pipeline("translation", model="VietAI/envit5-translation", src_lang="vi", tgt_lang="en", device=-1)
+        print("⏬ Pre-downloading Chinese–English translator...")
+        _ = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en", device=-1)
+        print("✅ Translators preloaded.")
+    except Exception as e:
+        print(f"⚠️ Skipping translator preload due to error: {e}")
+else:
+    print("ℹ️ Skipping translator pre-download (PRELOAD_TRANSLATORS != '1'). They will lazy-load at runtime.")
memory.py CHANGED
@@ -421,6 +421,4 @@ class MemoryManager:
         first = " ".join(words[:16])
         # ensure capitalized
         return first.strip().rstrip(':')
-        return topic
-
-
+        return topic
warmup.py CHANGED
@@ -1,8 +1,17 @@
 from sentence_transformers import SentenceTransformer
 import torch
+import os
 
 print("🚀 Warming up model...")
 embedding_model = SentenceTransformer("/app/model_cache", device="cpu")
-embedding_model = embedding_model.half()  # Reduce memory
+
+# Some CPU backends on HF Spaces fail on .half(); make it configurable
+USE_HALF = os.getenv("EMBEDDING_HALF", "1") == "1"
+try:
+    if USE_HALF and torch.cuda.is_available():
+        embedding_model = embedding_model.half()
+except Exception as e:
+    print(f"⚠️ Skipping half precision due to: {e}")
+
 embedding_model.to(torch.device("cpu"))
 print("✅ Model warm-up complete!")