# hymm_sp/data_kits/audio_dataset.py
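#
# Validation dataset for the Hunyuan-Avatar pipeline: each row of the metadata CSV
# provides a reference image, a driving-audio file, a text prompt and a target fps.
# __getitem__ returns the resized reference image (as inputs for the VAE, LLaVA and
# CLIP image branches), windowed audio features from the supplied feature extractor,
# tokenized prompts and fixed motion-bucket ids.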
import os
import cv2
import math
import json
import torch
import random
import librosa
import traceback
import torchvision
import numpy as np
import pandas as pd
from PIL import Image
from einops import rearrange
from torch.utils.data import Dataset
from decord import VideoReader, cpu
from transformers import CLIPImageProcessor
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage


def get_audio_feature(feature_extractor, audio_path):
    """Load a 16 kHz waveform and extract windowed audio features.

    The waveform is split into 30-second chunks (750 * 640 samples at 16 kHz), each
    chunk is run through the feature extractor, and the per-chunk features are
    concatenated along the time axis. Also returns the audio length measured in
    640-sample (40 ms) frames, i.e. the number of video frames at 25 fps.
    """
    audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
    assert sampling_rate == 16000

    audio_features = []
    window = 750 * 640  # 30 s of 16 kHz audio per chunk
    for i in range(0, len(audio_input), window):
        audio_feature = feature_extractor(
            audio_input[i: i + window],
            sampling_rate=sampling_rate,
            return_tensors="pt",
        ).input_features
        audio_features.append(audio_feature)

    audio_features = torch.cat(audio_features, dim=-1)
    return audio_features, len(audio_input) // 640
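
# Hedged usage sketch for get_audio_feature: the extractor is assumed to be a
# Whisper-style feature extractor from transformers (the checkpoint name below is
# illustrative only, not mandated by this file).
#
#   from transformers import WhisperFeatureExtractor
#   feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
#   audio_feats, audio_len = get_audio_feature(feature_extractor, "driving_audio.wav")
#   # audio_feats: per-window features concatenated along the time axis
#   # audio_len: number of 640-sample (40 ms) frames, i.e. audio length at 25 fps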


class VideoAudioTextLoaderVal(Dataset):
    def __init__(
        self,
        image_size: int,
        meta_file: str,
        **kwargs,
    ):
        super().__init__()
        self.meta_file = meta_file
        self.image_size = image_size
        self.text_encoder = kwargs.get("text_encoder", None)      # llava_text_encoder
        self.text_encoder_2 = kwargs.get("text_encoder_2", None)  # clipL_text_encoder
        self.feature_extractor = kwargs.get("feature_extractor", None)

        # Each CSV row describes one validation sample.
        self.meta_files = []
        csv_data = pd.read_csv(meta_file)
        for idx in range(len(csv_data)):
            self.meta_files.append(
                {
                    "videoid": str(csv_data["videoid"][idx]),
                    "image_path": str(csv_data["image"][idx]),
                    "audio_path": str(csv_data["audio"][idx]),
                    "prompt": str(csv_data["prompt"][idx]),
                    "fps": float(csv_data["fps"][idx]),
                }
            )

        # Preprocessing for the LLaVA image branch (CLIP normalization at 336x336).
        self.llava_transform = transforms.Compose(
            [
                transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )
        self.clip_image_processor = CLIPImageProcessor()

        self.device = torch.device("cuda")
        self.weight_dtype = torch.float16

    def __len__(self):
        return len(self.meta_files)

    @staticmethod
    def get_text_tokens(text_encoder, description, dtype_encode="video"):
        text_inputs = text_encoder.text2tokens(description, data_type=dtype_encode)
        text_ids = text_inputs["input_ids"].squeeze(0)
        text_mask = text_inputs["attention_mask"].squeeze(0)
        return text_ids, text_mask

    def get_batch_data(self, idx):
        meta_file = self.meta_files[idx]
        videoid = meta_file["videoid"]
        image_path = meta_file["image_path"]
        audio_path = meta_file["audio_path"]
        prompt = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + meta_file["prompt"]
        fps = meta_file["fps"]
        img_size = self.image_size

        ref_image = Image.open(image_path).convert('RGB')

        # Resize the reference image so the short side matches img_size and both
        # sides are multiples of 64.
        w, h = ref_image.size
        scale = img_size / min(w, h)
        new_w = round(w * scale / 64) * 64
        new_h = round(h * scale / 64) * 64

        if img_size == 704:
            # Cap the total area at 704 x 1216 for very elongated images.
            img_size_long = 1216
            if new_w * new_h > img_size * img_size_long:
                scale = math.sqrt(img_size * img_size_long / w / h)
                new_w = round(w * scale / 64) * 64
                new_h = round(h * scale / 64) * 64

        ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)
        ref_image = np.array(ref_image)
        ref_image = torch.from_numpy(ref_image)

        audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_path)
        audio_prompts = audio_input[0]

        # Fixed motion-bucket ids controlling head-pose and expression strength.
        motion_bucket_id_heads = np.array([25] * 4)
        motion_bucket_id_exps = np.array([30] * 4)
        motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads)
        motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps)
        fps = torch.from_numpy(np.array(fps))

        to_pil = ToPILImage()
        pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w")  # (b c h w)

        pixel_value_ref_llava = [self.llava_transform(to_pil(image)) for image in pixel_value_ref]
        pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0)

        pixel_value_ref_clip = self.clip_image_processor(
            images=Image.fromarray((pixel_value_ref[0].permute(1, 2, 0)).data.cpu().numpy().astype(np.uint8)),
            return_tensors="pt"
        ).pixel_values[0]
        pixel_value_ref_clip = pixel_value_ref_clip.unsqueeze(0)

        # Encode text prompts
        text_ids, text_mask = self.get_text_tokens(self.text_encoder, prompt)
        text_ids_2, text_mask_2 = self.get_text_tokens(self.text_encoder_2, prompt)

        # Output batch
        batch = {
            "text_prompt": prompt,
            "videoid": videoid,
            "pixel_value_ref": pixel_value_ref.to(dtype=torch.float16),              # reference image for VAE features, (1, 3, h, w), values in [0, 255]
            "pixel_value_ref_llava": pixel_value_ref_llava.to(dtype=torch.float16),  # reference image for LLaVA features, (1, 3, 336, 336), CLIP-normalized
            "pixel_value_ref_clip": pixel_value_ref_clip.to(dtype=torch.float16),    # reference image for the CLIP image encoder, (1, 3, 224, 224), CLIP-normalized
            "audio_prompts": audio_prompts.to(dtype=torch.float16),
            "motion_bucket_id_heads": motion_bucket_id_heads.to(dtype=text_ids.dtype),
            "motion_bucket_id_exps": motion_bucket_id_exps.to(dtype=text_ids.dtype),
            "fps": fps.to(dtype=torch.float16),
            "text_ids": text_ids.clone(),        # tokens for llava_text_encoder
            "text_mask": text_mask.clone(),      # mask for llava_text_encoder
            "text_ids_2": text_ids_2.clone(),    # tokens for clip_text_encoder
            "text_mask_2": text_mask_2.clone(),  # mask for clip_text_encoder
            "audio_len": audio_len,
            "image_path": image_path,
            "audio_path": audio_path,
        }
        return batch

    def __getitem__(self, idx):
        return self.get_batch_data(idx)
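

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original pipeline. Assumptions:
    # the CSV has columns videoid,image,audio,prompt,fps with valid local paths,
    # and the two text encoders are the repo's LLaVA / CLIP-L wrappers exposing
    # text2tokens(). They are left as None placeholders here, so only __len__ is
    # exercised; indexing a sample additionally needs the real encoders and an
    # audio feature extractor (e.g. transformers.WhisperFeatureExtractor).
    dataset = VideoAudioTextLoaderVal(
        image_size=704,
        meta_file="assets/val_meta.csv",  # hypothetical path
        text_encoder=None,                # placeholder: repo's LLaVA text encoder
        text_encoder_2=None,              # placeholder: repo's CLIP-L text encoder
        feature_extractor=None,           # placeholder: audio feature extractor
    )
    print(f"{len(dataset)} validation samples loaded from {dataset.meta_file}")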