"""Validation dataset pairing a reference image, an audio clip and a text prompt."""

import os
import math
import json
import random
import traceback

# Third-party imports keep their original relative order on purpose: some of
# these native-extension packages are sensitive to import order.
import cv2
import torch
import librosa
import torchvision
import numpy as np
import pandas as pd
from PIL import Image
from einops import rearrange
from torch.utils.data import Dataset
from decord import VideoReader, cpu
from transformers import CLIPImageProcessor
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage

# Audio is always resampled to this rate before feature extraction.
_AUDIO_SAMPLE_RATE = 16000
# 640 samples at 16 kHz is 40 ms; the returned audio length is counted in
# these units (presumably one unit per video frame at 25 fps -- confirm).
_SAMPLES_PER_UNIT = 640
# Audio is fed to the feature extractor in chunks of 750 units
# (750 * 640 = 480000 samples = 30 s at 16 kHz).
_UNITS_PER_WINDOW = 750
# Spatial sizes of the reference image are snapped to multiples of this.
_SIZE_MULTIPLE = 64


def _snap_size(width, height, scale):
    """Scale (width, height) by *scale* and round each to a multiple of 64."""
    snapped_w = round(width * scale / _SIZE_MULTIPLE) * _SIZE_MULTIPLE
    snapped_h = round(height * scale / _SIZE_MULTIPLE) * _SIZE_MULTIPLE
    return snapped_w, snapped_h


def get_audio_feature(feature_extractor, audio_path):
    """Load an audio file and extract features window by window.

    Args:
        feature_extractor: Callable invoked as
            ``feature_extractor(samples, sampling_rate=..., return_tensors="pt")``
            whose result exposes ``.input_features``.
        audio_path: Path of the audio file; it is loaded at 16 kHz.

    Returns:
        A tuple ``(audio_features, audio_len)`` where ``audio_features`` is the
        per-window features concatenated along the last dimension and
        ``audio_len`` is the number of samples divided (floor) by 640.

    Raises:
        ValueError: If the file contains no audio samples.
    """
    audio_input, sampling_rate = librosa.load(audio_path, sr=_AUDIO_SAMPLE_RATE)
    assert sampling_rate == _AUDIO_SAMPLE_RATE

    if len(audio_input) == 0:
        # Without this guard torch.cat would fail on an empty list with a
        # message that does not mention the offending file.
        raise ValueError(f"Audio file contains no samples: {audio_path}")

    window = _UNITS_PER_WINDOW * _SAMPLES_PER_UNIT
    audio_features = [
        feature_extractor(
            audio_input[start:start + window],
            sampling_rate=sampling_rate,
            return_tensors="pt",
        ).input_features
        for start in range(0, len(audio_input), window)
    ]
    audio_features = torch.cat(audio_features, dim=-1)
    return audio_features, len(audio_input) // _SAMPLES_PER_UNIT


