import glob

import cv2
import numpy as np
import torch
from diffusers.utils import USE_PEFT_BACKEND, load_image
from einops import rearrange
from natsort import natsorted
from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Resize, ToTensor
def read_mask(mask_dir):
    """Load every PNG in `mask_dir` (natural-sorted) as a boolean mask resized to 512x512."""
    transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.BILINEAR, antialias=True),
        # CenterCrop((512, 512)),
        ToTensor()])
    mask_paths = natsorted(glob.glob(mask_dir + '/*.png'))
    mask_list = []
    for mask_path in mask_paths:
        mask = load_image(mask_path)
        mask_torch = transform(mask).bool().unsqueeze(0)  # torch.Size([1, 3, 512, 512]), boolean
        mask_list.append(mask_torch)
    return mask_list
def read_rgb(rgb_dir):
    """Load every JPG in `rgb_dir` (sorted) as a float tensor resized to 512x512.

    Returns the frame tensors, the (width, height) of the last source image,
    and the frame number parsed from each file name (expects `<prefix>_<number>.jpg`).
    """
    transform = Compose([
        Resize((512, 512), interpolation=InterpolationMode.BILINEAR, antialias=True),
        # CenterCrop((512, 512)),
        ToTensor()])
    rgb_paths = sorted(glob.glob(rgb_dir + '/*.jpg'))
    rgb_list = []
    rgb_frame = []
    for rgb_path in rgb_paths:
        rgb = load_image(rgb_path)
        width, height = rgb.size
        file_name = rgb_path.split('/')[-1]
        frame_number = int(file_name.split('_')[1].split('.')[0])  # int() drops leading zeros
        rgb_frame.append(frame_number)
        rgb_torch = transform(rgb).unsqueeze(0)  # torch.Size([1, 3, 512, 512]), values in [0, 1]
        rgb_list.append(rgb_torch)
    return rgb_list, (width, height), rgb_frame
def read_depth2disparity(depth_dir):
    """Load every .npy depth map in `depth_dir` (sorted) and convert it to a
    max-normalized, 3-channel disparity tensor of shape [1, 3, 512, 512]."""
    depth_paths = sorted(glob.glob(depth_dir + '/*.npy'))
    disparity_list = []
    for depth_path in depth_paths:
        depth = np.load(depth_path)
        depth = cv2.resize(depth, (512, 512)).reshape((512, 512, 1))  # [512, 512, 1]
        # depth = CenterCrop((512, 512))(torch.from_numpy(depth))[..., None].numpy()  # [512, 512, 1]
        disparity = 1 / (depth + 1e-5)  # epsilon avoids division by zero
        disparity_map = disparity / np.max(disparity)  # normalize to (0, 1]
        # disparity_map = disparity_map.astype(np.uint8)[:, :, 0]
        disparity_map = np.concatenate([disparity_map, disparity_map, disparity_map], axis=2)  # [512, 512, 3]
        disparity_list.append(torch.from_numpy(disparity_map[None]).permute(0, 3, 1, 2).float())  # [1, 3, 512, 512]
    return disparity_list
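
# Example usage of the three loaders above (a sketch; the directory names are
# hypothetical and only the glob patterns *.png / *.jpg / *.npy matter):
#   masks = read_mask('data/masks')                    # list of [1, 3, 512, 512] bool tensors
#   rgbs, (w, h), frame_ids = read_rgb('data/images')  # list of [1, 3, 512, 512] float tensors
#   disparities = read_depth2disparity('data/depth')   # list of [1, 3, 512, 512] float tensors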
def compute_attn(attn, query, key, value, video_length, ref_frame_index, attention_mask):
    """Attend every frame's queries to the key/value of a single reference frame.

    `key`/`value` have shape [(b f), d, c]; `ref_frame_index` repeats one frame
    index `video_length` times, so after indexing, every frame position holds the
    reference frame's keys/values. `query` is already in head-batched form.
    """
    key_ref_cross = rearrange(key, "(b f) d c -> b f d c", f=video_length)
    key_ref_cross = key_ref_cross[:, ref_frame_index]
    key_ref_cross = rearrange(key_ref_cross, "b f d c -> (b f) d c")
    value_ref_cross = rearrange(value, "(b f) d c -> b f d c", f=video_length)
    value_ref_cross = value_ref_cross[:, ref_frame_index]
    value_ref_cross = rearrange(value_ref_cross, "b f d c -> (b f) d c")
    key_ref_cross = attn.head_to_batch_dim(key_ref_cross)
    value_ref_cross = attn.head_to_batch_dim(value_ref_cross)
    attention_probs = attn.get_attention_scores(query, key_ref_cross, attention_mask)
    hidden_states_ref_cross = torch.bmm(attention_probs, value_ref_cross)
    return hidden_states_ref_cross
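
# Shape trace for compute_attn (illustrative; not asserted anywhere in this file):
#   query (already head-batched):  [(b f) heads, d_q, c // heads]
#   key/value in:                  [(b f), d, c]; indexing [:, [r] * f] makes every
#                                  frame position hold frame r's tokens
#   output:                        [(b f) heads, d_q, c // heads], the same shape as
#                                  plain self-attention, so the results can be blended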
class CrossViewAttnProcessor:
    """Attention processor that blends per-frame self-attention with attention
    towards four reference frames (frames 0-3 of each chunk).

    `self_attn_coeff` weights plain self-attention against the mean of the four
    reference-frame attentions; `unet_chunk_size` is the number of latent chunks
    in the batch (typically 2 when classifier-free guidance doubles it).
    """

    def __init__(self, self_attn_coeff, unet_chunk_size=2):
        self.unet_chunk_size = unet_chunk_size
        self.self_attn_coeff = self_attn_coeff

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        scale=1.0,
    ):
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
is_cross_attention = encoder_hidden_states is not None
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
        # Sparse cross-view attention: every frame attends to reference frames 0-3.
        if not is_cross_attention:
            # Plain self-attention (each frame attends to itself).
            key_self = attn.head_to_batch_dim(key)
            value_self = attn.head_to_batch_dim(value)
            attention_probs = attn.get_attention_scores(query, key_self, attention_mask)
            hidden_states_self = torch.bmm(attention_probs, value_self)
            # Attention towards each of the first three reference frames.
            video_length = key.size()[0] // self.unet_chunk_size
            ref0_frame_index = [0] * video_length
            ref1_frame_index = [1] * video_length
            ref2_frame_index = [2] * video_length
            ref3_frame_index = [3] * video_length
            hidden_states_ref0 = compute_attn(attn, query, key, value, video_length, ref0_frame_index, attention_mask)
            hidden_states_ref1 = compute_attn(attn, query, key, value, video_length, ref1_frame_index, attention_mask)
            hidden_states_ref2 = compute_attn(attn, query, key, value, video_length, ref2_frame_index, attention_mask)
            # Re-point key/value at reference frame 3; the shared path below then
            # attends to it (in the cross-attention case, key/value are untouched
            # and the shared path is ordinary cross-attention).
            key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
            key = key[:, ref3_frame_index]
            key = rearrange(key, "b f d c -> (b f) d c")
            value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
            value = value[:, ref3_frame_index]
            value = rearrange(value, "b f d c -> (b f) d c")
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states_ref3 = torch.bmm(attention_probs, value)
        if not is_cross_attention:
            # Blend self-attention with the mean of the four reference-frame attentions.
            hidden_states = self.self_attn_coeff * hidden_states_self + (1 - self.self_attn_coeff) * torch.mean(
                torch.stack([hidden_states_ref0, hidden_states_ref1, hidden_states_ref2, hidden_states_ref3]),
                dim=0)
        else:
            hidden_states = hidden_states_ref3
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
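
# A minimal sketch of wiring the processor into a diffusers UNet. The pipeline
# class, checkpoint id, and coefficient value are assumptions for illustration,
# not part of this module:
#
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5')
#   pipe.unet.set_attn_processor(CrossViewAttnProcessor(self_attn_coeff=0.5, unet_chunk_size=2))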