Update app.py

Accelerating inference

app.py
CHANGED
@@ -1,4 +1,4 @@
-# app.py — storage-safe + HF Hub friendly
+# app.py — storage-safe + HF Hub friendly + speed-optimized
 
 import os
 
@@ -19,12 +19,15 @@ os.environ["OMP_NUM_THREADS"] = omp_val  # must be a positive integer string
 os.environ.setdefault("HF_HOME", "/data/.huggingface")
 os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
 os.environ.setdefault("HF_DATASETS_CACHE", "/data/.huggingface/datasets")
-# (TRANSFORMERS_CACHE is deprecated; rely on HF_HOME)
+# (TRANSFORMERS_CACHE is deprecated; rely on HF_HOME)
 
 # Disable Xet path, enable fast transfer
 os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
 os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
 
+# Faster + smoother CUDA memory behavior
+os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
+
 # ---------- NOW safe to import heavy libs ----------
 import sys
 import cv2
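A side note on the fast-transfer flag above: huggingface_hub raises an error at download time if HF_HUB_ENABLE_HF_TRANSFER is set but the hf_transfer package is not installed in the Space. A minimal hedged sketch of a guard, not part of this commit:

# Hedged sketch (not in this commit): only opt into hf_transfer when it is importable,
# since huggingface_hub errors out if the env var is set without the package.
import importlib.util
import os

if importlib.util.find_spec("hf_transfer") is None:
    os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)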
@@ -33,12 +36,37 @@ import torch
 import gradio as gr
 from PIL import Image, ImageFilter, ImageDraw
 
+# Global torch speed knobs
 try:
     torch.set_num_threads(int(omp_val))
     torch.set_num_interop_threads(1)
 except Exception:
     pass
 
+# Use TF32 when available (Ampere/Ada)
+try:
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+except Exception:
+    pass
+try:
+    # PyTorch 2.x matmul precision hint
+    torch.set_float32_matmul_precision("high")
+except Exception:
+    pass
+
+# Pick fastest cudnn convs once shapes are known
+try:
+    torch.backends.cudnn.benchmark = True
+except Exception:
+    pass
+
+# No autograd for inference-only app
+torch.set_grad_enabled(False)
+
+# SDPA availability flag
+USE_SDPA = hasattr(torch.nn.functional, "scaled_dot_product_attention")
+
 # ---------- HUB IMPORTS ----------
 from huggingface_hub import snapshot_download, hf_hub_download
 from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
@@ -92,14 +120,13 @@ GROUNDING_DINO_CONFIG_PATH = "./GroundingDINO_SwinB.cfg.py"
 GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(CKPT_DIR, "groundingdino_swinb_cogcoor.pth")
 
 # Segment-Anything checkpoint
-SAM_ENCODER_VERSION = "vit_h"
+SAM_ENCODER_VERSION = "vit_h"  # consider "vit_l" or "vit_b" for more speed
 SAM_CHECKPOINT_PATH = os.path.join(CKPT_DIR, "sam_vit_h_4b8939.pth")
 
 # ---------- AUTH TOKEN ----------
 hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
 
 # ---------- DOWNLOAD CHECKPOINTS (single files) ----------
-# Use hf_hub_download for single files, which returns a cached path. Keep files under /data. # https://huggingface.co/docs/huggingface_hub/en/guides/download
 if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
     g_dino_file = hf_hub_download(
         repo_id="ShilongLiu/GroundingDINO",
@@ -164,13 +191,14 @@ groundingdino_model = load_model(
     device="cuda"
 )
 
-# SAM + Predictor (registry API from official SAM)
+# SAM + Predictor (registry API from official SAM)
 sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
 sam.to(device="cuda")
 sam_predictor = SamPredictor(sam)
 
 # Diffusers (Flux)
+# Prefer float16 for speed; change to bfloat16 if you hit NaNs on your GPU/drivers.
+dtype = torch.float16
 size = (768, 768)
 
 pipe = FluxFillPipeline.from_pretrained(
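On the dtype switch above: FLUX checkpoints are published in bfloat16, and running them in float16 is the usual source of the NaN/black-image issue the new comment warns about. A hedged alternative, not part of this commit, that only falls back to float16 on GPUs without bfloat16 support:

# Hedged sketch: pick bfloat16 when the GPU supports it, otherwise fall back to float16.
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16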
@@ -178,17 +206,57 @@ pipe = FluxFillPipeline.from_pretrained(
     torch_dtype=dtype
 ).to("cuda")
 
+# Load LoRA
 pipe.load_lora_weights(
     os.path.join(LORA_DIR, "20250321_steps5000_pytorch_lora_weights.safetensors")
 )
 
+# Speed features
+try:
+    if USE_SDPA and hasattr(pipe, "enable_sdpa"):
+        pipe.enable_sdpa()
+    elif hasattr(pipe, "enable_xformers_memory_efficient_attention"):
+        pipe.enable_xformers_memory_efficient_attention()
+except Exception:
+    pass
+
+try:
+    pipe.enable_vae_tiling()
+except Exception:
+    pass
+
+# Compile hot paths (PyTorch 2.0+)
+try:
+    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+    if hasattr(pipe.vae, "decode"):
+        pipe.vae.decode = torch.compile(pipe.vae.decode, mode="reduce-overhead")
+except Exception:
+    pass
+
+# Disable progress bars for tiny perf win
+try:
+    pipe.set_progress_bar_config(disable=True)
+except Exception:
+    pass
+
 redux = FluxPriorReduxPipeline.from_pretrained(REDUX_DIR).to(dtype=dtype).to("cuda")
+try:
+    if USE_SDPA and hasattr(redux, "enable_sdpa"):
+        redux.enable_sdpa()
+except Exception:
+    pass
+try:
+    if hasattr(redux, "image_encoder"):
+        redux.image_encoder = torch.compile(redux.image_encoder, mode="reduce-overhead")
+except Exception:
+    pass
 
-# ----------
+# ---------- GLOBAL UTILS ----------
 def transform_image(image_pil):
+    # Smaller resize for faster DINO (was 800/max 1333)
     transform = T.Compose(
         [
-            T.RandomResize([800], max_size=1333),
+            T.RandomResize([640], max_size=1024),
             T.ToTensor(),
             T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         ]
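A caveat on the compile block in this hunk: recent diffusers Flux pipelines expose the denoiser as pipe.transformer (there is no pipe.unet attribute), so the guarded torch.compile(pipe.unet, ...) call above most likely no-ops inside its try/except. A minimal hedged sketch that targets the transformer instead, assuming diffusers >= 0.30 and PyTorch >= 2.0, not part of this commit:

# Hedged sketch (not in this commit): compile the Flux transformer rather than a UNet.
try:
    if hasattr(pipe, "transformer"):
        pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
except Exception:
    pass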
@@ -196,31 +264,42 @@ def transform_image(image_pil):
     image, _ = transform(image_pil, None)  # 3, h, w
     return image
 
 def get_grounding_output(model, image, caption, box_threshold=0.25, text_threshold=0.25, with_logits=True):
     caption = caption.lower().strip()
     if not caption.endswith("."):
         caption = caption + "."
+    device = next(model.parameters()).device
+    image = image.to(device, non_blocking=True)
+
+    # DINO forward in fp16 for speed
+    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
         outputs = model(image[None], captions=[caption])
+
+    logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256) CUDA
+    boxes = outputs["pred_boxes"][0]              # (nq, 4) CUDA
 
     # filter output
     filt_mask = logits.max(dim=1)[0] > box_threshold
     logits_filt = logits[filt_mask]
     boxes_filt = boxes[filt_mask]
 
+    # scores for NMS
+    scores = logits_filt.max(dim=1).values
+    # NMS on GPU
+    nms_idx = torchvision.ops.nms(boxes_filt, scores, 0.8)
+
+    # Move minimal tensors to CPU for tokenizer phrase mapping
+    boxes_filt_cpu = boxes_filt[nms_idx].detach().cpu()
+    scores_cpu = scores[nms_idx].detach().cpu()
+
     tokenlizer = model.tokenizer
     tokenized = tokenlizer(caption)
-    pred_phrases = []
-    for logit in logits_filt:
+    pred_phrases = []
+    for logit in logits_filt[nms_idx].detach().cpu():
         pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
-        pred_phrases.append(pred_phrase + f"({logit.max().item():.2f})")
-        scores.append(logit.max().item())
-    return boxes_filt, torch.Tensor(scores), pred_phrases
+        pred_phrases.append(pred_phrase + f"({float(logit.max()):.2f})" if with_logits else pred_phrase)
 
+    return boxes_filt_cpu, scores_cpu, pred_phrases
 
 def get_mask(image, label):
     global groundingdino_model, sam_predictor
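One behavioral point worth double-checking in this hunk: the NMS call moved from get_mask, where it previously ran on boxes already converted to corner format, into get_grounding_output, where pred_boxes are still normalized cxcywh; torchvision.ops.nms assumes (x1, y1, x2, y2) boxes. A hedged sketch, not part of this commit, that converts before suppression:

# Hedged sketch (not in this commit): convert cxcywh boxes to xyxy before NMS
# so the IoU computation matches what torchvision.ops.nms expects.
from torchvision.ops import box_convert, nms

boxes_xyxy = box_convert(boxes_filt, in_fmt="cxcywh", out_fmt="xyxy")
nms_idx = nms(boxes_xyxy, scores, iou_threshold=0.8)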
@@ -236,27 +315,25 @@ def get_mask(image, label):
         boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
         boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
         boxes_filt[i][2:] += boxes_filt[i][:2]
-    nms_idx = torchvision.ops.nms(boxes_filt, scores, 0.8).numpy().tolist()
-    boxes_filt = boxes_filt[nms_idx]
+    # keep CPU for transform, then CUDA for SAM
     image_np = np.array(image_pil)
     sam_predictor.set_image(image_np)
+
     transformed_boxes = sam_predictor.transform.apply_boxes_torch(
         boxes_filt, image_np.shape[:2]
     ).to("cuda")
 
+    # SAM forward (fp16 autocast for speed; switch to fp32 if masks degrade)
+    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
+        masks, _, _ = sam_predictor.predict_torch(
+            point_coords=None,
+            point_labels=None,
+            boxes=transformed_boxes,
+            multimask_output=False,
+        )
+    result_mask = masks[0][0].detach().cpu().numpy()
     return Image.fromarray(result_mask)
 
 def create_highlighted_mask(image_np, mask_np, alpha=0.5, gray_value=128):
     if mask_np.max() <= 1.0:
         mask_np = (mask_np * 255).astype(np.uint8)
@@ -267,6 +344,15 @@ def create_highlighted_mask(image_np, mask_np, alpha=0.5, gray_value=128):
     result[mask_bool] = (1 - alpha) * image_float[mask_bool] + alpha * gray_overlay[mask_bool]
     return result.astype(np.uint8)
 
+# Pre-allocated kernel to avoid repeated allocs
+KERNEL_7x7 = np.ones((7, 7), np.uint8)
+
+# Reusable CUDA generator (seedable)
+GLOBAL_GEN = torch.Generator(device="cuda")
+def make_gen(seed):
+    if seed is None or seed < 0:
+        return GLOBAL_GEN
+    return torch.Generator(device="cuda").manual_seed(int(seed))
 
 # ---------- EXAMPLES ----------
 ref_dir = './examples/ref_image'
@@ -334,9 +420,8 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
 
     masked_ref_image = pad_to_square(masked_ref_image, pad_value=255, random=False)
 
-    kernel = np.ones((7, 7), np.uint8)
     iterations = 2
-    tar_mask = cv2.dilate(tar_mask, kernel, iterations=iterations)
+    tar_mask = cv2.dilate(tar_mask, KERNEL_7x7, iterations=iterations)
 
     # zoom in
     tar_box_yyxx = get_bbox_from_mask(tar_mask)
@@ -355,8 +440,10 @@
     tar_mask = pad_to_square(tar_mask, pad_value=0)
     tar_mask = cv2.resize(tar_mask, size)
 
+    # --- Redux (prior) ---
     masked_ref_image = cv2.resize(masked_ref_image.astype(np.uint8), size).astype(np.uint8)
-    pipe_prior_output = redux(Image.fromarray(masked_ref_image))
+    with torch.inference_mode(), torch.autocast("cuda", dtype=dtype):
+        pipe_prior_output = redux(Image.fromarray(masked_ref_image))
 
     tar_image = pad_to_square(tar_image, pad_value=255)
     H2, W2 = tar_image.shape[0], tar_image.shape[1]
@@ -374,16 +461,22 @@
     mask_diptych[mask_diptych == 1] = 255
     mask_diptych = Image.fromarray(mask_diptych)
 
+    # Reusable CUDA generator
+    generator = make_gen(seed)
+
+    # --- Flux Fill ---
+    with torch.inference_mode(), torch.autocast("cuda", dtype=dtype):
+        edited_image = pipe(
+            image=diptych_ref_tar,
+            mask_image=mask_diptych,
+            height=mask_diptych.size[1],
+            width=mask_diptych.size[0],
+            max_sequence_length=512,
+            generator=generator,
+            num_inference_steps=18,  # tune 12–24 for quality/speed tradeoff
+            guidance_scale=3.5,      # lower often faster and still good
+            **pipe_prior_output,
+        ).images[0]
 
     width, height = edited_image.size
     left = width // 2
@@ -471,4 +564,4 @@
     inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt],
     outputs=[baseline_gallery]
 )
-demo.launch()
+demo.launch()
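Since torch.compile front-loads compilation into the first call, any before/after timing of this commit should be taken after a warm-up run. A small hedged helper for that measurement, not part of the commit:

# Hedged sketch: time a callable with CUDA synchronization; run it once before
# measuring so compile/warm-up cost is excluded from the reported number.
import time

def timed(fn, *args, **kwargs):
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    torch.cuda.synchronize()
    return out, time.perf_counter() - t0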