Spaces:
Build error
Build error
import torch | |
import numpy as np | |
def faceshifter_batch(source_emb: torch.Tensor,
                      target: torch.Tensor,
                      G: torch.nn.Module) -> np.ndarray:
    """
    Apply the FaceShifter generator ``G`` to a batch of target images.

    Args:
        source_emb: identity embedding of the source face. When its leading
            dimension differs from the batch size it is repeated along dim 0
            (``torch.cat``) so every target image gets a copy; when it already
            has one row per target image it is passed through unchanged.
            Presumably shaped (1, D) for a single source — TODO confirm.
        target: batch of target images, B x C x H x W. The output conversion
            assumes G returns values in [-1, 1] and 3 channels — TODO confirm
            against the model's preprocessing.
        G: generator called as ``G(target, source_emb)``; it must return a
            pair, of which only the first element (the swapped images) is used.

    Returns:
        CPU ``uint8`` numpy array of shape B x H x W x C with the channel axis
        reversed (index order [2, 1, 0]; presumably RGB -> BGR for OpenCV —
        confirm). Pixel values are truncated, not rounded, to integers.

    Raises:
        AssertionError: if ``target`` is not 4-dimensional.
    """
    # Validate before touching target.shape so a 0-d input gets this message.
    assert target.ndim == 4, "target should have 4 dimensions -- B x C x H x W"
    bs = target.shape[0]
    # Tile a single identity embedding across the batch; skip when the caller
    # already supplied one embedding per target (re-tiling would give bs*bs).
    if bs > 1 and source_emb.shape[0] != bs:
        source_emb = torch.cat([source_emb] * bs)
    with torch.no_grad():
        Y_st, _ = G(target, source_emb)
        # Map [-1, 1] -> [0, 255] in channels-last layout. Clamp before the
        # uint8 cast: out-of-range floats would otherwise wrap around instead
        # of saturating.
        Y_st = ((Y_st.permute(0, 2, 3, 1) * 0.5 + 0.5) * 255).clamp(0, 255)
        # Reverse the channel order, then truncate to uint8.
        Y_st = Y_st[:, :, :, [2, 1, 0]].to(torch.uint8)
    # Computed under no_grad, so no autograd graph is attached; .numpy() is safe.
    return Y_st.cpu().numpy()