File size: 3,460 Bytes
14ce5a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import numpy as np
import os
import PIL
import pickle
import torch
import argparse
import json
from PIL import Image
import torch.nn as nn
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from src.data_loader import DataCollatorForSupervisedDataset, get_dataset
from src.data_processing import tensor_to_pil
from src.model_processing import get_model
from PIL import Image
from accelerate import Accelerator
from torch.utils.data import DataLoader
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
# Command-line configuration for the reconstruction dump.
#   --model_name / --model_path : passed to get_model
#   --dataset_name / --split_name / --n_take : passed to get_dataset;
#       n_take <= 0 is passed on as None (no limit requested)
#   --batch_size : DataLoader batch size
#   --output_dir : destination for saved images/pickles; None skips saving
#   --begin_id   : parsed but not used anywhere in the visible script
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="chameleon")
parser.add_argument("--model_path", type=str, default=None)
parser.add_argument("--dataset_name", type=str, default="task3-movie-posters")
parser.add_argument("--split_name", type=str, default="test")
parser.add_argument("--batch_size", default=8, type=int)
parser.add_argument("--output_dir", type=str, default=None)
parser.add_argument("--begin_id", default=0, type=int)
parser.add_argument("--n_take", default=-1, type=int)
args = parser.parse_args()
batch_size = args.batch_size
output_dir = args.output_dir

accelerator = Accelerator()
if accelerator.is_main_process and output_dir is not None:
    # Only the main process builds the output tree: the root plus the three
    # subdirectories that save_results writes into.
    os.makedirs(output_dir, exist_ok=True)
    for subdir in ("original_images", "reconstructed_images", "results"):
        os.makedirs(f"{output_dir}/{subdir}", exist_ok=True)
# Barrier: no process may start saving into output_dir before the main
# process has finished creating it (otherwise non-main ranks can race ahead
# and fail with a missing-directory error).
accelerator.wait_for_everyone()

model, data_params = get_model(args.model_path, args.model_name)
dataset = get_dataset(
    args.dataset_name, args.split_name, None if args.n_take <= 0 else args.n_take
)
data_collator = DataCollatorForSupervisedDataset(args.dataset_name, **data_params)
dataloader = DataLoader(
    dataset, batch_size=batch_size, num_workers=0, collate_fn=data_collator
)
# Let accelerate place the model and wrap the dataloader for this process.
model, dataloader = accelerator.prepare(model, dataloader)
# accelerator.print emits only once (main process) instead of once per rank.
accelerator.print("Model prepared...")
def save_results(
    pixel_values, reconstructed_image, idx, output_dir, data_params
):
    """Persist one original/reconstruction pair under ``output_dir``.

    Both inputs are converted with ``tensor_to_pil(..., **data_params)`` and
    written as ``original_images/<idx>.png`` and
    ``reconstructed_images/<idx>.png``.  The two PIL images are also pickled
    together, as a dict with keys ``"ori_img"`` and ``"rec_img"``, into
    ``results/<idx>.pickle``.  ``<idx>`` is ``idx`` zero-padded to eight
    digits, so ``idx`` must accept the ``08d`` format spec (int-like).

    The three subdirectories are expected to exist already; the script
    creates them before inference starts.

    Returns ``None``.  When ``reconstructed_image`` is ``None`` nothing is
    converted or written.
    """
    if reconstructed_image is None:
        # No reconstruction for this sample: skip it entirely.
        return

    # Build the pickled payload first; its values are also the PNGs we save.
    payload = {
        "ori_img": tensor_to_pil(pixel_values, **data_params),
        "rec_img": tensor_to_pil(reconstructed_image, **data_params),
    }
    file_stem = f"{idx:08d}"

    payload["ori_img"].save(f"{output_dir}/original_images/{file_stem}.png")
    payload["rec_img"].save(f"{output_dir}/reconstructed_images/{file_stem}.png")

    pickle_path = f"{output_dir}/results/{file_stem}.pickle"
    with open(pickle_path, "wb") as pickle_file:
        pickle.dump(payload, pickle_file)
from collections import deque

# Saving is I/O bound, so it runs on a thread pool, overlapping inference.
SAVE_WORKERS = 16
# Upper bound on queued save jobs.  Each queued job holds CPU copies of one
# original and one reconstructed image, so an unbounded queue can grow
# without limit whenever saving is slower than inference.
MAX_PENDING_SAVES = 4 * SAVE_WORKERS

# The executor is a context manager so its worker threads are always joined,
# even if the inference loop raises.
with ThreadPoolExecutor(max_workers=SAVE_WORKERS) as executor, torch.no_grad():
    # Futures of submitted save jobs, oldest first.  Calling .result() on
    # each one re-raises any exception from save_results, which a discarded
    # Future would otherwise swallow silently.
    pending_saves = deque()

    print("Begin data loading...")
    for batch in tqdm(dataloader, disable=not accelerator.is_local_main_process):
        pixel_values = batch["image"]
        reconstructed_images = model(pixel_values)
        # Some models return a tuple; the reconstruction is its first item.
        if isinstance(reconstructed_images, tuple):
            reconstructed_images = reconstructed_images[0]
        if output_dir is None:
            continue

        idx_list = batch["idx"]
        # Move everything off the accelerator before handing it to threads.
        original_images = pixel_values.detach().cpu()
        if not isinstance(reconstructed_images, list):
            reconstructed_images = reconstructed_images.detach().cpu()

        for i in range(pixel_values.shape[0]):
            reconstruction = reconstructed_images[i]
            # List outputs may hold device tensors (or None); normalise the
            # tensors the same way as the batched-tensor path above.
            if isinstance(reconstruction, torch.Tensor):
                reconstruction = reconstruction.detach().cpu()
            pending_saves.append(
                executor.submit(
                    save_results,
                    original_images[i],
                    reconstruction,
                    # Resolve the index here rather than formatting a
                    # possibly device-resident tensor inside a worker.
                    int(idx_list[i]),
                    output_dir,
                    data_params,
                )
            )
            # Back-pressure: wait on the oldest jobs once the queue is full.
            while len(pending_saves) > MAX_PENDING_SAVES:
                pending_saves.popleft().result()

    # Drain the remaining jobs, surfacing any save failure.
    while pending_saves:
        pending_saves.popleft().result()
|