# imagier/src/inference_paired.py
# Uploaded by gaur3009 ("Upload 10 files"), commit a96e5dd (verified), 3.17 kB.
import os
import argparse
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import torchvision.transforms.functional as F
from pix2pix_turbo import Pix2Pix_Turbo
from image_prep import canny_from_pil
def parse_args(argv=None):
    """Parse and validate the command-line options for paired image translation.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        ValueError: if neither or both of ``--model_name`` and ``--model_path``
            are given.
        SystemExit: (from argparse) if a required option is missing or malformed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
    parser.add_argument('--prompt', type=str, required=True, help='the prompt to be used')
    parser.add_argument('--model_name', type=str, default='', help='name of the pretrained model to be used')
    parser.add_argument('--model_path', type=str, default='', help='path to a model state dict to be used')
    parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
    parser.add_argument('--low_threshold', type=int, default=100, help='Canny low threshold')
    parser.add_argument('--high_threshold', type=int, default=200, help='Canny high threshold')
    parser.add_argument('--gamma', type=float, default=0.4, help='The sketch interpolation guidance amount')
    parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
    args = parser.parse_args(argv)

    # Exactly one of model_name / model_path must be set: the two "is empty"
    # flags being equal means both are empty or both are provided.
    if (args.model_name == '') == (args.model_path == ''):
        raise ValueError('Exactly one of model_name or model_path should be provided')
    return args


def canny_output_name(bname):
    """Return the file name used for the Canny edge visualization of ``bname``.

    ``_canny`` is inserted before the extension (``a.jpg`` -> ``a_canny.jpg``)
    so the visualization never shares a name with the translated output,
    which is saved as ``bname`` itself.
    """
    stem, ext = os.path.splitext(bname)
    return f'{stem}_canny{ext}'


def load_input_image(path):
    """Open ``path`` as RGB and resize it so both sides are multiples of 8.

    Each side is rounded *down* to the nearest multiple of 8 (the stochastic
    branch below builds noise at ``H // 8`` x ``W // 8``).

    Raises:
        ValueError: if either side is smaller than 8 pixels, which would
            otherwise round down to an empty image.
    """
    image = Image.open(path).convert('RGB')
    new_width = image.width - image.width % 8
    new_height = image.height - image.height % 8
    if new_width == 0 or new_height == 0:
        raise ValueError(
            f'input image must be at least 8x8 pixels, got {image.width}x{image.height}'
        )
    return image.resize((new_width, new_height), Image.LANCZOS)


def main(argv=None):
    """Translate one input image with Pix2Pix_Turbo and save the result.

    The output is written to ``<output_dir>/<basename of input_image>``. For
    the ``edge_to_image`` model an inverted Canny edge map is also written
    next to it (see ``canny_output_name``).
    """
    args = parse_args(argv)
    os.makedirs(args.output_dir, exist_ok=True)

    # initialize the model
    model = Pix2Pix_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
    model.set_eval()

    input_image = load_input_image(args.input_image)
    bname = os.path.basename(args.input_image)

    # translate the image
    with torch.no_grad():
        if args.model_name == 'edge_to_image':
            canny = canny_from_pil(input_image, args.low_threshold, args.high_threshold)
            # Save an inverted (255 - value) copy of the edge map for inspection.
            canny_viz_inv = Image.fromarray(255 - np.array(canny))
            canny_viz_inv.save(os.path.join(args.output_dir, canny_output_name(bname)))
            c_t = F.to_tensor(canny).unsqueeze(0).cuda()
            output_image = model(c_t, args.prompt)
        elif args.model_name == 'sketch_to_image_stochastic':
            # Binarize the sketch: pixels darker than 0.5 (after to_tensor's
            # [0, 1] scaling) become 1, everything else 0.
            image_t = F.to_tensor(input_image) < 0.5
            c_t = image_t.unsqueeze(0).cuda().float()
            torch.manual_seed(args.seed)
            _, _, H, W = c_t.shape
            # 4 channels at 1/8 resolution; presumably matches the model's
            # latent space -- TODO confirm against Pix2Pix_Turbo.
            noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
            output_image = model(c_t, args.prompt, deterministic=False, r=args.gamma, noise_map=noise)
        else:
            c_t = F.to_tensor(input_image).unsqueeze(0).cuda()
            output_image = model(c_t, args.prompt)

    # Assumes the model returns values in [-1, 1]; rescale to [0, 1] for
    # ToPILImage -- TODO confirm.
    output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
    # save the output image
    output_pil.save(os.path.join(args.output_dir, bname))


if __name__ == "__main__":
    main()