import os
import re
from typing import Sequence

import numpy as np
import torch
import jsonc as json
from PIL import Image


def img_list_to_pil(img_list, cond_image=None, seperation=10):
    """Concatenate images horizontally into one new RGB image.

    Images are pasted left to right, top-aligned, with ``seperation`` pixels
    of unfilled canvas (black, PIL's default fill) between neighbours. The
    canvas height is the height of the tallest image.

    Args:
        img_list: Sequence of PIL images to place side by side. It is NOT
            modified.
        cond_image: Optional extra image placed at the right end.
        seperation: Horizontal gap in pixels between consecutive images
            (spelling kept for compatibility with keyword callers).

    Returns:
        A new ``PIL.Image.Image`` in ``'RGB'`` mode.

    Raises:
        ValueError: If there is no image to concatenate.
    """
    # Copy so the caller's list is never mutated by appending cond_image.
    images = list(img_list)
    if cond_image is not None:
        images.append(cond_image)
    if not images:
        raise ValueError("img_list_to_pil needs at least one image")

    widths, heights = zip(*(im.size for im in images))
    # Gaps sit only *between* images, so there are len(images) - 1 of them;
    # counting len(images) left a spurious strip on the right edge.
    total_width = sum(widths) + seperation * (len(images) - 1)
    max_height = max(heights)

    new_im = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for im in images:
        new_im.paste(im, (x_offset, 0))
        x_offset += im.size[0] + seperation
    return new_im


def grid_image_visualize(images, row_size, seperation=10):
    """Tile images into a grid with ``row_size`` images per row.

    Every image is placed in a cell as large as the widest / tallest image in
    ``images``, pasted at the cell's top-left corner. Columns are separated by
    ``seperation`` pixels; rows are stacked with no vertical gap.

    Args:
        images: Non-empty sequence of PIL images, filled row by row.
        row_size: Number of images per row (must be positive).
        seperation: Horizontal gap in pixels between columns (default 10,
            the previously hard-coded value).

    Returns:
        A new ``PIL.Image.Image`` in ``'RGB'`` mode.

    Raises:
        ValueError: If ``row_size`` is not positive or ``images`` is empty.
    """
    if row_size <= 0:
        raise ValueError(f"row_size must be positive, got {row_size}")
    if not images:
        raise ValueError("grid_image_visualize needs at least one image")

    widths, heights = zip(*(im.size for im in images))
    cell_w, cell_h = max(widths), max(heights)
    n_rows = (len(images) + row_size - 1) // row_size  # ceil division

    total_width = cell_w * row_size + seperation * (row_size - 1)
    total_height = cell_h * n_rows
    new_im = Image.new('RGB', (total_width, total_height))

    # Position from the uniform cell size (the canvas was sized from the
    # maxima); using each image's own size misplaced mixed-size tiles.
    for i, im in enumerate(images):
        row, col = divmod(i, row_size)
        new_im.paste(im, (col * (cell_w + seperation), row * cell_h))
    return new_im


def process_images(images, res=512):
    """Center-crop each image to a square and resize it to ``res`` x ``res``.

    The square side is the image's shorter dimension; resizing uses bilinear
    resampling.

    Args:
        images: Iterable of PIL images.
        res: Output side length in pixels.

    Returns:
        A list of new square PIL images, in input order.
    """
    res_images = []
    for image in images:
        width, height = image.size
        crop_size = min(width, height)
        left = (width - crop_size) // 2
        top = (height - crop_size) // 2
        right = (width + crop_size) // 2
        bottom = (height + crop_size) // 2
        image = image.crop((left, top, right, bottom))
        image = image.resize((res, res), Image.BILINEAR)
        res_images.append(image)
    return res_images


def sanitize_prompt(prompt: str, max_len: int = 50) -> str:
    """Turn a prompt into a filesystem-friendly name fragment.

    Every run of characters outside ``[a-zA-Z0-9_-]`` becomes a single ``_``.
    The result is truncated to ``max_len`` characters, then leading and
    trailing underscores are stripped (so it may be shorter than ``max_len``,
    or empty).
    """
    sanitized = re.sub(r'[^a-zA-Z0-9_\-]+', '_', prompt)
    return sanitized[:max_len].strip("_")


def get_next_index(folder_path: str) -> int:
    """Return one more than the largest numeric suffix used in ``folder_path``.

    Only file names of the form ``<anything>_<digits>.png`` or
    ``<anything>_<digits>.json`` are considered. Returns 0 if the folder does
    not exist or contains no matching file.
    """
    if not os.path.exists(folder_path):
        return 0

    pattern = re.compile(r'.*_(\d+)\.(?:png|json)$')
    max_index = -1
    for filename in os.listdir(folder_path):
        match = pattern.match(filename)
        if match:
            max_index = max(max_index, int(match.group(1)))
    return max_index + 1


def save_results(
    args,
    source_prompt: str,
    target_prompt: str,
    images: Sequence[Image.Image],
):
    """Save the first/last images, a side-by-side view, and the run args.

    Files go to ``<args.output_dir>/<src>#<tgt>/``, where both parts come
    from ``sanitize_prompt``. Each call uses the next free numeric index in
    that folder and writes ``concat_<i>.png``, ``input_<i>.png`` (first
    image), ``output_<i>.png`` (last image) and ``args_<i>.json``
    (``vars(args)``).

    Args:
        args: Object whose ``vars()`` is JSON-serialisable and which has an
            ``output_dir`` attribute (presumably an ``argparse.Namespace`` —
            confirm against callers).
        source_prompt: Prompt used for the first part of the folder name.
        target_prompt: Prompt used for the second part of the folder name.
        images: Indexable sequence of PIL images; only the first and last are
            saved.
    """
    src_name = sanitize_prompt(source_prompt)
    tgt_name = sanitize_prompt(target_prompt)
    folder_name = f"{src_name}#{tgt_name}"
    output_dir = os.path.join(args.output_dir, folder_name)
    os.makedirs(output_dir, exist_ok=True)
    next_idx = get_next_index(output_dir)

    concated_image = img_list_to_pil([images[0], images[-1]], cond_image=None, seperation=10)
    concated_image.save(os.path.join(output_dir, f"concat_{next_idx}.png"))
    images[0].save(os.path.join(output_dir, f"input_{next_idx}.png"))
    images[-1].save(os.path.join(output_dir, f"output_{next_idx}.png"))

    args_filename = f"args_{next_idx}.json"
    args_path = os.path.join(output_dir, args_filename)
    with open(args_path, "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=4)
    print(f"Saved image to {output_dir} and args to {args_path}")