isat commited on
Commit
8f1b38f
·
verified ·
1 Parent(s): 38264cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -25
app.py CHANGED
@@ -1,33 +1,46 @@
1
  # app.py — storage-safe + HF Hub friendly
2
 
3
  import os
4
- import sys
5
- import cv2
6
- import numpy as np
7
- import torch
8
- import gradio as gr
9
- from PIL import Image, ImageFilter, ImageDraw
10
 
11
- # ---------- ENV & THREADS ----------
12
- # Map a Spaces variable (no underscores allowed) to the real OpenMP var.
13
- omp_val = os.getenv("OMP-NUM-THREADS") or os.getenv("OMPNUMTHREADS") or "2"
14
- os.environ["OMP_NUM_THREADS"] = omp_val
 
 
 
 
15
  try:
16
- torch.set_num_threads(int(omp_val))
17
- torch.set_num_interop_threads(1)
18
  except Exception:
19
- pass
 
20
 
21
  # Send all caches to persistent storage
22
  os.environ.setdefault("HF_HOME", "/data/.huggingface")
23
  os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
24
- os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.huggingface/transformers")
25
  os.environ.setdefault("HF_DATASETS_CACHE", "/data/.huggingface/datasets")
 
26
 
27
  # Disable Xet path, enable fast transfer
28
  os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
29
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # ---------- HUB IMPORTS ----------
32
  from huggingface_hub import snapshot_download, hf_hub_download # noqa: E402
33
  from diffusers import FluxFillPipeline, FluxPriorReduxPipeline # noqa: E402
@@ -37,7 +50,7 @@ from utils.utils import ( # noqa: E402
37
  get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
38
  )
39
 
40
- # Optional editable installs ONLY if import fails (use requirements.txt ideally)
41
  def _ensure_local_editable(pkg_name, rel_path):
42
  try:
43
  __import__(pkg_name)
@@ -52,7 +65,8 @@ sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
52
 
53
  import torchvision # noqa: E402
54
  from GroundingDINO.groundingdino.util.inference import load_model # noqa: E402
55
- from segment_anything import build_sam, SamPredictor # noqa: E402
 
56
  import spaces # noqa: E402
57
  import GroundingDINO.groundingdino.datasets.transforms as T # noqa: E402
58
  from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap # noqa: E402
@@ -78,26 +92,25 @@ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
78
  # ---------- DOWNLOAD CHECKPOINTS (single files) ----------
79
  # GroundingDINO ckpt (single file)
80
  if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
81
- G_DINO_FILE = hf_hub_download(
82
  repo_id="ShilongLiu/GroundingDINO",
83
  filename="groundingdino_swinb_cogcoor.pth",
84
  local_dir=CKPT_DIR,
85
  token=hf_token,
86
  )
87
- if G_DINO_FILE != GROUNDING_DINO_CHECKPOINT_PATH:
88
- # Ensure the expected path exists for later code
89
- os.replace(G_DINO_FILE, GROUNDING_DINO_CHECKPOINT_PATH)
90
 
91
  # SAM ckpt (single file)
92
  if not os.path.exists(SAM_CHECKPOINT_PATH):
93
- SAM_FILE = hf_hub_download(
94
  repo_id="spaces/mrtlive/segment-anything-model",
95
  filename="sam_vit_h_4b8939.pth",
96
  local_dir=CKPT_DIR,
97
  token=hf_token,
98
  )
99
- if SAM_FILE != SAM_CHECKPOINT_PATH:
100
- os.replace(SAM_FILE, SAM_CHECKPOINT_PATH)
101
 
102
  # ---------- DOWNLOAD MODELS (filtered snapshots into /data) ----------
103
  FILL_DIR = os.path.join(MODELS_DIR, "FLUX.1-Fill-dev")
@@ -142,8 +155,8 @@ groundingdino_model = load_model(
142
  device="cuda"
143
  )
144
 
145
- # SAM + Predictor
146
- sam = build_sam(checkpoint=SAM_CHECKPOINT_PATH)
147
  sam.to(device="cuda")
148
  sam_predictor = SamPredictor(sam)
149
 
 
1
  # app.py — storage-safe + HF Hub friendly
