isat committed on
Commit 38264cb · verified · 1 Parent(s): 7ea451b

Update app.py

Files changed (1)
  1. app.py +207 -181

app.py CHANGED
@@ -1,3 +1,5 @@
  import os
  import sys
  import cv2
@@ -6,48 +8,162 @@ import torch
  import gradio as gr
  from PIL import Image, ImageFilter, ImageDraw

- os.environ["HF_HUB_DISABLE_XET"] = "1"
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

- from huggingface_hub import snapshot_download
- from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
- import math
- from utils.utils import get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask

- import os,sys
- os.system("python -m pip install -e segment_anything")
- os.system("python -m pip install -e GroundingDINO")
  sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
  sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
- os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth")
- os.system("wget https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth")
- import torchvision
- from GroundingDINO.groundingdino.util.inference import load_model
- from segment_anything import build_sam, SamPredictor
- import spaces
- import GroundingDINO.groundingdino.datasets.transforms as T
- from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

  # GroundingDINO config and checkpoint
  GROUNDING_DINO_CONFIG_PATH = "./GroundingDINO_SwinB.cfg.py"
- GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swinb_cogcoor.pth"

  # Segment-Anything checkpoint
  SAM_ENCODER_VERSION = "vit_h"
- SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"

- # Building GroundingDINO inference model
- groundingdino_model = load_model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device="cuda")
- # Building SAM Model and SAM Predictor
  sam = build_sam(checkpoint=SAM_CHECKPOINT_PATH)
  sam.to(device="cuda")
  sam_predictor = SamPredictor(sam)

- def transform_image(image_pil):
  transform = T.Compose(
  [
  T.RandomResize([800], max_size=1333),
@@ -60,80 +176,54 @@ def transform_image(image_pil):

  def get_grounding_output(model, image, caption, box_threshold=0.25, text_threshold=0.25, with_logits=True):
- caption = caption.lower()
- caption = caption.strip()
  if not caption.endswith("."):
  caption = caption + "."
-
  with torch.no_grad():
  outputs = model(image[None], captions=[caption])
  logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
- boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
- logits.shape[0]

  # filter output
- logits_filt = logits.clone()
- boxes_filt = boxes.clone()
- filt_mask = logits_filt.max(dim=1)[0] > box_threshold
- logits_filt = logits_filt[filt_mask]  # num_filt, 256
- boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
- logits_filt.shape[0]

  # get phrase
  tokenlizer = model.tokenizer
  tokenized = tokenlizer(caption)
- # build pred
- pred_phrases = []
- scores = []
  for logit, box in zip(logits_filt, boxes_filt):
- pred_phrase = get_phrases_from_posmap(
- logit > text_threshold, tokenized, tokenlizer)
- if with_logits:
- pred_phrases.append(
- pred_phrase + f"({str(logit.max().item())[:4]})")
- else:
- pred_phrases.append(pred_phrase)
  scores.append(logit.max().item())
-
  return boxes_filt, torch.Tensor(scores), pred_phrases


  def get_mask(image, label):
  global groundingdino_model, sam_predictor
-
-
  image_pil = image.convert("RGB")
  transformed_image = transform_image(image_pil)

-
  boxes_filt, scores, pred_phrases = get_grounding_output(
  groundingdino_model, transformed_image, label
  )

- size = image_pil.size
-
- # process boxes
- H, W = size[1], size[0]
  for i in range(boxes_filt.size(0)):
  boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
  boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
  boxes_filt[i][2:] += boxes_filt[i][:2]
-
  boxes_filt = boxes_filt.cpu()

- # nms
-
- nms_idx = torchvision.ops.nms(
- boxes_filt, scores, 0.8).numpy().tolist()
  boxes_filt = boxes_filt[nms_idx]
- pred_phrases = [pred_phrases[idx] for idx in nms_idx]
-
-
- image = np.array(image_pil)
- sam_predictor.set_image(image)

  transformed_boxes = sam_predictor.transform.apply_boxes_torch(
- boxes_filt, image.shape[:2]).to("cuda")

  masks, _, _ = sam_predictor.predict_torch(
  point_coords=None,
@@ -142,80 +232,34 @@ def get_mask(image, label):
  multimask_output=False,
  )
  result_mask = masks[0][0].cpu().numpy()

- result_mask = Image.fromarray(result_mask)
-
- return result_mask

  def create_highlighted_mask(image_np, mask_np, alpha=0.5, gray_value=128):
-
-
  if mask_np.max() <= 1.0:
  mask_np = (mask_np * 255).astype(np.uint8)
  mask_bool = mask_np > 128
-
  image_float = image_np.astype(np.float32)
-
- # gray overlay layer
  gray_overlay = np.full_like(image_float, gray_value, dtype=np.float32)
-
- # blend
  result = image_float.copy()
- result[mask_bool] = (
- (1 - alpha) * image_float[mask_bool] + alpha * gray_overlay[mask_bool]
- )
-
  return result.astype(np.uint8)

- hf_token = os.getenv("HF_TOKEN")
-
- snapshot_download(repo_id="black-forest-labs/FLUX.1-Fill-dev", local_dir="./FLUX.1-Fill-dev", token=hf_token)
- snapshot_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", local_dir="./FLUX.1-Redux-dev", token=hf_token)
- snapshot_download(repo_id="WensongSong/Insert-Anything", local_dir="./insertanything_model", token=hf_token)

- dtype = torch.bfloat16
- size = (768, 768)
-
- pipe = FluxFillPipeline.from_pretrained(
- "./FLUX.1-Fill-dev",
- torch_dtype=dtype
- ).to("cuda")
-
- pipe.load_lora_weights(
- "./insertanything_model/20250321_steps5000_pytorch_lora_weights.safetensors"
- )
-
-
- redux = FluxPriorReduxPipeline.from_pretrained("./FLUX.1-Redux-dev").to(dtype=dtype).to("cuda")
-
-
- ### example #####
- ref_dir='./examples/ref_image'
- ref_mask_dir='./examples/ref_mask'
- image_dir='./examples/source_image'
- image_mask_dir='./examples/source_mask'
-
- ref_list=[os.path.join(ref_dir,file) for file in os.listdir(ref_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file ]
- ref_list.sort()
-
- ref_mask_list=[os.path.join(ref_mask_dir,file) for file in os.listdir(ref_mask_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
- ref_mask_list.sort()
-
- image_list=[os.path.join(image_dir,file) for file in os.listdir(image_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file ]
- image_list.sort()
-
- image_mask_list=[os.path.join(image_mask_dir,file) for file in os.listdir(image_mask_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
- image_mask_list.sort()
- ### example #####
-

  @spaces.GPU
  def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt):
-
-
  if base_mask_option == "Draw Mask":
  tar_image = base_image["background"]
  tar_mask = base_image["layers"][0]
@@ -250,42 +294,37 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_

  if tar_mask.sum() == 0:
  raise gr.Error('No mask for the background image. Please check mask button!')
-
  if ref_mask.sum() == 0:
  raise gr.Error('No mask for the reference image. Please check mask button!')

  ref_box_yyxx = get_bbox_from_mask(ref_mask)
- ref_mask_3 = np.stack([ref_mask,ref_mask,ref_mask],-1)
- masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1-ref_mask_3)
- y1,y2,x1,x2 = ref_box_yyxx
- masked_ref_image = masked_ref_image[y1:y2,x1:x2,:]
- ref_mask = ref_mask[y1:y2,x1:x2]
  ratio = 1.3
  masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)

-
- masked_ref_image = pad_to_square(masked_ref_image, pad_value = 255, random = False)

  kernel = np.ones((7, 7), np.uint8)
  iterations = 2
  tar_mask = cv2.dilate(tar_mask, kernel, iterations=iterations)

- # zome in
  tar_box_yyxx = get_bbox_from_mask(tar_mask)
  tar_box_yyxx = expand_bbox(tar_mask, tar_box_yyxx, ratio=1.2)

- tar_box_yyxx_crop = expand_bbox(tar_image, tar_box_yyxx, ratio=2) #1.2 1.6
- tar_box_yyxx_crop = box2squre(tar_image, tar_box_yyxx_crop) # crop box
- y1,y2,x1,x2 = tar_box_yyxx_crop
-

  old_tar_image = tar_image.copy()
- tar_image = tar_image[y1:y2,x1:x2,:]
- tar_mask = tar_mask[y1:y2,x1:x2]

  H1, W1 = tar_image.shape[0], tar_image.shape[1]
- # zome in
-

  tar_mask = pad_to_square(tar_mask, pad_value=0)
  tar_mask = cv2.resize(tar_mask, size)
@@ -293,19 +332,15 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
  masked_ref_image = cv2.resize(masked_ref_image.astype(np.uint8), size).astype(np.uint8)
  pipe_prior_output = redux(Image.fromarray(masked_ref_image))

-
  tar_image = pad_to_square(tar_image, pad_value=255)
-
  H2, W2 = tar_image.shape[0], tar_image.shape[1]
-
  tar_image = cv2.resize(tar_image, size)
  diptych_ref_tar = np.concatenate([masked_ref_image, tar_image], axis=1)

-
- tar_mask = np.stack([tar_mask,tar_mask,tar_mask],-1)
  mask_black = np.ones_like(tar_image) * 0
  mask_diptych = np.concatenate([mask_black, tar_mask], axis=1)
-
  show_diptych_ref_tar = create_highlighted_mask(diptych_ref_tar, mask_diptych)
  show_diptych_ref_tar = Image.fromarray(show_diptych_ref_tar)

@@ -313,8 +348,6 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
  mask_diptych[mask_diptych == 1] = 255
  mask_diptych = Image.fromarray(mask_diptych)

-
-
  generator = torch.Generator("cuda").manual_seed(seed)
  edited_image = pipe(
  image=diptych_ref_tar,
@@ -323,27 +356,22 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
  width=mask_diptych.size[0],
  max_sequence_length=512,
  generator=generator,
- **pipe_prior_output,
  ).images[0]

-
-
  width, height = edited_image.size
  left = width // 2
- right = width
- top = 0
- bottom = height
- edited_image = edited_image.crop((left, top, right, bottom))
-

  edited_image = np.array(edited_image)
- edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
  edited_image = Image.fromarray(edited_image)

  if ref_mask_option != "Label to Mask":
  return [show_diptych_ref_tar, edited_image]
  else:
- return [return_ref_mask, show_diptych_ref_tar, edited_image]

  def update_ui(option):
  if option == "Draw Mask":
@@ -353,8 +381,6 @@ def update_ui(option):


  with gr.Blocks() as demo:
-
-
  gr.Markdown("# Insert-Anything")
  gr.Markdown("### Make sure to select the correct mask button!!")
  gr.Markdown("### Click the output image to toggle between Diptych and final results!!")
@@ -362,42 +388,42 @@ with gr.Blocks() as demo:
  with gr.Row():
  with gr.Column(scale=1):
  with gr.Row():
- base_image = gr.ImageEditor(label="Background Image", sources="upload", type="pil", brush=gr.Brush(colors=["#FFFFFF"],default_size = 30,color_mode = "fixed"),
- layers = False,
- interactive=True)
-
- base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil", layers = False, brush=False, eraser=False)
-
  with gr.Row():
- base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option", value="Upload with Mask")

  with gr.Row():
- ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil", brush=gr.Brush(colors=["#FFFFFF"],default_size = 30,color_mode = "fixed"),
- layers = False,
- interactive=True)
-
- ref_mask = gr.ImageEditor(label="Reference Mask", sources="upload", type="pil", layers = False, brush=False, eraser=False)
  with gr.Row():
- ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"], label="Reference Mask Input Option", value="Upload with Mask")
-
  with gr.Row():
- text_prompt = gr.Textbox(label="Label", placeholder="Enter the category of the reference object, e.g., car, dress, toy, etc.")

  with gr.Column(scale=1):
  baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=695, columns=1)
  with gr.Accordion("Advanced Option", open=True):
- seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
  gr.Markdown("### Guidelines")
  gr.Markdown(" Users can try using different seeds. For example, seeds like 42 and 123456 may produce different effects.")
  gr.Markdown(" Draw Mask means manually drawing a mask on the original image.")
  gr.Markdown(" Upload with Mask means uploading a mask file.")
  gr.Markdown(" Label to Mask means simply inputting a label to automatically extract the mask and obtain the result.")
-

  run_local_button = gr.Button(value="Run")

- # #### example #####
  num_examples = len(image_list)
  for i in range(num_examples):
  with gr.Row():
@@ -413,10 +439,10 @@
  gr.Examples([ref_mask_list[i]], inputs=[ref_mask], examples_per_page=1, label="")
  if i < num_examples - 1:
  gr.HTML("<hr>")
- # #### example #####

- run_local_button.click(fn=run_local,
- inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt],
- outputs=[baseline_gallery]
- )
- demo.launch()
 
 
+ # app.py — storage-safe + HF Hub friendly
+
  import os
  import sys
  import cv2

  import gradio as gr
  from PIL import Image, ImageFilter, ImageDraw

+ # ---------- ENV & THREADS ----------
+ # Map a Spaces variable (no underscores allowed) to the real OpenMP var.
+ omp_val = os.getenv("OMP-NUM-THREADS") or os.getenv("OMPNUMTHREADS") or "2"
+ os.environ["OMP_NUM_THREADS"] = omp_val
+ try:
+     torch.set_num_threads(int(omp_val))
+     torch.set_num_interop_threads(1)
+ except Exception:
+     pass
+
+ # Send all caches to persistent storage
+ os.environ.setdefault("HF_HOME", "/data/.huggingface")
+ os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
+ os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.huggingface/transformers")
+ os.environ.setdefault("HF_DATASETS_CACHE", "/data/.huggingface/datasets")
+
+ # Disable Xet path, enable fast transfer
+ os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
+ os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
+
+ # ---------- HUB IMPORTS ----------
+ from huggingface_hub import snapshot_download, hf_hub_download  # noqa: E402
+ from diffusers import FluxFillPipeline, FluxPriorReduxPipeline  # noqa: E402
+
+ import math  # noqa: E402
+ from utils.utils import (  # noqa: E402
+     get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
+ )

+ # Optional editable installs ONLY if import fails (use requirements.txt ideally)
+ def _ensure_local_editable(pkg_name, rel_path):
+     try:
+         __import__(pkg_name)
+     except ImportError:
+         os.system(f"python -m pip install -e {rel_path}")

+ _ensure_local_editable("segment_anything", "segment_anything")
+ _ensure_local_editable("GroundingDINO", "GroundingDINO")

  sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
  sys.path.append(os.path.join(os.getcwd(), "segment_anything"))

+ import torchvision  # noqa: E402
+ from GroundingDINO.groundingdino.util.inference import load_model  # noqa: E402
+ from segment_anything import build_sam, SamPredictor  # noqa: E402
+ import spaces  # noqa: E402
+ import GroundingDINO.groundingdino.datasets.transforms as T  # noqa: E402
+ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap  # noqa: E402

+ # ---------- PATHS ----------
+ PERSIST_ROOT = "/data"
+ MODELS_DIR = os.path.join(PERSIST_ROOT, "models")
+ CKPT_DIR = os.path.join(PERSIST_ROOT, "checkpoints")
+ os.makedirs(MODELS_DIR, exist_ok=True)
+ os.makedirs(CKPT_DIR, exist_ok=True)

  # GroundingDINO config and checkpoint
  GROUNDING_DINO_CONFIG_PATH = "./GroundingDINO_SwinB.cfg.py"
+ GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(CKPT_DIR, "groundingdino_swinb_cogcoor.pth")

  # Segment-Anything checkpoint
  SAM_ENCODER_VERSION = "vit_h"
+ SAM_CHECKPOINT_PATH = os.path.join(CKPT_DIR, "sam_vit_h_4b8939.pth")
+
+ # ---------- AUTH TOKEN ----------
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
+
+ # ---------- DOWNLOAD CHECKPOINTS (single files) ----------
+ # GroundingDINO ckpt (single file)
+ if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
+     G_DINO_FILE = hf_hub_download(
+         repo_id="ShilongLiu/GroundingDINO",
+         filename="groundingdino_swinb_cogcoor.pth",
+         local_dir=CKPT_DIR,
+         token=hf_token,
+     )
+     if G_DINO_FILE != GROUNDING_DINO_CHECKPOINT_PATH:
+         # Ensure the expected path exists for later code
+         os.replace(G_DINO_FILE, GROUNDING_DINO_CHECKPOINT_PATH)
+
+ # SAM ckpt (single file, hosted in a Space repo, so pass repo_type)
+ if not os.path.exists(SAM_CHECKPOINT_PATH):
+     SAM_FILE = hf_hub_download(
+         repo_id="mrtlive/segment-anything-model",
+         repo_type="space",
+         filename="sam_vit_h_4b8939.pth",
+         local_dir=CKPT_DIR,
+         token=hf_token,
+     )
+     if SAM_FILE != SAM_CHECKPOINT_PATH:
+         os.replace(SAM_FILE, SAM_CHECKPOINT_PATH)
+
+ # ---------- DOWNLOAD MODELS (filtered snapshots into /data) ----------
+ FILL_DIR = os.path.join(MODELS_DIR, "FLUX.1-Fill-dev")
+ REDUX_DIR = os.path.join(MODELS_DIR, "FLUX.1-Redux-dev")
+ LORA_DIR = os.path.join(MODELS_DIR, "insertanything_model")
+ for path in (FILL_DIR, REDUX_DIR, LORA_DIR):
+     os.makedirs(path, exist_ok=True)
+
+ # Only pull what we need (weights/configs). Keep symlinks to avoid copies.
+ if not os.listdir(FILL_DIR):
+     snapshot_download(
+         repo_id="black-forest-labs/FLUX.1-Fill-dev",
+         local_dir=FILL_DIR,
+         local_dir_use_symlinks=True,
+         allow_patterns=["*.safetensors", "*.json", "*.yaml", "*.txt", "*.py", "*.model"],
+         token=hf_token,
+     )
+
+ if not os.listdir(REDUX_DIR):
+     snapshot_download(
+         repo_id="black-forest-labs/FLUX.1-Redux-dev",
+         local_dir=REDUX_DIR,
+         local_dir_use_symlinks=True,
+         allow_patterns=["*.safetensors", "*.json", "*.yaml", "*.txt", "*.py", "*.model"],
+         token=hf_token,
+     )
+
+ if not os.listdir(LORA_DIR):
+     snapshot_download(
+         repo_id="WensongSong/Insert-Anything",
+         local_dir=LORA_DIR,
+         local_dir_use_symlinks=True,
+         allow_patterns=["*.safetensors", "*.json", "*.yaml", "*.txt"],
+         token=hf_token,
+     )
+
+ # ---------- BUILD MODELS ----------
+ # GroundingDINO
+ groundingdino_model = load_model(
+     model_config_path=GROUNDING_DINO_CONFIG_PATH,
+     model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
+     device="cuda"
+ )

+ # SAM + Predictor
  sam = build_sam(checkpoint=SAM_CHECKPOINT_PATH)
  sam.to(device="cuda")
  sam_predictor = SamPredictor(sam)

+ # Diffusers
+ dtype = torch.bfloat16
+ size = (768, 768)
+
+ pipe = FluxFillPipeline.from_pretrained(
+     FILL_DIR,
+     torch_dtype=dtype
+ ).to("cuda")
+
+ pipe.load_lora_weights(
+     os.path.join(LORA_DIR, "20250321_steps5000_pytorch_lora_weights.safetensors")
+ )

+ redux = FluxPriorReduxPipeline.from_pretrained(REDUX_DIR).to(dtype=dtype).to("cuda")
+
+ # ---------- APP LOGIC ----------
+ def transform_image(image_pil):
  transform = T.Compose(
  [
  T.RandomResize([800], max_size=1333),

  def get_grounding_output(model, image, caption, box_threshold=0.25, text_threshold=0.25, with_logits=True):
+ caption = caption.lower().strip()
  if not caption.endswith("."):
  caption = caption + "."
  with torch.no_grad():
  outputs = model(image[None], captions=[caption])
  logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
+ boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

  # filter output
+ filt_mask = logits.max(dim=1)[0] > box_threshold
+ logits_filt = logits[filt_mask]
+ boxes_filt = boxes[filt_mask]

  # get phrase
  tokenlizer = model.tokenizer
  tokenized = tokenlizer(caption)
+ pred_phrases, scores = [], []
  for logit, box in zip(logits_filt, boxes_filt):
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})" if with_logits else pred_phrase)
  scores.append(logit.max().item())
  return boxes_filt, torch.Tensor(scores), pred_phrases


  def get_mask(image, label):
  global groundingdino_model, sam_predictor
  image_pil = image.convert("RGB")
  transformed_image = transform_image(image_pil)

  boxes_filt, scores, pred_phrases = get_grounding_output(
  groundingdino_model, transformed_image, label
  )

+ W, H = image_pil.size
  for i in range(boxes_filt.size(0)):
  boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
  boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
  boxes_filt[i][2:] += boxes_filt[i][:2]
  boxes_filt = boxes_filt.cpu()

+ nms_idx = torchvision.ops.nms(boxes_filt, scores, 0.8).numpy().tolist()
  boxes_filt = boxes_filt[nms_idx]

+ image_np = np.array(image_pil)
+ sam_predictor.set_image(image_np)
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(
+     boxes_filt, image_np.shape[:2]
+ ).to("cuda")

  masks, _, _ = sam_predictor.predict_torch(
  point_coords=None,

  multimask_output=False,
  )
  result_mask = masks[0][0].cpu().numpy()
+ return Image.fromarray(result_mask)


  def create_highlighted_mask(image_np, mask_np, alpha=0.5, gray_value=128):
  if mask_np.max() <= 1.0:
  mask_np = (mask_np * 255).astype(np.uint8)
  mask_bool = mask_np > 128
  image_float = image_np.astype(np.float32)
  gray_overlay = np.full_like(image_float, gray_value, dtype=np.float32)
  result = image_float.copy()
+ result[mask_bool] = (1 - alpha) * image_float[mask_bool] + alpha * gray_overlay[mask_bool]
  return result.astype(np.uint8)

+ # ---------- EXAMPLES ----------
+ ref_dir = './examples/ref_image'
+ ref_mask_dir = './examples/ref_mask'
+ image_dir = './examples/source_image'
+ image_mask_dir = './examples/source_mask'

+ ref_list = sorted([os.path.join(ref_dir, f) for f in os.listdir(ref_dir) if f.lower().endswith((".jpg", ".png", ".jpeg"))])
+ ref_mask_list = sorted([os.path.join(ref_mask_dir, f) for f in os.listdir(ref_mask_dir) if f.lower().endswith((".jpg", ".png", ".jpeg"))])
+ image_list = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.lower().endswith((".jpg", ".png", ".jpeg"))])
+ image_mask_list = sorted([os.path.join(image_mask_dir, f) for f in os.listdir(image_mask_dir) if f.lower().endswith((".jpg", ".png", ".jpeg"))])

  @spaces.GPU
  def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt):
  if base_mask_option == "Draw Mask":
  tar_image = base_image["background"]
  tar_mask = base_image["layers"][0]

  if tar_mask.sum() == 0:
  raise gr.Error('No mask for the background image. Please check mask button!')
  if ref_mask.sum() == 0:
  raise gr.Error('No mask for the reference image. Please check mask button!')

  ref_box_yyxx = get_bbox_from_mask(ref_mask)
+ ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
+ masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1 - ref_mask_3)
+ y1, y2, x1, x2 = ref_box_yyxx
+ masked_ref_image = masked_ref_image[y1:y2, x1:x2, :]
+ ref_mask = ref_mask[y1:y2, x1:x2]
  ratio = 1.3
  masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)

+ masked_ref_image = pad_to_square(masked_ref_image, pad_value=255, random=False)

  kernel = np.ones((7, 7), np.uint8)
  iterations = 2
  tar_mask = cv2.dilate(tar_mask, kernel, iterations=iterations)

+ # zoom in
  tar_box_yyxx = get_bbox_from_mask(tar_mask)
  tar_box_yyxx = expand_bbox(tar_mask, tar_box_yyxx, ratio=1.2)

+ tar_box_yyxx_crop = expand_bbox(tar_image, tar_box_yyxx, ratio=2)
+ tar_box_yyxx_crop = box2squre(tar_image, tar_box_yyxx_crop)  # crop box
+ y1, y2, x1, x2 = tar_box_yyxx_crop

  old_tar_image = tar_image.copy()
+ tar_image = tar_image[y1:y2, x1:x2, :]
+ tar_mask = tar_mask[y1:y2, x1:x2]

  H1, W1 = tar_image.shape[0], tar_image.shape[1]

  tar_mask = pad_to_square(tar_mask, pad_value=0)
  tar_mask = cv2.resize(tar_mask, size)

  masked_ref_image = cv2.resize(masked_ref_image.astype(np.uint8), size).astype(np.uint8)
  pipe_prior_output = redux(Image.fromarray(masked_ref_image))

  tar_image = pad_to_square(tar_image, pad_value=255)
  H2, W2 = tar_image.shape[0], tar_image.shape[1]
  tar_image = cv2.resize(tar_image, size)
  diptych_ref_tar = np.concatenate([masked_ref_image, tar_image], axis=1)

+ tar_mask = np.stack([tar_mask, tar_mask, tar_mask], -1)
  mask_black = np.ones_like(tar_image) * 0
  mask_diptych = np.concatenate([mask_black, tar_mask], axis=1)
+
  show_diptych_ref_tar = create_highlighted_mask(diptych_ref_tar, mask_diptych)
  show_diptych_ref_tar = Image.fromarray(show_diptych_ref_tar)

  mask_diptych[mask_diptych == 1] = 255
  mask_diptych = Image.fromarray(mask_diptych)

  generator = torch.Generator("cuda").manual_seed(seed)
  edited_image = pipe(
  image=diptych_ref_tar,

  width=mask_diptych.size[0],
  max_sequence_length=512,
  generator=generator,
+ **pipe_prior_output,
  ).images[0]

  width, height = edited_image.size
  left = width // 2
+ edited_image = edited_image.crop((left, 0, width, height))

  edited_image = np.array(edited_image)
+ edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
  edited_image = Image.fromarray(edited_image)

  if ref_mask_option != "Label to Mask":
  return [show_diptych_ref_tar, edited_image]
  else:
+ return [return_ref_mask, show_diptych_ref_tar, edited_image]
+

  def update_ui(option):
  if option == "Draw Mask":


  with gr.Blocks() as demo:
  gr.Markdown("# Insert-Anything")
  gr.Markdown("### Make sure to select the correct mask button!!")
  gr.Markdown("### Click the output image to toggle between Diptych and final results!!")

  with gr.Row():
  with gr.Column(scale=1):
  with gr.Row():
+ base_image = gr.ImageEditor(label="Background Image", sources="upload", type="pil",
+                             brush=gr.Brush(colors=["#FFFFFF"], default_size=30, color_mode="fixed"),
+                             layers=False, interactive=True)
+ base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil",
+                            layers=False, brush=False, eraser=False)

  with gr.Row():
+ base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option",
+                             value="Upload with Mask")

  with gr.Row():
+ ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil",
+                            brush=gr.Brush(colors=["#FFFFFF"], default_size=30, color_mode="fixed"),
+                            layers=False, interactive=True)
+ ref_mask = gr.ImageEditor(label="Reference Mask", sources="upload", type="pil",
+                           layers=False, brush=False, eraser=False)

  with gr.Row():
+ ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"],
+                            label="Reference Mask Input Option", value="Upload with Mask")
  with gr.Row():
+ text_prompt = gr.Textbox(label="Label",
+                          placeholder="Enter the category of the reference object, e.g., car, dress, toy, etc.")

  with gr.Column(scale=1):
  baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=695, columns=1)
  with gr.Accordion("Advanced Option", open=True):
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=999_999_999, step=1, value=666)
  gr.Markdown("### Guidelines")
  gr.Markdown(" Users can try using different seeds. For example, seeds like 42 and 123456 may produce different effects.")
  gr.Markdown(" Draw Mask means manually drawing a mask on the original image.")
  gr.Markdown(" Upload with Mask means uploading a mask file.")
  gr.Markdown(" Label to Mask means simply inputting a label to automatically extract the mask and obtain the result.")

  run_local_button = gr.Button(value="Run")

+ # examples
  num_examples = len(image_list)
  for i in range(num_examples):
  with gr.Row():

  gr.Examples([ref_mask_list[i]], inputs=[ref_mask], examples_per_page=1, label="")
  if i < num_examples - 1:
  gr.HTML("<hr>")

+ run_local_button.click(
+     fn=run_local,
+     inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt],
+     outputs=[baseline_gallery]
+ )
+ demo.launch()
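
A quick way to double-check the storage-safe setup after this change (a minimal sketch, not part of the commit, assuming the Space really has persistent storage mounted at /data and uses the paths introduced above):

# Sketch: report usage of the volume backing each persistent path (assumes the /data mount).
import os
import shutil

for root in ("/data/.huggingface", "/data/checkpoints", "/data/models"):
    if os.path.isdir(root):
        usage = shutil.disk_usage(root)  # filesystem-level stats for the volume holding `root`
        print(f"{root}: {usage.used / 1e9:.1f} GB used, {usage.free / 1e9:.1f} GB free")
    else:
        print(f"{root}: missing (check the persistent-storage mount)")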