import argparse
import os
import pickle
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import PIL
import torch
import torch.nn as nn
from PIL import Image
from accelerate import Accelerator
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForImageTextToText

from src.data_loader import DataCollatorForSupervisedDataset, get_dataset
from src.data_processing import tensor_to_pil
from src.model_processing import get_model

# Command-line arguments for model selection, dataset slicing, and output locations.
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="chameleon")
parser.add_argument("--model_path", type=str, default=None)
parser.add_argument("--dataset_name", type=str, default="task3-movie-posters")
parser.add_argument("--split_name", type=str, default="test")
parser.add_argument("--batch_size", default=8, type=int)
parser.add_argument("--output_dir", type=str, default=None)
parser.add_argument("--begin_id", default=0, type=int)
parser.add_argument("--n_take", default=-1, type=int)
args = parser.parse_args()

batch_size = args.batch_size
output_dir = args.output_dir

accelerator = Accelerator()

# Only the main process creates the output directory tree, to avoid races across ranks.
if accelerator.is_main_process and output_dir is not None:
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(f"{output_dir}/original_images", exist_ok=True)
    os.makedirs(f"{output_dir}/reconstructed_images", exist_ok=True)
    os.makedirs(f"{output_dir}/results", exist_ok=True)

# Load the model and its preprocessing parameters, then build the dataset and dataloader.
model, data_params = get_model(args.model_path, args.model_name)
dataset = get_dataset(args.dataset_name, args.split_name, None if args.n_take <= 0 else args.n_take)
data_collator = DataCollatorForSupervisedDataset(args.dataset_name, **data_params)
dataloader = DataLoader(
    dataset, batch_size=batch_size, num_workers=0, collate_fn=data_collator
)

model, dataloader = accelerator.prepare(model, dataloader)
print("Model prepared...")


def save_results(pixel_values, reconstructed_image, idx, output_dir, data_params):
    """Convert one original/reconstructed tensor pair back to PIL images and save them."""
    if reconstructed_image is None:
        return
    ori_img = tensor_to_pil(pixel_values, **data_params)
    rec_img = tensor_to_pil(reconstructed_image, **data_params)
    ori_img.save(f"{output_dir}/original_images/{idx:08d}.png")
    rec_img.save(f"{output_dir}/reconstructed_images/{idx:08d}.png")
    result = {
        "ori_img": ori_img,
        "rec_img": rec_img,
    }
    with open(f"{output_dir}/results/{idx:08d}.pickle", "wb") as fw:
        pickle.dump(result, fw)


# Saving images is I/O-bound, so hand it off to a thread pool while the GPU keeps working.
executor = ThreadPoolExecutor(max_workers=16)

with torch.no_grad():
    print("Begin data loading...")
    for batch in tqdm(dataloader):
        pixel_values = batch["image"]
        reconstructed_images = model(pixel_values)
        # Some models return (reconstruction, aux_outputs); keep only the reconstruction.
        if isinstance(reconstructed_images, tuple):
            reconstructed_images = reconstructed_images[0]
        if output_dir is not None:
            idx_list = batch["idx"]
            original_images = pixel_values.detach().cpu()
            if not isinstance(reconstructed_images, list):
                reconstructed_images = reconstructed_images.detach().cpu()
            for i in range(pixel_values.shape[0]):
                executor.submit(
                    save_results,
                    original_images[i],
                    reconstructed_images[i],
                    idx_list[i],
                    output_dir,
                    data_params,
                )

# Block until all queued image writes have finished before the process exits.
executor.shutdown(wait=True)
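
# Example invocation (a sketch only; the script filename and paths below are
# hypothetical and not taken from the repository):
#
#   accelerate launch reconstruct.py \
#       --model_name chameleon \
#       --model_path /path/to/checkpoint \
#       --dataset_name task3-movie-posters \
#       --split_name test \
#       --batch_size 8 \
#       --output_dir outputs/chameleon_recon
#
# With --output_dir set, original and reconstructed PNGs plus per-sample pickles
# are written under output_dir/{original_images,reconstructed_images,results}.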