2
 
3
  import os
 
 
 
 
 
 
4
 
5
+ # ---------- ENV & THREADS (set BEFORE importing numpy/torch) ----------
6
+ # Accept any of these names from Space Settings, prefer the standard one:
7
+ omp_val = (
8
+ os.getenv("OMP_NUM_THREADS")
9
+ or os.getenv("OMP-NUM-THREADS")
10
+ or os.getenv("OMPNUMTHREADS")
11
+ or "2"
12
+ )
13
  try:
14
+ omp_val = str(int(omp_val))
 
15
  except Exception:
16
+ omp_val = "2"
17
+ os.environ["OMP_NUM_THREADS"] = omp_val # must be a positive integer string
18
 
19
  # Send all caches to persistent storage
20
  os.environ.setdefault("HF_HOME", "/data/.huggingface")
21
  os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
 
22
  os.environ.setdefault("HF_DATASETS_CACHE", "/data/.huggingface/datasets")
23
+ # NOTE: TRANSFORMERS_CACHE is deprecated; using HF_HOME instead.
24
 
25
  # Disable Xet path, enable fast transfer
26
  os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
27
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
28
 
29
+ # ---------- NOW safe to import heavy libs ----------
30
+ import sys
31
+ import cv2
32
+ import numpy as np
33
+ import torch
34
+ import gradio as gr
35
+ from PIL import Image, ImageFilter, ImageDraw
36
+
37
+ # Optional: align PyTorch thread pools with OMP setting
38
+ try:
39
+ torch.set_num_threads(int(omp_val))
40
+ torch.set_num_interop_threads(1)
41
+ except Exception:
42
+ pass
43
+
44
  # ---------- HUB IMPORTS ----------
45
  from huggingface_hub import snapshot_download, hf_hub_download # noqa: E402
46
  from diffusers import FluxFillPipeline, FluxPriorReduxPipeline # noqa: E402
 
50
  get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
51
  )
52
 
53
+ # Optional editable installs ONLY if import fails (prefer requirements.txt)
54
  def _ensure_local_editable(pkg_name, rel_path):
55
  try:
56
  __import__(pkg_name)
 
65
 
66
  import torchvision # noqa: E402
67
  from GroundingDINO.groundingdino.util.inference import load_model # noqa: E402
68
+ # Use the stable SAM API (avoids build_sam import error)
69
+ from segment_anything import sam_model_registry, SamPredictor # noqa: E402
70
  import spaces # noqa: E402
71
  import GroundingDINO.groundingdino.datasets.transforms as T # noqa: E402
72
  from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap # noqa: E402
 
92
  # ---------- DOWNLOAD CHECKPOINTS (single files) ----------
93
  # GroundingDINO ckpt (single file)
94
  if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
95
+ g_dino_file = hf_hub_download(
96
  repo_id="ShilongLiu/GroundingDINO",
97
  filename="groundingdino_swinb_cogcoor.pth",
98
  local_dir=CKPT_DIR,
99
  token=hf_token,
100
  )
101
+ if g_dino_file != GROUNDING_DINO_CHECKPOINT_PATH:
102
+ os.replace(g_dino_file, GROUNDING_DINO_CHECKPOINT_PATH)
 
103
 
104
  # SAM ckpt (single file)
105
  if not os.path.exists(SAM_CHECKPOINT_PATH):
106
+ sam_file = hf_hub_download(
107
  repo_id="spaces/mrtlive/segment-anything-model",
108
  filename="sam_vit_h_4b8939.pth",
109
  local_dir=CKPT_DIR,
110
  token=hf_token,
111
  )
112
+ if sam_file != SAM_CHECKPOINT_PATH:
113
+ os.replace(sam_file, SAM_CHECKPOINT_PATH)
114
 
115
  # ---------- DOWNLOAD MODELS (filtered snapshots into /data) ----------
116
  FILL_DIR = os.path.join(MODELS_DIR, "FLUX.1-Fill-dev")
 
155
  device="cuda"
156
  )
157
 
158
+ # SAM + Predictor (registry API)
159
+ sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
160
  sam.to(device="cuda")
161
  sam_predictor = SamPredictor(sam)
162