class VideoAudioTextLoaderVal(Dataset):
    """Dataset yielding one reference image, audio clip and prompt per CSV row.

    The metadata CSV must provide the columns ``videoid``, ``image``,
    ``audio``, ``prompt`` and ``fps``.
    """

    def __init__(
        self,
        image_size: int,
        meta_file: str,
        **kwargs,
    ):
        """Read the metadata CSV and set up the image preprocessing.

        Args:
            image_size: Target length of the shorter image side before it is
                snapped to a multiple of 64.
            meta_file: Path of the metadata CSV.
            **kwargs: Optional ``text_encoder``, ``text_encoder_2`` and
                ``feature_extractor`` objects; each defaults to ``None``.
        """
        super().__init__()
        self.meta_file = meta_file
        self.image_size = image_size
        self.text_encoder = kwargs.get("text_encoder", None)      # llava_text_encoder
        self.text_encoder_2 = kwargs.get("text_encoder_2", None)  # clipL_text_encoder
        self.feature_extractor = kwargs.get("feature_extractor", None)

        csv_data = pd.read_csv(meta_file)
        self.meta_files = [
            {
                "videoid": str(videoid),
                "image_path": str(image),
                "audio_path": str(audio),
                "prompt": str(prompt),
                "fps": float(fps),
            }
            for videoid, image, audio, prompt, fps in zip(
                csv_data["videoid"],
                csv_data["image"],
                csv_data["audio"],
                csv_data["prompt"],
                csv_data["fps"],
            )
        ]

        # Fixed 336x336 resize followed by per-channel normalisation.
        self.llava_transform = transforms.Compose(
            [
                transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )
        self.clip_image_processor = CLIPImageProcessor()

        self.device = torch.device("cuda")
        self.weight_dtype = torch.float16

    def __len__(self):
        """Return the number of metadata rows."""
        return len(self.meta_files)

    @staticmethod
    def get_text_tokens(text_encoder, description, dtype_encode="video"):
        """Tokenize *description* with *text_encoder*.

        Args:
            text_encoder: Object exposing ``text2tokens(text, data_type=...)``
                that returns a mapping with ``input_ids`` and
                ``attention_mask`` tensors.
            description: Prompt text to tokenize.
            dtype_encode: Forwarded to ``text2tokens`` as ``data_type``.

        Returns:
            ``(text_ids, text_mask)`` with dimension 0 squeezed.
        """
        text_inputs = text_encoder.text2tokens(description, data_type=dtype_encode)
        text_ids = text_inputs["input_ids"].squeeze(0)
        text_mask = text_inputs["attention_mask"].squeeze(0)
        return text_ids, text_mask

    def get_batch_data(self, idx):
        """Build the sample dictionary for metadata row *idx*.

        Loads and resizes the reference image, extracts audio features,
        tokenizes the prompt with both text encoders and returns everything
        in a single dict.
        """
        meta = self.meta_files[idx]
        image_path = meta["image_path"]
        audio_path = meta["audio_path"]
        prompt = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + meta["prompt"]
        fps = meta["fps"]

        img_size = self.image_size
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released immediately.
        with Image.open(image_path) as opened:
            ref_image = opened.convert('RGB')

        # Resize reference image: shorter side -> img_size, both sides
        # snapped to a multiple of 64.
        w, h = ref_image.size
        scale = img_size / min(w, h)
        new_w, new_h = _snap_size(w, h, scale)

        # The area cap is only defined for img_size == 704 (long side 1216).
        # Other sizes skip the cap instead of referencing an undefined bound.
        if img_size == 704:
            img_size_long = 1216
            if new_w * new_h > img_size * img_size_long:
                scale = math.sqrt(img_size * img_size_long / w / h)
                new_w, new_h = _snap_size(w, h, scale)

        ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)
        ref_image = torch.from_numpy(np.array(ref_image))

        audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_path)
        audio_prompts = audio_input[0]

        motion_bucket_id_heads = torch.from_numpy(np.array([25] * 4))
        motion_bucket_id_exps = torch.from_numpy(np.array([30] * 4))
        fps = torch.from_numpy(np.array(fps))

        to_pil = ToPILImage()
        pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w")  # (b c h w)

        pixel_value_ref_llava = torch.stack(
            [self.llava_transform(to_pil(image)) for image in pixel_value_ref],
            dim=0,
        )
        clip_source = Image.fromarray(
            (pixel_value_ref[0].permute(1, 2, 0)).data.cpu().numpy().astype(np.uint8)
        )
        pixel_value_ref_clip = self.clip_image_processor(
            images=clip_source,
            return_tensors="pt",
        ).pixel_values[0]
        pixel_value_ref_clip = pixel_value_ref_clip.unsqueeze(0)

        # Encode text prompts
        text_ids, text_mask = self.get_text_tokens(self.text_encoder, prompt)
        text_ids_2, text_mask_2 = self.get_text_tokens(self.text_encoder_2, prompt)

        # Output batch
        batch = {
            "text_prompt": prompt,
            # Reference image for VAE feature extraction, (1, 3, h, w), values in [0, 255].
            "pixel_value_ref": pixel_value_ref.to(dtype=torch.float16),
            # Reference image for LLaVA feature extraction, (1, 3, 336, 336), CLIP-normalised.
            "pixel_value_ref_llava": pixel_value_ref_llava.to(dtype=torch.float16),
            # Reference image for the CLIP image encoder, CLIP-normalised.
            # NOTE(review): the original comment gave the shape as
            # (1, 3, 244, 244); presumably 224 was meant -- confirm against
            # the CLIPImageProcessor defaults.
            "pixel_value_ref_clip": pixel_value_ref_clip.to(dtype=torch.float16),
            "audio_prompts": audio_prompts.to(dtype=torch.float16),
            "motion_bucket_id_heads": motion_bucket_id_heads.to(dtype=text_ids.dtype),
            "motion_bucket_id_exps": motion_bucket_id_exps.to(dtype=text_ids.dtype),
            "fps": fps.to(dtype=torch.float16),
            "text_ids": text_ids.clone(),        # from text_encoder (llava_text_encoder)
            "text_mask": text_mask.clone(),      # from text_encoder (llava_text_encoder)
            "text_ids_2": text_ids_2.clone(),    # from text_encoder_2 (clip text encoder)
            "text_mask_2": text_mask_2.clone(),  # from text_encoder_2 (clip text encoder)
            "audio_len": audio_len,
            "image_path": image_path,
            "audio_path": audio_path,
        }

        return batch

    def __getitem__(self, idx):
        """Return the sample dictionary for index *idx*."""
        return self.get_batch_data(idx)