eawolf2357-git / utils /old_encode_video.py
seawolf2357's picture
Upload folder using huggingface_hub
321d89c verified
raw
history blame contribute delete
814 Bytes
def encode_video(video):
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
#print(video.shape)
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
image = video[:, :, :1].clone()
with torch.no_grad():
latent_dist = vae.encode(video).latent_dist
#为什么训练的时候不引入参考图像呢?而是使用纯噪声,并且这个噪声的均值是-3,方差是0.5,例如tensor([0.0549])
image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device)
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype)
noisy_image = torch.randn_like(image) * image_noise_sigma[:, None, None, None, None]
image_latent_dist = vae.encode(noisy_image).latent_dist
return latent_dist, image_latent_dist