rootglitch committed
Commit dbafd34 · 1 Parent(s): 2649278
Files changed (1)
app.py +184 -285
app.py CHANGED
@@ -5,6 +5,7 @@ import warnings
 import random
 import time
 import logging
+import fal_client
 import base64
 import numpy as np
 import math
@@ -53,6 +54,10 @@ CONFIG_FILE = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
 GROUNDINGDINO_CHECKPOINT = "groundingdino_swint_ogc.pth"
 SAM_CHECKPOINT = 'sam_hq_vit_l.pth'
 OUTPUT_DIR = "outputs"
+# FAL_KEY = os.getenv("FAL_KEY")
+# UPLOAD_DIR = "./tmp/images"
+
+# os.makedirs(UPLOAD_DIR, exist_ok=True)
 
 # Global variables for model caching
 _models = {
@@ -106,6 +111,7 @@ class ModelManager:
                 _models['sam_predictor'] = SamPredictor(sam)
 
                 logger.info(f"SAM-HQ model loaded in {time.time() - start_time:.2f} seconds")
+
 
         except Exception as e:
             logger.error(f"Error loading {model_name} model: {e}")
@@ -193,7 +199,9 @@ def get_grounding_output(
 
 def draw_mask(mask: np.ndarray, draw: ImageDraw.Draw) -> None:
     """Draw mask on image"""
+
     color = (255, 255, 255, 255)
+
     nonzero_coords = np.transpose(np.nonzero(mask))
     for coord in nonzero_coords:
         draw.point(coord[::-1], fill=color)
@@ -230,8 +238,8 @@ def run_grounded_sam(input_image):
     # Process input image
     if isinstance(input_image, dict):
         # Input from gradio sketch component
-        scribble = np.array(input_image["mask"]) if "mask" in input_image else None
-        image_pil = input_image["image"].convert("RGB") if "image" in input_image else None
+        scribble = np.array(input_image["mask"])
+        image_pil = input_image["image"].convert("RGB")
     else:
         # Direct image input
         image_pil = input_image.convert("RGB") if input_image else None
@@ -239,7 +247,7 @@ def run_grounded_sam(input_image):
 
     if image_pil is None:
         logger.error("No input image provided")
-        return Image.new('RGBA', (400, 300), color=(0, 0, 0, 0))
+        return [Image.new('RGB', (400, 300), color='gray')]
 
     # Transform image for GroundingDINO
     transformed_image = transform_image(image_pil)
@@ -254,32 +262,20 @@ def run_grounded_sam(input_image):
         transformed_image, text_prompt, box_threshold, text_threshold
     )
 
-    # Fix: Handle case when no boxes are detected
-    if len(boxes_filt) == 0 or boxes_filt.nelement() == 0:
-        logger.warning("No boxes detected")
-        # Create a simple fallback mask - a circle in the center
-        mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
-        mask_draw = ImageDraw.Draw(mask_image)
-        center_x, center_y = W // 2, H // 2
-        radius = min(W, H) // 4
-        mask_draw.ellipse((center_x - radius, center_y - radius,
-                           center_x + radius, center_y + radius),
-                          fill=(255, 255, 255, 255))
-        return mask_image
-
-    # Scale boxes to image dimensions
-    for i in range(boxes_filt.size(0)):
-        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
-        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
-        boxes_filt[i][2:] += boxes_filt[i][:2]
-
-    # Apply non-maximum suppression if we have multiple boxes
-    if boxes_filt.size(0) > 1:
-        logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes")
-        nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
-        boxes_filt = boxes_filt[nms_idx]
-        pred_phrases = [pred_phrases[idx] for idx in nms_idx]
-        logger.info(f"After NMS: {boxes_filt.shape[0]} boxes")
+    if boxes_filt is not None:
+        # Scale boxes to image dimensions
+        for i in range(boxes_filt.size(0)):
+            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+            boxes_filt[i][2:] += boxes_filt[i][:2]
+
+        # Apply non-maximum suppression if we have multiple boxes
+        if boxes_filt.size(0) > 1:
+            logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes")
+            nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
+            boxes_filt = boxes_filt[nms_idx]
+            pred_phrases = [pred_phrases[idx] for idx in nms_idx]
+            logger.info(f"After NMS: {boxes_filt.shape[0]} boxes")
 
     # Load SAM model
     ModelManager.load_model('sam')
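Note: GroundingDINO emits boxes in normalized (cx, cy, w, h) form, so the loop above rescales them to pixel units and converts to (x1, y1, x2, y2) corners before `torchvision.ops.nms` prunes overlaps. The same transform in vectorized form, as an illustrative sketch (the helper name and tensor shapes are assumptions, not part of this commit):

```python
import torch
import torchvision

def scale_and_filter_boxes(boxes_cxcywh, scores, W, H, iou_threshold=0.5):
    """Convert normalized (cx, cy, w, h) boxes to pixel (x1, y1, x2, y2) and NMS-filter."""
    boxes = boxes_cxcywh * torch.tensor([W, H, W, H], dtype=boxes_cxcywh.dtype)
    boxes_xyxy = torch.cat([boxes[:, :2] - boxes[:, 2:] / 2,   # top-left corner
                            boxes[:, :2] + boxes[:, 2:] / 2],  # bottom-right corner
                           dim=1)
    keep = torchvision.ops.nms(boxes_xyxy, scores, iou_threshold)  # indices to keep
    return boxes_xyxy[keep], keep
```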
@@ -289,18 +285,11 @@ def run_grounded_sam(input_image):
     image = np.array(image_pil)
     sam_predictor.set_image(image)
 
-    # Run SAM
+    # Run SAM
+    # Use boxes for these task types
     if boxes_filt.size(0) == 0:
-        logger.warning("No boxes detected after NMS")
-        # Create a simple fallback mask
-        mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
-        mask_draw = ImageDraw.Draw(mask_image)
-        center_x, center_y = W // 2, H // 2
-        radius = min(W, H) // 4
-        mask_draw.ellipse((center_x - radius, center_y - radius,
-                           center_x + radius, center_y + radius),
-                          fill=(255, 255, 255, 255))
-        return mask_image
+        logger.warning("No boxes detected")
+        return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
 
     transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
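For context, `apply_boxes_torch` maps the pixel-space boxes into SAM's internal input resolution. The predict call itself sits outside this hunk; with the standard segment-anything `SamPredictor` API it would look roughly like this (box prompts only, no point prompts):

```python
# One mask per detected box; point prompts are disabled.
masks, iou_predictions, _ = sam_predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)
# masks: boolean tensor of shape (num_boxes, 1, H, W)
```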
@@ -330,47 +319,30 @@ def run_grounded_sam(input_image):
 
     except Exception as e:
         logger.error(f"Error in run_grounded_sam: {e}")
-        # Return transparent image on error
+        # Return original image on error
        if isinstance(input_image, dict) and "image" in input_image:
-            return Image.new('RGBA', input_image["image"].size, color=(0, 0, 0, 0))
+            return [input_image["image"], Image.new('RGBA', input_image["image"].size, color=(0, 0, 0, 0))]
         elif isinstance(input_image, Image.Image):
-            return Image.new('RGBA', input_image.size, color=(0, 0, 0, 0))
+            return [input_image, Image.new('RGBA', input_image.size, color=(0, 0, 0, 0))]
         else:
-            return Image.new('RGBA', (400, 300), color=(0, 0, 0, 0))
-
+            return [Image.new('RGB', (400, 300), color='gray'), Image.new('RGBA', (400, 300), color=(0, 0, 0, 0))]
 
 def split_image_with_alpha(image):
-    """Ensure image is RGB and return a copy"""
-    if image.mode == 'RGBA':
-        # Create a white background
-        background = Image.new('RGB', image.size, (255, 255, 255))
-        # Composite the image with alpha onto the background
-        return Image.alpha_composite(background.convert('RGBA'), image).convert('RGB')
-    else:
-        return image.convert("RGB")
-
+    image = image.convert("RGB")
+    return image
 
 def gaussian_blur(image, radius=10):
     """Apply Gaussian blur to image."""
-    # Ensure image is in RGB mode
-    image = image.convert("RGB")
-    blurred = image.filter(ImageFilter.GaussianBlur(radius=radius))
+    blurred = image.filter(ImageFilter.GaussianBlur(radius=10))
     return blurred
 
-
 def invert_image(image):
-    """Invert image colors"""
-    # Ensure image is in RGB mode for inversion
-    image = image.convert("RGB")
     img_inverted = ImageOps.invert(image)
     return img_inverted
 
-
 def expand_mask(mask, expand, tapered_corners):
-    """Expand or contract a mask with proper error handling"""
-    # Ensure mask is in grayscale mode 'L'
-    if mask.mode != 'L':
-        mask = mask.convert("L")
+    # Ensure mask is in grayscale (mode 'L')
+    mask = mask.convert("L")
 
     # Convert to NumPy array
     mask_np = np.array(mask)
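The body of `expand_mask` is unchanged by this commit and therefore elided from the diff (note also that the new `gaussian_blur` above hardcodes `radius=10` instead of using its parameter). Implementations of this grow/shrink operation typically erode for negative `expand` and dilate for positive, with the corner taper controlling the kernel. A hedged sketch of that approach, assuming a scipy-based implementation rather than the one actually in app.py:

```python
import numpy as np
from scipy.ndimage import grey_dilation, grey_erosion

def expand_mask_np(mask_np, expand, tapered_corners):
    """Grow (expand > 0) or shrink (expand < 0) a uint8 mask by |expand| pixels."""
    c = 0 if tapered_corners else 1
    kernel = np.array([[c, 1, c],
                       [1, 1, 1],
                       [c, 1, c]])
    for _ in range(abs(expand)):
        if expand < 0:
            mask_np = grey_erosion(mask_np, footprint=kernel)   # shrink by one pixel
        else:
            mask_np = grey_dilation(mask_np, footprint=kernel)  # grow by one pixel
    return mask_np
```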
@@ -392,45 +364,31 @@ def expand_mask(mask, expand, tapered_corners):
     # Convert back to PIL image
     return Image.fromarray(mask_np, mode="L")
 
-
 def image_blend_by_mask(image_a, image_b, mask, blend_percentage):
-    """Blend two images using a mask with proper error handling"""
-    # Ensure both images are in RGB mode
-    image_a = image_a.convert('RGB')
-    image_b = image_b.convert('RGB')
-
-    # Ensure mask is in grayscale mode 'L'
-    if mask.mode != 'L':
-        mask = mask.convert('L')
-
-    # Invert mask for proper blending
-    mask = ImageOps.invert(mask)
+
+    mask = ImageOps.invert(mask.convert('L'))
 
     # Mask image
     masked_img = Image.composite(image_a, image_b, mask)
 
     # Blend image
     blend_mask = Image.new(mode="L", size=image_a.size,
-                           color=(round(blend_percentage * 255)))
+                           color=(round(blend_percentage * 255)))
     blend_mask = ImageOps.invert(blend_mask)
     img_result = Image.composite(image_a, masked_img, blend_mask)
 
-    return img_result
+    del image_a, image_b, blend_mask, mask
 
+    return img_result
 
 def blend_images(image_a, image_b, blend_percentage):
-    """Blend two images with proper format handling"""
-    # Ensure both images are in RGBA mode
+    """Blend img_b over image_a using the normal mode with a blend percentage."""
     img_a = image_a.convert("RGBA")
     img_b = image_b.convert("RGBA")
-
-    # Fix: Check if sizes match and resize if needed
-    if img_a.size != img_b.size:
-        logger.warning(f"Image sizes don't match: {img_a.size} vs {img_b.size}. Resizing second image.")
-        img_b = img_b.resize(img_a.size, Image.LANCZOS)
 
     # Blend img_b over img_a using alpha_composite (normal blend mode)
     out_image = Image.alpha_composite(img_a, img_b)
+
     out_image = out_image.convert("RGB")
 
     # Create blend mask
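Worth spelling out: with a uniform gray mask, `Image.composite(a, b, mask)` degenerates into a fixed-opacity mix, `a * m/255 + b * (1 - m/255)` per pixel. Because the mask is inverted first, `blend_percentage=1.0` leaves the composited image untouched and `0.0` returns `image_a`. A quick check of the arithmetic with hypothetical 2x2 images:

```python
from PIL import Image, ImageOps

a = Image.new("RGB", (2, 2), (200, 0, 0))
b = Image.new("RGB", (2, 2), (0, 0, 200))
blend_percentage = 0.25

blend_mask = Image.new("L", a.size, round(blend_percentage * 255))  # uniform 64
blend_mask = ImageOps.invert(blend_mask)                            # uniform 191
out = Image.composite(a, b, blend_mask)
# composite keeps `a` where the mask is high: 200 * 191/255 ~= 150 red,
# and lets `b` through for the rest: 200 * 64/255 ~= 50 blue
print(out.getpixel((0, 0)))  # approximately (150, 0, 50)
```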
@@ -438,9 +396,13 @@ def blend_images(image_a, image_b, blend_percentage):
     blend_mask = ImageOps.invert(blend_mask)  # Invert the mask
 
     # Apply composite blend
-    result = Image.composite(image_a.convert("RGB"), out_image, blend_mask)
+    result = Image.composite(image_a, out_image, blend_mask)
     return result
 
+def apply_image_levels(image, black_level, mid_level, white_level):
+    levels = AdjustLevels(black_level, mid_level, white_level)
+    adjusted_image = levels.adjust(image)
+    return adjusted_image
 
 class AdjustLevels:
     def __init__(self, min_level, mid_level, max_level):
@@ -449,238 +411,175 @@ class AdjustLevels:
         self.max_level = max_level
 
     def adjust(self, im):
-        """Adjust image levels with proper error handling"""
-        # Ensure image is in RGB mode
-        im = im.convert("RGB")
-
-        # Convert to numpy array
+
         im_arr = np.array(im).astype(np.float32)
-
-        # Apply levels adjustment
         im_arr[im_arr < self.min_level] = self.min_level
-        im_arr = (im_arr - self.min_level) * (255 / (self.max_level - self.min_level))
+        im_arr = (im_arr - self.min_level) * \
+            (255 / (self.max_level - self.min_level))
         im_arr = np.clip(im_arr, 0, 255)
 
-        # Apply mid-level adjustment (gamma correction)
-        # Fix: Add error handling for potential division by zero or log errors
-        try:
-            gamma_dividend = (self.mid_level - self.min_level)
-            gamma_divisor = (self.max_level - self.min_level)
-
-            # Avoid division by zero
-            if gamma_divisor == 0:
-                gamma = 1.0
-            elif gamma_dividend <= 0:
-                gamma = 1.0
-            else:
-                gamma = math.log(0.5) / math.log(gamma_dividend / gamma_divisor)
-
-            # Ensure gamma is reasonable
-            gamma = max(0.1, min(5.0, gamma))
-
-            im_arr = np.power(im_arr / 255, gamma) * 255
-        except Exception as e:
-            logger.error(f"Error in gamma calculation: {e}")
-            # Fall back to no gamma adjustment
-            pass
+        # mid-level adjustment
+        gamma = math.log(0.5) / math.log((self.mid_level - self.min_level) / (self.max_level - self.min_level))
+        im_arr = np.power(im_arr / 255, gamma) * 255
 
         im_arr = im_arr.astype(np.uint8)
-        im = Image.fromarray(im_arr)
-        return im
-
+        im = Image.fromarray(im_arr)
 
-def apply_image_levels(image, black_level, mid_level, white_level):
-    """Apply levels adjustment to an image"""
-    levels = AdjustLevels(black_level, mid_level, white_level)
-    adjusted_image = levels.adjust(image)
-    return adjusted_image
+        return im
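The rewritten `adjust` drops the old guards: `gamma = log(0.5) / log((mid - min) / (max - min))` pins the mid level to 50% gray, but it now divides by zero when `mid_level == min_level` and raises a domain error for non-positive ratios, so callers must keep `min < mid < max`. For the defaults used in `blend_details` below, the curve comes out nearly linear:

```python
import math

min_level, mid_level, max_level = 83, 128, 172  # values passed by blend_details
gamma = math.log(0.5) / math.log((mid_level - min_level) / (max_level - min_level))
print(round(gamma, 3))  # ~1.016: the mid level sits almost exactly halfway
```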
 def resize_image(image, scaling_factor=1):
-    """Resize image with error handling"""
-    if scaling_factor <= 0:
-        logger.warning(f"Invalid scaling factor: {scaling_factor}, using 1.0 instead")
-        scaling_factor = 1.0
-
-    try:
-        new_width = int(image.width * scaling_factor)
-        new_height = int(image.height * scaling_factor)
-
-        # Ensure minimum size
-        new_width = max(1, new_width)
-        new_height = max(1, new_height)
-
-        resized = image.resize((new_width, new_height), Image.LANCZOS)
-        return resized
-    except Exception as e:
-        logger.error(f"Error resizing image: {e}")
-        return image
-
+    image = image.resize((int(image.width * scaling_factor),
+                          int(image.height * scaling_factor)))
+    return image
 def resize_to_square(image, size=1024):
-    """Resize image to a square canvas while maintaining aspect ratio"""
-    try:
-        # Convert to RGBA if needed
-        img = image.convert("RGBA") if isinstance(image, Image.Image) else Image.open(image).convert("RGBA")
-
-        # Resize while maintaining aspect ratio
-        img.thumbnail((size, size), Image.LANCZOS)
-
-        # Create a transparent square canvas
-        square_img = Image.new("RGBA", (size, size), (0, 0, 0, 0))
-
-        # Calculate the position to paste the resized image (centered)
-        x_offset = (size - img.width) // 2
-        y_offset = (size - img.height) // 2
-
-        # Extract the alpha channel as a mask
-        mask = img.split()[3] if img.mode == "RGBA" else None
-
-        # Paste the resized image onto the square canvas with the correct transparency mask
-        square_img.paste(img, (x_offset, y_offset), mask)
-
-        return square_img
-    except Exception as e:
-        logger.error(f"Error creating square image: {e}")
-        # Return the original image in case of an error
-        return image
+
+    # Load image if a file path is provided
+    if isinstance(image, str):
+        img = Image.open(image).convert("RGBA")
+    else:
+        img = image.convert("RGBA")  # If already an Image object
+
+    # Resize while maintaining aspect ratio
+    img.thumbnail((size, size), Image.LANCZOS)
+
+    # Create a transparent square canvas
+    square_img = Image.new("RGBA", (size, size), (0, 0, 0, 0))
+
+    # Calculate the position to paste the resized image (centered)
+    x_offset = (size - img.width) // 2
+    y_offset = (size - img.height) // 2
+
+    # Extract the alpha channel as a mask
+    mask = img.split()[3] if img.mode == "RGBA" else None
+
+    # Paste the resized image onto the square canvas with the correct transparency mask
+    square_img.paste(img, (x_offset, y_offset), mask)
+
+    return square_img
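A quick usage check of the letterboxing behavior (sizes are hypothetical):

```python
from PIL import Image

img = Image.new("RGB", (2048, 1024), "white")
sq = resize_to_square(img, size=1024)
assert sq.size == (1024, 1024)  # content scaled to 1024x512 and centered
# Caveat: Image.thumbnail only ever downscales, so inputs smaller than
# `size` are padded onto the canvas at their original resolution.
```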
 def encode_image(image):
-    """Encode image to base64 for API requests"""
     buffer = BytesIO()
-    # Ensure image is in proper format
-    if image.mode not in ["RGB", "RGBA"]:
-        image = image.convert("RGBA")
     image.save(buffer, format="PNG")
     encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
     return f"data:image/png;base64,{encoded_image}"
-
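The data URI produced here can be decoded straight back, which makes for an easy round-trip test:

```python
import base64
from io import BytesIO
from PIL import Image

uri = encode_image(Image.new("RGB", (4, 4), "red"))
header, payload = uri.split(",", 1)
assert header == "data:image/png;base64"
roundtrip = Image.open(BytesIO(base64.b64decode(payload)))
assert roundtrip.size == (4, 4)
```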
 def generate_ai_bg(input_img, prompt):
-    """Generate AI background using external service"""
-    try:
-        # Make sure the prompt is not empty
-        if not prompt or prompt.strip() == "":
-            prompt = "realistic automotive photography, professional lighting"
-            logger.info("Using default prompt for AI background generation")
-
-        # Use gradio_client for the API call
-        from gradio_client import Client, handle_file
-
-        try:
-            client = Client("lllyasviel/iclight-v2-vary")
-            result = client.predict(
-                input_fg=handle_file(input_img),
-                bg_source="None",
-                prompt=prompt,
-                image_width=1024,
-                image_height=1024,
-                num_samples=1,
-                seed=12345,
-                steps=25,
-                n_prompt="lowres, bad anatomy, bad hands, cropped, worst quality",
-                cfg=2,
-                gs=5,
-                enable_hr_fix=True,
-                hr_downscale=0.5,
-                lowres_denoise=0.8,
-                highres_denoise=0.99,
-                api_name="/process"
-            )
-
-            logger.info(f"AI background generation result: {result}")
-            relight_img_path = result[1]
-
-            # Load the generated image
-            relight_img = Image.open(relight_img_path).convert("RGBA")
-
-            # Ensure sizes match
-            if relight_img.size != input_img.size:
-                logger.info(f"Resizing generated image from {relight_img.size} to {input_img.size}")
-                relight_img = relight_img.resize(input_img.size, Image.LANCZOS)
-
-            return relight_img
-
-        except Exception as e:
-            logger.error(f"Error using gradio_client API: {e}")
-            # Fall back to a simpler method - just use the input image with a color filter
-            relight_img = input_img.copy()
-            color_overlay = Image.new("RGBA", input_img.size, (100, 150, 200, 128))
-            relight_img = Image.alpha_composite(relight_img.convert("RGBA"), color_overlay)
-            logger.info("Using fallback method for image generation")
-            return relight_img
-
-    except Exception as e:
-        logger.error(f"Error in AI background generation: {e}")
-        # Return the original image in case of an error
-        return input_img.copy()
+    # input_img = resize_image(input_img, 0.01)
+    # hf_input_img = encode_image(input_img)
+
+    # handler = fal_client.submit(
+    #     "fal-ai/iclight-v2",
+    #     arguments={
+    #         "prompt": prompt,
+    #         "image_url": hf_input_img
+    #     },
+    #     webhook_url="https://optional.webhook.url/for/results",
+    # )
+
+    # request_id = handler.request_id
+
+    # status = fal_client.status("fal-ai/iclight-v2", request_id, with_logs=True)
+
+    # result = fal_client.result("fal-ai/iclight-v2", request_id)
+
+    # relight_img_path = result['images'][0]['url']
+
+    # response = requests.get(relight_img_path, stream=True)
+
+    # relight_img = Image.open(BytesIO(response.content)).convert("RGBA")
+
+    from gradio_client import Client, handle_file
+
+    client = Client("lllyasviel/iclight-v2-vary")
+    result = client.predict(
+        input_fg=handle_file(input_img),
+        bg_source="None",
+        prompt=prompt,
+        image_width=1024,
+        image_height=1024,
+        num_samples=1,
+        seed=12345,
+        steps=25,
+        n_prompt="lowres, bad anatomy, bad hands, cropped, worst quality",
+        cfg=2,
+        gs=5,
+        enable_hr_fix=True,
+        hr_downscale=0.5,
+        lowres_denoise=0.8,
+        highres_denoise=0.99,
+        api_name="/process"
+    )
+    print(result)
+
+    relight_img_path = result[1]
+
+    # response = requests.get(relight_img_path, stream=True)
+    relight_img = Image.open(relight_img_path).convert("RGBA")
+
+    # relight_img = Image.open(BytesIO(response.content)).convert("RGBA")
+
+    return relight_img
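One caveat on the block above: `gradio_client.handle_file` expects a file path or URL, so `input_img` must already be a path here; a PIL image would have to be written to disk first. A hedged sketch of that save-then-upload pattern (the helper, temp-file handling, and reduced argument list are assumptions; unspecified parameters would fall back to the endpoint's defaults):

```python
import tempfile
from gradio_client import Client, handle_file
from PIL import Image

def predict_from_pil(client: Client, img: Image.Image, prompt: str):
    # Persist the PIL image so handle_file() receives a real path.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        img.save(tmp, format="PNG")
    return client.predict(input_fg=handle_file(tmp.name), prompt=prompt,
                          api_name="/process")
```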
 def blend_details(input_image, relit_image, masked_image, scaling_factor=1):
-    """Blend original and relit images using mask with detailed processing"""
-    try:
-        # Ensure all inputs are proper images
-        if input_image is None or relit_image is None or masked_image is None:
-            logger.error("Missing input for blend_details")
-            return input_image if input_image is not None else Image.new("RGB", (800, 600), (0, 0, 0))
-
-        # Ensure all images have the same size
-        if input_image.size != relit_image.size:
-            logger.warning(f"Relit image size ({relit_image.size}) doesn't match input image size ({input_image.size})")
-            relit_image = relit_image.resize(input_image.size, Image.LANCZOS)
-
-        if input_image.size != masked_image.size:
-            logger.warning(f"Mask image size ({masked_image.size}) doesn't match input image size ({input_image.size})")
-            masked_image = masked_image.resize(input_image.size, Image.LANCZOS)
-
-        # Process masked image
-        masked_image_rgb = split_image_with_alpha(masked_image)
-        masked_image_blurred = gaussian_blur(masked_image_rgb, radius=10)
-
-        # Fix: Add error handling for mask expansion
-        try:
-            grow_mask = expand_mask(masked_image_blurred, -15, True)
-        except Exception as e:
-            logger.error(f"Error expanding mask: {e}")
-            grow_mask = masked_image_blurred.convert("L")
-
-        # Process input image
-        input_image_rgb = split_image_with_alpha(input_image)
-        input_blurred = gaussian_blur(input_image_rgb, radius=10)
-        input_inverted = invert_image(input_image_rgb)
-
-        # Add blurred and inverted images
-        input_blend_1 = blend_images(input_inverted, input_blurred, blend_percentage=0.5)
-        input_blend_1_inverted = invert_image(input_blend_1)
-        input_blend_2 = blend_images(input_blurred, input_blend_1_inverted, blend_percentage=1.0)
-
-        # Process relit image
-        relit_image_rgb = split_image_with_alpha(relit_image)
-        relit_blurred = gaussian_blur(relit_image_rgb, radius=10)
-        relit_inverted = invert_image(relit_image_rgb)
-
-        # Add blurred and inverted relit images
-        relit_blend_1 = blend_images(relit_inverted, relit_blurred, blend_percentage=0.5)
-        relit_blend_1_inverted = invert_image(relit_blend_1)
-        relit_blend_2 = blend_images(relit_blurred, relit_blend_1_inverted, blend_percentage=1.0)
-
-        # Blend high frequency components
-        high_freq_comp = image_blend_by_mask(relit_blend_2, input_blend_2, grow_mask, blend_percentage=1.0)
-
-        # Final compositing
-        comped_image = blend_images(relit_blurred, high_freq_comp, blend_percentage=0.65)
-
-        # Apply levels adjustment
-        final_image = apply_image_levels(comped_image, black_level=83, mid_level=128, white_level=172)
-
-        return final_image
-
-    except Exception as e:
-        logger.error(f"Error in blend_details: {e}")
-        # Return the relit image on error as a fallback
-        return relit_image if relit_image is not None else input_image
+
+    # input_image = resize_image(input_image)
+
+    # relit_image = resize_image(relit_image)
+
+    # masked_image = resize_image(masked_image)
+
+    masked_image_rgb = split_image_with_alpha(masked_image)
+    masked_image_blurred = gaussian_blur(masked_image_rgb, radius=10)
+    grow_mask = expand_mask(masked_image_blurred, -15, True)
+
+    # grow_mask.save("output/grow_mask.png")
+
+    # Split images and get RGB channels
+    input_image_rgb = split_image_with_alpha(input_image)
+    input_blurred = gaussian_blur(input_image_rgb, radius=10)
+    input_inverted = invert_image(input_image_rgb)
+
+    # input_blurred.save("output/input_blurred.png")
+    # input_inverted.save("output/input_inverted.png")
+
+    # Add blurred and inverted images
+    input_blend_1 = blend_images(input_inverted, input_blurred, blend_percentage=0.5)
+    input_blend_1_inverted = invert_image(input_blend_1)
+    input_blend_2 = blend_images(input_blurred, input_blend_1_inverted, blend_percentage=1.0)
+
+    # input_blend_2.save("output/input_blend_2.png")
+
+    # Process relit image
+    relit_image_rgb = split_image_with_alpha(relit_image)
+    relit_blurred = gaussian_blur(relit_image_rgb, radius=10)
+    relit_inverted = invert_image(relit_image_rgb)
+
+    # relit_blurred.save("output/relit_blurred.png")
+    # relit_inverted.save("output/relit_inverted.png")
+
+    # Add blurred and inverted relit images
+    relit_blend_1 = blend_images(relit_inverted, relit_blurred, blend_percentage=0.5)
+    relit_blend_1_inverted = invert_image(relit_blend_1)
+    relit_blend_2 = blend_images(relit_blurred, relit_blend_1_inverted, blend_percentage=1.0)
+
+    # relit_blend_2.save("output/relit_blend_2.png")
+
+    high_freq_comp = image_blend_by_mask(relit_blend_2, input_blend_2, grow_mask, blend_percentage=1.0)
+
+    # high_freq_comp.save("output/high_freq_comp.png")
+
+    comped_image = blend_images(relit_blurred, high_freq_comp, blend_percentage=0.65)
+
+    # comped_image.save("output/comped_image.png")
+
+    final_image = apply_image_levels(comped_image, black_level=83, mid_level=128, white_level=172)
+
+    # final_image.save("output/final_image.png")
+
+    return final_image
 
 @spaces.GPU
 def generate_image(input_img, prompt):
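Note: the new `blend_details` is a frequency-separation composite. Blurring isolates the low frequencies, and blending a blurred copy against an inverted copy at 50% yields a mid-gray layer carrying only high-frequency detail, which is then re-applied to the relit render inside the shrunken mask before the final levels pass. The same idea in direct array arithmetic, as an illustration rather than the committed code path:

```python
import numpy as np
from PIL import Image, ImageFilter

def high_frequency_layer(img: Image.Image, radius: int = 10) -> Image.Image:
    """Detail layer: original minus its low-pass (blurred) copy, offset to mid-gray."""
    rgb = np.asarray(img.convert("RGB"), dtype=np.float32)
    low = np.asarray(img.convert("RGB").filter(ImageFilter.GaussianBlur(radius)),
                     dtype=np.float32)
    detail = np.clip(rgb - low + 128.0, 0, 255).astype(np.uint8)
    return Image.fromarray(detail)
```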