rootglitch commited on
Commit
2649278
·
1 Parent(s): 5045825

Major changes

Browse files
Files changed (1) hide show
  1. app.py +285 -184
app.py CHANGED
@@ -5,7 +5,6 @@ import warnings
5
  import random
6
  import time
7
  import logging
8
- import fal_client
9
  import base64
10
  import numpy as np
11
  import math
@@ -54,10 +53,6 @@ CONFIG_FILE = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
54
  GROUNDINGDINO_CHECKPOINT = "groundingdino_swint_ogc.pth"
55
  SAM_CHECKPOINT = 'sam_hq_vit_l.pth'
56
  OUTPUT_DIR = "outputs"
57
- # FAL_KEY = os.getenv("FAL_KEY")
58
- # UPLOAD_DIR = "./tmp/images"
59
-
60
- # os.makedirs(UPLOAD_DIR, exist_ok=True)
61
 
62
  # Global variables for model caching
63
  _models = {
@@ -111,7 +106,6 @@ class ModelManager:
111
  _models['sam_predictor'] = SamPredictor(sam)
112
 
113
  logger.info(f"SAM-HQ model loaded in {time.time() - start_time:.2f} seconds")
114
-
115
 
116
  except Exception as e:
117
  logger.error(f"Error loading {model_name} model: {e}")
@@ -199,9 +193,7 @@ def get_grounding_output(
199
 
200
  def draw_mask(mask: np.ndarray, draw: ImageDraw.Draw) -> None:
201
  """Draw mask on image"""
202
-
203
  color = (255, 255, 255, 255)
204
-
205
  nonzero_coords = np.transpose(np.nonzero(mask))
206
  for coord in nonzero_coords:
207
  draw.point(coord[::-1], fill=color)
@@ -238,8 +230,8 @@ def run_grounded_sam(input_image):
238
  # Process input image
239
  if isinstance(input_image, dict):
240
  # Input from gradio sketch component
241
- scribble = np.array(input_image["mask"])
242
- image_pil = input_image["image"].convert("RGB")
243
  else:
244
  # Direct image input
245
  image_pil = input_image.convert("RGB") if input_image else None
@@ -247,7 +239,7 @@ def run_grounded_sam(input_image):
247
 
248
  if image_pil is None:
249
  logger.error("No input image provided")
250
- return [Image.new('RGB', (400, 300), color='gray')]
251
 
252
  # Transform image for GroundingDINO
253
  transformed_image = transform_image(image_pil)
@@ -262,20 +254,32 @@ def run_grounded_sam(input_image):
262
  transformed_image, text_prompt, box_threshold, text_threshold
263
  )
264
 
265
- if boxes_filt is not None:
266
- # Scale boxes to image dimensions
267
- for i in range(boxes_filt.size(0)):
268
- boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
269
- boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
270
- boxes_filt[i][2:] += boxes_filt[i][:2]
271
-
272
- # Apply non-maximum suppression if we have multiple boxes
273
- if boxes_filt.size(0) > 1:
274
- logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes")
275
- nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
276
- boxes_filt = boxes_filt[nms_idx]
277
- pred_phrases = [pred_phrases[idx] for idx in nms_idx]
278
- logger.info(f"After NMS: {boxes_filt.shape[0]} boxes")
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
  # Load SAM model
281
  ModelManager.load_model('sam')
@@ -285,11 +289,18 @@ def run_grounded_sam(input_image):
285
  image = np.array(image_pil)
286
  sam_predictor.set_image(image)
287
 
288
- # Run SAM
289
- # Use boxes for these task types
290
  if boxes_filt.size(0) == 0:
291
- logger.warning("No boxes detected")
292
- return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
 
 
 
 
 
 
 
 
293
 
294
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
295
 
@@ -319,30 +330,47 @@ def run_grounded_sam(input_image):
319
 
320
  except Exception as e:
321
  logger.error(f"Error in run_grounded_sam: {e}")
322
- # Return original image on error
323
  if isinstance(input_image, dict) and "image" in input_image:
324
- return [input_image["image"], Image.new('RGBA', input_image["image"].size, color=(0, 0, 0, 0))]
325
  elif isinstance(input_image, Image.Image):
326
- return [input_image, Image.new('RGBA', input_image.size, color=(0, 0, 0, 0))]
327
  else:
328
- return [Image.new('RGB', (400, 300), color='gray'), Image.new('RGBA', (400, 300), color=(0, 0, 0, 0))]
 
329
 
330
  def split_image_with_alpha(image):
331
- image = image.convert("RGB")
332
- return image
 
 
 
 
 
 
 
333
 
334
  def gaussian_blur(image, radius=10):
335
  """Apply Gaussian blur to image."""
336
- blurred = image.filter(ImageFilter.GaussianBlur(radius=10))
 
 
337
  return blurred
338
 
 
339
  def invert_image(image):
 
 
 
340
  img_inverted = ImageOps.invert(image)
341
  return img_inverted
342
 
 
343
  def expand_mask(mask, expand, tapered_corners):
344
- # Ensure mask is in grayscale (mode 'L')
345
- mask = mask.convert("L")
 
 
346
 
347
  # Convert to NumPy array
348
  mask_np = np.array(mask)
@@ -364,31 +392,45 @@ def expand_mask(mask, expand, tapered_corners):
364
  # Convert back to PIL image
365
  return Image.fromarray(mask_np, mode="L")
366
 
367
- def image_blend_by_mask(image_a, image_b, mask, blend_percentage):
368
 
369
- mask = ImageOps.invert(mask.convert('L'))
 
 
 
 
 
 
 
 
 
 
 
370
 
371
  # Mask image
372
  masked_img = Image.composite(image_a, image_b, mask)
373
 
374
  # Blend image
375
  blend_mask = Image.new(mode="L", size=image_a.size,
376
- color=(round(blend_percentage * 255)))
377
  blend_mask = ImageOps.invert(blend_mask)
378
  img_result = Image.composite(image_a, masked_img, blend_mask)
379
 
380
- del image_a, image_b, blend_mask, mask
381
-
382
  return img_result
383
 
 
384
  def blend_images(image_a, image_b, blend_percentage):
385
- """Blend img_b over image_a using the normal mode with a blend percentage."""
 
386
  img_a = image_a.convert("RGBA")
387
  img_b = image_b.convert("RGBA")
 
 
 
 
 
388
 
389
  # Blend img_b over img_a using alpha_composite (normal blend mode)
390
  out_image = Image.alpha_composite(img_a, img_b)
391
-
392
  out_image = out_image.convert("RGB")
393
 
394
  # Create blend mask
@@ -396,13 +438,9 @@ def blend_images(image_a, image_b, blend_percentage):
396
  blend_mask = ImageOps.invert(blend_mask) # Invert the mask
397
 
398
  # Apply composite blend
399
- result = Image.composite(image_a, out_image, blend_mask)
400
  return result
401
 
402
- def apply_image_levels(image, black_level, mid_level, white_level):
403
- levels = AdjustLevels(black_level, mid_level, white_level)
404
- adjusted_image = levels.adjust(image)
405
- return adjusted_image
406
 
407
  class AdjustLevels:
408
  def __init__(self, min_level, mid_level, max_level):
@@ -411,175 +449,238 @@ class AdjustLevels:
411
  self.max_level = max_level
412
 
413
  def adjust(self, im):
414
-
 
 
 
 
415
  im_arr = np.array(im).astype(np.float32)
 
 
416
  im_arr[im_arr < self.min_level] = self.min_level
417
- im_arr = (im_arr - self.min_level) * \
418
- (255 / (self.max_level - self.min_level))
419
  im_arr = np.clip(im_arr, 0, 255)
420
 
421
- # mid-level adjustment
422
- gamma = math.log(0.5) / math.log((self.mid_level - self.min_level) / (self.max_level - self.min_level))
423
- im_arr = np.power(im_arr / 255, gamma) * 255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
 
425
  im_arr = im_arr.astype(np.uint8)
426
-
427
  im = Image.fromarray(im_arr)
428
-
429
  return im
430
 
 
 
 
 
 
 
 
 
431
  def resize_image(image, scaling_factor=1):
432
- image = image.resize((int(image.width * scaling_factor),
433
- int(image.height * scaling_factor)))
434
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
 
436
- def resize_to_square(image, size=1024):
437
 
438
- # Load image if a file path is provided
439
- if isinstance(image, str):
440
- img = Image.open(image).convert("RGBA")
441
- else:
442
- img = image.convert("RGBA") # If already an Image object
443
 
444
- # Resize while maintaining aspect ratio
445
- img.thumbnail((size, size), Image.LANCZOS)
446
 
447
- # Create a transparent square canvas
448
- square_img = Image.new("RGBA", (size, size), (0, 0, 0, 0))
449
 
450
- # Calculate the position to paste the resized image (centered)
451
- x_offset = (size - img.width) // 2
452
- y_offset = (size - img.height) // 2
453
 
454
- # Extract the alpha channel as a mask
455
- mask = img.split()[3] if img.mode == "RGBA" else None
456
 
457
- # Paste the resized image onto the square canvas with the correct transparency mask
458
- square_img.paste(img, (x_offset, y_offset), mask)
459
 
460
- return square_img
 
 
 
 
461
 
462
 
463
  def encode_image(image):
 
464
  buffer = BytesIO()
 
 
 
465
  image.save(buffer, format="PNG")
466
  encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
467
  return f"data:image/png;base64,{encoded_image}"
468
 
469
- def generate_ai_bg(input_img, prompt):
470
- # input_img = resize_image(input_img, 0.01)
471
- # hf_input_img = encode_image(input_img)
472
-
473
- # handler = fal_client.submit(
474
- # "fal-ai/iclight-v2",
475
- # arguments={
476
- # "prompt": prompt,
477
- # "image_url": hf_input_img
478
- # },
479
- # webhook_url="https://optional.webhook.url/for/results",
480
- # )
481
-
482
- # request_id = handler.request_id
483
-
484
- # status = fal_client.status("fal-ai/iclight-v2", request_id, with_logs=True)
485
-
486
- # result = fal_client.result("fal-ai/iclight-v2", request_id)
487
-
488
- # relight_img_path = result['images'][0]['url']
489
-
490
- # response = requests.get(relight_img_path, stream=True)
491
-
492
- # relight_img = Image.open(BytesIO(response.content)).convert("RGBA")
493
-
494
- from gradio_client import Client, handle_file
495
-
496
- client = Client("lllyasviel/iclight-v2-vary")
497
- result = client.predict(
498
- input_fg=handle_file(input_img),
499
- bg_source="None",
500
- prompt=prompt,
501
- image_width=1024,
502
- image_height=1024,
503
- num_samples=1,
504
- seed=12345,
505
- steps=25,
506
- n_prompt="lowres, bad anatomy, bad hands, cropped, worst quality",
507
- cfg=2,
508
- gs=5,
509
- enable_hr_fix=True,
510
- hr_downscale=0.5,
511
- lowres_denoise=0.8,
512
- highres_denoise=0.99,
513
- api_name="/process"
514
- )
515
- print(result)
516
-
517
- relight_img_path = result[1]
518
-
519
- # response = requests.get(relight_img_path, stream=True)
520
- relight_img = Image.open(relight_img_path).convert("RGBA")
521
 
522
- # relight_img = Image.open(BytesIO(response.content)).convert("RGBA")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
- return relight_img
525
 
526
  def blend_details(input_image, relit_image, masked_image, scaling_factor=1):
527
-
528
- # input_image = resize_image(input_image)
529
-
530
- # relit_image = resize_image(relit_image)
531
-
532
- # masked_image = resize_image(masked_image)
533
-
534
- masked_image_rgb = split_image_with_alpha(masked_image)
535
- masked_image_blurred = gaussian_blur(masked_image_rgb, radius=10)
536
- grow_mask = expand_mask(masked_image_blurred, -15, True)
537
-
538
- # grow_mask.save("output/grow_mask.png")
539
-
540
- # Split images and get RGB channels
541
- input_image_rgb = split_image_with_alpha(input_image)
542
- input_blurred = gaussian_blur(input_image_rgb, radius=10)
543
- input_inverted = invert_image(input_image_rgb)
544
-
545
- # input_blurred.save("output/input_blurred.png")
546
- # input_inverted.save("output/input_inverted.png")
547
 
548
- # Add blurred and inverted images
549
- input_blend_1 = blend_images(input_inverted, input_blurred, blend_percentage=0.5)
550
- input_blend_1_inverted = invert_image(input_blend_1)
551
- input_blend_2 = blend_images(input_blurred, input_blend_1_inverted, blend_percentage=1.0)
552
-
553
- # input_blend_2.save("output/input_blend_2.png")
554
-
555
- # Process relit image
556
- relit_image_rgb = split_image_with_alpha(relit_image)
557
- relit_blurred = gaussian_blur(relit_image_rgb, radius=10)
558
- relit_inverted = invert_image(relit_image_rgb)
559
-
560
- # relit_blurred.save("output/relit_blurred.png")
561
- # relit_inverted.save("output/relit_inverted.png")
562
-
563
- # Add blurred and inverted relit images
564
- relit_blend_1 = blend_images(relit_inverted, relit_blurred, blend_percentage=0.5)
565
- relit_blend_1_inverted = invert_image(relit_blend_1)
566
- relit_blend_2 = blend_images(relit_blurred, relit_blend_1_inverted, blend_percentage=1.0)
567
-
568
- # relit_blend_2.save("output/relit_blend_2.png")
569
-
570
- high_freq_comp = image_blend_by_mask(relit_blend_2, input_blend_2, grow_mask, blend_percentage=1.0)
571
-
572
- # high_freq_comp.save("output/high_freq_comp.png")
573
-
574
- comped_image = blend_images(relit_blurred, high_freq_comp, blend_percentage=0.65)
575
 
576
- # comped_image.save("output/comped_image.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
 
578
- final_image = apply_image_levels(comped_image, black_level=83, mid_level=128, white_level=172)
 
 
 
 
 
 
 
579
 
580
- # final_image.save("output/final_image.png")
 
 
 
 
 
581
 
582
- return final_image
583
 
584
  @spaces.GPU
585
  def generate_image(input_img, prompt):
 
5
  import random
6
  import time
7
  import logging
 
8
  import base64
9
  import numpy as np
10
  import math
 
53
  GROUNDINGDINO_CHECKPOINT = "groundingdino_swint_ogc.pth"
54
  SAM_CHECKPOINT = 'sam_hq_vit_l.pth'
55
  OUTPUT_DIR = "outputs"
 
 
 
 
56
 
57
  # Global variables for model caching
58
  _models = {
 
106
  _models['sam_predictor'] = SamPredictor(sam)
107
 
108
  logger.info(f"SAM-HQ model loaded in {time.time() - start_time:.2f} seconds")
 
109
 
110
  except Exception as e:
111
  logger.error(f"Error loading {model_name} model: {e}")
 
193
 
194
  def draw_mask(mask: np.ndarray, draw: ImageDraw.Draw) -> None:
195
  """Draw mask on image"""
 
196
  color = (255, 255, 255, 255)
 
197
  nonzero_coords = np.transpose(np.nonzero(mask))
198
  for coord in nonzero_coords:
199
  draw.point(coord[::-1], fill=color)
 
230
  # Process input image
231
  if isinstance(input_image, dict):
232
  # Input from gradio sketch component
233
+ scribble = np.array(input_image["mask"]) if "mask" in input_image else None
234
+ image_pil = input_image["image"].convert("RGB") if "image" in input_image else None
235
  else:
236
  # Direct image input
237
  image_pil = input_image.convert("RGB") if input_image else None
 
239
 
240
  if image_pil is None:
241
  logger.error("No input image provided")
242
+ return Image.new('RGBA', (400, 300), color=(0, 0, 0, 0))
243
 
244
  # Transform image for GroundingDINO
245
  transformed_image = transform_image(image_pil)
 
254
  transformed_image, text_prompt, box_threshold, text_threshold
255
  )
256
 
257
+ # Fix: Handle case when no boxes are detected
258
+ if len(boxes_filt) == 0 or boxes_filt.nelement() == 0:
259
+ logger.warning("No boxes detected")
260
+ # Create a simple fallback mask - a circle in the center
261
+ mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
262
+ mask_draw = ImageDraw.Draw(mask_image)
263
+ center_x, center_y = W // 2, H // 2
264
+ radius = min(W, H) // 4
265
+ mask_draw.ellipse((center_x - radius, center_y - radius,
266
+ center_x + radius, center_y + radius),
267
+ fill=(255, 255, 255, 255))
268
+ return mask_image
269
+
270
+ # Scale boxes to image dimensions
271
+ for i in range(boxes_filt.size(0)):
272
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
273
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
274
+ boxes_filt[i][2:] += boxes_filt[i][:2]
275
+
276
+ # Apply non-maximum suppression if we have multiple boxes
277
+ if boxes_filt.size(0) > 1:
278
+ logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes")
279
+ nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
280
+ boxes_filt = boxes_filt[nms_idx]
281
+ pred_phrases = [pred_phrases[idx] for idx in nms_idx]
282
+ logger.info(f"After NMS: {boxes_filt.shape[0]} boxes")
283
 
284
  # Load SAM model
285
  ModelManager.load_model('sam')
 
289
  image = np.array(image_pil)
290
  sam_predictor.set_image(image)
291
 
292
+ # Run SAM
 
293
  if boxes_filt.size(0) == 0:
294
+ logger.warning("No boxes detected after NMS")
295
+ # Create a simple fallback mask
296
+ mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
297
+ mask_draw = ImageDraw.Draw(mask_image)
298
+ center_x, center_y = W // 2, H // 2
299
+ radius = min(W, H) // 4
300
+ mask_draw.ellipse((center_x - radius, center_y - radius,
301
+ center_x + radius, center_y + radius),
302
+ fill=(255, 255, 255, 255))
303
+ return mask_image
304
 
305
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
306
 
 
330
 
331
  except Exception as e:
332
  logger.error(f"Error in run_grounded_sam: {e}")
333
+ # Return transparent image on error
334
  if isinstance(input_image, dict) and "image" in input_image:
335
+ return Image.new('RGBA', input_image["image"].size, color=(0, 0, 0, 0))
336
  elif isinstance(input_image, Image.Image):
337
+ return Image.new('RGBA', input_image.size, color=(0, 0, 0, 0))
338
  else:
339
+ return Image.new('RGBA', (400, 300), color=(0, 0, 0, 0))
340
+
341
 
342
  def split_image_with_alpha(image):
343
+ """Ensure image is RGB and return a copy"""
344
+ if image.mode == 'RGBA':
345
+ # Create a white background
346
+ background = Image.new('RGB', image.size, (255, 255, 255))
347
+ # Composite the image with alpha onto the background
348
+ return Image.alpha_composite(background.convert('RGBA'), image).convert('RGB')
349
+ else:
350
+ return image.convert("RGB")
351
+
352
 
353
  def gaussian_blur(image, radius=10):
354
  """Apply Gaussian blur to image."""
355
+ # Ensure image is in RGB mode
356
+ image = image.convert("RGB")
357
+ blurred = image.filter(ImageFilter.GaussianBlur(radius=radius))
358
  return blurred
359
 
360
+
361
  def invert_image(image):
362
+ """Invert image colors"""
363
+ # Ensure image is in RGB mode for inversion
364
+ image = image.convert("RGB")
365
  img_inverted = ImageOps.invert(image)
366
  return img_inverted
367
 
368
+
369
  def expand_mask(mask, expand, tapered_corners):
370
+ """Expand or contract a mask with proper error handling"""
371
+ # Ensure mask is in grayscale mode 'L'
372
+ if mask.mode != 'L':
373
+ mask = mask.convert("L")
374
 
375
  # Convert to NumPy array
376
  mask_np = np.array(mask)
 
392
  # Convert back to PIL image
393
  return Image.fromarray(mask_np, mode="L")
394
 
 
395
 
396
+ def image_blend_by_mask(image_a, image_b, mask, blend_percentage):
397
+ """Blend two images using a mask with proper error handling"""
398
+ # Ensure both images are in RGB mode
399
+ image_a = image_a.convert('RGB')
400
+ image_b = image_b.convert('RGB')
401
+
402
+ # Ensure mask is in grayscale mode 'L'
403
+ if mask.mode != 'L':
404
+ mask = mask.convert('L')
405
+
406
+ # Invert mask for proper blending
407
+ mask = ImageOps.invert(mask)
408
 
409
  # Mask image
410
  masked_img = Image.composite(image_a, image_b, mask)
411
 
412
  # Blend image
413
  blend_mask = Image.new(mode="L", size=image_a.size,
414
+ color=(round(blend_percentage * 255)))
415
  blend_mask = ImageOps.invert(blend_mask)
416
  img_result = Image.composite(image_a, masked_img, blend_mask)
417
 
 
 
418
  return img_result
419
 
420
+
421
  def blend_images(image_a, image_b, blend_percentage):
422
+ """Blend two images with proper format handling"""
423
+ # Ensure both images are in RGBA mode
424
  img_a = image_a.convert("RGBA")
425
  img_b = image_b.convert("RGBA")
426
+
427
+ # Fix: Check if sizes match and resize if needed
428
+ if img_a.size != img_b.size:
429
+ logger.warning(f"Image sizes don't match: {img_a.size} vs {img_b.size}. Resizing second image.")
430
+ img_b = img_b.resize(img_a.size, Image.LANCZOS)
431
 
432
  # Blend img_b over img_a using alpha_composite (normal blend mode)
433
  out_image = Image.alpha_composite(img_a, img_b)
 
434
  out_image = out_image.convert("RGB")
435
 
436
  # Create blend mask
 
438
  blend_mask = ImageOps.invert(blend_mask) # Invert the mask
439
 
440
  # Apply composite blend
441
+ result = Image.composite(image_a.convert("RGB"), out_image, blend_mask)
442
  return result
443
 
 
 
 
 
444
 
445
  class AdjustLevels:
446
  def __init__(self, min_level, mid_level, max_level):
 
449
  self.max_level = max_level
450
 
451
  def adjust(self, im):
452
+ """Adjust image levels with proper error handling"""
453
+ # Ensure image is in RGB mode
454
+ im = im.convert("RGB")
455
+
456
+ # Convert to numpy array
457
  im_arr = np.array(im).astype(np.float32)
458
+
459
+ # Apply levels adjustment
460
  im_arr[im_arr < self.min_level] = self.min_level
461
+ im_arr = (im_arr - self.min_level) * (255 / (self.max_level - self.min_level))
 
462
  im_arr = np.clip(im_arr, 0, 255)
463
 
464
+ # Apply mid-level adjustment (gamma correction)
465
+ # Fix: Add error handling for potential division by zero or log errors
466
+ try:
467
+ gamma_dividend = (self.mid_level - self.min_level)
468
+ gamma_divisor = (self.max_level - self.min_level)
469
+
470
+ # Avoid division by zero
471
+ if gamma_divisor == 0:
472
+ gamma = 1.0
473
+ elif gamma_dividend <= 0:
474
+ gamma = 1.0
475
+ else:
476
+ gamma = math.log(0.5) / math.log(gamma_dividend / gamma_divisor)
477
+
478
+ # Ensure gamma is reasonable
479
+ gamma = max(0.1, min(5.0, gamma))
480
+
481
+ im_arr = np.power(im_arr / 255, gamma) * 255
482
+ except Exception as e:
483
+ logger.error(f"Error in gamma calculation: {e}")
484
+ # Fall back to no gamma adjustment
485
+ pass
486
 
487
  im_arr = im_arr.astype(np.uint8)
 
488
  im = Image.fromarray(im_arr)
 
489
  return im
490
 
491
+
492
+ def apply_image_levels(image, black_level, mid_level, white_level):
493
+ """Apply levels adjustment to an image"""
494
+ levels = AdjustLevels(black_level, mid_level, white_level)
495
+ adjusted_image = levels.adjust(image)
496
+ return adjusted_image
497
+
498
+
499
  def resize_image(image, scaling_factor=1):
500
+ """Resize image with error handling"""
501
+ if scaling_factor <= 0:
502
+ logger.warning(f"Invalid scaling factor: {scaling_factor}, using 1.0 instead")
503
+ scaling_factor = 1.0
504
+
505
+ try:
506
+ new_width = int(image.width * scaling_factor)
507
+ new_height = int(image.height * scaling_factor)
508
+
509
+ # Ensure minimum size
510
+ new_width = max(1, new_width)
511
+ new_height = max(1, new_height)
512
+
513
+ resized = image.resize((new_width, new_height), Image.LANCZOS)
514
+ return resized
515
+ except Exception as e:
516
+ logger.error(f"Error resizing image: {e}")
517
+ return image
518
 
 
519
 
520
+ def resize_to_square(image, size=1024):
521
+ """Resize image to a square canvas while maintaining aspect ratio"""
522
+ try:
523
+ # Convert to RGBA if needed
524
+ img = image.convert("RGBA") if isinstance(image, Image.Image) else Image.open(image).convert("RGBA")
525
 
526
+ # Resize while maintaining aspect ratio
527
+ img.thumbnail((size, size), Image.LANCZOS)
528
 
529
+ # Create a transparent square canvas
530
+ square_img = Image.new("RGBA", (size, size), (0, 0, 0, 0))
531
 
532
+ # Calculate the position to paste the resized image (centered)
533
+ x_offset = (size - img.width) // 2
534
+ y_offset = (size - img.height) // 2
535
 
536
+ # Extract the alpha channel as a mask
537
+ mask = img.split()[3] if img.mode == "RGBA" else None
538
 
539
+ # Paste the resized image onto the square canvas with the correct transparency mask
540
+ square_img.paste(img, (x_offset, y_offset), mask)
541
 
542
+ return square_img
543
+ except Exception as e:
544
+ logger.error(f"Error creating square image: {e}")
545
+ # Return the original image in case of an error
546
+ return image
547
 
548
 
549
  def encode_image(image):
550
+ """Encode image to base64 for API requests"""
551
  buffer = BytesIO()
552
+ # Ensure image is in proper format
553
+ if image.mode not in ["RGB", "RGBA"]:
554
+ image = image.convert("RGBA")
555
  image.save(buffer, format="PNG")
556
  encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
557
  return f"data:image/png;base64,{encoded_image}"
558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
559
 
560
+ def generate_ai_bg(input_img, prompt):
561
+ """Generate AI background using external service"""
562
+ try:
563
+ # Make sure the prompt is not empty
564
+ if not prompt or prompt.strip() == "":
565
+ prompt = "realistic automotive photography, professional lighting"
566
+ logger.info("Using default prompt for AI background generation")
567
+
568
+ # Use gradio_client for the API call
569
+ from gradio_client import Client, handle_file
570
+
571
+ try:
572
+ client = Client("lllyasviel/iclight-v2-vary")
573
+ result = client.predict(
574
+ input_fg=handle_file(input_img),
575
+ bg_source="None",
576
+ prompt=prompt,
577
+ image_width=1024,
578
+ image_height=1024,
579
+ num_samples=1,
580
+ seed=12345,
581
+ steps=25,
582
+ n_prompt="lowres, bad anatomy, bad hands, cropped, worst quality",
583
+ cfg=2,
584
+ gs=5,
585
+ enable_hr_fix=True,
586
+ hr_downscale=0.5,
587
+ lowres_denoise=0.8,
588
+ highres_denoise=0.99,
589
+ api_name="/process"
590
+ )
591
+
592
+ logger.info(f"AI background generation result: {result}")
593
+ relight_img_path = result[1]
594
+
595
+ # Load the generated image
596
+ relight_img = Image.open(relight_img_path).convert("RGBA")
597
+
598
+ # Ensure sizes match
599
+ if relight_img.size != input_img.size:
600
+ logger.info(f"Resizing generated image from {relight_img.size} to {input_img.size}")
601
+ relight_img = relight_img.resize(input_img.size, Image.LANCZOS)
602
+
603
+ return relight_img
604
+
605
+ except Exception as e:
606
+ logger.error(f"Error using gradio_client API: {e}")
607
+ # Fall back to a simpler method - just use the input image with a color filter
608
+ relight_img = input_img.copy()
609
+ color_overlay = Image.new("RGBA", input_img.size, (100, 150, 200, 128))
610
+ relight_img = Image.alpha_composite(relight_img.convert("RGBA"), color_overlay)
611
+ logger.info("Using fallback method for image generation")
612
+ return relight_img
613
+
614
+ except Exception as e:
615
+ logger.error(f"Error in AI background generation: {e}")
616
+ # Return the original image in case of an error
617
+ return input_img.copy()
618
 
 
619
 
620
  def blend_details(input_image, relit_image, masked_image, scaling_factor=1):
621
+ """Blend original and relit images using mask with detailed processing"""
622
+ try:
623
+ # Ensure all inputs are proper images
624
+ if input_image is None or relit_image is None or masked_image is None:
625
+ logger.error("Missing input for blend_details")
626
+ return input_image if input_image is not None else Image.new("RGB", (800, 600), (0, 0, 0))
627
+
628
+ # Ensure all images have the same size
629
+ if input_image.size != relit_image.size:
630
+ logger.warning(f"Relit image size ({relit_image.size}) doesn't match input image size ({input_image.size})")
631
+ relit_image = relit_image.resize(input_image.size, Image.LANCZOS)
632
+
633
+ if input_image.size != masked_image.size:
634
+ logger.warning(f"Mask image size ({masked_image.size}) doesn't match input image size ({input_image.size})")
635
+ masked_image = masked_image.resize(input_image.size, Image.LANCZOS)
636
+
637
+ # Process masked image
638
+ masked_image_rgb = split_image_with_alpha(masked_image)
639
+ masked_image_blurred = gaussian_blur(masked_image_rgb, radius=10)
 
640
 
641
+ # Fix: Add error handling for mask expansion
642
+ try:
643
+ grow_mask = expand_mask(masked_image_blurred, -15, True)
644
+ except Exception as e:
645
+ logger.error(f"Error expanding mask: {e}")
646
+ grow_mask = masked_image_blurred.convert("L")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
647
 
648
+ # Process input image
649
+ input_image_rgb = split_image_with_alpha(input_image)
650
+ input_blurred = gaussian_blur(input_image_rgb, radius=10)
651
+ input_inverted = invert_image(input_image_rgb)
652
+
653
+ # Add blurred and inverted images
654
+ input_blend_1 = blend_images(input_inverted, input_blurred, blend_percentage=0.5)
655
+ input_blend_1_inverted = invert_image(input_blend_1)
656
+ input_blend_2 = blend_images(input_blurred, input_blend_1_inverted, blend_percentage=1.0)
657
+
658
+ # Process relit image
659
+ relit_image_rgb = split_image_with_alpha(relit_image)
660
+ relit_blurred = gaussian_blur(relit_image_rgb, radius=10)
661
+ relit_inverted = invert_image(relit_image_rgb)
662
+
663
+ # Add blurred and inverted relit images
664
+ relit_blend_1 = blend_images(relit_inverted, relit_blurred, blend_percentage=0.5)
665
+ relit_blend_1_inverted = invert_image(relit_blend_1)
666
+ relit_blend_2 = blend_images(relit_blurred, relit_blend_1_inverted, blend_percentage=1.0)
667
 
668
+ # Blend high frequency components
669
+ high_freq_comp = image_blend_by_mask(relit_blend_2, input_blend_2, grow_mask, blend_percentage=1.0)
670
+
671
+ # Final compositing
672
+ comped_image = blend_images(relit_blurred, high_freq_comp, blend_percentage=0.65)
673
+
674
+ # Apply levels adjustment
675
+ final_image = apply_image_levels(comped_image, black_level=83, mid_level=128, white_level=172)
676
 
677
+ return final_image
678
+
679
+ except Exception as e:
680
+ logger.error(f"Error in blend_details: {e}")
681
+ # Return the relit image on error as a fallback
682
+ return relit_image if relit_image is not None else input_image
683
 
 
684
 
685
  @spaces.GPU
686
  def generate_image(input_img, prompt):