isat committed on
Commit
12b4ecb
·
verified ·
1 Parent(s): ecf3916

Update app.py


Accelerating inference (fp16/autocast, TF32, SDPA, torch.compile, GPU NMS, fewer diffusion steps)

Files changed (1)
  1. app.py +138 -45
app.py CHANGED
@@ -1,4 +1,4 @@
-# app.py — storage-safe + HF Hub friendly + SAM import guard
+# app.py — storage-safe + HF Hub friendly + speed-optimized
 
 import os
 
@@ -19,12 +19,15 @@ os.environ["OMP_NUM_THREADS"] = omp_val # must be a positive integer string
 os.environ.setdefault("HF_HOME", "/data/.huggingface")
 os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
 os.environ.setdefault("HF_DATASETS_CACHE", "/data/.huggingface/datasets")
-# (TRANSFORMERS_CACHE is deprecated; rely on HF_HOME) # https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache
+# (TRANSFORMERS_CACHE is deprecated; rely on HF_HOME)
 
 # Disable Xet path, enable fast transfer
 os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
 os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
 
+# Faster + smoother CUDA memory behavior
+os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
+
 # ---------- NOW safe to import heavy libs ----------
 import sys
 import cv2
@@ -33,12 +36,37 @@ import torch
 import gradio as gr
 from PIL import Image, ImageFilter, ImageDraw
 
+# Global torch speed knobs
 try:
     torch.set_num_threads(int(omp_val))
     torch.set_num_interop_threads(1)
 except Exception:
     pass
 
+# Use TF32 when available (Ampere/Ada)
+try:
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+except Exception:
+    pass
+try:
+    # PyTorch 2.x matmul precision hint
+    torch.set_float32_matmul_precision("high")
+except Exception:
+    pass
+
+# Pick fastest cudnn convs once shapes are known
+try:
+    torch.backends.cudnn.benchmark = True
+except Exception:
+    pass
+
+# No autograd for inference-only app
+torch.set_grad_enabled(False)
+
+# SDPA availability flag
+USE_SDPA = hasattr(torch.nn.functional, "scaled_dot_product_attention")
+
 # ---------- HUB IMPORTS ----------
 from huggingface_hub import snapshot_download, hf_hub_download
 from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
@@ -92,14 +120,13 @@ GROUNDING_DINO_CONFIG_PATH = "./GroundingDINO_SwinB.cfg.py"
 GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(CKPT_DIR, "groundingdino_swinb_cogcoor.pth")
 
 # Segment-Anything checkpoint
-SAM_ENCODER_VERSION = "vit_h"
+SAM_ENCODER_VERSION = "vit_h"  # consider "vit_l" or "vit_b" for more speed
 SAM_CHECKPOINT_PATH = os.path.join(CKPT_DIR, "sam_vit_h_4b8939.pth")
 
 # ---------- AUTH TOKEN ----------
 hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
 
 # ---------- DOWNLOAD CHECKPOINTS (single files) ----------
