import os
from os.path import join as opj

import omegaconf
import cv2
import einops
import torch
import torch as th
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np

from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import instantiate_from_config


class ControlLDM(LatentDiffusion):
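    """ControlNet-style wrapper around ``LatentDiffusion``.

    A separately instantiated ``control_model`` consumes spatial hint images
    (looked up in the batch via ``control_key``) and produces residuals that
    are scaled by ``control_scales`` and injected into the diffusion U-Net.
    """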
    def __init__(
        self,
        control_stage_config,
        validation_config,
        control_key,
        only_mid_control,
        use_VAEdownsample=False,
        config_name="",
        control_scales=None,
        use_pbe_weight=False,
        u_cond_percent=0.0,
        img_H=512,
        img_W=384,
        always_learnable_param=False,
        *args,
        **kwargs
    ):
        # Set these before super().__init__() so they are available while the
        # base class initializes.
        self.control_stage_config = control_stage_config
        self.use_pbe_weight = use_pbe_weight
        self.u_cond_percent = u_cond_percent
        self.img_H = img_H
        self.img_W = img_W
        self.config_name = config_name
        self.always_learnable_param = always_learnable_param
        super().__init__(*args, **kwargs)
        # Propagate the VAE-downsample flag into the control-network config.
        control_stage_config.params["use_VAEdownsample"] = use_VAEdownsample
        self.control_model = instantiate_from_config(control_stage_config)
        self.control_key = control_key
        self.only_mid_control = only_mid_control
        if control_scales is None:
            # One scale per control residual; a ControlNet branch emits 13
            # (12 encoder-block outputs plus the middle block).
            self.control_scales = [1.0] * 13
        else:
            self.control_scales = control_scales
        self.first_stage_key_cond = kwargs.get("first_stage_key_cond", None)
        self.valid_config = validation_config
        self.use_VAEDownsample = use_VAEdownsample

    def get_input(self, batch, k, bs=None, *args, **kwargs):
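        """Build the model input and conditioning dict from a batch.

        Returns the first-stage latents ``x`` plus a dict holding the
        cross-attention context (``c_crossattn``), a list of spatial control
        hints (``c_concat``), and, if configured, extra first-stage conditions
        concatenated along the channel dimension (``first_stage_cond``).
        Note that ``k`` is unused; the key comes from ``self.first_stage_key``.
        """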
        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
        if isinstance(self.control_key, omegaconf.listconfig.ListConfig):
            control_lst = []
            for key in self.control_key:
                control = batch[key]
                if bs is not None:
                    control = control[:bs]
                control = control.to(self.device)
                control = einops.rearrange(control, "b h w c -> b c h w")
                control = control.to(memory_format=torch.contiguous_format).float()
                control_lst.append(control)
            control = control_lst
        else:
            control = batch[self.control_key]
            if bs is not None:
                control = control[:bs]
            control = control.to(self.device)
            control = einops.rearrange(control, "b h w c -> b c h w")
            control = control.to(memory_format=torch.contiguous_format).float()
            control = [control]
        cond_dict = dict(c_crossattn=[c], c_concat=control)
        if self.first_stage_key_cond is not None:
            first_stage_cond = []
            for key in self.first_stage_key_cond:
                if "mask" not in key:
                    cond, _ = super().get_input(batch, key, *args, **kwargs)
                else:
                    # ``no_latent=True``: keep masks in pixel space instead of
                    # encoding them with the first-stage VAE.
                    cond, _ = super().get_input(batch, key, no_latent=True, *args, **kwargs)
                first_stage_cond.append(cond)
            first_stage_cond = torch.cat(first_stage_cond, dim=1)
            cond_dict["first_stage_cond"] = first_stage_cond
        return x, cond_dict

    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
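        """Predict the noise for ``x_noisy`` under the given conditioning.

        Control hints are optionally VAE-encoded down to latent resolution,
        run through ``control_model``, scaled by ``control_scales``, and
        injected into the diffusion U-Net. Returns a ``(eps, None)`` tuple.
        """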
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model
        cond_txt = torch.cat(cond["c_crossattn"], 1)
        # ``proj_out`` (set up outside this method) maps 1024-dim image-CLIP
        # features to the 768-dim context expected by the U-Net.
        if self.proj_out is not None:
            if cond_txt.shape[-1] == 1024:
                cond_txt = self.proj_out(cond_txt)  # [BS x 1 x 768]
        if self.always_learnable_param:
            cond_txt = self.get_unconditional_conditioning(cond_txt.shape[0])

        if cond["c_concat"] is None:
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
        else:
            if "first_stage_cond" in cond:
                # Channel-concatenate extra first-stage conditions onto the noisy latents.
                x_noisy = torch.cat([x_noisy, cond["first_stage_cond"]], dim=1)
            if not self.use_VAEDownsample:
                hint = cond["c_concat"]
            else:
                # Encode full-resolution hints to latent space; hints already
                # at latent resolution pass through unchanged.
                hint = []
                for h in cond["c_concat"]:
                    if h.shape[2] == self.img_H and h.shape[3] == self.img_W:
                        h = self.encode_first_stage(h)
                        h = self.get_first_stage_encoding(h).detach()
                    hint.append(h)
            hint = torch.cat(hint, dim=1)
            control, _ = self.control_model(x=x_noisy, hint=hint, timesteps=t, context=cond_txt, only_mid_control=self.only_mid_control)
            if len(control) == len(self.control_scales):
                control = [c * scale for c, scale in zip(control, self.control_scales)]
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
        return eps, None

    def get_unconditional_conditioning(self, N):
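        """Return the null conditioning for ``N`` samples (e.g., for
        classifier-free guidance): an empty-prompt embedding for text
        conditioning, or a learned vector when image-CLIP conditioning is used.
        """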
        # ``self.kwargs`` and ``learnable_vector`` are expected to be set up
        # by the base-class configuration.
        if not self.kwargs["use_imageCLIP"]:
            return self.get_learned_conditioning([""] * N)
        else:
            return self.learnable_vector.repeat(N, 1, 1)

    def low_vram_shift(self, is_diffusing):
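        """Swap submodules between GPU and CPU to limit peak VRAM.

        While diffusing, only the U-Net wrapper and the control model stay on
        the GPU; otherwise the VAE and the conditioning encoder are moved back
        for encoding/decoding.
        """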
        if is_diffusing:
            self.model = self.model.cuda()
            self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()
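
# Minimal usage sketch (hypothetical: the config path, config layout, and
# ``batch`` below are assumptions, not part of this file):
#
#   import omegaconf
#   conf = omegaconf.OmegaConf.load("configs/model.yaml")  # assumed config file
#   model = instantiate_from_config(conf.model)            # builds this ControlLDM
#   model.low_vram_shift(is_diffusing=False)               # VAE + encoder on GPU
#   x, cond = model.get_input(batch, None)                 # ``batch`` from a DataLoader
#   model.low_vram_shift(is_diffusing=True)                # U-Net + control on GPU
#   eps, _ = model.apply_model(x_noisy, t, cond)           # noise prediction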