# simswapNSCR / run_faceswap.py
# Uploaded by NeoSerge: "Upload run_faceswap.py" (commit 57d5944, verified)
import os
import torch
from types import SimpleNamespace
from SimSwap.models.models import create_model
from PIL import Image
from torchvision import transforms
def run_faceswap(
    image_path_A,
    image_path_B,
    output_path,
    *,
    checkpoints_dir="./checkpoints",
    name="people",
    which_epoch="latest",
    crop_size=224,
    **opt_overrides,
):
    """Run a SimSwap model built by ``create_model`` on two images and save the result.

    Loads both images as RGB, resizes them to ``crop_size`` x ``crop_size``,
    builds the option namespace for ``create_model``, runs one forward pass
    under ``torch.no_grad()`` and writes the output image to ``output_path``.

    Args:
        image_path_A: Path of the first input image; passed to the model as
            its first positional argument.
        image_path_B: Path of the second input image; passed to the model as
            its second positional argument.
        output_path: Destination file. PIL chooses the format from the
            file extension.
        checkpoints_dir: Value of ``opt.checkpoints_dir``.
        name: Value of ``opt.name``.
        which_epoch: Value of ``opt.which_epoch``.
        crop_size: Value of ``opt.crop_size``. It is also the square size
            both inputs are resized to.
        **opt_overrides: Extra or replacement attributes set on the options
            namespace after the defaults. Use this for any option
            ``create_model`` reads that is not listed above.

    Raises:
        FileNotFoundError / PIL.UnidentifiedImageError: if an input image
            cannot be opened. These are raised before the model is loaded.
    """
    # Load inputs first so a bad path fails before the (expensive) model load.
    tensor_A = _load_image_tensor(image_path_A, crop_size)
    tensor_B = _load_image_tensor(image_path_B, crop_size)

    opt = _build_options(checkpoints_dir, name, which_epoch, crop_size, opt_overrides)
    model = create_model(opt)
    model.eval()

    # Keep inputs on the same device as the model's weights.
    device = _module_device(model)
    tensor_A = tensor_A.to(device)
    tensor_B = tensor_B.to(device)

    with torch.no_grad():
        # NOTE(review): upstream SimSwap's inference scripts appear to call the
        # model as model(img_id, img_att, latent_id, latent_att, True), with an
        # ArcFace identity embedding and an ImageNet-normalized identity image.
        # Confirm this two-argument call matches the object create_model returns.
        output = model(tensor_A, tensor_B)

    _output_to_pil(output).save(output_path)


def _build_options(checkpoints_dir, name, which_epoch, crop_size, overrides):
    """Return the SimpleNamespace of options handed to ``create_model``."""
    defaults = {
        "name": name,
        "checkpoints_dir": checkpoints_dir,
        "gpu_ids": [],  # a fresh list per call
        "isTrain": False,
        "resize_or_crop": "none",
        "crop_size": crop_size,
        "which_epoch": which_epoch,
        # Attributes required for netG.
        "input_nc": 3,
        "output_nc": 3,
        "ngf": 64,
        "netG": "unet_128",
    }
    # Caller-supplied overrides win over the defaults above.
    return SimpleNamespace(**{**defaults, **overrides})


def _load_image_tensor(path, size):
    """Load *path* as RGB and return a (1, 3, size, size) float tensor in [0, 1]."""
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])
    # The context manager closes the file handle that Image.open keeps open;
    # convert() inside the block forces the pixel data to be read first.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0)


def _module_device(model):
    """Return the device of the model's first parameter, or CPU if there is none."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cpu")


def _output_to_pil(output):
    """Convert a model output tensor to a PIL image.

    Assumes *output* is (1, C, H, W) or (C, H, W). Float values are assumed
    to be nominally in [0, 1]; TODO confirm against the model.
    """
    img = output.detach().cpu().squeeze(0)
    if img.is_floating_point():
        # ToPILImage multiplies floats by 255 and casts to uint8, so values
        # outside [0, 1] would wrap around instead of saturating.
        img = img.clamp(0.0, 1.0)
    return transforms.ToPILImage()(img)