Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1 |
-
# app.py — storage-safe + HF Hub friendly
|
2 |
|
3 |
import os
|
4 |
|
5 |
# ---------- ENV & THREADS (set BEFORE importing numpy/torch) ----------
|
6 |
-
# Accept any of these names from Space Settings, prefer the standard one:
|
7 |
omp_val = (
|
8 |
os.getenv("OMP_NUM_THREADS")
|
9 |
or os.getenv("OMP-NUM-THREADS")
|
@@ -16,11 +15,11 @@ except Exception:
|
|
16 |
omp_val = "2"
|
17 |
os.environ["OMP_NUM_THREADS"] = omp_val # must be a positive integer string
|
18 |
|
19 |
-
#
|
20 |
os.environ.setdefault("HF_HOME", "/data/.huggingface")
|
21 |
os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
|
22 |
os.environ.setdefault("HF_DATASETS_CACHE", "/data/.huggingface/datasets")
|
23 |
-
#
|
24 |
|
25 |
# Disable Xet path, enable fast transfer
|
26 |
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
|
@@ -34,7 +33,6 @@ import torch
|
|
34 |
import gradio as gr
|
35 |
from PIL import Image, ImageFilter, ImageDraw
|
36 |
|
37 |
-
# Optional: align PyTorch thread pools with OMP setting
|
38 |
try:
|
39 |
torch.set_num_threads(int(omp_val))
|
40 |
torch.set_num_interop_threads(1)
|
@@ -42,34 +40,45 @@ except Exception:
|
|
42 |
pass
|
43 |
|
44 |
# ---------- HUB IMPORTS ----------
|
45 |
-
from huggingface_hub import snapshot_download, hf_hub_download
|
46 |
-
from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
|
47 |
|
48 |
-
import math
|
49 |
-
from utils.utils import (
|
50 |
get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
|
51 |
)
|
52 |
|
53 |
-
#
|
54 |
def _ensure_local_editable(pkg_name, rel_path):
|
55 |
try:
|
56 |
__import__(pkg_name)
|
57 |
except ImportError:
|
58 |
-
os.system(f"
|
59 |
|
60 |
-
|
61 |
_ensure_local_editable("GroundingDINO", "GroundingDINO")
|
62 |
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
#
|
69 |
-
|
70 |
-
import
|
71 |
-
|
72 |
-
from
|
|
|
|
|
|
|
73 |
|
74 |
# ---------- PATHS ----------
|
75 |
PERSIST_ROOT = "/data"
|
@@ -90,7 +99,7 @@ SAM_CHECKPOINT_PATH = os.path.join(CKPT_DIR, "sam_vit_h_4b8939.pth")
|
|
90 |
hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
|
91 |
|
92 |
# ---------- DOWNLOAD CHECKPOINTS (single files) ----------
|
93 |
-
#
|
94 |
if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
|
95 |
g_dino_file = hf_hub_download(
|
96 |
repo_id="ShilongLiu/GroundingDINO",
|
@@ -101,7 +110,6 @@ if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
|
|
101 |
if g_dino_file != GROUNDING_DINO_CHECKPOINT_PATH:
|
102 |
os.replace(g_dino_file, GROUNDING_DINO_CHECKPOINT_PATH)
|
103 |
|
104 |
-
# SAM ckpt (single file)
|
105 |
if not os.path.exists(SAM_CHECKPOINT_PATH):
|
106 |
sam_file = hf_hub_download(
|
107 |
repo_id="spaces/mrtlive/segment-anything-model",
|
@@ -155,12 +163,12 @@ groundingdino_model = load_model(
|
|
155 |
device="cuda"
|
156 |
)
|
157 |
|
158 |
-
# SAM + Predictor (registry API)
|
159 |
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
|
160 |
sam.to(device="cuda")
|
161 |
sam_predictor = SamPredictor(sam)
|
162 |
|
163 |
-
# Diffusers
|
164 |
dtype = torch.bfloat16
|
165 |
size = (768, 768)
|
166 |
|
|
|
1 |
+
# app.py — storage-safe + HF Hub friendly + SAM import guard
|
2 |
|
3 |
import os
|
4 |
|
5 |
# ---------- ENV & THREADS (set BEFORE importing numpy/torch) ----------
|
|
|
6 |
omp_val = (
|
7 |
os.getenv("OMP_NUM_THREADS")
|
8 |
or os.getenv("OMP-NUM-THREADS")
|
|
|
15 |
omp_val = "2"
|
16 |
os.environ["OMP_NUM_THREADS"] = omp_val # must be a positive integer string
|
17 |
|
18 |
+
# Persistent caches
|
19 |
os.environ.setdefault("HF_HOME", "/data/.huggingface")
|
20 |
os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
|
21 |
os.environ.setdefault("HF_DATASETS_CACHE", "/data/.huggingface/datasets")
|
22 |
+
# (TRANSFORMERS_CACHE is deprecated; rely on HF_HOME) # https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache
|
23 |
|
24 |
# Disable Xet path, enable fast transfer
|
25 |
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
|
|
|
33 |
import gradio as gr
|
34 |
from PIL import Image, ImageFilter, ImageDraw
|
35 |
|
|
|
36 |
try:
|
37 |
torch.set_num_threads(int(omp_val))
|
38 |
torch.set_num_interop_threads(1)
|
|
|
40 |
pass
|
41 |
|
42 |
# ---------- HUB IMPORTS ----------
|
43 |
+
from huggingface_hub import snapshot_download, hf_hub_download
|
44 |
+
from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
|
45 |
|
46 |
+
import math
|
47 |
+
from utils.utils import (
|
48 |
get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
|
49 |
)
|
50 |
|
51 |
+
# ---------- Ensure GroundingDINO & SAM are the right ones ----------
|
52 |
def _ensure_local_editable(pkg_name, rel_path):
|
53 |
try:
|
54 |
__import__(pkg_name)
|
55 |
except ImportError:
|
56 |
+
os.system(f"{sys.executable} -m pip install -e {rel_path}")
|
57 |
|
58 |
+
# GroundingDINO (local editable if present)
|
59 |
_ensure_local_editable("GroundingDINO", "GroundingDINO")
|
60 |
|
61 |
+
# SAM: verify the real package; fix automatically if a wrong one is installed
|
62 |
+
def _ensure_official_sam():
|
63 |
+
try:
|
64 |
+
import segment_anything as sa
|
65 |
+
if not hasattr(sa, "sam_model_registry"):
|
66 |
+
raise ImportError("Found 'segment_anything' without sam_model_registry")
|
67 |
+
except Exception:
|
68 |
+
# Nuke imposters and install the official repo
|
69 |
+
os.system(f"{sys.executable} -m pip uninstall -y segment-anything segment_anything")
|
70 |
+
os.system(f"{sys.executable} -m pip install -U git+https://github.com/facebookresearch/segment-anything.git")
|
71 |
|
72 |
+
_ensure_official_sam()
|
73 |
+
|
74 |
+
# Now import
|
75 |
+
sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
|
76 |
+
import torchvision
|
77 |
+
from GroundingDINO.groundingdino.util.inference import load_model
|
78 |
+
from segment_anything import sam_model_registry, SamPredictor # official API
|
79 |
+
import spaces
|
80 |
+
import GroundingDINO.groundingdino.datasets.transforms as T
|
81 |
+
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
82 |
|
83 |
# ---------- PATHS ----------
|
84 |
PERSIST_ROOT = "/data"
|
|
|
99 |
hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
|
100 |
|
101 |
# ---------- DOWNLOAD CHECKPOINTS (single files) ----------
|
102 |
+
# Use hf_hub_download for single files, which returns a cached path. Keep files under /data. # https://huggingface.co/docs/huggingface_hub/en/guides/download
|
103 |
if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
|
104 |
g_dino_file = hf_hub_download(
|
105 |
repo_id="ShilongLiu/GroundingDINO",
|
|
|
110 |
if g_dino_file != GROUNDING_DINO_CHECKPOINT_PATH:
|
111 |
os.replace(g_dino_file, GROUNDING_DINO_CHECKPOINT_PATH)
|
112 |
|
|
|
113 |
if not os.path.exists(SAM_CHECKPOINT_PATH):
|
114 |
sam_file = hf_hub_download(
|
115 |
repo_id="spaces/mrtlive/segment-anything-model",
|
|
|
163 |
device="cuda"
|
164 |
)
|
165 |
|
166 |
+
# SAM + Predictor (registry API from official SAM) # https://github.com/facebookresearch/segment-anything
|
167 |
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
|
168 |
sam.to(device="cuda")
|
169 |
sam_predictor = SamPredictor(sam)
|
170 |
|
171 |
+
# Diffusers (Flux)
|
172 |
dtype = torch.bfloat16
|
173 |
size = (768, 768)
|
174 |
|