Visualizr / src /visualizr /LIA_Model.py
MH0386's picture
Upload folder using huggingface_hub
3e165b2 verified
from torch import load, nn
from visualizr import logger
from visualizr.networks.encoder import Encoder
from visualizr.networks.styledecoder import Synthesis
class LIA_Model(nn.Module):
    """Bundle an ``Encoder`` and a ``Synthesis`` decoder in a single module.

    The encoder is exposed as ``self.enc`` and the decoder as ``self.dec``.
    ``get_start_direction_code`` runs the encoder, ``render`` runs the decoder,
    and ``load_lightning_model`` copies matching tensors from a checkpoint
    file into this module's parameters and buffers.
    """

    def __init__(
        self,
        size=256,
        style_dim=512,
        motion_dim=20,
        channel_multiplier=1,
        blur_kernel=None,
        fusion_type="",
    ):
        """Build the encoder and the decoder.

        Args:
            size: Forwarded to both ``Encoder`` and ``Synthesis``.
            style_dim: Forwarded to both ``Encoder`` and ``Synthesis``.
            motion_dim: Forwarded to both ``Encoder`` and ``Synthesis``.
            channel_multiplier: Forwarded to ``Synthesis``.
            blur_kernel: Forwarded to ``Synthesis``. ``None`` (the default)
                selects a fresh ``[1, 3, 3, 1]`` list.
            fusion_type: Forwarded to ``Encoder``.
        """
        super().__init__()
        # A list literal as a default would be one object shared by every
        # instance; build a new one per call instead.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]
        self.enc = Encoder(size, style_dim, motion_dim, fusion_type)
        self.dec = Synthesis(
            size, style_dim, motion_dim, blur_kernel, channel_multiplier
        )

    def get_start_direction_code(self, x_start, x_target, x_face, x_aug):
        """Run the encoder and unpack its result.

        All four arguments are passed to ``self.enc`` unchanged.

        Returns:
            A ``(wa, alpha, feats)`` tuple taken from the encoder output's
            ``"h_source"``, ``"h_motion"`` and ``"feats"`` entries.
        """
        enc_dic = self.enc(x_start, x_target, x_face, x_aug)
        return enc_dic["h_source"], enc_dic["h_motion"], enc_dic["feats"]

    def render(self, start, direction, feats):
        """Return the decoder's output for ``(start, direction, feats)``."""
        return self.dec(start, direction, feats)

    def load_lightning_model(self, lia_pretrained_model_path):
        """Copy matching tensors from a checkpoint file into this module.

        The file is loaded onto the CPU and iterated as a ``name -> tensor``
        mapping. A name that is not in this module's ``state_dict`` is retried
        with every ``"lia."`` occurrence removed. Entries that still have no
        match, or whose size differs from the module's tensor, are logged as
        errors and skipped; everything else is copied in place.

        Args:
            lia_pretrained_model_path: Passed straight to ``torch.load``.
        """
        own_state = self.state_dict()
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from a trusted source (or consider weights_only=True).
        state = load(lia_pretrained_model_path, map_location="cpu")
        for orig_name, param in state.items():
            name = orig_name
            if name not in own_state:
                name = name.replace("lia.", "")
            if name not in own_state:
                # These can be ignored: some parameters are only used for
                # training. logger.error (not logger.exception) because no
                # exception is being handled here. The message is formatted
                # eagerly since the logger implementation is not visible in
                # this file.
                logger.error(f"{orig_name} is not in the model.")
                continue
            target = own_state[name]
            if target.size() != param.size():
                logger.error(
                    f"Wrong parameter length: {orig_name}, "
                    f"model: {target.size()}, loaded: {param.size()}"
                )
                continue
            # state_dict tensors share storage with the live parameters, so an
            # in-place copy updates the module itself.
            target.copy_(param)