from typing import Optional, Union

import torch
import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput
from diffusers.models.autoencoders.vae import DecoderOutput


# Lossless "VAE" stand-in: pixel unshuffle as encode, pixel shuffle as decode.
class PixelMixer(nn.Module):
    def __init__(self, in_channels, downscale_factor):
        super().__init__()
        self.downscale_factor = downscale_factor
        self.in_channels = in_channels

    def forward(self, x):
        latent = self.encode(x)
        out = self.decode(latent)
        return out

    def encode(self, x):
        # Space-to-depth: trade spatial resolution for channels.
        return torch.nn.functional.pixel_unshuffle(x, self.downscale_factor)

    def decode(self, x):
        # Depth-to-space: the exact inverse of encode.
        return torch.nn.functional.pixel_shuffle(x, self.downscale_factor)
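
# A quick shape sketch (sizes are illustrative): with r = downscale_factor,
# encode maps (B, C, H, W) -> (B, C * r**2, H // r, W // r) and decode inverts it:
#   mixer = PixelMixer(in_channels=3, downscale_factor=8)
#   x = torch.randn(1, 3, 512, 512)
#   z = mixer.encode(x)                     # -> (1, 192, 64, 64)
#   assert torch.equal(mixer.decode(z), x)  # bit-exact roundtrip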
# For reference: none of this matters with llvae, but we need to match the
# AutoencoderKL config interface (latent_channels might matter).
class Config:
    in_channels = 3
    out_channels = 3
    down_block_types = ('1', '1', '1', '1')
    up_block_types = ('1', '1', '1', '1')
    block_out_channels = (1, 1, 1, 1)
    latent_channels = 192  # usually 4
    norm_num_groups = 32
    sample_size = 512
    # scaling_factor = 1
    # shift_factor = 0
    scaling_factor = 1.8
    shift_factor = -0.123
    # Measured VAE latent statistics:
    # - Mean: -0.12306906282901764
    # - Std:  0.556016206741333
    # Normalization parameters derived from them:
    # - Shift factor:   -0.12306906282901764  (the mean)
    # - Scaling factor:  1.7985087266803625   (1 / std)

    def __getitem__(self, x):
        # Allow dict-style access, e.g. config["latent_channels"], like diffusers configs.
        return getattr(self, x)
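
# How diffusers-style pipelines typically consume these values (a sketch of the
# usual convention, not code from this file): latents are normalized after
# encode and denormalized before decode.
#   z = (vae.encode(x).latent_dist.sample() - config.shift_factor) * config.scaling_factor
#   x_rec = vae.decode(z / config.scaling_factor + config.shift_factor).sample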
# Wraps PixelMixer behind the diffusers AutoencoderKL interface.
class AutoencoderPixelMixer(nn.Module):
    def __init__(self, in_channels=3, downscale_factor=8):
        super().__init__()
        self.mixer = PixelMixer(in_channels, downscale_factor)
        self._dtype = torch.float32
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.config = Config()
        if downscale_factor == 8:
            # Callers go by len(block_out_channels) in code, so simulate it.
            self.config.block_out_channels = (1, 1, 1, 1)
            self.config.latent_channels = 192  # 3 * 8**2 (assumes RGB input)
        elif downscale_factor == 16:
            self.config.block_out_channels = (1, 1, 1, 1, 1)
            self.config.latent_channels = 768  # 3 * 16**2 (assumes RGB input)
        else:
            raise ValueError(
                f"downscale_factor {downscale_factor} not supported")

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        self._device = value

    # Mimic torch's Module.to, recording dtype/device so the properties stay accurate.
    def to(self, *args, **kwargs):
        if 'dtype' in kwargs:
            self._dtype = kwargs['dtype']
        if 'device' in kwargs:
            self._device = kwargs['device']
        return super().to(*args, **kwargs)

    def enable_xformers_memory_efficient_attention(self):
        pass
    class _FakeDist:
        # Deterministic stand-in for DiagonalGaussianDistribution: there is no
        # real posterior here, so sample() and mode() both return the latent.
        def __init__(self, x):
            self._sample = x

        def sample(self, generator=None):
            # Accepts (and ignores) a generator to match the real API.
            return self._sample

        def mode(self):
            return self._sample

    # @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.mixer.encode(x)
        # moments = self.quant_conv(h)
        # posterior = DiagonalGaussianDistribution(moments)
        if not return_dict:
            return (h,)
        return AutoencoderKLOutput(latent_dist=self._FakeDist(h))
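
    # Call pattern sketch (mirrors AutoencoderKL, so pipeline code works unchanged):
    #   z = vae.encode(x).latent_dist.sample()
    #   x_rec = vae.decode(z).sample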
    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        dec = self.mixer.decode(z)
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

    # @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        decoded = self._decode(z).sample
        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    # No-op stubs so callers that toggle AutoencoderKL features don't crash.
    def _set_gradient_checkpointing(self, module, value=False):
        pass

    def enable_tiling(self, use_tiling: bool = True):
        pass

    def disable_tiling(self):
        pass

    def enable_slicing(self):
        pass

    def disable_slicing(self):
        pass

    def set_use_memory_efficient_attention_xformers(self, value: bool = True):
        pass
    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
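
# Since _FakeDist is deterministic, sample_posterior has no effect: forward is a
# pure encode -> decode roundtrip. A quick sketch (shapes are illustrative):
#   vae = AutoencoderPixelMixer(in_channels=3, downscale_factor=8)
#   x = torch.randn(1, 3, 512, 512)
#   assert torch.equal(vae(x).sample, x)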
# test it
if __name__ == '__main__':
    import os

    from PIL import Image
    import torchvision.transforms as transforms

    user_path = os.path.expanduser('~')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32
    input_path = os.path.join(user_path, "Pictures/test/test.jpg")
    # Write to a separate file so the roundtrip doesn't overwrite the input.
    output_path = os.path.join(user_path, "Pictures/test/test_out.jpg")

    # Note: H and W must be divisible by the downscale factor for pixel_unshuffle.
    img = Image.open(input_path).convert("RGB")  # ensure 3 channels
    img_tensor = transforms.ToTensor()(img)
    img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype)
    print("input_shape: ", list(img_tensor.shape))

    vae = PixelMixer(in_channels=3, downscale_factor=8)
    latent = vae.encode(img_tensor)
    print("latent_shape: ", list(latent.shape))
    out_tensor = vae.decode(latent)
    print("out_shape: ", list(out_tensor.shape))

    # The pixel (un)shuffle pair is lossless, so this should print 0.0.
    mse_loss = nn.MSELoss()
    mse = mse_loss(img_tensor, out_tensor)
    print("roundtrip_loss: ", mse.item())

    out_img = transforms.ToPILImage()(out_tensor.squeeze(0).cpu())
    out_img.save(output_path)
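
    # A minimal check of the AutoencoderKL-style wrapper as well (assumption:
    # the same test image; _FakeDist makes the roundtrip deterministic).
    ae = AutoencoderPixelMixer(in_channels=3, downscale_factor=8)
    rec = ae(img_tensor).sample
    print("wrapper_roundtrip_loss: ", mse_loss(img_tensor, rec).item())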