File size: 16,062 Bytes
72d1759
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
import ast
import base64
import json
import os
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import pyrebase
import requests
from openai import OpenAI
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO

from prompts import remove_unwanted_prompt


def get_middle_thumbnail(input_image: Image, grid_size=(10, 10), padding=3):
    """
    Crop the cell halfway through a sprite-sheet grid, minus its padding border.

    The sheet is divided into ``grid_size[0]`` columns and ``grid_size[1]``
    rows using integer division of its pixel size, so leftover pixels on the
    right/bottom edges are ignored. The chosen cell is the one at flat index
    ``(columns * rows) // 2`` in row-major order (for the default 10x10 grid
    that is row 5, column 0).

    Args:
        input_image: PIL image containing the sprite sheet.
        grid_size: ``(columns, rows)`` of the grid.
        padding: Pixels trimmed from each side of the chosen cell (default 3).

    Returns:
        PIL.Image.Image: The selected cell with its padding removed.
    """
    columns, rows = grid_size[0], grid_size[1]
    sheet_width, sheet_height = input_image.size

    # Size of one grid cell *including* its padding border.
    cell_width = sheet_width // columns
    cell_height = sheet_height // rows

    # Row-major (row, column) of the cell halfway through the sheet.
    row, column = divmod((columns * rows) // 2, columns)

    # Step inside the padding on the top-left, and stop short of it on the
    # bottom-right.
    left = column * cell_width + padding
    top = row * cell_height + padding
    box = (
        left,
        top,
        left + cell_width - 2 * padding,
        top + cell_height - 2 * padding,
    )
    return input_image.crop(box)


def get_person_bbox(frame, model):
    """
    Run ``model`` on ``frame`` for class 0 and return the largest-area box.

    Args:
        frame: Image passed straight through to ``model``.
        model: Detector callable as ``model(frame, classes=[0])`` whose first
            result exposes ``boxes`` (sized) and ``boxes.xyxy.cpu().numpy()``.

    Returns:
        The ``[x1, y1, x2, y2]`` row with the largest area (NumPy array), or
        ``None`` when nothing was detected.
    """
    detections = model(frame, classes=[0])  # class 0 is person in COCO
    if not detections or len(detections[0].boxes) == 0:
        return None

    corners = detections[0].boxes.xyxy.cpu().numpy()
    widths = corners[:, 2] - corners[:, 0]
    heights = corners[:, 3] - corners[:, 1]
    return corners[np.argmax(widths * heights)]


def _fit_aspect_box(desired_width, aspect_ratio, max_width, max_height):
    """Return a ``(width, height)`` with ``width / height ~= aspect_ratio`` that fits the frame.

    ``desired_width`` is first capped at ``max_width``. If the matching height
    would exceed ``max_height``, the height is capped and the width shrunk so
    the aspect ratio is preserved instead of silently changing.
    """
    width = min(max_width, int(desired_width))
    height = int(width / aspect_ratio)
    if height > max_height:
        height = max_height
        width = min(width, int(height * aspect_ratio))
    return width, height


def _clamp_window_start(preferred_start, size, limit):
    """Shift a 1-D window of ``size`` so it starts near ``preferred_start`` but lies within ``[0, limit]``."""
    return max(0, min(preferred_start, limit - size))


def generate_crops(frame, model=None):
    """
    Generate 16:9 and 9:16 crops around the largest detected person.

    The 16:9 window is horizontally centred on the person and starts at the
    person's top edge; the 9:16 window is centred on the person both ways.
    Both windows are sized from the person's box height, kept inside the
    frame by *shifting* them (not truncating), so their aspect ratio survives
    the final resize.

    Args:
        frame: PIL Image (converted from RGB to BGR) or an image array that is
            passed through as-is (presumably already BGR — confirm at callers).
        model: Optional preloaded YOLO model. When omitted, ``yolo11n.pt`` is
            loaded on every call, which is expensive for many frames.

    Returns:
        tuple: ``(crop_16_9, crop_9_16)`` resized to 426x240 and 240x426, or
        ``(None, None)`` when no person is found or the detection is too
        small to produce a non-empty crop.
    """
    if model is None:
        model = YOLO("yolo11n.pt")

    # Convert PIL Image to cv2 format if needed
    if isinstance(frame, Image.Image):
        frame = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)

    frame_height, frame_width = frame.shape[:2]
    bbox = get_person_bbox(frame, model)
    if bbox is None:
        return None, None

    x1, y1, x2, y2 = map(int, bbox)
    person_height = y2 - y1
    center_x = (x1 + x2) // 2
    center_y = (y1 + y2) // 2

    wide_w, wide_h = _fit_aspect_box(
        person_height * 16 / 9, 16 / 9, frame_width, frame_height
    )
    tall_w, tall_h = _fit_aspect_box(
        person_height * 9 / 16, 9 / 16, frame_width, frame_height
    )
    # A degenerate detection would yield an empty slice, which cv2.resize rejects.
    if min(wide_w, wide_h, tall_w, tall_h) <= 0:
        return None, None

    # 16:9: centred horizontally, anchored at the person's top edge.
    wide_left = _clamp_window_start(center_x - wide_w // 2, wide_w, frame_width)
    wide_top = _clamp_window_start(y1, wide_h, frame_height)

    # 9:16: centred on the person in both directions.
    tall_left = _clamp_window_start(center_x - tall_w // 2, tall_w, frame_width)
    tall_top = _clamp_window_start(center_y - tall_h // 2, tall_h, frame_height)

    crop_16_9 = frame[wide_top : wide_top + wide_h, wide_left : wide_left + wide_w]
    crop_9_16 = frame[tall_top : tall_top + tall_h, tall_left : tall_left + tall_w]

    # Resize to standard dimensions
    crop_16_9 = cv2.resize(crop_16_9, (426, 240))  # 16:9 aspect ratio
    crop_9_16 = cv2.resize(crop_9_16, (240, 426))  # 9:16 aspect ratio

    return crop_16_9, crop_9_16


def visualize_crops(image, bbox, crops_info):
    """
    Draw the person box and both crop windows on a copy of ``image``.

    Args:
        image: Image array with a ``.copy()`` method, drawn on with OpenCV
            (so the colours below are BGR).
        bbox: Person box as ``[x1, y1, x2, y2]``.
        crops_info: Mapping with ``"crop_16_9"`` and ``"crop_9_16"`` entries,
            each a mapping with ``"x1"``, ``"y1"``, ``"x2"``, ``"y2"`` keys.

    Returns:
        The annotated copy; ``image`` itself is not modified.
    """
    canvas = image.copy()
    thickness = 2

    # Person box first, in blue (BGR).
    top_left = (int(bbox[0]), int(bbox[1]))
    bottom_right = (int(bbox[2]), int(bbox[3]))
    cv2.rectangle(canvas, top_left, bottom_right, (255, 0, 0), thickness)

    # Then the crop windows: 16:9 in green, 9:16 in red (BGR).
    for key, colour in (("crop_16_9", (0, 255, 0)), ("crop_9_16", (0, 0, 255))):
        window = crops_info[key]
        cv2.rectangle(
            canvas,
            (int(window["x1"]), int(window["y1"])),
            (int(window["x2"]), int(window["y2"])),
            colour,
            thickness,
        )

    return canvas


def encode_image_to_base64(image: Image.Image, format: str = "JPEG") -> str:
    """
    Convert a PIL image to a base64 string.

    When encoding as JPEG, images in a mode the JPEG encoder cannot write
    (e.g. RGBA, LA, P) are converted to RGB first; Pillow would otherwise
    raise ``OSError: cannot write mode ... as JPEG``. Any alpha channel is
    discarded in that case.

    Args:
        image: PIL Image object
        format: Pillow format name used for encoding (default: JPEG)

    Returns:
        Base64 encoded string of the encoded image bytes
    """
    # Modes Pillow's JPEG encoder can write without conversion.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if (format or "").upper() == "JPEG" and image.mode not in jpeg_modes:
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format=format)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def add_top_numbers(
    input_image,
    num_divisions=20,
    margin=90,
    font_size=120,
    dot_spacing=20,
):
    """
    Add numbered divisions above and below an image, with dotted vertical guides.

    The image is pasted onto a white RGB canvas that is ``2 * margin`` pixels
    taller. Numbers ``1..num_divisions`` are centred over each division in
    both margins, and a dotted line runs down the centre of every division
    across the original image area.

    Args:
        input_image (Image): PIL Image
        num_divisions (int): Number of divisions to create
        margin (int): Size of margin in pixels for numbers
        font_size (int): Font size for numbers
        dot_spacing (int): Spacing between dots in pixels (must be > 0)

    Returns:
        PIL.Image.Image: New RGB image; ``input_image`` is not modified.

    Raises:
        ValueError: if ``dot_spacing`` is not positive (it previously looped
            forever).
    """
    if dot_spacing <= 0:
        raise ValueError(f"dot_spacing must be positive, got {dot_spacing}")

    canvas_width = input_image.width
    canvas_height = input_image.height + 2 * margin  # margin on top and bottom
    new_image = Image.new("RGB", (canvas_width, canvas_height), "white")
    new_image.paste(input_image, (0, margin))

    draw = ImageDraw.Draw(new_image)

    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        print("Using default font")
        font = ImageFont.load_default(size=font_size)

    division_width = input_image.width / num_divisions
    top_label_y = margin // 2
    bottom_label_y = canvas_height - (margin // 2)

    # Dot rows are identical for every division, so compute them once:
    # from the top of the pasted image down to (not including) its bottom.
    dot_rows = []
    y = margin
    while y < canvas_height - margin:
        dot_rows.append(y)
        y += dot_spacing

    for index in range(num_divisions):
        x = (index * division_width) + (division_width / 2)
        label = str(index + 1)

        draw.text((x, top_label_y), label, fill="black", font=font, anchor="mm")
        draw.text((x, bottom_label_y), label, fill="black", font=font, anchor="mm")

        for dot_y in dot_rows:
            # ImageDraw.circle takes the centre point; passing a bounding box
            # used only its first two values and shifted every dot by (-1, -1).
            draw.circle((x, dot_y), radius=3, fill="black")

    return new_image


def crop_and_draw_divisions(
    input_image,
    left_division,
    right_division,
    num_divisions=20,
    line_color=(255, 0, 0),
    line_width=2,
    head_margin_percent=0.1,
    model=None,
):
    """
    Create both 9:16 and 16:9 crops between two divisions and draw guide lines.

    The image width is split into ``num_divisions`` equal columns. The
    "9:16" crop is the full-height strip from the left edge of
    ``left_division`` to the right edge of ``right_division``; its actual
    aspect ratio depends on the divisions chosen. The largest person in that
    strip is detected, and the 16:9 crop (same horizontal extent) starts
    ``head_margin_percent`` of the person's height above their top edge. It is
    shifted up rather than truncated if it would run past the bottom of the
    image.

    Args:
        input_image (Image): PIL Image
        left_division (int): Left-side division number (1-num_divisions);
            numeric strings/floats are accepted and converted with ``int``.
        right_division (int): Right-side division number (1-num_divisions)
        num_divisions (int): Total number of divisions (default=20)
        line_color (tuple): RGB color tuple for lines (default: red)
        line_width (int): Width of lines in pixels (default: 2)
        head_margin_percent (float): Fraction of person height kept above the
            head (default: 0.1)
        model: Optional preloaded YOLO model; ``yolo11n.pt`` is loaded when
            omitted.

    Returns:
        tuple: (cropped_image_16_9, image_with_lines, cropped_image_9_16)

    Raises:
        ValueError: if the divisions are not integers with
            ``1 <= left_division <= right_division <= num_divisions``, or if
            no person is detected in the strip.
    """
    # Divisions typically come from an LLM reply, so normalise and validate them.
    left_division = int(left_division)
    right_division = int(right_division)
    if not 1 <= left_division <= right_division <= num_divisions:
        raise ValueError(
            f"expected 1 <= left_division ({left_division}) <= "
            f"right_division ({right_division}) <= {num_divisions}"
        )

    if model is None:
        model = YOLO("yolo11n.pt")

    # Calculate division width and boundaries
    division_width = input_image.width / num_divisions
    left_boundary = (left_division - 1) * division_width
    right_boundary = right_division * division_width

    # Full-height strip between the chosen divisions.
    cropped_image_9_16 = input_image.crop(
        (left_boundary, 0, right_boundary, input_image.height)
    )

    # Largest detected person in the strip, via get_person_bbox.
    bbox = get_person_bbox(cropped_image_9_16, model)
    if bbox is None:
        raise ValueError("no person detected between the selected divisions")
    _, y1, _, y2 = bbox

    # The strip spans the full image height, so its y-coordinates are also
    # image y-coordinates.
    head_margin = (y2 - y1) * head_margin_percent
    crop_width = right_boundary - left_boundary
    crop_height_16_9 = min(input_image.height, int(crop_width * 9 / 16))

    # Start just above the head, but move up instead of truncating the crop
    # when it would extend past the bottom of the image.
    top_boundary = max(
        0, min(y1 - head_margin, input_image.height - crop_height_16_9)
    )
    bottom_boundary = top_boundary + crop_height_16_9

    cropped_image_16_9 = input_image.crop(
        (left_boundary, top_boundary, right_boundary, bottom_boundary)
    )

    # Draw guide lines for both crops on a copy of the original image.
    image_with_lines = input_image.copy()
    draw = ImageDraw.Draw(image_with_lines)

    # Vertical lines (shared by both crops)
    draw.line(
        [(left_boundary, 0), (left_boundary, input_image.height)],
        fill=line_color,
        width=line_width,
    )
    draw.line(
        [(right_boundary, 0), (right_boundary, input_image.height)],
        fill=line_color,
        width=line_width,
    )

    # Horizontal lines (16:9 crop only)
    draw.line(
        [(left_boundary, top_boundary), (right_boundary, top_boundary)],
        fill=line_color,
        width=line_width,
    )
    draw.line(
        [(left_boundary, bottom_boundary), (right_boundary, bottom_boundary)],
        fill=line_color,
        width=line_width,
    )

    return cropped_image_16_9, image_with_lines, cropped_image_9_16


def _parse_division_reply(reply):
    """
    Extract the ``{...}`` object from a chat-model reply without executing it.

    Args:
        reply: Raw reply text (may be ``None`` when the model returns no content).

    Returns:
        dict: The parsed object, expected to hold ``left_row``/``right_row``.

    Raises:
        ValueError / SyntaxError: when no parseable object is present.
    """
    if not reply:
        raise ValueError("model returned an empty reply")
    start = reply.find("{")
    end = reply.rfind("}")
    if start == -1 or end < start:
        raise ValueError(f"no JSON object found in model reply: {reply!r}")
    snippet = reply[start : end + 1]
    try:
        parsed = json.loads(snippet)
    except json.JSONDecodeError:
        # Accept Python-literal syntax (e.g. single quotes), which the old
        # eval()-based parser tolerated, without executing arbitrary code.
        parsed = ast.literal_eval(snippet)
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


def analyze_image(numbered_input_image: Image, prompt, input_image):
    """
    Ask gpt-4o which division columns to crop, then produce the crops.

    The numbered image is sent with ``prompt``; a follow-up turn asks the
    model to restate its answer as JSON with ``left_row``/``right_row`` keys,
    which are passed to ``crop_and_draw_divisions`` on ``input_image``.

    Args:
        numbered_input_image (Image): PIL Image with division numbers drawn on it
        prompt (str): The prompt/question about the image
        input_image (Image): input image without numbers

    Returns:
        tuple: (cropped_image_16_9, image_with_lines, cropped_image_9_16). If
        the reply cannot be parsed or cropping fails, ``input_image`` is
        returned in all three positions.
    """
    client = OpenAI()
    base64_image = encode_image_to_base64(numbered_input_image, format="JPEG")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
            ],
        }
    ]

    response = client.chat.completions.create(
        model="gpt-4o", messages=messages, max_tokens=300
    )

    messages.extend(
        [
            {"role": "assistant", "content": response.choices[0].message.content},
            {
                "role": "user",
                "content": "please return the response in the json with keys left_row and right_row",
            },
        ],
    )

    reply = (
        client.chat.completions.create(model="gpt-4o", messages=messages)
        .choices[0]
        .message.content
    )

    try:
        # SECURITY: the reply is untrusted model output; it is parsed as data
        # (json / ast.literal_eval), never eval()'d.
        divisions = _parse_division_reply(reply)
        cropped_image_16_9, image_with_lines, cropped_image_9_16 = (
            crop_and_draw_divisions(
                input_image=input_image,
                left_division=divisions["left_row"],
                right_division=divisions["right_row"],
            )
        )
    except Exception as e:
        # Deliberate best-effort boundary for the UI: any parse/crop failure
        # falls back to showing the uncropped image in all three slots.
        print(e)
        return input_image, input_image, input_image

    return cropped_image_16_9, image_with_lines, cropped_image_9_16


