# simswapNSCR / run_faceswap.py
# Uploaded by NeoSerge ("Upload run_faceswap.py", commit 5258830, 2.17 kB).
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import cv2
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from types import SimpleNamespace
from SimSwap.models.models import create_model
from SimSwap.arcface.model import Backbone
# All inference runs on the CPU; CUDA is additionally hidden through
# CUDA_VISIBLE_DEVICES at the top of the file.
DEVICE = torch.device("cpu")

# Option set handed to SimSwap's create_model() / model.setup().
# Training-related entries (gan_mode, lambda_*, beta1, lr, ...) are kept even
# though isTrain is False -- presumably because the model code reads them
# unconditionally; TODO confirm before pruning any of them.
opt_dict = dict(
    isTrain=False,
    Arc_path="arcface_model.tar",
    which_epoch="latest",
    load_pretrain="checkpoints/simswap",
    crop_size=224,
    resize_or_crop="none",
    gan_mode="hinge",
    no_ganFeat_loss=True,
    no_vgg_loss=True,
    lambda_feat=10.0,
    lambda_rec=10.0,
    beta1=0.5,
    lr=0.0002,
    continue_train=False,
    name="simswap",
    checkpoints_dir="./checkpoints",
    use_mask=True,
    dataset_mode="Swap",
    model="fs_swap",
    # Empty GPU list: this keeps the model from making internal CUDA calls.
    gpu_ids=[],
)

# Attribute-style view of the same options (e.g. opt.Arc_path, opt.crop_size).
opt = SimpleNamespace(**opt_dict)
# Build the SimSwap model from the options above, then pin it to DEVICE.
model = create_model(opt)
model.setup(opt)
model.device = DEVICE

# Move each sub-network that this model instance actually has onto DEVICE;
# the hasattr guard tolerates a model lacking either attribute.
for _net_name in ('netG', 'netArc'):
    if hasattr(model, _net_name):
        setattr(model, _net_name, getattr(model, _net_name).to(DEVICE))
del _net_name

# Inference mode (disables dropout / uses running batch-norm statistics).
model.eval()
# Stand-alone ArcFace identity encoder used by run_faceswap.
# The positional arguments are presumably (num_layers, drop_ratio, mode) --
# TODO confirm against SimSwap/arcface/model.py.
arcface = Backbone(50, 0.6, 'ir_se')

# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code; only point opt.Arc_path at a trusted file.
arcface.load_state_dict(
    torch.load(opt.Arc_path, map_location=DEVICE)
)

# Place on DEVICE and switch to inference mode (both act in place).
arcface.to(DEVICE).eval()
# Preprocessing for the identity image fed to `arcface`.
# run_faceswap resizes the identity image to 112x112, and the
# Backbone(50, 0.6, 'ir_se') constructor matches the IR-SE ArcFace backbone,
# whose output layer is presumably sized for 112x112 input (7x7 final feature
# maps). Resizing to 224x224 here, as before, would up-sample that crop and
# feed the backbone an input size it was presumably not built for.
# NOTE(review): confirm the input size against SimSwap/arcface/model.py.
# NOTE(review): no Normalize() step is applied, so the backbone receives
# values in [0, 1]; check which mean/std the checkpoint was trained with.
transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
])
def _tensor_to_uint8_image(image_tensor):
    """Convert a model output tensor into an ``H x W x C`` ``uint8`` array.

    Values are scaled by 255, i.e. they are assumed to lie in ``[0, 1]``.
    Anything outside that range is clipped instead of wrapping around during
    the ``uint8`` cast, and values are rounded rather than truncated.

    Args:
        image_tensor: ``C x H x W`` tensor, optionally with a leading batch
            dimension; only the first image of a batch is converted.

    Returns:
        ``numpy.ndarray`` with dtype ``uint8`` and shape ``H x W x C``.
    """
    if image_tensor.dim() == 4:
        image_tensor = image_tensor[0]
    hwc = image_tensor.detach().permute(1, 2, 0).cpu().numpy()
    return np.clip(hwc * 255.0, 0, 255).round().astype(np.uint8)


def run_faceswap(base_image_pil, face_image_pil):
    """Put the identity of ``face_image_pil`` onto ``base_image_pil``.

    Embeds the identity face with the module-level ``arcface`` network, writes
    the base image to a private temporary file (the model receives its path
    through ``opt.pic_a_path``), runs the module-level SimSwap ``model`` and
    converts its ``synthesized_image`` visual to a ``uint8`` array.

    Not safe to call concurrently: it mutates the module-level ``opt`` and
    drives the shared module-level ``model``.

    Args:
        base_image_pil: PIL image supplying the scene. Any PIL mode is
            accepted; it is converted to RGB first.
        face_image_pil: PIL image supplying the identity.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with shape ``H x W x C``.
        NOTE(review): the channel order is whatever the model emits (the
        input file is written as BGR for cv2.imwrite); confirm before display.

    Raises:
        RuntimeError: if the temporary input image cannot be written.
    """
    import contextlib
    import tempfile

    # np.array() of an RGBA / L / P image does not have 3 channels, which
    # COLOR_RGB2BGR rejects, so normalise the mode first.
    base_bgr = cv2.cvtColor(np.array(base_image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

    # Identity image is brought to 112x112 before the ArcFace preprocessing.
    id_img = face_image_pil.convert("RGB").resize((112, 112))
    id_tensor = transform(id_img).unsqueeze(0).to(DEVICE)

    # A unique per-call file rather than a fixed 'temp_base.jpg' in the CWD:
    # calls no longer overwrite each other's input, nothing is left behind,
    # and PNG avoids a lossy JPEG round-trip of the base image.
    fd, temp_input_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        if not cv2.imwrite(temp_input_path, base_bgr):
            raise RuntimeError(f"failed to write temporary image {temp_input_path!r}")

        # Inference only: keep autograd off for the whole pipeline so the
        # output tensor can be converted with .numpy().
        with torch.no_grad():
            # normalize() clamps the norm with a small eps, so a zero
            # embedding cannot produce NaNs.
            id_embedding = torch.nn.functional.normalize(arcface(id_tensor), p=2, dim=-1)
            opt.pic_a_path = temp_input_path
            opt.pic_b_path = None
            model.set_input(opt, id_embedding)
            model.test()
            swapped = model.get_current_visuals()['synthesized_image']
    finally:
        # Best-effort cleanup; a failed delete must not mask the real error.
        with contextlib.suppress(OSError):
            os.remove(temp_input_path)

    return _tensor_to_uint8_image(swapped)