import math

import librosa
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import ToPILImage
from transformers import CLIPImageProcessor
def get_audio_feature(feature_extractor, audio_path):
    """Load an audio file at 16 kHz and extract features in fixed-size chunks.

    Returns the concatenated features and the audio length measured in
    25 fps video frames (640 samples = one frame at 16 kHz).
    """
    audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
    assert sampling_rate == 16000

    audio_features = []
    window = 750 * 640  # 750 frames of audio per chunk (640 samples each at 16 kHz)
    for i in range(0, len(audio_input), window):
        audio_feature = feature_extractor(
            audio_input[i : i + window],
            sampling_rate=sampling_rate,
            return_tensors="pt",
        ).input_features
        audio_features.append(audio_feature)

    # Concatenate the per-chunk features along the time axis.
    audio_features = torch.cat(audio_features, dim=-1)
    return audio_features, len(audio_input) // 640
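
# Usage sketch (assumption: `feature_extractor` is a Whisper-style extractor
# such as transformers.WhisperFeatureExtractor -- the 750 * 640 = 480000-sample
# window matches Whisper's 30 s input at 16 kHz, but this file does not pin the
# extractor type; the checkpoint name below is illustrative):
#
#     from transformers import WhisperFeatureExtractor
#     extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
#     audio_feats, num_frames = get_audio_feature(extractor, "speech.wav")
#     # audio_feats: (1, n_mels, T) features; num_frames: length in 25 fps frames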
class VideoAudioTextLoaderVal(Dataset):
    """Validation dataset yielding a reference image, audio features, and text
    prompts for audio-driven video generation."""

    def __init__(
        self,
        image_size: int,
        meta_file: str,
        **kwargs,
    ):
        super().__init__()
        self.meta_file = meta_file
        self.image_size = image_size
        self.text_encoder = kwargs.get("text_encoder", None)      # llava_text_encoder
        self.text_encoder_2 = kwargs.get("text_encoder_2", None)  # clipL_text_encoder
        self.feature_extractor = kwargs.get("feature_extractor", None)

        # Each CSV row describes one validation sample.
        self.meta_files = []
        csv_data = pd.read_csv(meta_file)
        for idx in range(len(csv_data)):
            self.meta_files.append(
                {
                    "videoid": str(csv_data["videoid"][idx]),
                    "image_path": str(csv_data["image"][idx]),
                    "audio_path": str(csv_data["audio"][idx]),
                    "prompt": str(csv_data["prompt"][idx]),
                    "fps": float(csv_data["fps"][idx]),
                }
            )

        # Preprocessing for the LLaVA image branch: 336x336 with CLIP normalization.
        self.llava_transform = transforms.Compose(
            [
                transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )
        self.clip_image_processor = CLIPImageProcessor()

        self.device = torch.device("cuda")
        self.weight_dtype = torch.float16
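
    # Expected layout of `meta_file` (illustrative values, not from the source):
    #
    #     videoid,image,audio,prompt,fps
    #     0001,assets/ref_0001.png,assets/speech_0001.wav,"a person speaking to camera",25.0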
    def __len__(self):
        return len(self.meta_files)

    @staticmethod
    def get_text_tokens(text_encoder, description, dtype_encode="video"):
        # Tokenize the prompt and strip the batch dimension.
        text_inputs = text_encoder.text2tokens(description, data_type=dtype_encode)
        text_ids = text_inputs["input_ids"].squeeze(0)
        text_mask = text_inputs["attention_mask"].squeeze(0)
        return text_ids, text_mask
    def get_batch_data(self, idx):
        meta_file = self.meta_files[idx]
        videoid = meta_file["videoid"]
        image_path = meta_file["image_path"]
        audio_path = meta_file["audio_path"]
        prompt = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + meta_file["prompt"]
        fps = meta_file["fps"]
        img_size = self.image_size

        ref_image = Image.open(image_path).convert('RGB')

        # Resize the reference image so its short side matches `img_size`,
        # rounding both dimensions to multiples of 64.
        w, h = ref_image.size
        scale = img_size / min(w, h)
        new_w = round(w * scale / 64) * 64
        new_h = round(h * scale / 64) * 64

        if img_size == 704:
            img_size_long = 1216
            if new_w * new_h > img_size * img_size_long:
                # Cap the total pixel count at 704 * 1216, preserving aspect ratio.
                scale = math.sqrt(img_size * img_size_long / w / h)
                new_w = round(w * scale / 64) * 64
                new_h = round(h * scale / 64) * 64

        ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)
        ref_image = np.array(ref_image)
        ref_image = torch.from_numpy(ref_image)
        # Audio features plus the audio length measured in 25 fps frames.
        audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_path)
        audio_prompts = audio_input[0]

        # Fixed motion-bucket ids for head motion and expression intensity.
        motion_bucket_id_heads = torch.from_numpy(np.array([25] * 4))
        motion_bucket_id_exps = torch.from_numpy(np.array([30] * 4))
        fps = torch.from_numpy(np.array(fps))
        to_pil = ToPILImage()
        pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w")  # (b c h w), values in (0, 255)

        # LLaVA branch: 336x336 with CLIP normalization.
        pixel_value_ref_llava = [self.llava_transform(to_pil(image)) for image in pixel_value_ref]
        pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0)

        # CLIP image-encoder branch: CLIPImageProcessor handles resize and normalization.
        pixel_value_ref_clip = self.clip_image_processor(
            images=Image.fromarray((pixel_value_ref[0].permute(1, 2, 0)).data.cpu().numpy().astype(np.uint8)),
            return_tensors="pt"
        ).pixel_values[0]
        pixel_value_ref_clip = pixel_value_ref_clip.unsqueeze(0)
        # Encode text prompts
        text_ids, text_mask = self.get_text_tokens(self.text_encoder, prompt)
        text_ids_2, text_mask_2 = self.get_text_tokens(self.text_encoder_2, prompt)

        # Output batch
        batch = {
            "text_prompt": prompt,
            "videoid": videoid,
            "pixel_value_ref": pixel_value_ref.to(dtype=torch.float16),              # reference image for VAE features, (1, 3, h, w), values in (0, 255)
            "pixel_value_ref_llava": pixel_value_ref_llava.to(dtype=torch.float16),  # reference image for LLaVA features, (1, 3, 336, 336), CLIP value range
            "pixel_value_ref_clip": pixel_value_ref_clip.to(dtype=torch.float16),    # reference image for the CLIP image encoder, (1, 3, 224, 224), CLIP value range
            "audio_prompts": audio_prompts.to(dtype=torch.float16),
            "motion_bucket_id_heads": motion_bucket_id_heads.to(dtype=text_ids.dtype),
            "motion_bucket_id_exps": motion_bucket_id_exps.to(dtype=text_ids.dtype),
            "fps": fps.to(dtype=torch.float16),
            "text_ids": text_ids.clone(),        # for llava_text_encoder
            "text_mask": text_mask.clone(),      # for llava_text_encoder
            "text_ids_2": text_ids_2.clone(),    # for clip_text_encoder
            "text_mask_2": text_mask_2.clone(),  # for clip_text_encoder
            "audio_len": audio_len,
            "image_path": image_path,
            "audio_path": audio_path,
        }
        return batch
    def __getitem__(self, idx):
        return self.get_batch_data(idx)
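
# Minimal usage sketch (assumptions: a CSV with columns videoid/image/audio/prompt/fps,
# text encoders exposing a `text2tokens(description, data_type=...)` method, and a
# Whisper-style feature extractor; all names below are illustrative, not from the source):
#
#     from torch.utils.data import DataLoader
#     from transformers import WhisperFeatureExtractor
#
#     dataset = VideoAudioTextLoaderVal(
#         image_size=704,
#         meta_file="val_meta.csv",         # hypothetical path
#         text_encoder=llava_text_encoder,  # must implement text2tokens(...)
#         text_encoder_2=clip_text_encoder, # must implement text2tokens(...)
#         feature_extractor=WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny"),
#     )
#     loader = DataLoader(dataset, batch_size=1, shuffle=False)
#     batch = next(iter(loader))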