def get_sprite_firebase(cid, rsid, uid):
    """
    Read the sprite-sheet entry for a collaboration edit link from Firebase.

    Args:
        cid: Path component placed after ``uid``.
        rsid: Final path component.
        uid: Path component placed after the link-handler node.

    Returns:
        The value of
        ``{account_id}/collab_sprite_link_handler/{uid}/{cid}/{rsid}`` as
        returned by pyrebase's ``.val()``. Callers iterate it as a list of
        image URLs; presumably it is ``None`` when the node is missing
        (TODO confirm against pyrebase).
    """
    # Environment variables take precedence so a deployment can point at its
    # own project. An env-based config used to be built and then
    # unconditionally overwritten by the hard-coded dev config, so the
    # environment was ignored.
    # NOTE(review): the dev-project values below are kept only as fallbacks
    # so existing setups keep working; move them out of source control.
    dev_defaults = {
        "apiKey": "AIzaSyB4n2UpGtWsTPj2qd9zChzLevhFkLPliXI",
        "authDomain": "roll-dev-7c14a.firebaseapp.com",
        "databaseURL": "https://roll-dev-7c14a-default-rtdb.firebaseio.com",
        "projectId": "roll-dev-7c14a",
        "storageBucket": "roll-dev-7c14a.firebasestorage.app",
        "messagingSenderId": "556047642295",
        "appId": "1:556047642295:web:be8714a223d3763efa2732",
        "measurementId": "G-RE6ZGE7DGG",
    }
    env_names = {
        "apiKey": "FIREBASE_API_KEY",
        "authDomain": "FIREBASE_AUTH_DOMAIN",
        "databaseURL": "FIREBASE_DATABASE_URL",
        "projectId": "FIREBASE_PROJECT_ID",
        "storageBucket": "FIREBASE_STORAGE_BUCKET",
        "messagingSenderId": "FIREBASE_MESSAGING_SENDER_ID",
        "appId": "FIREBASE_APP_ID",
        "measurementId": "FIREBASE_MEASUREMENT_ID",
    }
    # `or` (not a getenv default) so an empty variable also falls back,
    # rather than producing "" (or the string "None" as the old f-strings did).
    config = {
        key: os.getenv(env_name) or dev_defaults[key]
        for key, env_name in env_names.items()
    }

    firebase = pyrebase.initialize_app(config)
    db = firebase.database()
    account_id = os.getenv("ROLL_ACCOUNT") or "roll-dev-account"

    COLLAB_EDIT_LINK = "collab_sprite_link_handler"

    path = f"{account_id}/{COLLAB_EDIT_LINK}/{uid}/{cid}/{rsid}"

    return db.child(path).get().val()


