# cora/visualization/image_utils.py
# Source snapshot: commit 79c5088 ("code added") by armikaeili.
import os
import re
from typing import Sequence

import jsonc as json
import numpy as np
import torch
from PIL import Image
def img_list_to_pil(img_list, cond_image=None, seperation=10):
    """Concatenate images horizontally into a single RGB image.

    Args:
        img_list: Sequence of PIL images, placed left to right. The caller's
            list is not modified.
        cond_image: Optional extra PIL image placed after the images of
            ``img_list``.
        seperation: Gap in pixels between adjacent images. The spelling is
            kept for backward compatibility with existing keyword callers.

    Returns:
        A new RGB image whose width is the sum of the image widths plus one
        gap between each adjacent pair. Its height is that of the tallest
        image. Images are top-aligned, and uncovered pixels keep
        ``Image.new``'s default fill (black).

    Raises:
        ValueError: if there is no image to concatenate.
    """
    # Work on a copy so that appending ``cond_image`` never mutates the
    # caller's list.
    images = list(img_list)
    if cond_image is not None:
        images.append(cond_image)
    if not images:
        raise ValueError("img_list_to_pil needs at least one image")

    widths, heights = zip(*(im.size for im in images))
    # One gap between each adjacent pair and none after the last image,
    # matching the spacing rule used by grid_image_visualize.
    total_width = sum(widths) + seperation * (len(images) - 1)
    max_height = max(heights)

    new_im = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for im in images:
        new_im.paste(im, (x_offset, 0))
        x_offset += im.size[0] + seperation
    return new_im
def grid_image_visualize(images, row_size, seperation=10):
    """Arrange images in a grid with ``row_size`` images per row.

    Every image occupies a cell as wide as the widest image and as tall as
    the tallest image, so rows and columns stay aligned even when sizes
    differ. Cells in a row are separated horizontally by ``seperation``
    pixels, and rows are stacked with no vertical gap. For equally sized
    images the layout is identical to the previous implementation.

    Args:
        images: Sequence of PIL images, laid out row-major.
        row_size: Number of images per row (must be >= 1).
        seperation: Horizontal gap in pixels between adjacent cells.
            Defaults to 10 (the previously hard-coded value). The spelling
            matches ``img_list_to_pil``.

    Returns:
        A new RGB image. Its width is ``row_size`` cells plus the gaps
        between them, and its height is one cell per (possibly partial) row.

    Raises:
        ValueError: if ``row_size`` < 1 or ``images`` is empty.
    """
    if row_size < 1:
        raise ValueError(f"row_size must be >= 1, got {row_size}")
    if not images:
        raise ValueError("grid_image_visualize needs at least one image")

    widths, heights = zip(*(im.size for im in images))
    cell_w, cell_h = max(widths), max(heights)
    n_rows = (len(images) + row_size - 1) // row_size  # ceil division

    total_width = cell_w * row_size + seperation * (row_size - 1)
    total_height = cell_h * n_rows
    grid = Image.new('RGB', (total_width, total_height))

    for i, im in enumerate(images):
        row, col = divmod(i, row_size)
        # Place at fixed cell offsets derived from the same cell size used
        # for the canvas, so mixed-size images never overlap.
        grid.paste(im, (col * (cell_w + seperation), row * cell_h))
    return grid
def process_images(images, res=512):
    """Center-crop every image to a square and resize it to ``res`` x ``res``.

    The square's side is the shorter of the image's width and height, and
    the crop window is centred on the image. Resizing uses bilinear
    resampling.

    Args:
        images: Iterable of PIL images.
        res: Side length in pixels of each output image.

    Returns:
        A list of processed images, in the same order as ``images``.
    """
    def _square_box(width, height):
        # (left, upper, right, lower) of the largest centred square.
        side = min(width, height)
        return (
            (width - side) // 2,
            (height - side) // 2,
            (width + side) // 2,
            (height + side) // 2,
        )

    return [
        img.crop(_square_box(*img.size)).resize((res, res), Image.BILINEAR)
        for img in images
    ]
def sanitize_prompt(prompt: str, max_len: int = 50) -> str:
    """Turn ``prompt`` into a short, filesystem-friendly name.

    Each run of characters other than ASCII letters, digits, ``_`` and ``-``
    is replaced by one ``_``. The result is cut to ``max_len`` characters
    *before* leading/trailing underscores are stripped, so the returned
    string may be shorter than ``max_len``.
    """
    collapsed = re.sub(r'[^a-zA-Z0-9_\-]+', '_', prompt)
    head = collapsed[:max_len]
    return head.strip("_")
def get_next_index(folder_path: str) -> int:
    """Return the next free numeric suffix for result files in ``folder_path``.

    Looks at entries named ``<anything>_<n>.png`` or ``<anything>_<n>.json``
    and returns one more than the largest ``n`` found. Returns 0 when the
    path does not exist or when no entry matches.
    """
    if not os.path.exists(folder_path):
        return 0
    index_re = re.compile(r'.*_(\d+)\.(?:png|json)$')
    found = (
        int(m.group(1))
        for m in map(index_re.match, os.listdir(folder_path))
        if m
    )
    return max(found, default=-1) + 1
def save_results(
    args,
    source_prompt: str,
    target_prompt: str,
    images: Sequence[Image.Image],
):
    """Save the first and last images of a run together with the run args.

    Files go to ``<args.output_dir>/<src>#<tgt>/``, where ``src`` and
    ``tgt`` are the prompts passed through ``sanitize_prompt``. All files
    written by one call share the index returned by ``get_next_index``, so
    a later sequential call does not overwrite an earlier one:

    - ``concat_<i>.png``: first and last image side by side.
    - ``input_<i>.png``: ``images[0]``.
    - ``output_<i>.png``: ``images[-1]``.
    - ``args_<i>.json``: ``vars(args)``.

    Args:
        args: Object accepted by ``vars()`` with an ``output_dir`` attribute.
        source_prompt: Prompt describing the input image.
        target_prompt: Prompt describing the edited image.
        images: Sequence of PIL images. Only the first and last are saved.

    Raises:
        ValueError: if ``images`` is empty. This is checked before any
            directory is created.
    """
    if len(images) == 0:
        raise ValueError("save_results needs at least one image")
    first, last = images[0], images[-1]

    src_name = sanitize_prompt(source_prompt)
    tgt_name = sanitize_prompt(target_prompt)
    output_dir = os.path.join(args.output_dir, f"{src_name}#{tgt_name}")
    os.makedirs(output_dir, exist_ok=True)
    next_idx = get_next_index(output_dir)

    concated_image = img_list_to_pil([first, last], cond_image=None, seperation=10)
    concated_image.save(os.path.join(output_dir, f"concat_{next_idx}.png"))
    first.save(os.path.join(output_dir, f"input_{next_idx}.png"))
    last.save(os.path.join(output_dir, f"output_{next_idx}.png"))

    args_path = os.path.join(output_dir, f"args_{next_idx}.json")
    with open(args_path, "w", encoding="utf-8") as f:
        # NOTE(review): assumes every value in vars(args) is serializable by
        # the jsonc module. This runs after the images are already written.
        json.dump(vars(args), f, indent=4)
    print(f"Saved image to {output_dir} and args to {args_path}")