from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
from einops import repeat, rearrange

from ...util import append_dims, instantiate_from_config
from .denoiser_scaling import DenoiserScaling


class DenoiserDub(nn.Module):
    def __init__(self, scaling_config: Dict, mask_input: bool = True):
        super().__init__()
        self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
        self.mask_input = mask_input

    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        return sigma

    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        return c_noise

    def forward(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        num_overlap_frames: int = 1,
        num_frames: int = 14,
        n_skips: int = 1,
        chunk_size: Optional[int] = None,
        **additional_model_inputs,
    ) -> torch.Tensor:
        sigma = self.possibly_quantize_sigma(sigma)
        if input.ndim == 5:
            T = input.shape[2]
            input = rearrange(input, "b c t h w -> (b t) c h w")
            if sigma.shape[0] != input.shape[0]:
                sigma = repeat(sigma, "b ... -> b t ...", t=T)
                sigma = rearrange(sigma, "b t ... -> (b t) ...")
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))

        gt = cond.get("gt", torch.Tensor([]).type_as(input))
        if gt.dim() == 5:
            gt = rearrange(gt, "b c t h w -> (b t) c h w")
        masks = cond.get("masks", None)
        if masks.dim() == 5:
            masks = rearrange(masks, "b c t h w -> (b t) c h w")

        if self.mask_input:
            # Replace everything outside the mask with the ground-truth frames.
            input = input * masks + gt * (1.0 - masks)

        if chunk_size is not None:
            assert chunk_size % num_frames == 0, (
                "Chunk size should be a multiple of num_frames"
            )
            out = chunk_network(
                network,
                input,
                c_in,
                c_noise,
                cond,
                additional_model_inputs,
                chunk_size,
                num_frames=num_frames,
            )
        else:
            out = network(input * c_in, c_noise, cond, **additional_model_inputs)
        out = out * c_out + input * c_skip
        # Composite: keep the network output inside the mask, the ground truth outside.
        out = out * masks + gt * (1.0 - masks)
        return out


class DenoiserTemporalMultiDiffusion(nn.Module):
    def __init__(self, scaling_config: Dict, is_dub: bool = False):
        super().__init__()
        self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
        self.is_dub = is_dub

    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        return sigma

    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        return c_noise

    def forward(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        num_overlap_frames: int,
        num_frames: int,
        n_skips: int,
        chunk_size: Optional[int] = None,
        **additional_model_inputs,
    ) -> torch.Tensor:
        """
        Args:
            network: Denoising network.
            input: Noisy input.
            sigma: Noise level.
            cond: Dictionary containing additional conditioning information.
            num_overlap_frames: Number of frames shared between consecutive segments.
            num_frames: Number of frames per segment.
            n_skips: Segment offset used when writing back averaged overlapping frames.
            chunk_size: Optional batch chunk size for the network; must be a multiple
                of num_frames.
            additional_model_inputs: Additional inputs for the denoising network.

        Returns:
            out: Denoised output.

        This function assumes the input has shape (B, C, T, H, W), where the B dimension
        is the number of segments in the video. num_overlap_frames is the number of
        overlapping frames between consecutive segments and is used to handle the
        temporal overlap.
        """
        sigma = self.possibly_quantize_sigma(sigma)
        T = num_frames
        if input.ndim == 5:
            T = input.shape[2]
            input = rearrange(input, "b c t h w -> (b t) c h w")
            if sigma.shape[0] != input.shape[0]:
                sigma = repeat(sigma, "b ... -> b t ...", t=T)
                sigma = rearrange(sigma, "b t ... -> (b t) ...")
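        # Scale n_skips by the number of segments present in the flattened batch,
        # so it can serve as a segment offset in the overlap-averaging loop below.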
        n_skips = n_skips * input.shape[0] // T
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))

        if self.is_dub:
            gt = cond.get("gt", torch.Tensor([]).type_as(input))
            if gt.dim() == 5:
                gt = rearrange(gt, "b c t h w -> (b t) c h w")
            masks = cond.get("masks", None)
            if masks.dim() == 5:
                masks = rearrange(masks, "b c t h w -> (b t) c h w")
            input = input * masks + gt * (1.0 - masks)

        # Find the overlapping frames between segments and average them.
        input = rearrange(input, "(b t) c h w -> b c t h w", t=T)
        # Overlapping frames are at the beginning and end of each segment and are
        # given by num_overlap_frames.
        for i in range(input.shape[0] - n_skips):
            average_frame = torch.stack(
                [
                    input[i, :, -num_overlap_frames:],
                    input[i + 1, :, :num_overlap_frames],
                ]
            ).mean(0)
            input[i, :, -num_overlap_frames:] = average_frame
            input[i + n_skips, :, :num_overlap_frames] = average_frame
        input = rearrange(input, "b c t h w -> (b t) c h w")

        if chunk_size is not None:
            assert chunk_size % num_frames == 0, (
                "Chunk size should be a multiple of num_frames"
            )
            out = chunk_network(
                network,
                input,
                c_in,
                c_noise,
                cond,
                additional_model_inputs,
                chunk_size,
                num_frames=num_frames,
            )
        else:
            out = network(input * c_in, c_noise, cond, **additional_model_inputs)
        out = out * c_out + input * c_skip
        if self.is_dub:
            out = out * masks + gt * (1.0 - masks)
        return out


def chunk_network(
    network,
    input,
    c_in,
    c_noise,
    cond,
    additional_model_inputs,
    chunk_size,
    num_frames=1,
):
    """Run the network on the batch in chunks of chunk_size and concatenate the outputs."""
    out = []
    for i in range(0, input.shape[0], chunk_size):
        start_idx = i
        end_idx = i + chunk_size
        input_chunk = input[start_idx:end_idx]
        # Tensors batched per frame are sliced directly; tensors batched per segment
        # are sliced by segment index.
        c_in_chunk = (
            c_in[start_idx:end_idx]
            if c_in.shape[0] == input.shape[0]
            else c_in[start_idx // num_frames : end_idx // num_frames]
        )
        c_noise_chunk = (
            c_noise[start_idx:end_idx]
            if c_noise.shape[0] == input.shape[0]
            else c_noise[start_idx // num_frames : end_idx // num_frames]
        )
        cond_chunk = {}
        for k, v in cond.items():
            if isinstance(v, torch.Tensor) and v.shape[0] == input.shape[0]:
                cond_chunk[k] = v[start_idx:end_idx]
            elif isinstance(v, torch.Tensor):
                cond_chunk[k] = v[start_idx // num_frames : end_idx // num_frames]
            else:
                cond_chunk[k] = v

        additional_model_inputs_chunk = {}
        for k, v in additional_model_inputs.items():
            if isinstance(v, torch.Tensor):
                # Repeat tensor inputs along the batch dimension to cover the chunk,
                # then trim to the length of the chunk's "concat" conditioning.
                or_size = v.shape[0]
                additional_model_inputs_chunk[k] = repeat(
                    v,
                    "b c -> (b t) c",
                    t=(input_chunk.shape[0] // num_frames // or_size) + 1,
                )[: cond_chunk["concat"].shape[0]]
            else:
                additional_model_inputs_chunk[k] = v

        out.append(
            network(
                input_chunk * c_in_chunk,
                c_noise_chunk,
                cond_chunk,
                **additional_model_inputs_chunk,
            )
        )
    return torch.cat(out, dim=0)


class KarrasTemporalMultiDiffusion(DenoiserTemporalMultiDiffusion):
    def __init__(self, scaling_config: Dict):
        super().__init__(scaling_config)
        self.bad_network = None

    def set_bad_network(self, bad_network: nn.Module):
        self.bad_network = bad_network

    def split_inputs(
        self, input: torch.Tensor, cond: Dict, additional_model_inputs
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict, Dict, Dict, Dict]:
        """Split input, cond and additional_model_inputs into first/second batch halves."""
        half_input = input.shape[0] // 2
        first_cond_half = {}
        second_cond_half = {}
        for k, v in cond.items():
            if isinstance(v, torch.Tensor):
                half_cond = v.shape[0] // 2
                first_cond_half[k] = v[:half_cond]
                second_cond_half[k] = v[half_cond:]
            elif isinstance(v, list):
                half_add = v[0].shape[0] // 2
                first_cond_half[k] = [v[i][:half_add] for i in range(len(v))]
                second_cond_half[k] = [v[i][half_add:] for i in range(len(v))]
            else:
                first_cond_half[k] = v
                second_cond_half[k] = v
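        # Apply the same halving to additional_model_inputs, mirroring the split of
        # cond above (tensors and lists of tensors are cut at the batch midpoint).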
        add_good = {}
        add_bad = {}
        for k, v in additional_model_inputs.items():
            if isinstance(v, torch.Tensor):
                half_add = v.shape[0] // 2
                add_good[k] = v[:half_add]
                add_bad[k] = v[half_add:]
            elif isinstance(v, list):
                half_add = v[0].shape[0] // 2
                add_good[k] = [v[i][:half_add] for i in range(len(v))]
                add_bad[k] = [v[i][half_add:] for i in range(len(v))]
            else:
                add_good[k] = v
                add_bad[k] = v
        return (
            input[:half_input],
            input[half_input:],
            first_cond_half,
            second_cond_half,
            add_good,
            add_bad,
        )

    def forward(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        num_overlap_frames: int,
        num_frames: int,
        n_skips: int,
        chunk_size: Optional[int] = None,
        **additional_model_inputs,
    ) -> torch.Tensor:
        """
        Args:
            network: Denoising network.
            input: Noisy input.
            sigma: Noise level.
            cond: Dictionary containing additional conditioning information.
            num_overlap_frames: Number of frames shared between consecutive segments.
            num_frames: Number of frames per segment.
            n_skips: Segment offset used when writing back averaged overlapping frames.
            chunk_size: Optional batch chunk size for the network; must be a multiple
                of num_frames.
            additional_model_inputs: Additional inputs for the denoising network.

        Returns:
            out: Denoised output.

        This function assumes the input has shape (B, C, T, H, W), where the B dimension
        is the number of segments in the video. num_overlap_frames is the number of
        overlapping frames between consecutive segments and is used to handle the
        temporal overlap. The batch is additionally split in half: the first half is
        denoised with self.bad_network and the second half with network, and the two
        outputs are concatenated.
        """
        sigma = self.possibly_quantize_sigma(sigma)
        T = num_frames
        if input.ndim == 5:
            T = input.shape[2]
            input = rearrange(input, "b c t h w -> (b t) c h w")
            if sigma.shape[0] != input.shape[0]:
                sigma = repeat(sigma, "b ... -> b t ...", t=T)
                sigma = rearrange(sigma, "b t ... -> (b t) ...")
        n_skips = n_skips * input.shape[0] // T
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))

        if self.is_dub:
            gt = cond.get("gt", torch.Tensor([]).type_as(input))
            if gt.dim() == 5:
                gt = rearrange(gt, "b c t h w -> (b t) c h w")
            masks = cond.get("masks", None)
            if masks.dim() == 5:
                masks = rearrange(masks, "b c t h w -> (b t) c h w")
            input = input * masks + gt * (1.0 - masks)

        # Find the overlapping frames between segments and average them.
        input = rearrange(input, "(b t) c h w -> b c t h w", t=T)
        # Overlapping frames are at the beginning and end of each segment and are
        # given by num_overlap_frames.
        for i in range(input.shape[0] - n_skips):
            average_frame = torch.stack(
                [
                    input[i, :, -num_overlap_frames:],
                    input[i + 1, :, :num_overlap_frames],
                ]
            ).mean(0)
            input[i, :, -num_overlap_frames:] = average_frame
            input[i + n_skips, :, :num_overlap_frames] = average_frame
        input = rearrange(input, "b c t h w -> (b t) c h w")

        half = c_in.shape[0] // 2
        in_bad, in_good, cond_bad, cond_good, add_inputs_good, add_inputs_bad = (
            self.split_inputs(input, cond, additional_model_inputs)
        )

        if chunk_size is not None:
            assert chunk_size % num_frames == 0, (
                "Chunk size should be a multiple of num_frames"
            )
            out = chunk_network(
                network,
                in_good,
                c_in[half:],
                c_noise[half:],
                cond_good,
                add_inputs_good,
                chunk_size,
                num_frames=num_frames,
            )
            bad_out = chunk_network(
                self.bad_network,
                in_bad,
                c_in[:half],
                c_noise[:half],
                cond_bad,
                add_inputs_bad,
                chunk_size,
                num_frames=num_frames,
            )
        else:
            out = network(
                in_good * c_in[half:], c_noise[half:], cond_good, **add_inputs_good
            )
            bad_out = self.bad_network(
                in_bad * c_in[:half], c_noise[:half], cond_bad, **add_inputs_bad
            )
        out = torch.cat([bad_out, out], dim=0)
        out = out * c_out + input * c_skip
        if self.is_dub:
            out = out * masks + gt * (1.0 - masks)
        return out
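# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal, standalone
# demonstration of the overlap-averaging step used by
# DenoiserTemporalMultiDiffusion.forward. All shapes and values below are made
# up for the example; no denoising network or scaling config is involved, and
# the snippet only needs torch, so it can also be copied into a separate script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    num_segments, channels, t, h, w = 3, 4, 14, 8, 8
    num_overlap_frames = 2
    n_skips = 1

    # Dummy "noisy latents" for a video split into overlapping segments.
    x = torch.randn(num_segments, channels, t, h, w)

    # Average the trailing frames of segment i with the leading frames of
    # segment i + 1, writing the result back to both locations.
    for i in range(x.shape[0] - n_skips):
        average_frame = torch.stack(
            [x[i, :, -num_overlap_frames:], x[i + 1, :, :num_overlap_frames]]
        ).mean(0)
        x[i, :, -num_overlap_frames:] = average_frame
        x[i + n_skips, :, :num_overlap_frames] = average_frame

    # After blending, consecutive segments agree on their shared frames.
    assert torch.allclose(x[0, :, -num_overlap_frames:], x[1, :, :num_overlap_frames])
    print("Overlapping frames match across consecutive segments.")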