def get_image_crop(cid=None, rsid=None, uid=None):
    """
    Build a Gradio gallery of 16:9 and 9:16 crops for one collaboration link.

    For every sprite-sheet URL returned by ``get_sprite_firebase``, the sheet
    is downloaded, its middle thumbnail extracted, numbered column guides are
    drawn on it, and ``analyze_image`` produces the crops.

    Args:
        cid: Forwarded to ``get_sprite_firebase``.
        rsid: Forwarded to ``get_sprite_firebase``.
        uid: Forwarded to ``get_sprite_firebase``.

    Returns:
        gr.Gallery: all input sheets, then all middle thumbnails, then all
        16:9 crops, then all guide-line images, then all 9:16 crops (grouped
        by kind, not by sheet). Empty when nothing is stored for the link.

    Raises:
        requests.HTTPError: if a sprite-sheet download returns an error status.
        requests.Timeout: if a download stalls past the timeout.
    """
    download_timeout_s = 30  # avoid hanging the UI forever on a stalled URL

    # A missing Firebase node yields None, which is not iterable.
    # NOTE(review): assumes the stored value is a list of URLs; an object
    # node would come back as a mapping and iterate over its keys.
    image_paths = get_sprite_firebase(cid, rsid, uid) or []

    input_images = []
    mid_images = []
    cropped_image_16_9s = []
    images_with_lines = []
    cropped_image_9_16s = []

    for image_path in image_paths:
        response = requests.get(image_path, timeout=download_timeout_s)
        # Fail with a clear HTTP error instead of a confusing image-decode error.
        response.raise_for_status()

        input_image = Image.open(BytesIO(response.content))
        input_images.append(input_image)

        # Get the middle thumbnail
        mid_image = get_middle_thumbnail(input_image)
        mid_images.append(mid_image)

        numbered_mid_image = add_top_numbers(
            input_image=mid_image,
            num_divisions=20,
            margin=50,
            font_size=30,
            dot_spacing=20,
        )

        cropped_image_16_9, image_with_lines, cropped_image_9_16 = analyze_image(
            numbered_mid_image, remove_unwanted_prompt(2), mid_image
        )
        cropped_image_16_9s.append(cropped_image_16_9)
        images_with_lines.append(image_with_lines)
        cropped_image_9_16s.append(cropped_image_9_16)

    return gr.Gallery(
        [
            *input_images,
            *mid_images,
            *cropped_image_16_9s,
            *images_with_lines,
            *cropped_image_9_16s,
        ]
    )