Commit 52b5462 · Parent: d04cb87
Add mask and blending modules

app.py CHANGED
@@ -225,9 +225,7 @@ def draw_box(box: torch.Tensor, draw: ImageDraw.Draw, label: Optional[str]) -> None:
     draw.text((box[0], box[1]), str(label), fill="white")
 
 
-def run_grounded_sam(
-    input_image
-) -> List[Image.Image]:
+def run_grounded_sam(input_image):
     """Main function to run GroundingDINO and SAM-HQ"""
     try:
         # Create output directory
@@ -319,7 +317,7 @@ def run_grounded_sam(
         for box, label in zip(boxes_filt, pred_phrases):
            draw_box(box, image_draw, label)
 
-        return
+        return mask_image
 
    except Exception as e:
        logger.error(f"Error in run_grounded_sam: {e}")
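These first two hunks narrow `run_grounded_sam` to a single `input_image` argument and make it return the segmentation result (`mask_image`) instead of nothing, which is the contract the new `generate_image` function added later in this commit relies on. A minimal sketch of that call contract, with a hypothetical stub standing in for the real GroundingDINO + SAM-HQ pipeline:

from PIL import Image

def run_grounded_sam_stub(input_image: Image.Image) -> Image.Image:
    # Hypothetical stand-in: the real function detects the subject with
    # GroundingDINO, segments it with SAM-HQ, and returns the mask image.
    return Image.new("RGBA", input_image.size, (255, 255, 255, 255))

mask = run_grounded_sam_stub(Image.new("RGB", (640, 480), "gray"))
print(mask.size, mask.mode)  # (640, 480) RGBA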
@@ -331,14 +329,85 @@ def run_grounded_sam(
        else:
            return [Image.new('RGB', (400, 300), color='gray'), Image.new('RGBA', (400, 300), color=(0, 0, 0, 0))]
 
-def
-
-
-
-
-
-
-
+def image_gaussian_blur(image: torch.Tensor, radius: float) -> torch.Tensor:
+    if image.ndim == 4:  # Remove batch dimension if present
+        image = image.squeeze(0)
+    pil_image = tensor2pil(image)
+    blurred_pil_image = pil_image.filter(ImageFilter.GaussianBlur(radius))
+    return pil2tensor(blurred_pil_image).squeeze(0)
+
+def load_image(image_path: str) -> torch.Tensor:
+    image = Image.open(image_path).convert("RGBA")
+    image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
+    return image_tensor
+
+def split_image_with_alpha(image: torch.Tensor):
+    out_images = image[:3, :, :]
+    out_alphas = image[3, :, :] if image.shape[0] > 3 else torch.ones_like(image[0, :, :])
+    result = (out_images.unsqueeze(0), 1.0 - out_alphas.unsqueeze(0))
+    return result
+
+def pil2numpy(image: Image.Image):
+    return np.array(image).astype(np.float32) / 255.0
+
+def numpy2pil(image: np.ndarray, mode=None):
+    return Image.fromarray(np.clip(255.0 * image, 0, 255).astype(np.uint8), mode)
+
+def pil2tensor(image: Image.Image):
+    return torch.from_numpy(pil2numpy(image)).unsqueeze(0)
+
+def invert(image):
+    s = 1.0 - image
+    return s
+
+def tensor2pil(image: torch.Tensor, mode=None):
+    if image.ndim == 2:  # Grayscale image
+        image = image.unsqueeze(0)  # Add channel dimension
+
+    if image.ndim != 3 or image.shape[1:] == (0, 0):
+        raise ValueError(f"Invalid tensor dimensions: {image.shape}")
+
+    if image.shape[0] == 1:  # Single channel, replicate to 3 channels
+        image = image.repeat(3, 1, 1)
+    elif image.shape[0] != 3:
+        raise ValueError("Unexpected number of channels in the image tensor")
+
+    return numpy2pil(image.cpu().numpy().transpose(1, 2, 0), mode=mode)
+
+def extract_high_frequency(image: torch.Tensor, blur_radius: float = 5.0) -> torch.Tensor:
+    """Extract high-frequency details by subtracting the blurred image from the original."""
+    if image.ndim == 4:
+        image = image.squeeze(0)
+
+    blurred = image_gaussian_blur(image, blur_radius)
+
+    if blurred.ndim == 4:
+        blurred = blurred.squeeze(0)
+    elif blurred.ndim == 3 and blurred.shape[0] != 3:
+        blurred = blurred.permute(2, 0, 1)
+
+    high_freq = image - blurred
+    return high_freq
+
+def image_blend_mask(image_a, image_b, mask, blend_percentage):
+
+    # Convert images to PIL
+    img_a = tensor2pil(image_a)
+    img_b = tensor2pil(image_b)
+    mask = ImageOps.invert(tensor2pil(mask).convert('L'))
+
+    # Mask image
+    masked_img = Image.composite(img_a, img_b, mask.resize(img_a.size))
+
+    # Blend image
+    blend_mask = Image.new(mode="L", size=img_a.size,
+                           color=(round(blend_percentage * 255)))
+    blend_mask = ImageOps.invert(blend_mask)
+    img_result = Image.composite(img_a, masked_img, blend_mask)
+
+    del img_a, img_b, blend_mask, mask
+
+    return (pil2tensor(img_result), )
 
 def encode_image(image):
     buffer = BytesIO()
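The helpers added in this hunk center on a blur-subtract decomposition: `extract_high_frequency` takes the original image minus its Gaussian blur, leaving only edges and fine texture. A self-contained sketch of the same idea using only PIL and torch (the `demo_high_freq` name and the random test image are illustrative, not part of the commit):

import numpy as np
import torch
from PIL import Image, ImageFilter

def demo_high_freq(pil_img: Image.Image, radius: float = 3.0) -> torch.Tensor:
    # Original minus its Gaussian blur isolates high-frequency detail;
    # values land roughly in [-1, 1] and get added back onto a base later.
    to_tensor = lambda im: torch.from_numpy(
        np.asarray(im, dtype=np.float32) / 255.0).permute(2, 0, 1)
    blurred = pil_img.filter(ImageFilter.GaussianBlur(radius))
    return to_tensor(pil_img) - to_tensor(blurred)

img = Image.fromarray((np.random.rand(64, 64, 3) * 255).astype(np.uint8))
hf = demo_high_freq(img)
print(hf.shape, float(hf.min()), float(hf.max()))  # torch.Size([3, 64, 64]) ...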
@@ -368,10 +437,85 @@ def generate_ai_bg(input_img, prompt):
 
     return ic_light_img
 
-def
-
+def blend_details(input_image, relit_image, masked_image):
+    with torch.inference_mode():
+        # Load and resize images
+        # input_image = load_image(input_image_path)
+        # relit_image = load_image(relit_image_path)
+        # masked_image = load_image(masked_image_path)
+
+        # Resize input image
+        input_image = torch.nn.functional.interpolate(
+            input_image.unsqueeze(0),
+            size=(1024, 1024),
+            mode="bicubic",
+            align_corners=False
+        ).squeeze(0)
+
+        # Resize relit image
+        relit_image = torch.nn.functional.interpolate(
+            relit_image.unsqueeze(0),
+            size=(1024, 1024),
+            mode="bicubic",
+            align_corners=False
+        ).squeeze(0)
+
+        # Resize masked image
+        masked_image = torch.nn.functional.interpolate(
+            masked_image.unsqueeze(0),
+            size=(1024, 1024),
+            mode="bicubic",
+            align_corners=False
+        ).squeeze(0)
+
+        # Split images and get RGB channels
+        input_image_rgb = split_image_with_alpha(input_image)[0].squeeze(0)
+        relit_image_rgb = split_image_with_alpha(relit_image)[0].squeeze(0)
+
+        # Use masked image RGB channels as segmentation mask (average of RGB channels)
+        segmentation_mask = masked_image[:3].mean(dim=0)  # Average RGB channels to get grayscale mask
+
+        print(f"segmentation_mask shape: {segmentation_mask.shape}")
+
+        # Extract high-frequency details from input image
+        high_freq_details = extract_high_frequency(input_image_rgb, blur_radius=3.0)
+
+        # Print shapes for debugging
+        print(f"high_freq_details shape: {high_freq_details.shape}")
+        print(f"segmentation_mask shape: {segmentation_mask.shape}")
+        print(f"relit_image_rgb shape: {relit_image_rgb.shape}")
+
+        # Apply high-frequency details only in masked areas
+        detail_strength = 0.5
+        segmentation_mask = segmentation_mask.unsqueeze(0).repeat(3, 1, 1)  # Expand mask to match RGB channels
+        masked_details = high_freq_details * segmentation_mask
+        # final_image = relit_image_rgb + (masked_details * detail_strength)
+        # final_image = image_blend_mask(relit_image_rgb, masked_details, mask, blend_percentage)
+        final_image = relit_image_rgb + masked_details
+        print('final_image shape:', final_image.shape)
+
+        # Normalize to [0, 1] range
+        final_image = torch.clamp(final_image, 0, 1)
+
+        # Save intermediate results for debugging
+        tensor2pil(segmentation_mask).save("output/segmentation_mask.png")
+        tensor2pil(high_freq_details).save("output/high_freq_details.png")
+        tensor2pil(masked_details).save("output/masked_details.png")
+
+        # Save final result
+        final_image_pil = tensor2pil(final_image)
+        # final_image_pil.save("output/output_image.png")
+        return [final_image_pil]
+
+def generate_image(input_img, ai_gen_image, prompt):
+
+    # ai_gen_image = generate_ai_bg(input_img, prompt)
+
+    mask_input_image = run_grounded_sam(input_img)
+
+    final_image = blend_details(input_img, ai_gen_image, mask_input_image)
 
-    return [
+    return [final_image]
 
 def create_ui():
     """Create Gradio UI for CarViz demo"""
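At its core, `blend_details` is a masked detail transfer: resize everything to 1024×1024, turn the SAM mask into a 3-channel weight, and add the original's high-frequency details onto the relit render only where the mask is active, clamping back to [0, 1]. A minimal sketch of just that compositing step, assuming the three tensors are already 3×H×W floats (tensor names are illustrative, not from the commit):

import torch

def transfer_details(relit: torch.Tensor, high_freq: torch.Tensor,
                     mask: torch.Tensor) -> torch.Tensor:
    # relit: 3xHxW in [0, 1]; high_freq: 3xHxW around 0; mask: HxW in [0, 1]
    mask3 = mask.unsqueeze(0).expand(3, -1, -1)  # broadcast mask over RGB
    out = relit + high_freq * mask3              # add detail only inside the mask
    return out.clamp(0.0, 1.0)                   # keep a displayable range

relit = torch.rand(3, 1024, 1024)
high_freq = torch.rand(3, 1024, 1024) - 0.5
mask = torch.zeros(1024, 1024)
mask[256:768, 256:768] = 1.0                     # details only in the center square
print(transfer_details(relit, high_freq, mask).shape)  # torch.Size([3, 1024, 1024])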
@@ -383,6 +527,7 @@ def create_ui():
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(type="pil", label="image")
+            ai_image = gr.Image(type="pil", label="image")
             prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
             run_button = gr.Button(value='Run')
 
@@ -396,6 +541,7 @@ def create_ui():
             fn=generate_image,
             inputs=[
                 input_image,
+                ai_image,
                 prompt
             ],
             outputs=gallery
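The two UI hunks add a second `gr.Image` input (`ai_image`, the pre-generated relit background) and thread it through to `generate_image`. The surrounding `create_ui` code is not shown in the diff; a minimal self-contained Gradio wiring consistent with these hunks (the `gr.Blocks` layout and the `gallery` definition are assumptions):

import gradio as gr
from PIL import Image

def generate_image(input_img, ai_gen_image, prompt):
    # Placeholder: the real app segments input_img and blends its details
    # onto ai_gen_image; here we just echo the input to exercise the wiring.
    return [input_img if input_img is not None else Image.new("RGB", (400, 300), "gray")]

with gr.Blocks() as block:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="image")
            ai_image = gr.Image(type="pil", label="image")
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
            run_button = gr.Button(value='Run')
        with gr.Column():
            gallery = gr.Gallery(label="results")
    run_button.click(fn=generate_image, inputs=[input_image, ai_image, prompt], outputs=gallery)

if __name__ == "__main__":
    block.launch()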