-# Use hf_hub_download for single files, which returns a cached path. Keep files under /data. # https://huggingface.co/docs/huggingface_hub/en/guides/download
 if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
     g_dino_file = hf_hub_download(
         repo_id="ShilongLiu/GroundingDINO",
@@ -164,13 +191,14 @@ groundingdino_model = load_model(
     device="cuda"
 )
 
-# SAM + Predictor (registry API from official SAM) # https://github.com/facebookresearch/segment-anything
+# SAM + Predictor (registry API from official SAM)
 sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
 sam.to(device="cuda")
 sam_predictor = SamPredictor(sam)
 
 # Diffusers (Flux)
-dtype = torch.bfloat16
+# Prefer float16 for speed; change to bfloat16 if you hit NaNs on your GPU/drivers.
+dtype = torch.float16
 size = (768, 768)
 
 pipe = FluxFillPipeline.from_pretrained(
@@ -178,17 +206,57 @@ pipe = FluxFillPipeline.from_pretrained(
     torch_dtype=dtype
 ).to("cuda")
 
+# Load LoRA
 pipe.load_lora_weights(
     os.path.join(LORA_DIR, "20250321_steps5000_pytorch_lora_weights.safetensors")
 )
 
+# Speed features
+try:
+    if USE_SDPA and hasattr(pipe, "enable_sdpa"):
+        pipe.enable_sdpa()
+    elif hasattr(pipe, "enable_xformers_memory_efficient_attention"):
+        pipe.enable_xformers_memory_efficient_attention()
+except Exception:
+    pass
+
+try:
+    pipe.enable_vae_tiling()
+except Exception:
+    pass
+
+# Compile hot paths (PyTorch 2.0+)
+try:
+    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+    if hasattr(pipe.vae, "decode"):
+        pipe.vae.decode = torch.compile(pipe.vae.decode, mode="reduce-overhead")
+except Exception:
+    pass
+
+# Disable progress bars for tiny perf win
+try:
+    pipe.set_progress_bar_config(disable=True)
+except Exception:
+    pass
+
 redux = FluxPriorReduxPipeline.from_pretrained(REDUX_DIR).to(dtype=dtype).to("cuda")
+try:
+    if USE_SDPA and hasattr(redux, "enable_sdpa"):
+        redux.enable_sdpa()
+except Exception:
+    pass
+try:
+    if hasattr(redux, "image_encoder"):
+        redux.image_encoder = torch.compile(redux.image_encoder, mode="reduce-overhead")
+except Exception:
+    pass
 
-# ---------- APP LOGIC ----------
+# ---------- GLOBAL UTILS ----------
 def transform_image(image_pil):
+    # Smaller resize for faster DINO (was 800/max 1333)
     transform = T.Compose(
        [
-            T.RandomResize([800], max_size=1333),
+            T.RandomResize([640], max_size=1024),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
@@ -196,31 +264,42 @@ def transform_image(image_pil):
     image, _ = transform(image_pil, None)  # 3, h, w
     return image
 
-
 def get_grounding_output(model, image, caption, box_threshold=0.25, text_threshold=0.25, with_logits=True):
     caption = caption.lower().strip()
     if not caption.endswith("."):
         caption = caption + "."
-    with torch.no_grad():
+    device = next(model.parameters()).device
+    image = image.to(device, non_blocking=True)
+
+    # DINO forward in fp16 for speed
+    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
         outputs = model(image[None], captions=[caption])
-    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
-    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
+
+    logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256) CUDA
+    boxes = outputs["pred_boxes"][0]  # (nq, 4) CUDA
 
     # filter output
     filt_mask = logits.max(dim=1)[0] > box_threshold
     logits_filt = logits[filt_mask]
     boxes_filt = boxes[filt_mask]
 
-    # get phrase
+    # scores for NMS
+    scores = logits_filt.max(dim=1).values
+    # NMS on GPU
+    nms_idx = torchvision.ops.nms(boxes_filt, scores, 0.8)
+
+    # Move minimal tensors to CPU for tokenizer phrase mapping
+    boxes_filt_cpu = boxes_filt[nms_idx].detach().cpu()
+    scores_cpu = scores[nms_idx].detach().cpu()
+
     tokenlizer = model.tokenizer
     tokenized = tokenlizer(caption)
-    pred_phrases, scores = [], []
-    for logit, box in zip(logits_filt, boxes_filt):
+    pred_phrases = []
+    for logit in logits_filt[nms_idx].detach().cpu():
         pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
-        pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})" if with_logits else pred_phrase)
-        scores.append(logit.max().item())
-    return boxes_filt, torch.Tensor(scores), pred_phrases
+        pred_phrases.append(pred_phrase + f"({float(logit.max()):.2f})" if with_logits else pred_phrase)
 
+    return boxes_filt_cpu, scores_cpu, pred_phrases
 
 def get_mask(image, label):
     global groundingdino_model, sam_predictor
@@ -236,27 +315,25 @@ def get_mask(image, label):
         boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
         boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
         boxes_filt[i][2:] += boxes_filt[i][:2]
-    boxes_filt = boxes_filt.cpu()
-
-    nms_idx = torchvision.ops.nms(boxes_filt, scores, 0.8).numpy().tolist()
-    boxes_filt = boxes_filt[nms_idx]
-
+    # keep CPU for transform, then CUDA for SAM
     image_np = np.array(image_pil)
     sam_predictor.set_image(image_np)
+
     transformed_boxes = sam_predictor.transform.apply_boxes_torch(
         boxes_filt, image_np.shape[:2]
     ).to("cuda")
 
-    masks, _, _ = sam_predictor.predict_torch(
-        point_coords=None,
-        point_labels=None,
-        boxes=transformed_boxes,
-        multimask_output=False,
-    )
-    result_mask = masks[0][0].cpu().numpy()
+    # SAM forward (fp16 autocast for speed; switch to fp32 if masks degrade)
+    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
+        masks, _, _ = sam_predictor.predict_torch(
+            point_coords=None,
+            point_labels=None,
+            boxes=transformed_boxes,
+            multimask_output=False,
+        )
+    result_mask = masks[0][0].detach().cpu().numpy()
     return Image.fromarray(result_mask)
 
-
 def create_highlighted_mask(image_np, mask_np, alpha=0.5, gray_value=128):
     if mask_np.max() <= 1.0:
         mask_np = (mask_np * 255).astype(np.uint8)
@@ -267,6 +344,15 @@ def create_highlighted_mask(image_np, mask_np, alpha=0.5, gray_value=128):
     result[mask_bool] = (1 - alpha) * image_float[mask_bool] + alpha * gray_overlay[mask_bool]
     return result.astype(np.uint8)
 
+# Pre-allocated kernel to avoid repeated allocs
+KERNEL_7x7 = np.ones((7, 7), np.uint8)
+
+# Reusable CUDA generator (seedable)
+GLOBAL_GEN = torch.Generator(device="cuda")
+def make_gen(seed):
+    if seed is None or seed < 0:
+        return GLOBAL_GEN
+    return torch.Generator(device="cuda").manual_seed(int(seed))
 
 # ---------- EXAMPLES ----------
 ref_dir = './examples/ref_image'
@@ -334,9 +420,8 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
 
     masked_ref_image = pad_to_square(masked_ref_image, pad_value=255, random=False)
 
-    kernel = np.ones((7, 7), np.uint8)
     iterations = 2
-    tar_mask = cv2.dilate(tar_mask, kernel, iterations=iterations)
+    tar_mask = cv2.dilate(tar_mask, KERNEL_7x7, iterations=iterations)
 
     # zoom in
     tar_box_yyxx = get_bbox_from_mask(tar_mask)
@@ -355,8 +440,10 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
     tar_mask = pad_to_square(tar_mask, pad_value=0)
     tar_mask = cv2.resize(tar_mask, size)
 
+    # --- Redux (prior) ---
     masked_ref_image = cv2.resize(masked_ref_image.astype(np.uint8), size).astype(np.uint8)
-    pipe_prior_output = redux(Image.fromarray(masked_ref_image))
+    with torch.inference_mode(), torch.autocast("cuda", dtype=dtype):
+        pipe_prior_output = redux(Image.fromarray(masked_ref_image))
 
     tar_image = pad_to_square(tar_image, pad_value=255)
     H2, W2 = tar_image.shape[0], tar_image.shape[1]
@@ -374,16 +461,22 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
     mask_diptych[mask_diptych == 1] = 255
     mask_diptych = Image.fromarray(mask_diptych)
 
-    generator = torch.Generator("cuda").manual_seed(seed)
-    edited_image = pipe(
-        image=diptych_ref_tar,
-        mask_image=mask_diptych,
-        height=mask_diptych.size[1],
-        width=mask_diptych.size[0],
-        max_sequence_length=512,
-        generator=generator,
-        **pipe_prior_output,
-    ).images[0]
+    # Reusable CUDA generator
+    generator = make_gen(seed)
+
+    # --- Flux Fill ---
+    with torch.inference_mode(), torch.autocast("cuda", dtype=dtype):
+        edited_image = pipe(
+            image=diptych_ref_tar,
+            mask_image=mask_diptych,
+            height=mask_diptych.size[1],
+            width=mask_diptych.size[0],
+            max_sequence_length=512,
+            generator=generator,
+            num_inference_steps=18,  # tune 12–24 for quality/speed tradeoff
+            guidance_scale=3.5,  # lower often faster and still good
+            **pipe_prior_output,
+        ).images[0]
 
     width, height = edited_image.size
     left = width // 2
@@ -471,4 +564,4 @@ with gr.Blocks() as demo:
         inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt],
         outputs=[baseline_gallery]
     )
-demo.launch()
+demo.launch()