Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -1,33 +1,46 @@
|
|
1 |
# app.py — storage-safe + HF Hub friendly
|
2 |
|
3 |
import os
|
4 |
-
import sys
|
5 |
-
import cv2
|
6 |
-
import numpy as np
|
7 |
-
import torch
|
8 |
-
import gradio as gr
|
9 |
-
from PIL import Image, ImageFilter, ImageDraw
|
10 |
|
11 |
-
# ---------- ENV & THREADS ----------
|
12 |
-
#
|
13 |
-
omp_val =
|
14 |
-
os.
|
|
|
|
|
|
|
|
|
15 |
try:
|
16 |
-
|
17 |
-
torch.set_num_interop_threads(1)
|
18 |
except Exception:
|
19 |
-
|
|
|
20 |
|
21 |
# Send all caches to persistent storage
|
22 |
os.environ.setdefault("HF_HOME", "/data/.huggingface")
|
23 |
os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
|
24 |
-
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.huggingface/transformers")
|
25 |
os.environ.setdefault("HF_DATASETS_CACHE", "/data/.huggingface/datasets")
|
|
|
26 |
|
27 |
# Disable Xet path, enable fast transfer
|
28 |
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
|
29 |
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
# ---------- HUB IMPORTS ----------
|
32 |
from huggingface_hub import snapshot_download, hf_hub_download # noqa: E402
|
33 |
from diffusers import FluxFillPipeline, FluxPriorReduxPipeline # noqa: E402
|
@@ -37,7 +50,7 @@ from utils.utils import ( # noqa: E402
|
|
37 |
get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
|
38 |
)
|
39 |
|
40 |
-
# Optional editable installs ONLY if import fails (
|
41 |
def _ensure_local_editable(pkg_name, rel_path):
|
42 |
try:
|
43 |
__import__(pkg_name)
|
@@ -52,7 +65,8 @@ sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
|
|
52 |
|
53 |
import torchvision # noqa: E402
|
54 |
from GroundingDINO.groundingdino.util.inference import load_model # noqa: E402
|
55 |
-
|
|
|
56 |
import spaces # noqa: E402
|
57 |
import GroundingDINO.groundingdino.datasets.transforms as T # noqa: E402
|
58 |
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap # noqa: E402
|
@@ -78,26 +92,25 @@ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
|
|
78 |
# ---------- DOWNLOAD CHECKPOINTS (single files) ----------
|
79 |
# GroundingDINO ckpt (single file)
|
80 |
if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
|
81 |
-
|
82 |
repo_id="ShilongLiu/GroundingDINO",
|
83 |
filename="groundingdino_swinb_cogcoor.pth",
|
84 |
local_dir=CKPT_DIR,
|
85 |
token=hf_token,
|
86 |
)
|
87 |
-
if
|
88 |
-
|
89 |
-
os.replace(G_DINO_FILE, GROUNDING_DINO_CHECKPOINT_PATH)
|
90 |
|
91 |
# SAM ckpt (single file)
|
92 |
if not os.path.exists(SAM_CHECKPOINT_PATH):
|
93 |
-
|
94 |
repo_id="spaces/mrtlive/segment-anything-model",
|
95 |
filename="sam_vit_h_4b8939.pth",
|
96 |
local_dir=CKPT_DIR,
|
97 |
token=hf_token,
|
98 |
)
|
99 |
-
if
|
100 |
-
os.replace(
|
101 |
|
102 |
# ---------- DOWNLOAD MODELS (filtered snapshots into /data) ----------
|
103 |
FILL_DIR = os.path.join(MODELS_DIR, "FLUX.1-Fill-dev")
|
@@ -142,8 +155,8 @@ groundingdino_model = load_model(
|
|
142 |
device="cuda"
|
143 |
)
|
144 |
|
145 |
-
# SAM + Predictor
|
146 |
-
sam =
|
147 |
sam.to(device="cuda")
|
148 |
sam_predictor = SamPredictor(sam)
|
149 |
|
|
|
1 |
# app.py — storage-safe + HF Hub friendly
|
2 |
|
3 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
+
# ---------- ENV & THREADS (set BEFORE importing numpy/torch) ----------
|
6 |
+
# Accept any of these names from Space Settings, prefer the standard one:
|
7 |
+
omp_val = (
|
8 |
+
os.getenv("OMP_NUM_THREADS")
|
9 |
+
or os.getenv("OMP-NUM-THREADS")
|
10 |
+
or os.getenv("OMPNUMTHREADS")
|
11 |
+
or "2"
|
12 |
+
)
|
13 |
try:
|
14 |
+
omp_val = str(int(omp_val))
|
|
|
15 |
except Exception:
|
16 |
+
omp_val = "2"
|
17 |
+
os.environ["OMP_NUM_THREADS"] = omp_val # must be a positive integer string
|
18 |
|
19 |
# Send all caches to persistent storage
|
20 |
os.environ.setdefault("HF_HOME", "/data/.huggingface")
|
21 |
os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
|
|
|
22 |
os.environ.setdefault("HF_DATASETS_CACHE", "/data/.huggingface/datasets")
|
23 |
+
# NOTE: TRANSFORMERS_CACHE is deprecated; using HF_HOME instead.
|
24 |
|
25 |
# Disable Xet path, enable fast transfer
|
26 |
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
|
27 |
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
28 |
|
29 |
+
# ---------- NOW safe to import heavy libs ----------
|
30 |
+
import sys
|
31 |
+
import cv2
|
32 |
+
import numpy as np
|
33 |
+
import torch
|
34 |
+
import gradio as gr
|
35 |
+
from PIL import Image, ImageFilter, ImageDraw
|
36 |
+
|
37 |
+
# Optional: align PyTorch thread pools with OMP setting
|
38 |
+
try:
|
39 |
+
torch.set_num_threads(int(omp_val))
|
40 |
+
torch.set_num_interop_threads(1)
|
41 |
+
except Exception:
|
42 |
+
pass
|
43 |
+
|
44 |
# ---------- HUB IMPORTS ----------
|
45 |
from huggingface_hub import snapshot_download, hf_hub_download # noqa: E402
|
46 |
from diffusers import FluxFillPipeline, FluxPriorReduxPipeline # noqa: E402
|
|
|
50 |
get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
|
51 |
)
|
52 |
|
53 |
+
# Optional editable installs ONLY if import fails (prefer requirements.txt)
|
54 |
def _ensure_local_editable(pkg_name, rel_path):
|
55 |
try:
|
56 |
__import__(pkg_name)
|
|
|
65 |
|
66 |
import torchvision # noqa: E402
|
67 |
from GroundingDINO.groundingdino.util.inference import load_model # noqa: E402
|
68 |
+
# Use the stable SAM API (avoids build_sam import error)
|
69 |
+
from segment_anything import sam_model_registry, SamPredictor # noqa: E402
|
70 |
import spaces # noqa: E402
|
71 |
import GroundingDINO.groundingdino.datasets.transforms as T # noqa: E402
|
72 |
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap # noqa: E402
|
|
|
92 |
# ---------- DOWNLOAD CHECKPOINTS (single files) ----------
|
93 |
# GroundingDINO ckpt (single file)
|
94 |
if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
|
95 |
+
g_dino_file = hf_hub_download(
|
96 |
repo_id="ShilongLiu/GroundingDINO",
|
97 |
filename="groundingdino_swinb_cogcoor.pth",
|
98 |
local_dir=CKPT_DIR,
|
99 |
token=hf_token,
|
100 |
)
|
101 |
+
if g_dino_file != GROUNDING_DINO_CHECKPOINT_PATH:
|
102 |
+
os.replace(g_dino_file, GROUNDING_DINO_CHECKPOINT_PATH)
|
|
|
103 |
|
104 |
# SAM ckpt (single file)
|
105 |
if not os.path.exists(SAM_CHECKPOINT_PATH):
|
106 |
+
sam_file = hf_hub_download(
|
107 |
repo_id="spaces/mrtlive/segment-anything-model",
|
108 |
filename="sam_vit_h_4b8939.pth",
|
109 |
local_dir=CKPT_DIR,
|
110 |
token=hf_token,
|
111 |
)
|
112 |
+
if sam_file != SAM_CHECKPOINT_PATH:
|
113 |
+
os.replace(sam_file, SAM_CHECKPOINT_PATH)
|
114 |
|
115 |
# ---------- DOWNLOAD MODELS (filtered snapshots into /data) ----------
|
116 |
FILL_DIR = os.path.join(MODELS_DIR, "FLUX.1-Fill-dev")
|
|
|
155 |
device="cuda"
|
156 |
)
|
157 |
|
158 |
+
# SAM + Predictor (registry API)
|
159 |
+
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
|
160 |
sam.to(device="cuda")
|
161 |
sam_predictor = SamPredictor(sam)
|
162 |
|