isat commited on
Commit
776b8aa
·
verified ·
1 Parent(s): 1f6d141

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -25
app.py CHANGED
@@ -1,9 +1,8 @@
1
- # app.py — storage-safe + HF Hub friendly
2
 
3
  import os
4
 
5
  # ---------- ENV & THREADS (set BEFORE importing numpy/torch) ----------
6
- # Accept any of these names from Space Settings, prefer the standard one:
7
  omp_val = (
8
  os.getenv("OMP_NUM_THREADS")
9
  or os.getenv("OMP-NUM-THREADS")
@@ -16,11 +15,11 @@ except Exception:
16
  omp_val = "2"
17
  os.environ["OMP_NUM_THREADS"] = omp_val # must be a positive integer string
18
 
19
- # Send all caches to persistent storage
20
  os.environ.setdefault("HF_HOME", "/data/.huggingface")
21
  os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
22
  os.environ.setdefault("HF_DATASETS_CACHE", "/data/.huggingface/datasets")
23
- # NOTE: TRANSFORMERS_CACHE is deprecated; using HF_HOME instead.
24
 
25
  # Disable Xet path, enable fast transfer
26
  os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
@@ -34,7 +33,6 @@ import torch
34
  import gradio as gr
35
  from PIL import Image, ImageFilter, ImageDraw
36
 
37
- # Optional: align PyTorch thread pools with OMP setting
38
  try:
39
  torch.set_num_threads(int(omp_val))
40
  torch.set_num_interop_threads(1)
@@ -42,34 +40,45 @@ except Exception:
42
  pass
43
 
44
  # ---------- HUB IMPORTS ----------
45
- from huggingface_hub import snapshot_download, hf_hub_download # noqa: E402
46
- from diffusers import FluxFillPipeline, FluxPriorReduxPipeline # noqa: E402
47
 
48
- import math # noqa: E402
49
- from utils.utils import ( # noqa: E402
50
  get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
51
  )
52
 
53
- # Optional editable installs ONLY if import fails (prefer requirements.txt)
54
  def _ensure_local_editable(pkg_name, rel_path):
55
  try:
56
  __import__(pkg_name)
57
  except ImportError:
58
- os.system(f"python -m pip install -e {rel_path}")
59
 
60
- _ensure_local_editable("segment_anything", "segment_anything")
61
  _ensure_local_editable("GroundingDINO", "GroundingDINO")
62
 
63
- sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
64
- sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
 
 
 
 
 
 
 
 
65
 
66
- import torchvision # noqa: E402
67
- from GroundingDINO.groundingdino.util.inference import load_model # noqa: E402
68
- # Use the stable SAM API (avoids build_sam import error)
69
- from segment_anything import sam_model_registry, SamPredictor # noqa: E402
70
- import spaces # noqa: E402
71
- import GroundingDINO.groundingdino.datasets.transforms as T # noqa: E402
72
- from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap # noqa: E402
 
 
 
73
 
74
  # ---------- PATHS ----------
75
  PERSIST_ROOT = "/data"
@@ -90,7 +99,7 @@ SAM_CHECKPOINT_PATH = os.path.join(CKPT_DIR, "sam_vit_h_4b8939.pth")
90
  hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
91
 
92
  # ---------- DOWNLOAD CHECKPOINTS (single files) ----------
93
- # GroundingDINO ckpt (single file)
94
  if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
95
  g_dino_file = hf_hub_download(
96
  repo_id="ShilongLiu/GroundingDINO",
@@ -101,7 +110,6 @@ if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
101
  if g_dino_file != GROUNDING_DINO_CHECKPOINT_PATH:
102
  os.replace(g_dino_file, GROUNDING_DINO_CHECKPOINT_PATH)
103
 
104
- # SAM ckpt (single file)
105
  if not os.path.exists(SAM_CHECKPOINT_PATH):
106
  sam_file = hf_hub_download(
107
  repo_id="spaces/mrtlive/segment-anything-model",
@@ -155,12 +163,12 @@ groundingdino_model = load_model(
155
  device="cuda"
156
  )
157
 
158
- # SAM + Predictor (registry API)
159
  sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
160
  sam.to(device="cuda")
161
  sam_predictor = SamPredictor(sam)
162
 
163
- # Diffusers
164
  dtype = torch.bfloat16
165
  size = (768, 768)
166
 
 
1
+ # app.py — storage-safe + HF Hub friendly + SAM import guard
2
 
3
  import os
4
 
5
  # ---------- ENV & THREADS (set BEFORE importing numpy/torch) ----------
 
6
  omp_val = (
7
  os.getenv("OMP_NUM_THREADS")
8
  or os.getenv("OMP-NUM-THREADS")
 
15
  omp_val = "2"
16
  os.environ["OMP_NUM_THREADS"] = omp_val # must be a positive integer string
17
 
18
+ # Persistent caches
19
  os.environ.setdefault("HF_HOME", "/data/.huggingface")
20
  os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
21
  os.environ.setdefault("HF_DATASETS_CACHE", "/data/.huggingface/datasets")
22
+ # (TRANSFORMERS_CACHE is deprecated; rely on HF_HOME) # https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache
23
 
24
  # Disable Xet path, enable fast transfer
25
  os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
 
33
  import gradio as gr
34
  from PIL import Image, ImageFilter, ImageDraw
35
 
 
36
  try:
37
  torch.set_num_threads(int(omp_val))
38
  torch.set_num_interop_threads(1)
 
40
  pass
41
 
42
  # ---------- HUB IMPORTS ----------
43
+ from huggingface_hub import snapshot_download, hf_hub_download
44
+ from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
45
 
46
+ import math
47
+ from utils.utils import (
48
  get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
49
  )
50
 
51
+ # ---------- Ensure GroundingDINO & SAM are the right ones ----------
52
  def _ensure_local_editable(pkg_name, rel_path):
53
  try:
54
  __import__(pkg_name)
55
  except ImportError:
56
+ os.system(f"{sys.executable} -m pip install -e {rel_path}")
57
 
58
+ # GroundingDINO (local editable if present)
59
  _ensure_local_editable("GroundingDINO", "GroundingDINO")
60
 
61
+ # SAM: verify the real package; fix automatically if a wrong one is installed
62
+ def _ensure_official_sam():
63
+ try:
64
+ import segment_anything as sa
65
+ if not hasattr(sa, "sam_model_registry"):
66
+ raise ImportError("Found 'segment_anything' without sam_model_registry")
67
+ except Exception:
68
+ # Nuke imposters and install the official repo
69
+ os.system(f"{sys.executable} -m pip uninstall -y segment-anything segment_anything")
70
+ os.system(f"{sys.executable} -m pip install -U git+https://github.com/facebookresearch/segment-anything.git")
71
 
72
+ _ensure_official_sam()
73
+
74
+ # Now import
75
+ sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
76
+ import torchvision
77
+ from GroundingDINO.groundingdino.util.inference import load_model
78
+ from segment_anything import sam_model_registry, SamPredictor # official API
79
+ import spaces
80
+ import GroundingDINO.groundingdino.datasets.transforms as T
81
+ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
82
 
83
  # ---------- PATHS ----------
84
  PERSIST_ROOT = "/data"
 
99
  hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
100
 
101
  # ---------- DOWNLOAD CHECKPOINTS (single files) ----------
102
+ # Use hf_hub_download for single files, which returns a cached path. Keep files under /data. # https://huggingface.co/docs/huggingface_hub/en/guides/download
103
  if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
104
  g_dino_file = hf_hub_download(
105
  repo_id="ShilongLiu/GroundingDINO",
 
110
  if g_dino_file != GROUNDING_DINO_CHECKPOINT_PATH:
111
  os.replace(g_dino_file, GROUNDING_DINO_CHECKPOINT_PATH)
112
 
 
113
  if not os.path.exists(SAM_CHECKPOINT_PATH):
114
  sam_file = hf_hub_download(
115
  repo_id="spaces/mrtlive/segment-anything-model",
 
163
  device="cuda"
164
  )
165
 
166
+ # SAM + Predictor (registry API from official SAM) # https://github.com/facebookresearch/segment-anything
167
  sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
168
  sam.to(device="cuda")
169
  sam_predictor = SamPredictor(sam)
170
 
171
+ # Diffusers (Flux)
172
  dtype = torch.bfloat16
173
  size = (768, 768)
174