import glob

import cv2
import numpy as np
import torch
from diffusers.utils import USE_PEFT_BACKEND, load_image
from einops import rearrange
from natsort import natsorted
from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Resize, ToTensor


def read_mask(mask_dir):
    # Load all PNG masks in natural (frame) order and resize them to 512x512 boolean tensors.
    transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.BILINEAR, antialias=True),
        # CenterCrop((512, 512)),
        ToTensor()])
    mask_paths = natsorted(glob.glob(mask_dir + '/*.png'))
    mask_list = []
    for mask_path in mask_paths:
        mask = load_image(mask_path)
        mask_torch = transform(mask).bool().unsqueeze(0)  # [1, 3, 512, 512], boolean
        mask_list.append(mask_torch)
    return mask_list


def read_rgb(rgb_dir):
    # Load all JPG frames, resize them to 512x512, and parse the frame index from
    # file names of the form <prefix>_<frame_number>.jpg.
    transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.BILINEAR, antialias=True),
        # CenterCrop((512, 512)),
        ToTensor()])
    rgb_paths = sorted(glob.glob(rgb_dir + '/*.jpg'))
    rgb_list = []
    rgb_frame = []
    for rgb_path in rgb_paths:
        rgb = load_image(rgb_path)
        width, height = rgb.size
        file_name = rgb_path.split('/')[-1]
        frame_number = int(file_name.split('_')[1].split('.')[0].lstrip('0') or '0')
        rgb_frame.append(frame_number)
        rgb_torch = transform(rgb).unsqueeze(0)  # [1, 3, 512, 512], values in [0, 1]
        rgb_list.append(rgb_torch)
    return rgb_list, (width, height), rgb_frame


def read_depth2disparity(depth_dir):
    # Load per-frame depth maps (.npy), convert them to normalized disparity, and
    # replicate to three channels so they match the RGB tensor layout.
    depth_paths = sorted(glob.glob(depth_dir + '/*.npy'))
    disparity_list = []
    for depth_path in depth_paths:
        depth = np.load(depth_path)
        depth = cv2.resize(depth, (512, 512)).reshape((512, 512, 1))  # [512, 512, 1]
        # depth = CenterCrop((512, 512))(torch.from_numpy(depth))[..., None].numpy()  # [512, 512, 1]
        disparity = 1 / (depth + 1e-5)
        disparity_map = disparity / np.max(disparity)  # normalized to (0, 1]
        # disparity_map = disparity_map.astype(np.uint8)[:, :, 0]
        disparity_map = np.concatenate([disparity_map, disparity_map, disparity_map], axis=2)
        disparity_list.append(torch.from_numpy(disparity_map[None]).permute(0, 3, 1, 2).float())  # [1, 3, 512, 512]
    return disparity_list
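
# Illustrative helper (not part of the original module): one way the three readers above
# might be combined by a caller. The directory layout ("masks", "rgb", "depth") and the
# name load_scene are assumptions made only for this sketch.
def load_scene(data_root):
    masks = read_mask(data_root + '/masks')
    rgbs, (width, height), frame_ids = read_rgb(data_root + '/rgb')
    disparities = read_depth2disparity(data_root + '/depth')
    # Concatenate the per-frame [1, 3, 512, 512] tensors into [F, 3, 512, 512] batches.
    return (
        torch.cat(masks, dim=0),
        torch.cat(rgbs, dim=0),
        torch.cat(disparities, dim=0),
        (width, height),
        frame_ids,
    )
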
def compute_attn(attn, query, key, value, video_length, ref_frame_index, attention_mask):
    # Cross-frame attention: every query attends to the key/value features of a single
    # reference frame (the same frame index repeated for all frames of the clip).
    key_ref_cross = rearrange(key, "(b f) d c -> b f d c", f=video_length)
    key_ref_cross = key_ref_cross[:, ref_frame_index]
    key_ref_cross = rearrange(key_ref_cross, "b f d c -> (b f) d c")

    value_ref_cross = rearrange(value, "(b f) d c -> b f d c", f=video_length)
    value_ref_cross = value_ref_cross[:, ref_frame_index]
    value_ref_cross = rearrange(value_ref_cross, "b f d c -> (b f) d c")

    key_ref_cross = attn.head_to_batch_dim(key_ref_cross)
    value_ref_cross = attn.head_to_batch_dim(value_ref_cross)

    attention_probs = attn.get_attention_scores(query, key_ref_cross, attention_mask)
    hidden_states_ref_cross = torch.bmm(attention_probs, value_ref_cross)
    return hidden_states_ref_cross


class CrossViewAttnProcessor:
    def __init__(self, self_attn_coeff, unet_chunk_size=2):
        self.unet_chunk_size = unet_chunk_size
        self.self_attn_coeff = self_attn_coeff

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        scale=1.0,
    ):
        residual = hidden_states
        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        query = attn.head_to_batch_dim(query)

        # Sparse cross-view attention: self-attention layers additionally attend to the
        # key/value features of reference frames 0-3 of each chunk.
        if not is_cross_attention:
            # Plain per-frame self-attention.
            key_self = attn.head_to_batch_dim(key)
            value_self = attn.head_to_batch_dim(value)
            attention_probs = attn.get_attention_scores(query, key_self, attention_mask)
            hidden_states_self = torch.bmm(attention_probs, value_self)

            video_length = key.size()[0] // self.unet_chunk_size
            ref0_frame_index = [0] * video_length
            ref1_frame_index = [1] * video_length
            ref2_frame_index = [2] * video_length
            ref3_frame_index = [3] * video_length

            # Cross-frame attention against reference frames 0, 1, and 2.
            hidden_states_ref0 = compute_attn(attn, query, key, value, video_length, ref0_frame_index, attention_mask)
            hidden_states_ref1 = compute_attn(attn, query, key, value, video_length, ref1_frame_index, attention_mask)
            hidden_states_ref2 = compute_attn(attn, query, key, value, video_length, ref2_frame_index, attention_mask)

            # Swap key/value for those of reference frame 3; the shared attention
            # below then produces hidden_states_ref3.
            key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
            key = key[:, ref3_frame_index]
            key = rearrange(key, "b f d c -> (b f) d c")
            value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
            value = value[:, ref3_frame_index]
            value = rearrange(value, "b f d c -> (b f) d c")

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # For cross-attention layers this is ordinary text conditioning;
        # for self-attention layers it is attention against reference frame 3.
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states_ref3 = torch.bmm(attention_probs, value)

        if not is_cross_attention:
            # Blend per-frame self-attention with the mean of the four reference-frame attentions.
            hidden_states = self.self_attn_coeff * hidden_states_self + (1 - self.self_attn_coeff) * torch.mean(
                torch.stack([hidden_states_ref0, hidden_states_ref1, hidden_states_ref2, hidden_states_ref3]),
                dim=0,
            )
        else:
            hidden_states = hidden_states_ref3

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
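
# Illustrative usage sketch (not part of the original module): register the processor on
# every attention layer of a diffusers UNet. The checkpoint path is a placeholder and
# self_attn_coeff=0.5 is an assumed value; unet_chunk_size=2 matches the usual
# classifier-free-guidance layout of two chunks (unconditional + conditional).
if __name__ == "__main__":
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "path/to/stable-diffusion-checkpoint", subfolder="unet"
    )
    unet.set_attn_processor(CrossViewAttnProcessor(self_attn_coeff=0.5, unet_chunk_size=2))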