import ast
import base64
import os
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import pyrebase
import requests
from openai import OpenAI
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO

from prompts import remove_unwanted_prompt


def get_middle_thumbnail(input_image: Image.Image, grid_size=(10, 10), padding=3):
    """
    Extract the middle thumbnail from a sprite sheet, handling different aspect
    ratios and removing padding.

    Args:
        input_image: PIL Image containing the sprite sheet
        grid_size: Tuple of (columns, rows)
        padding: Number of padding pixels on each side of a thumbnail (default 3)

    Returns:
        PIL.Image.Image: The middle thumbnail with padding removed
    """
    sprite_width, sprite_height = input_image.size
    thumb_width_with_padding = sprite_width // grid_size[0]
    thumb_height_with_padding = sprite_height // grid_size[1]

    # Thumbnail size once the padding is stripped from both sides
    thumb_width = thumb_width_with_padding - (2 * padding)
    thumb_height = thumb_height_with_padding - (2 * padding)

    # Index of the middle cell in row-major order
    total_thumbs = grid_size[0] * grid_size[1]
    middle_index = total_thumbs // 2
    middle_row = middle_index // grid_size[0]
    middle_col = middle_index % grid_size[0]

    # Pixel coordinates of the middle cell, offset past the padding
    left = (middle_col * thumb_width_with_padding) + padding
    top = (middle_row * thumb_height_with_padding) + padding
    right = left + thumb_width
    bottom = top + thumb_height

    return input_image.crop((left, top, right, bottom))


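# Usage sketch (illustrative, not executed on import; "sprite.png" stands in
# for a real 10x10 sprite sheet):
#
#     sheet = Image.open("sprite.png")
#     thumb = get_middle_thumbnail(sheet, grid_size=(10, 10), padding=3)
#     thumb.save("middle_thumb.png")

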
def get_person_bbox(frame, model):
    """Detect people in a frame and return the largest bounding box, or None."""
    # Class 0 is "person" in the COCO label set used by YOLO models
    results = model(frame, classes=[0])

    if not results or len(results[0].boxes) == 0:
        return None

    boxes = results[0].boxes.xyxy.cpu().numpy()

    # Keep the detection covering the largest area
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    largest_idx = np.argmax(areas)

    return boxes[largest_idx]


def generate_crops(frame):
    """Generate both 16:9 and 9:16 crops centered on the detected person."""
    model = YOLO("yolo11n.pt")

    # Accept PIL images as well as OpenCV BGR arrays
    if isinstance(frame, Image.Image):
        frame = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)

    original_height, original_width = frame.shape[:2]
    bbox = get_person_bbox(frame, model)

    if bbox is None:
        return None, None

    x1, y1, x2, y2 = map(int, bbox)
    person_height = y2 - y1
    person_center_x = (x1 + x2) // 2
    person_center_y = (y1 + y2) // 2

    # 16:9 crop: width derived from the person's height, anchored at the head
    aspect_ratio_16_9 = 16 / 9
    crop_width_16_9 = min(original_width, int(person_height * aspect_ratio_16_9))
    crop_height_16_9 = min(original_height, int(crop_width_16_9 / aspect_ratio_16_9))

    x1_16_9 = max(0, person_center_x - crop_width_16_9 // 2)
    x2_16_9 = min(original_width, x1_16_9 + crop_width_16_9)
    y1_16_9 = max(0, y1)
    y2_16_9 = min(original_height, y1_16_9 + crop_height_16_9)

    # Shift the window back inside the frame if it ran off an edge
    if x2_16_9 > original_width:
        x1_16_9 = original_width - crop_width_16_9
        x2_16_9 = original_width
    if y2_16_9 > original_height:
        y1_16_9 = original_height - crop_height_16_9
        y2_16_9 = original_height

    # 9:16 crop: centered on the person's center point
    aspect_ratio_9_16 = 9 / 16
    crop_width_9_16 = min(original_width, int(person_height * aspect_ratio_9_16))
    crop_height_9_16 = min(original_height, int(crop_width_9_16 / aspect_ratio_9_16))

    x1_9_16 = max(0, person_center_x - crop_width_9_16 // 2)
    x2_9_16 = min(original_width, x1_9_16 + crop_width_9_16)
    y1_9_16 = max(0, person_center_y - crop_height_9_16 // 2)
    y2_9_16 = min(original_height, y1_9_16 + crop_height_9_16)

    if x2_9_16 > original_width:
        x1_9_16 = original_width - crop_width_9_16
        x2_9_16 = original_width
    if y2_9_16 > original_height:
        y1_9_16 = original_height - crop_height_9_16
        y2_9_16 = original_height

    crop_16_9 = frame[y1_16_9:y2_16_9, x1_16_9:x2_16_9]
    crop_9_16 = frame[y1_9_16:y2_9_16, x1_9_16:x2_9_16]

    # Normalize output sizes; this can distort slightly when a crop was
    # clamped at a frame edge and no longer has the exact target ratio
    crop_16_9 = cv2.resize(crop_16_9, (426, 240))
    crop_9_16 = cv2.resize(crop_9_16, (240, 426))

    return crop_16_9, crop_9_16


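# Usage sketch (illustrative; "frame.jpg" is a hypothetical still that
# contains a person -- the function returns (None, None) otherwise):
#
#     frame = cv2.imread("frame.jpg")
#     crop_16_9, crop_9_16 = generate_crops(frame)
#     if crop_16_9 is not None:
#         cv2.imwrite("crop_16_9.jpg", crop_16_9)
#         cv2.imwrite("crop_9_16.jpg", crop_9_16)

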
def visualize_crops(image, bbox, crops_info):
    """
    Draw the detected person bbox and both calculated crop windows.

    Args:
        image: OpenCV BGR image
        bbox: [x1, y1, x2, y2] person bounding box
        crops_info: dict with "crop_16_9" and "crop_9_16" coordinate dicts,
            each holding "x1", "y1", "x2", "y2"
    """
    viz = image.copy()

    # Person bounding box in blue (BGR)
    cv2.rectangle(
        viz, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (255, 0, 0), 2
    )

    # 16:9 crop window in green
    crop_16_9 = crops_info["crop_16_9"]
    cv2.rectangle(
        viz,
        (int(crop_16_9["x1"]), int(crop_16_9["y1"])),
        (int(crop_16_9["x2"]), int(crop_16_9["y2"])),
        (0, 255, 0),
        2,
    )

    # 9:16 crop window in red
    crop_9_16 = crops_info["crop_9_16"]
    cv2.rectangle(
        viz,
        (int(crop_9_16["x1"]), int(crop_9_16["y1"])),
        (int(crop_9_16["x2"]), int(crop_9_16["y2"])),
        (0, 0, 255),
        2,
    )

    return viz


def encode_image_to_base64(image: Image.Image, format: str = "JPEG") -> str:
    """
    Convert a PIL image to a base64 string.

    Args:
        image: PIL Image object
        format: Image format to use for encoding (default: JPEG)

    Returns:
        Base64-encoded string of the image
    """
    buffered = BytesIO()
    image.save(buffered, format=format)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


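# Usage sketch (illustrative): this produces the string that analyze_image()
# below embeds in its data URL:
#
#     img = Image.new("RGB", (64, 64), "white")
#     data_url = f"data:image/jpeg;base64,{encode_image_to_base64(img)}"

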
def add_top_numbers(
    input_image,
    num_divisions=20,
    margin=90,
    font_size=120,
    dot_spacing=20,
):
    """
    Add numbered divisions across the top and bottom of an image, with a dotted
    vertical line marking the center of each division.

    Args:
        input_image (Image): PIL Image
        num_divisions (int): Number of divisions to create
        margin (int): Size of the margin in pixels reserved for the numbers
        font_size (int): Font size for the numbers
        dot_spacing (int): Spacing between dots in pixels
    """
    # Extend the canvas with a white margin above and below the original image
    new_width = input_image.width
    new_height = input_image.height + (2 * margin)
    new_image = Image.new("RGB", (new_width, new_height), "white")
    new_image.paste(input_image, (0, margin))

    draw = ImageDraw.Draw(new_image)

    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        print("Using default font")
        font = ImageFont.load_default(size=font_size)

    division_width = input_image.width / num_divisions

    for i in range(num_divisions):
        x = (i * division_width) + (division_width / 2)

        # Division number in the top margin
        draw.text((x, margin // 2), str(i + 1), fill="black", font=font, anchor="mm")

        # Division number in the bottom margin
        draw.text(
            (x, new_height - (margin // 2)),
            str(i + 1),
            fill="black",
            font=font,
            anchor="mm",
        )

        # Dotted vertical line through the image area at this division's center
        y_start = margin
        y_end = new_height - margin

        current_y = y_start
        while current_y < y_end:
            # draw.ellipse replaces the original draw.circle call, whose
            # argument list did not match Pillow's signature
            draw.ellipse(
                [x - 1, current_y - 1, x + 1, current_y + 1],
                fill="black",
            )
            current_y += dot_spacing

    return new_image


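# Usage sketch (illustrative; the argument values mirror the ones used by
# get_image_crop() below):
#
#     numbered = add_top_numbers(thumb, num_divisions=20, margin=50,
#                                font_size=30, dot_spacing=20)
#     numbered.save("numbered.png")

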
def crop_and_draw_divisions(
    input_image,
    left_division,
    right_division,
    num_divisions=20,
    line_color=(255, 0, 0),
    line_width=2,
    head_margin_percent=0.1,
):
    """
    Create both 9:16 and 16:9 crops between two divisions and draw guide lines.

    Args:
        input_image (Image): PIL Image
        left_division (int): Left-side division number (1-20)
        right_division (int): Right-side division number (1-20)
        num_divisions (int): Total number of divisions (default: 20)
        line_color (tuple): RGB color tuple for guide lines (default: red)
        line_width (int): Width of guide lines in pixels (default: 2)
        head_margin_percent (float): Fractional margin above the head (default: 0.1)

    Returns:
        tuple: (cropped_image_16_9, image_with_lines, cropped_image_9_16)
    """
    yolo_model = YOLO("yolo11n.pt")

    # Convert division numbers to pixel boundaries
    division_width = input_image.width / num_divisions
    left_boundary = (left_division - 1) * division_width
    right_boundary = right_division * division_width

    # 9:16 crop: full image height between the two division boundaries
    cropped_image_9_16 = input_image.crop(
        (left_boundary, 0, right_boundary, input_image.height)
    )

    # Locate the person inside the 9:16 crop; this assumes at least one
    # detection, otherwise the [0] indexing raises IndexError
    bbox = yolo_model(cropped_image_9_16, classes=[0])[0].boxes.xyxy.cpu().numpy()[0]
    x1, y1, x2, y2 = bbox

    # Leave some headroom above the detected person
    head_margin = (y2 - y1) * head_margin_percent
    top_boundary = max(0, y1 - head_margin)

    # 16:9 crop: same width, height derived from the aspect ratio
    crop_width = right_boundary - left_boundary
    crop_height_16_9 = int(crop_width * 9 / 16)
    bottom_boundary = min(input_image.height, top_boundary + crop_height_16_9)

    cropped_image_16_9 = input_image.crop(
        (left_boundary, top_boundary, right_boundary, bottom_boundary)
    )

    # Draw guide lines for both crops on a copy of the original
    image_with_lines = input_image.copy()
    draw = ImageDraw.Draw(image_with_lines)

    # Vertical boundaries shared by both crops
    draw.line(
        [(left_boundary, 0), (left_boundary, input_image.height)],
        fill=line_color,
        width=line_width,
    )
    draw.line(
        [(right_boundary, 0), (right_boundary, input_image.height)],
        fill=line_color,
        width=line_width,
    )

    # Horizontal boundaries of the 16:9 crop
    draw.line(
        [(left_boundary, top_boundary), (right_boundary, top_boundary)],
        fill=line_color,
        width=line_width,
    )
    draw.line(
        [(left_boundary, bottom_boundary), (right_boundary, bottom_boundary)],
        fill=line_color,
        width=line_width,
    )

    return cropped_image_16_9, image_with_lines, cropped_image_9_16


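# Usage sketch (illustrative; divisions 8 and 13 are arbitrary example values
# and a person is assumed to be visible between them):
#
#     wide, guides, tall = crop_and_draw_divisions(
#         thumb, left_division=8, right_division=13
#     )

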
def analyze_image(numbered_input_image: Image.Image, prompt, input_image):
    """
    Ask GPT-4o which divisions bound the person, then crop accordingly.

    Args:
        numbered_input_image (Image): PIL Image with division numbers drawn on it
        prompt (str): The prompt/question about the image
        input_image (Image): The same image without numbers

    Returns:
        tuple: (cropped_image_16_9, image_with_lines, cropped_image_9_16);
        falls back to the unmodified input image for all three on failure
    """
    client = OpenAI()
    base64_image = encode_image_to_base64(numbered_input_image, format="JPEG")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
            ],
        }
    ]

    response = client.chat.completions.create(
        model="gpt-4o", messages=messages, max_tokens=300
    )

    # Follow up to get the answer as JSON with the two division numbers
    messages.extend(
        [
            {"role": "assistant", "content": response.choices[0].message.content},
            {
                "role": "user",
                "content": "Please return the response as JSON with keys left_row and right_row.",
            },
        ],
    )

    response = (
        client.chat.completions.create(model="gpt-4o", messages=messages)
        .choices[0]
        .message.content
    )

    # Extract the JSON object from the (possibly fenced) model response
    left_index = response.find("{")
    right_index = response.rfind("}")

    try:
        if left_index == -1 or right_index == -1:
            raise ValueError("no JSON object found in model response")
        # literal_eval is safer than eval for untrusted model output
        response_json = ast.literal_eval(response[left_index : right_index + 1])
        cropped_image_16_9, image_with_lines, cropped_image_9_16 = (
            crop_and_draw_divisions(
                input_image=input_image,
                left_division=response_json["left_row"],
                right_division=response_json["right_row"],
            )
        )
    except Exception as e:
        print(e)
        return input_image, input_image, input_image

    return cropped_image_16_9, image_with_lines, cropped_image_9_16


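# Usage sketch (illustrative; requires OPENAI_API_KEY in the environment):
#
#     wide, guides, tall = analyze_image(
#         numbered_mid_image, remove_unwanted_prompt(2), mid_image
#     )

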
def get_sprite_firebase(cid, rsid, uid):
    """Fetch the sprite-sheet URLs for a session from the Firebase realtime DB."""
    # Firebase credentials are read from environment variables
    config = {
        "apiKey": os.getenv("FIREBASE_API_KEY"),
        "authDomain": os.getenv("FIREBASE_AUTH_DOMAIN"),
        "databaseURL": os.getenv("FIREBASE_DATABASE_URL"),
        "projectId": os.getenv("FIREBASE_PROJECT_ID"),
        "storageBucket": os.getenv("FIREBASE_STORAGE_BUCKET"),
        "messagingSenderId": os.getenv("FIREBASE_MESSAGING_SENDER_ID"),
        "appId": os.getenv("FIREBASE_APP_ID"),
        "measurementId": os.getenv("FIREBASE_MEASUREMENT_ID"),
    }

    firebase = pyrebase.initialize_app(config)
    db = firebase.database()
    account_id = os.getenv("ROLL_ACCOUNT")

    COLLAB_EDIT_LINK = "collab_sprite_link_handler"

    path = f"{account_id}/{COLLAB_EDIT_LINK}/{uid}/{cid}/{rsid}"

    data = db.child(path).get()
    return data.val()


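# Usage sketch (illustrative; the cid/rsid/uid values are hypothetical, and
# the FIREBASE_* and ROLL_ACCOUNT environment variables must be set):
#
#     sprite_urls = get_sprite_firebase("some-cid", "some-rsid", "some-uid")

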
def get_image_crop(cid=None, rsid=None, uid=None):
    """Fetch sprite sheets from Firebase and return a gallery of all crops."""
    image_paths = get_sprite_firebase(cid, rsid, uid)

    input_images = []
    mid_images = []
    cropped_image_16_9s = []
    images_with_lines = []
    cropped_image_9_16s = []

    for image_path in image_paths:
        response = requests.get(image_path)
        input_image = Image.open(BytesIO(response.content))
        input_images.append(input_image)

        # Work with the middle thumbnail of each sprite sheet
        mid_image = get_middle_thumbnail(input_image)
        mid_images.append(mid_image)

        numbered_mid_image = add_top_numbers(
            input_image=mid_image,
            num_divisions=20,
            margin=50,
            font_size=30,
            dot_spacing=20,
        )

        cropped_image_16_9, image_with_lines, cropped_image_9_16 = analyze_image(
            numbered_mid_image, remove_unwanted_prompt(2), mid_image
        )
        cropped_image_16_9s.append(cropped_image_16_9)
        images_with_lines.append(image_with_lines)
        cropped_image_9_16s.append(cropped_image_9_16)

    return gr.Gallery(
        [
            *input_images,
            *mid_images,
            *cropped_image_16_9s,
            *images_with_lines,
            *cropped_image_9_16s,
        ]
    )


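# Usage sketch (illustrative): one way to wire get_image_crop into a minimal
# Gradio app. The three-textbox interface is an assumption, not the project's
# actual UI.
#
#     if __name__ == "__main__":
#         demo = gr.Interface(
#             fn=get_image_crop,
#             inputs=[gr.Textbox(label="cid"), gr.Textbox(label="rsid"), gr.Textbox(label="uid")],
#             outputs=gr.Gallery(),
#         )
#         demo.launch()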