AhmadMustafa committed on
Commit 0c6b4cf · 1 Parent(s): 675f40a

add: 16:9 crops

Files changed (2)
  1. crop_utils.py +603 -152
  2. prompts.py +8 -2
crop_utils.py CHANGED
@@ -13,6 +13,8 @@ from ultralytics import YOLO
 
 from prompts import remove_unwanted_prompt
 
+model = YOLO("yolo11n.pt")
+
 
 def get_middle_thumbnail(input_image: Image, grid_size=(10, 10), padding=3):
     """
@@ -57,129 +59,6 @@ def get_middle_thumbnail(input_image: Image, grid_size=(10, 10), padding=3):
     return middle_thumb
 
 
-def get_person_bbox(frame, model):
-    """Detect person and return the largest bounding box"""
-    results = model(frame, classes=[0])  # class 0 is person in COCO
-
-    if not results or len(results[0].boxes) == 0:
-        return None
-
-    # Get all person boxes
-    boxes = results[0].boxes.xyxy.cpu().numpy()
-    # Calculate areas to find the largest person
-    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
-    largest_idx = np.argmax(areas)
-
-    return boxes[largest_idx]
-
-
-def generate_crops(frame):
-    """Generate both 16:9 and 9:16 crops based on person detection"""
-    # Load YOLO model
-    model = YOLO("yolo11n.pt")
-
-    # Convert PIL Image to cv2 format if needed
-    if isinstance(frame, Image.Image):
-        frame = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
-
-    original_height, original_width = frame.shape[:2]
-    bbox = get_person_bbox(frame, model)
-
-    if bbox is None:
-        return None, None
-
-    # Extract coordinates
-    x1, y1, x2, y2 = map(int, bbox)
-    person_height = y2 - y1
-    person_width = x2 - x1
-    person_center_x = (x1 + x2) // 2
-    person_center_y = (y1 + y2) // 2
-
-    # Generate 16:9 crop (focus on upper body)
-    aspect_ratio_16_9 = 16 / 9
-    crop_width_16_9 = min(original_width, int(person_height * aspect_ratio_16_9))
-    crop_height_16_9 = min(original_height, int(crop_width_16_9 / aspect_ratio_16_9))
-
-    # For 16:9, center horizontally and align top with person's top
-    x1_16_9 = max(0, person_center_x - crop_width_16_9 // 2)
-    x2_16_9 = min(original_width, x1_16_9 + crop_width_16_9)
-    y1_16_9 = max(0, y1)  # Start from person's top
-    y2_16_9 = min(original_height, y1_16_9 + crop_height_16_9)
-
-    # Adjust if exceeding boundaries
-    if x2_16_9 > original_width:
-        x1_16_9 = original_width - crop_width_16_9
-        x2_16_9 = original_width
-    if y2_16_9 > original_height:
-        y1_16_9 = original_height - crop_height_16_9
-        y2_16_9 = original_height
-
-    # Generate 9:16 crop (full body)
-    aspect_ratio_9_16 = 9 / 16
-    crop_width_9_16 = min(original_width, int(person_height * aspect_ratio_9_16))
-    crop_height_9_16 = min(original_height, int(crop_width_9_16 / aspect_ratio_9_16))
-
-    # For 9:16, center both horizontally and vertically
-    x1_9_16 = max(0, person_center_x - crop_width_9_16 // 2)
-    x2_9_16 = min(original_width, x1_9_16 + crop_width_9_16)
-    y1_9_16 = max(0, person_center_y - crop_height_9_16 // 2)
-    y2_9_16 = min(original_height, y1_9_16 + crop_height_9_16)
-
-    # Adjust if exceeding boundaries
-    if x2_9_16 > original_width:
-        x1_9_16 = original_width - crop_width_9_16
-        x2_9_16 = original_width
-    if y2_9_16 > original_height:
-        y1_9_16 = original_height - crop_height_9_16
-        y2_9_16 = original_height
-
-    # Create crops
-    crop_16_9 = frame[y1_16_9:y2_16_9, x1_16_9:x2_16_9]
-    crop_9_16 = frame[y1_9_16:y2_9_16, x1_9_16:x2_9_16]
-
-    # Resize to standard dimensions
-    crop_16_9 = cv2.resize(crop_16_9, (426, 240))  # 16:9 aspect ratio
-    crop_9_16 = cv2.resize(crop_9_16, (240, 426))  # 9:16 aspect ratio
-
-    return crop_16_9, crop_9_16
-
-
-def visualize_crops(image, bbox, crops_info):
-    """
-    Visualize original bbox and calculated crops
-    bbox: [x1, y1, x2, y2]
-    crops_info: dict with 'crop_16_9' and 'crop_9_16' coordinates
-    """
-    viz = image.copy()
-
-    # Draw original person bbox in blue
-    cv2.rectangle(
-        viz, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (255, 0, 0), 2
-    )
-
-    # Draw 16:9 crop in green
-    crop_16_9 = crops_info["crop_16_9"]
-    cv2.rectangle(
-        viz,
-        (int(crop_16_9["x1"]), int(crop_16_9["y1"])),
-        (int(crop_16_9["x2"]), int(crop_16_9["y2"])),
-        (0, 255, 0),
-        2,
-    )
-
-    # Draw 9:16 crop in red
-    crop_9_16 = crops_info["crop_9_16"]
-    cv2.rectangle(
-        viz,
-        (int(crop_9_16["x1"]), int(crop_9_16["y1"])),
-        (int(crop_9_16["x2"]), int(crop_9_16["y2"])),
-        (0, 0, 255),
-        2,
-    )
-
-    return viz
-
-
 def encode_image_to_base64(image: Image.Image, format: str = "JPEG") -> str:
     """
     Convert a PIL image to a base64 string.
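
Note on the first hunk: the YOLO weights now load once at import time (model = YOLO("yolo11n.pt")) instead of on every call, as the deleted generate_crops did. A minimal alternative sketch, assuming the same ultralytics package and weights file, defers the load until first use so importing the module stays cheap:

    # Sketch only: lazy, cached model loading instead of import-time loading.
    from functools import lru_cache

    from ultralytics import YOLO


    @lru_cache(maxsize=1)
    def get_model():
        # First call loads yolo11n.pt; later calls reuse the same instance.
        return YOLO("yolo11n.pt")
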
@@ -421,9 +300,15 @@ def analyze_image(numbered_input_image: Image, prompt, input_image):
         )
     except Exception as e:
         print(e)
-        return input_image, input_image, input_image
-
-    return cropped_image_16_9, image_with_lines, cropped_image_9_16
+        return input_image, input_image, input_image, 0, 20
+
+    return (
+        cropped_image_16_9,
+        image_with_lines,
+        cropped_image_9_16,
+        response_json["left_row"],
+        response_json["right_row"],
+    )
 
 
 def get_sprite_firebase(cid, rsid, uid):
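
analyze_image now returns five values instead of three, falling back to (input_image, input_image, input_image, 0, 20) on error; the two new values are read from the model reply as response_json["left_row"] and response_json["right_row"]. A hedged sketch of parsing that reply shape defensively (the exact JSON layout is assumed from the keys used here):

    import json

    def parse_divisions(reply: str) -> tuple[int, int]:
        # Fall back to the full 1-20 span if the reply is malformed.
        try:
            data = json.loads(reply)
            return int(data["left_row"]), int(data["right_row"])
        except (ValueError, KeyError, TypeError):
            return 1, 20

    assert parse_divisions('{"left_row": 2, "right_row": 19}') == (2, 19)
    assert parse_divisions("not json") == (1, 20)
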
@@ -450,26 +335,548 @@ def get_sprite_firebase(cid, rsid, uid):
     return data.val()
 
 
+def find_persons_center(image):
+    """
+    Find the center point of all persons in the image.
+    If multiple persons are detected, merge all bounding boxes and find the center.
+
+    Args:
+        image: CV2/numpy array image
+
+    Returns:
+        int: x-coordinate of the center point of all persons
+    """
+    # Detect persons (class 0 in COCO dataset)
+    results = model(image, classes=[0])
+
+    if not results or len(results[0].boxes) == 0:
+        # If no persons detected, return center of image
+        return image.shape[1] // 2
+
+    # Get all person boxes
+    boxes = results[0].boxes.xyxy.cpu().numpy()
+
+    # Print the number of persons detected (for debugging)
+    print(f"Detected {len(boxes)} persons in the image")
+
+    if len(boxes) == 1:
+        # If only one person, return center of their bounding box
+        x1, _, x2, _ = boxes[0]
+        center_x = int((x1 + x2) // 2)
+        print(f"Single person detected at center x: {center_x}")
+        return center_x
+    else:
+        # Multiple persons - create a merged bounding box
+        left_x = min(box[0] for box in boxes)
+        right_x = max(box[2] for box in boxes)
+        merged_center_x = int((left_x + right_x) // 2)
+
+        print(f"Multiple persons merged bounding box center x: {merged_center_x}")
+        print(f"Merged bounds: left={left_x}, right={right_x}")
+
+        return merged_center_x
+
+
+def create_layouts(image, left_division, right_division):
+    """
+    Create different layout variations of the image using half, one-third, and two-thirds width.
+    All layout variations will be centered on detected persons, including 16:9 and 9:16 crops.
+
+    Args:
+        image: PIL Image
+        left_division: Left division index (1-20)
+        right_division: Right division index (1-20)
+
+    Returns:
+        tuple: (list of layout variations, cutout_image, cutout_16_9, cutout_9_16)
+    """
+    # Convert PIL Image to cv2 format
+    if isinstance(image, Image.Image):
+        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+    else:
+        image_cv = image.copy()
+
+    # Get image dimensions
+    height, width = image_cv.shape[:2]
+
+    # Calculate division width and crop boundaries
+    division_width = width / 20  # Assuming 20 divisions
+    left_boundary = int((left_division - 1) * division_width)
+    right_boundary = int(right_division * division_width)
+
+    # 1. Create cutout image based on divisions
+    cutout_image = image_cv[:, left_boundary:right_boundary].copy()
+    cutout_width = right_boundary - left_boundary
+    cutout_height = cutout_image.shape[0]
+
+    # 2. Run YOLO on cutout to get person bounding box and center
+    results = model(cutout_image, classes=[0])
+
+    # Default center if no detection
+    cutout_center_x = cutout_image.shape[1] // 2
+    cutout_center_y = cutout_height // 2
+
+    # Default values for bounding box
+    person_top = 0.0
+    person_height = float(cutout_height)
+
+    if results and len(results[0].boxes) > 0:
+        # Get person detection
+        boxes = results[0].boxes.xyxy.cpu().numpy()
+
+        if len(boxes) == 1:
+            # Single person
+            x1, y1, x2, y2 = boxes[0]
+            cutout_center_x = int((x1 + x2) // 2)
+            cutout_center_y = int((y1 + y2) // 2)
+            person_top = y1
+            person_height = y2 - y1
+        else:
+            # Multiple persons - merge bounding boxes
+            left_x = min(box[0] for box in boxes)
+            right_x = max(box[2] for box in boxes)
+            top_y = min(box[1] for box in boxes)  # Top of highest person
+            bottom_y = max(box[3] for box in boxes)  # Bottom of lowest person
+
+            cutout_center_x = int((left_x + right_x) // 2)
+            cutout_center_y = int((top_y + bottom_y) // 2)
+            person_top = top_y
+            person_height = bottom_y - top_y
+
+    # 3. Create 16:9 and 9:16 versions with person properly framed
+    aspect_16_9 = 16 / 9
+    aspect_9_16 = 9 / 16
+
+    # For 16:9 version (with 20% margin above person)
+    target_height_16_9 = int(cutout_width / aspect_16_9)
+    if target_height_16_9 <= cutout_height:
+        # Calculate 20% of person height for top margin
+        top_margin = int(person_height * 0.2)
+
+        # Start 20% above the person's top
+        y_start = int(max(0, person_top - top_margin))
+
+        # If this would make the crop exceed the bottom, adjust y_start
+        if y_start + target_height_16_9 > cutout_height:
+            y_start = int(max(0, cutout_height - target_height_16_9))
+
+        y_end = int(min(cutout_height, y_start + target_height_16_9))
+        cutout_16_9 = cutout_image[y_start:y_end, :].copy()
+    else:
+        # Handle rare case where we need to adjust width (not expected with normal images)
+        new_width = int(cutout_height * aspect_16_9)
+        x_start = max(
+            0, min(cutout_width - new_width, cutout_center_x - new_width // 2)
+        )
+        x_end = min(cutout_width, x_start + new_width)
+        cutout_16_9 = cutout_image[:, x_start:x_end].copy()
+
+    # For 9:16 version (centered on person)
+    target_width_9_16 = int(cutout_height * aspect_9_16)
+    if target_width_9_16 <= cutout_width:
+        # Center horizontally around person
+        x_start = int(
+            max(
+                0,
+                min(
+                    cutout_width - target_width_9_16,
+                    cutout_center_x - target_width_9_16 // 2,
+                ),
+            )
+        )
+        x_end = int(min(cutout_width, x_start + target_width_9_16))
+        cutout_9_16 = cutout_image[:, x_start:x_end].copy()
+    else:
+        # Handle rare case where we need to adjust height
+        new_height = int(cutout_width / aspect_9_16)
+        y_start = int(
+            max(0, min(cutout_height - new_height, cutout_center_y - new_height // 2))
+        )
+        y_end = int(min(cutout_height, y_start + new_height))
+        cutout_9_16 = cutout_image[y_start:y_end, :].copy()
+
+    # 4. Scale the center back to original image coordinates
+    original_center_x = left_boundary + cutout_center_x
+
+    # 5. Create layout variations on the original image centered on persons
+    # Half width layout
+    half_width = width // 2
+    half_left_x = max(0, min(width - half_width, original_center_x - half_width // 2))
+    half_right_x = half_left_x + half_width
+    half_width_crop = image_cv[:, half_left_x:half_right_x].copy()
+
+    # Third width layout
+    third_width = width // 3
+    third_left_x = max(
+        0, min(width - third_width, original_center_x - third_width // 2)
+    )
+    third_right_x = third_left_x + third_width
+    third_width_crop = image_cv[:, third_left_x:third_right_x].copy()
+
+    # Two-thirds width layout
+    two_thirds_width = (width * 2) // 3
+    two_thirds_left_x = max(
+        0, min(width - two_thirds_width, original_center_x - two_thirds_width // 2)
+    )
+    two_thirds_right_x = two_thirds_left_x + two_thirds_width
+    two_thirds_crop = image_cv[:, two_thirds_left_x:two_thirds_right_x].copy()
+
+    # Add labels to all crops
+    font = cv2.FONT_HERSHEY_SIMPLEX
+    label_settings = {
+        "fontScale": 1.0,
+        "fontFace": 1,
+        "thickness": 2,
+    }
+
+    # Draw label backgrounds for better visibility
+    def add_label(img, label):
+        # Draw background for text
+        text_size = cv2.getTextSize(
+            label, **{k: v for k, v in label_settings.items() if k != "color"}
+        )
+        cv2.rectangle(
+            img,
+            (10, 10),
+            (10 + text_size[0][0] + 10, 10 + text_size[0][1] + 10),
+            (0, 0, 0),
+            -1,
+        )  # Black background
+        # Draw text
+        cv2.putText(
+            img,
+            label,
+            (15, 15 + text_size[0][1]),
+            **label_settings,
+            color=(255, 255, 255),
+            lineType=cv2.LINE_AA,
+        )
+        return img
+
+    cutout_image = add_label(cutout_image, "Cutout")
+    cutout_16_9 = add_label(cutout_16_9, "16:9")
+    cutout_9_16 = add_label(cutout_9_16, "9:16")
+    half_width_crop = add_label(half_width_crop, "Half Width")
+    third_width_crop = add_label(third_width_crop, "Third Width")
+    two_thirds_crop = add_label(two_thirds_crop, "Two-Thirds Width")
+
+    # Convert all output images to PIL format
+    layout_crops = []
+    for layout, label in [
+        (half_width_crop, "Half Width"),
+        (third_width_crop, "Third Width"),
+        (two_thirds_crop, "Two-Thirds Width"),
+    ]:
+        pil_layout = Image.fromarray(cv2.cvtColor(layout, cv2.COLOR_BGR2RGB))
+        layout_crops.append(pil_layout)
+
+    cutout_pil = Image.fromarray(cv2.cvtColor(cutout_image, cv2.COLOR_BGR2RGB))
+    cutout_16_9_pil = Image.fromarray(cv2.cvtColor(cutout_16_9, cv2.COLOR_BGR2RGB))
+    cutout_9_16_pil = Image.fromarray(cv2.cvtColor(cutout_9_16, cv2.COLOR_BGR2RGB))
+
+    return layout_crops, cutout_pil, cutout_16_9_pil, cutout_9_16_pil
+
+
+def draw_all_crops_on_original(image, left_division, right_division):
+    """
+    Create a visualization showing all crop regions overlaid on the original image.
+    Each crop region is outlined with a different color and labeled.
+    All crops are centered on the person's center point.
+
+    Args:
+        image: PIL Image
+        left_division: Left division index (1-20)
+        right_division: Right division index (1-20)
+
+    Returns:
+        PIL Image: Original image with all crop regions visualized
+    """
+    # Convert PIL Image to cv2 format
+    if isinstance(image, Image.Image):
+        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+    else:
+        image_cv = image.copy()
+
+    # Get a clean copy for drawing
+    visualization = image_cv.copy()
+
+    # Get image dimensions
+    height, width = image_cv.shape[:2]
+
+    # Calculate division width and crop boundaries
+    division_width = width / 20  # Assuming 20 divisions
+    left_boundary = int((left_division - 1) * division_width)
+    right_boundary = int(right_division * division_width)
+
+    # Find person bounding box and center in cutout
+    cutout_image = image_cv[:, left_boundary:right_boundary].copy()
+
+    # Get YOLO detections for person bounding box
+    results = model(cutout_image, classes=[0])
+
+    # Default values
+    cutout_center_x = cutout_image.shape[1] // 2
+    cutout_center_y = cutout_image.shape[0] // 2
+    person_top = 0.0
+    person_height = float(cutout_image.shape[0])
+
+    if results and len(results[0].boxes) > 0:
+        # Get person detection
+        boxes = results[0].boxes.xyxy.cpu().numpy()
+
+        if len(boxes) == 1:
+            # Single person
+            x1, y1, x2, y2 = boxes[0]
+            cutout_center_x = int((x1 + x2) // 2)
+            cutout_center_y = int((y1 + y2) // 2)
+            person_top = y1
+            person_height = y2 - y1
+        else:
+            # Multiple persons - merge bounding boxes
+            left_x = min(box[0] for box in boxes)
+            right_x = max(box[2] for box in boxes)
+            top_y = min(box[1] for box in boxes)  # Top of highest person
+            bottom_y = max(box[3] for box in boxes)  # Bottom of lowest person
+
+            cutout_center_x = int((left_x + right_x) // 2)
+            cutout_center_y = int((top_y + bottom_y) // 2)
+            person_top = top_y
+            person_height = bottom_y - top_y
+
+    # Scale back to original image
+    original_center_x = left_boundary + cutout_center_x
+    original_center_y = cutout_center_y
+    original_person_top = (
+        person_top  # Already in original image space since we didn't crop vertically
+    )
+    original_person_height = person_height  # Same in original space
+
+    # Define colors for different crops (BGR format)
+    colors = {
+        "cutout": (0, 165, 255),  # Orange
+        "16:9": (0, 255, 0),  # Green
+        "9:16": (255, 0, 0),  # Blue
+        "half": (255, 255, 0),  # Cyan
+        "third": (255, 0, 255),  # Magenta
+        "two_thirds": (0, 255, 255),  # Yellow
+    }
+
+    # Define line thickness and font
+    thickness = 3
+    font = cv2.FONT_HERSHEY_SIMPLEX
+    font_scale = 0.8
+    font_thickness = 2
+
+    # 1. Draw cutout region (original divisions)
+    cv2.rectangle(
+        visualization,
+        (left_boundary, 0),
+        (right_boundary, height),
+        colors["cutout"],
+        thickness,
+    )
+    cv2.putText(
+        visualization,
+        "Cutout",
+        (left_boundary + 5, 30),
+        font,
+        font_scale,
+        colors["cutout"],
+        font_thickness,
+    )
+
+    # 2. Create 16:9 and 9:16 versions of the cutout - CENTERED on person
+    cutout_width = right_boundary - left_boundary
+    cutout_height = height
+
+    # For 16:9 version with 20% margin above person
+    aspect_16_9 = 16 / 9
+    target_height_16_9 = int(cutout_width / aspect_16_9)
+    if target_height_16_9 <= height:
+        # Calculate 20% of person height for top margin
+        top_margin = int(original_person_height * 0.2)
+
+        # Start 20% above the person's top
+        y_start = int(max(0, original_person_top - top_margin))
+
+        # If this would make the crop exceed the bottom, adjust y_start
+        if y_start + target_height_16_9 > height:
+            y_start = int(max(0, height - target_height_16_9))
+
+        y_end = int(min(height, y_start + target_height_16_9))
+
+        cv2.rectangle(
+            visualization,
+            (left_boundary, y_start),
+            (right_boundary, y_end),
+            colors["16:9"],
+            thickness,
+        )
+        cv2.putText(
+            visualization,
+            "16:9",
+            (left_boundary + 5, y_start + 30),
+            font,
+            font_scale,
+            colors["16:9"],
+            font_thickness,
+        )
+
+    # For 9:16 version centered on person
+    aspect_9_16 = 9 / 16
+    target_width_9_16 = int(cutout_height * aspect_9_16)
+    if target_width_9_16 <= cutout_width:
+        # Center horizontally around person
+        x_start = max(
+            0,
+            min(
+                left_boundary + cutout_width - target_width_9_16,
+                original_center_x - target_width_9_16 // 2,
+            ),
+        )
+        x_end = x_start + target_width_9_16
+        cv2.rectangle(
+            visualization, (x_start, 0), (x_end, height), colors["9:16"], thickness
+        )
+        cv2.putText(
+            visualization,
+            "9:16",
+            (x_start + 5, 60),
+            font,
+            font_scale,
+            colors["9:16"],
+            font_thickness,
+        )
+
+    # 3. Draw centered layout variations
+    # Half width layout
+    half_width = width // 2
+    half_left_x = max(0, min(width - half_width, original_center_x - half_width // 2))
+    half_right_x = half_left_x + half_width
+    cv2.rectangle(
+        visualization,
+        (half_left_x, 0),
+        (half_right_x, height),
+        colors["half"],
+        thickness,
+    )
+    cv2.putText(
+        visualization,
+        "Half Width",
+        (half_left_x + 5, 90),
+        font,
+        font_scale,
+        colors["half"],
+        font_thickness,
+    )
+
+    # Third width layout
+    third_width = width // 3
+    third_left_x = max(
+        0, min(width - third_width, original_center_x - third_width // 2)
+    )
+    third_right_x = third_left_x + third_width
+    cv2.rectangle(
+        visualization,
+        (third_left_x, 0),
+        (third_right_x, height),
+        colors["third"],
+        thickness,
+    )
+    cv2.putText(
+        visualization,
+        "Third Width",
+        (third_left_x + 5, 120),
+        font,
+        font_scale,
+        colors["third"],
+        font_thickness,
+    )
+
+    # Two-thirds width layout
+    two_thirds_width = (width * 2) // 3
+    two_thirds_left_x = max(
+        0, min(width - two_thirds_width, original_center_x - two_thirds_width // 2)
+    )
+    two_thirds_right_x = two_thirds_left_x + two_thirds_width
+    cv2.rectangle(
+        visualization,
+        (two_thirds_left_x, 0),
+        (two_thirds_right_x, height),
+        colors["two_thirds"],
+        thickness,
+    )
+    cv2.putText(
+        visualization,
+        "Two-Thirds Width",
+        (two_thirds_left_x + 5, 150),
+        font,
+        font_scale,
+        colors["two_thirds"],
+        font_thickness,
+    )
+
+    # 4. Draw center point of person(s)
+    center_radius = 8
+    cv2.circle(
+        visualization,
+        (original_center_x, height // 2),
+        center_radius,
+        (255, 255, 255),
+        -1,
+    )
+    cv2.circle(
+        visualization, (original_center_x, height // 2), center_radius, (0, 0, 0), 2
+    )
+    cv2.putText(
+        visualization,
+        "Person Center",
+        (original_center_x + 10, height // 2),
+        font,
+        font_scale,
+        (255, 255, 255),
+        font_thickness,
+    )
+
+    # Convert back to PIL format
+    visualization_pil = Image.fromarray(cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB))
+
+    return visualization_pil
+
+
 def get_image_crop(cid=None, rsid=None, uid=None):
-    """Function that returns both 16:9 and 9:16 crops"""
-    image_paths = get_sprite_firebase(cid, rsid, uid)
+    """
+    Function that returns both 16:9 and 9:16 crops and layout variations for visualization.
+
+    Returns:
+        gr.Gallery: Gallery of all generated images
+    """
+    # Uncomment this line when using Firebase
+    # image_paths = get_sprite_firebase(cid, rsid, uid)
 
-    input_images = []
-    mid_images = []
-    cropped_image_16_9s = []
-    images_with_lines = []
-    cropped_image_9_16s = []
+    # For testing, use a local image path
+    image_paths = ["sprite1.jpg", "sprite2.jpg"]
 
-    for image_path in image_paths:
-        response = requests.get(image_path)
-
-        input_image = Image.open(BytesIO(response.content))
-        input_images.append(input_image)
+    # Lists to store all images
+    all_images = []
+    all_captions = []
+
+    for image_path in image_paths:
+        # Load image (from local file or URL)
+        try:
+            if image_path.startswith(("http://", "https://")):
+                response = requests.get(image_path)
+                input_image = Image.open(BytesIO(response.content))
+            else:
+                input_image = Image.open(image_path)
+        except Exception as e:
+            print(f"Error loading image {image_path}: {e}")
+            continue
 
         # Get the middle thumbnail
         mid_image = get_middle_thumbnail(input_image)
-        mid_images.append(mid_image)
 
+        # Add numbered divisions for GPT-4V analysis
        numbered_mid_image = add_top_numbers(
             input_image=mid_image,
             num_divisions=20,
@@ -478,19 +885,63 @@ def get_image_crop(cid=None, rsid=None, uid=None):
             dot_spacing=20,
         )
 
-        cropped_image_16_9, image_with_lines, cropped_image_9_16 = analyze_image(
-            numbered_mid_image, remove_unwanted_prompt(2), mid_image
+        # Analyze the image to get optimal crop divisions
+        # This uses GPT-4V to identify the optimal crop points
+        (
+            _,
+            _,
+            _,
+            left_division,
+            right_division,
+        ) = analyze_image(numbered_mid_image, remove_unwanted_prompt(2), mid_image)
+
+        # Safety check for divisions
+        if left_division <= 0:
+            left_division = 1
+        if right_division > 20:
+            right_division = 20
+        if left_division >= right_division:
+            left_division = 1
+            right_division = 20
+
+        print(f"Using divisions: left={left_division}, right={right_division}")
+
+        # Create layouts and cutouts
+        layouts, cutout_image, cutout_16_9, cutout_9_16 = create_layouts(
+            mid_image, left_division, right_division
         )
-        cropped_image_16_9s.append(cropped_image_16_9)
-        images_with_lines.append(image_with_lines)
-        cropped_image_9_16s.append(cropped_image_9_16)
 
-    return gr.Gallery(
-        [
-            *input_images,
-            *mid_images,
-            *cropped_image_16_9s,
-            *images_with_lines,
-            *cropped_image_9_16s,
-        ]
-    )
+        # Create the visualization with all crops overlaid on original
+        all_crops_visualization = draw_all_crops_on_original(
+            mid_image, left_division, right_division
+        )
+
+        # Start with the visualization showing all crops
+        all_images.append(all_crops_visualization)
+        all_captions.append(f"All Crops Visualization {all_crops_visualization.size}")
+
+        # Add input and middle image to gallery
+        all_images.append(input_image)
+        all_captions.append(f"Input Image {input_image.size}")
+
+        all_images.append(mid_image)
+        all_captions.append(f"Middle Thumbnail {mid_image.size}")
+
+        # Add cutout images to gallery
+        all_images.append(cutout_image)
+        all_captions.append(f"Cutout Image {cutout_image.size}")
+
+        all_images.append(cutout_16_9)
+        all_captions.append(f"16:9 Crop {cutout_16_9.size}")
+
+        all_images.append(cutout_9_16)
+        all_captions.append(f"9:16 Crop {cutout_9_16.size}")
+
+        # Add layout variations
+        for i, layout in enumerate(layouts):
+            label = ["Half Width", "Third Width", "Two-Thirds Width"][i]
+            all_images.append(layout)
+            all_captions.append(f"{label} {layout.size}")
+
+    # Return gallery with all images
+    return gr.Gallery(value=list(zip(all_images, all_captions)))
prompts.py CHANGED
@@ -153,5 +153,11 @@ If the user provides the correct call type, use the correct_call_type function t
 
 def remove_unwanted_prompt(number_of_speakers: int):
     if number_of_speakers == 2:
-        return """I want to crop this image such that no unwanted or Partial Object or Partial Human is in the image.
-        Please analyze the image such that you tell me the row number on both the left and right sides of the image inside which there is the no unwanted partial object."""
+        return """I want to crop this image only when absolutely necessary to remove partial objects or humans.
+
+        Please analyze the image and tell me:
+        1. The column number (1-20) on the left side where I should start the crop. Only suggest cropping (columns 1-4) if there are clear partial objects or humans that need removal. If no cropping is needed on the left, return 1.
+
+        2. The column number (1-20) on the right side where I should end the crop. Only suggest cropping (columns 17-20) if there are clear partial objects or humans that need removal. If no cropping is needed on the right, return 20.
+
+        I'm looking for minimal cropping - only cut when absolutely necessary to remove distracting partial elements."""
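
The 1-20 column indices this prompt asks for map to pixel boundaries with the same arithmetic create_layouts uses. A minimal worked example:

    def divisions_to_pixels(width: int, left_division: int, right_division: int):
        # Mirrors create_layouts: 20 equal columns across the frame width.
        division_width = width / 20
        left_boundary = int((left_division - 1) * division_width)
        right_boundary = int(right_division * division_width)
        return left_boundary, right_boundary

    # A 1920-px frame with no cropping requested keeps the full width.
    assert divisions_to_pixels(1920, 1, 20) == (0, 1920)
    # Trimming one column on each side keeps pixels 96..1824.
    assert divisions_to_pixels(1920, 2, 19) == (96, 1824)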