import os
import random
from typing import Dict
import librosa
import numpy as np
import python_speech_features
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
class LatentDataLoader:
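    """Windowed training dataset pairing audio features with facial landmarks,
    head poses, and motion latents for each video clip.

    Despite the name, this is an indexable dataset (`__getitem__`/`__len__`)
    rather than a batching loader.
    """
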
def __init__(
self,
window_size,
frame_jpgs,
lmd_feats_prefix,
audio_prefix,
raw_audio_prefix,
motion_latents_prefix,
pose_prefix,
db_name,
video_fps=25,
audio_hz=50,
size=256,
mfcc_mode=False,
):
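        """Scan the databases and index every usable clip.

        Args:
            window_size: length of the sampled window, in seconds.
            frame_jpgs: root directory of per-database, per-clip frame images.
            lmd_feats_prefix: root of per-clip landmark `.txt` files.
            audio_prefix: root of per-clip HuBERT feature `.npy` files.
            raw_audio_prefix: root of per-clip `.wav` files (used in mfcc_mode).
            motion_latents_prefix: root of per-clip motion latent `.npy` files.
            pose_prefix: root of per-clip yaw/pitch/roll `.npy` files.
            db_name: database name (stored; the scan itself iterates a
                hardcoded list below).
            video_fps: video frame rate.
            audio_hz: audio feature rate, in features per second.
            size: side length that images are resized to.
            mfcc_mode: compute MFCC features from raw audio instead of
                loading precomputed HuBERT features.
        """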
self.window_size = window_size
self.lmd_feats_prefix = lmd_feats_prefix
self.audio_prefix = audio_prefix
self.pose_prefix = pose_prefix
self.video_fps = video_fps
self.audio_hz = audio_hz
self.db_name = db_name
self.raw_audio_prefix = raw_audio_prefix
self.mfcc_mode = mfcc_mode
        self.transform = transforms.Compose(
[
transforms.Resize((size, size)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
)
self.data = []
        # NOTE: the scan iterates these hardcoded databases; the loop variable
        # shadows the `db_name` argument stored above.
        for db_name in ["VoxCeleb2", "HDTF"]:
db_png_path = os.path.join(frame_jpgs, db_name)
for clip_name in tqdm(os.listdir(db_png_path)):
item_dict: Dict = {}
item_dict["clip_name"] = clip_name
item_dict["frame_count"] = len(
list(os.listdir(os.path.join(frame_jpgs, db_name, clip_name)))
)
item_dict["hubert_path"] = os.path.join(
audio_prefix, db_name, clip_name + ".npy"
)
item_dict["wav_path"] = os.path.join(
raw_audio_prefix, db_name, clip_name + ".wav"
)
item_dict["yaw_pitch_roll_path"] = os.path.join(
pose_prefix,
db_name,
"raw_videos_pose_yaw_pitch_roll",
clip_name + ".npy",
)
if not os.path.exists(item_dict["yaw_pitch_roll_path"]):
print(f"{db_name}'s {clip_name} miss yaw_pitch_roll_path")
continue
item_dict["yaw_pitch_roll"] = np.load(item_dict["yaw_pitch_roll_path"])
item_dict["yaw_pitch_roll"] = (
np.clip(item_dict["yaw_pitch_roll"], -90, 90) / 90.0
)
if not os.path.exists(item_dict["wav_path"]):
print(f"{db_name}'s {clip_name} miss wav_path")
continue
if not os.path.exists(item_dict["hubert_path"]):
print(f"{db_name}'s {clip_name} miss hubert_path")
continue
                if self.mfcc_mode:
                    # 13 MFCCs at 100 Hz (winstep=0.01), stacked with their
                    # first- and second-order deltas into 39-dim frames.
                    wav, sr = librosa.load(item_dict["wav_path"], sr=16000)
                    input_values = python_speech_features.mfcc(
                        signal=wav, samplerate=sr, numcep=13, winlen=0.025, winstep=0.01
                    )
                    d_mfcc_feat = python_speech_features.base.delta(input_values, 1)
                    d_mfcc_feat2 = python_speech_features.base.delta(input_values, 2)
                    input_values = np.hstack((input_values, d_mfcc_feat, d_mfcc_feat2))
                    item_dict["hubert_obj"] = input_values
else:
item_dict["hubert_obj"] = np.load(
item_dict["hubert_path"], mmap_mode="r"
)
item_dict["lmd_path"] = os.path.join(
lmd_feats_prefix, db_name, clip_name + ".txt"
)
item_dict["lmd_obj_full"] = self.read_landmark_info(
item_dict["lmd_path"], upper_face=False
)
motion_start_path = os.path.join(
motion_latents_prefix, db_name, "motions", clip_name + ".npy"
)
motion_direction_path = os.path.join(
motion_latents_prefix, db_name, "directions", clip_name + ".npy"
)
if not os.path.exists(motion_start_path):
print(f"{db_name}'s {clip_name} miss motion_start_path")
continue
if not os.path.exists(motion_direction_path):
print(f"{db_name}'s {clip_name} miss motion_direction_path")
continue
item_dict["motion_start_obj"] = np.load(motion_start_path)
item_dict["motion_direction_obj"] = np.load(motion_direction_path)
                if self.mfcc_mode:
                    # MFCC features come at 100 Hz, i.e. 4 audio frames per
                    # video frame at the default 25 fps.
                    min_len = min(
                        item_dict["lmd_obj_full"].shape[0],
                        item_dict["yaw_pitch_roll"].shape[0],
                        item_dict["motion_start_obj"].shape[0],
                        item_dict["motion_direction_obj"].shape[0],
                        item_dict["hubert_obj"].shape[0] // 4,
                        item_dict["frame_count"],
                    )
                    item_dict["frame_count"] = min_len
                    item_dict["hubert_obj"] = item_dict["hubert_obj"][: min_len * 4, :]
                else:
                    # HuBERT features come at 50 Hz, i.e. 2 audio frames per
                    # video frame at the default 25 fps.
                    min_len = min(
                        item_dict["lmd_obj_full"].shape[0],
                        item_dict["yaw_pitch_roll"].shape[0],
                        item_dict["motion_start_obj"].shape[0],
                        item_dict["motion_direction_obj"].shape[0],
                        item_dict["hubert_obj"].shape[1] // 2,
                        item_dict["frame_count"],
                    )
                    item_dict["frame_count"] = min_len
                    item_dict["hubert_obj"] = item_dict["hubert_obj"][
                        :, : min_len * 2, :
                    ]
                # Skip clips too short to hold a full window plus a small margin.
                if min_len < self.window_size * self.video_fps + 5:
                    continue
                self.data.append(item_dict)
        print("Db count:", len(self.data))
def get_single_image(self, image_path):
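        """Load an image as RGB and apply the resize/normalize transform."""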
img_source = Image.open(image_path).convert("RGB")
img_source = self.transform(img_source)
return img_source
def get_multiple_ranges(self, lists, multi_ranges):
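        """Concatenate the slices of `lists` given by (start, end) tuples."""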
# Ensure that multi_ranges is a list of tuples
if not all(isinstance(item, tuple) and len(item) == 2 for item in multi_ranges):
raise ValueError(
"multi_ranges must be a list of (start, end) tuples with exactly two elements each"
)
extracted_elements = [lists[start:end] for start, end in multi_ranges]
return [item for sublist in extracted_elements for item in sublist]
def read_landmark_info(self, lmd_path, upper_face=True):
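        """Parse a landmark `.txt` file into a (frames, points, 2) float32 array.

        Each line starts with a frame file name followed by `x_y` pixel pairs,
        normalized here by the 512 px frame size.
        """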
with open(lmd_path, "r") as file:
lmd_lines = file.readlines()
        lmd_lines.sort()  # lines start with the frame file name, so this orders frames
total_lmd_obj = []
        for line in lmd_lines:
# Split the coordinates and filter out any empty strings
coords = [c for c in line.strip().split(" ") if c]
            coords = coords[1:]  # drop the leading file name field
lmd_obj = []
            if upper_face:
                # Keep only the upper-face subset: 28 landmark points in total.
                for coord_pair in self.get_multiple_ranges(
                    coords, [(0, 3), (14, 27), (36, 48)]
                ):
                    x, y = coord_pair.split("_")
                    # Normalize pixel coordinates by the 512 px frame size.
                    lmd_obj.append((int(x) / 512, int(y) / 512))
else:
for coord_pair in coords:
x, y = coord_pair.split("_")
lmd_obj.append((int(x) / 512, int(y) / 512))
total_lmd_obj.append(lmd_obj)
return np.array(total_lmd_obj, dtype=np.float32)
def calculate_face_height(self, landmarks):
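        """Per-frame distance from the brow midpoint (landmarks 21 and 22) to
        the chin (landmark 8), used downstream as a face-scale proxy."""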
forehead_center = (landmarks[:, 21, :] + landmarks[:, 22, :]) / 2
chin_bottom = landmarks[:, 8, :]
distances = np.linalg.norm(forehead_center - chin_bottom, axis=1, keepdims=True)
return distances
def __getitem__(self, index):
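        """Sample a random window_size-second window from clip `index`."""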
data_item = self.data[index]
hubert_obj = data_item["hubert_obj"]
frame_count = data_item["frame_count"]
lmd_obj_full = data_item["lmd_obj_full"]
yaw_pitch_roll = data_item["yaw_pitch_roll"]
motion_start_obj = data_item["motion_start_obj"]
motion_direction_obj = data_item["motion_direction_obj"]
        # Sample a random window; the frame immediately before it serves as
        # the hint/reference frame.
        frame_end_index = random.randint(
            self.window_size * self.video_fps + 1, frame_count - 1
        )
        frame_start_index = frame_end_index - self.window_size * self.video_fps
        frame_hint_index = frame_start_index - 1
        # Map video-frame indices to audio-feature indices.
        audio_start_index = int(frame_start_index * (self.audio_hz / self.video_fps))
        audio_end_index = int(frame_end_index * (self.audio_hz / self.video_fps))
if self.mfcc_mode:
audio_feats = hubert_obj[audio_start_index:audio_end_index, :]
else:
audio_feats = hubert_obj[:, audio_start_index:audio_end_index, :]
lmd_obj_full = lmd_obj_full[frame_hint_index:frame_end_index, :]
yaw_pitch_roll = yaw_pitch_roll[frame_start_index:frame_end_index, :]
motion_start = motion_start_obj[frame_hint_index]
motion_direction_start = motion_direction_obj[frame_hint_index]
motion_direction = motion_direction_obj[frame_start_index:frame_end_index, :]
return {
"motion_start": motion_start,
"motion_direction": motion_direction,
"audio_feats": audio_feats,
            # '1:' drops the hint frame, keeping only the driven frames;
            # '30' is the nose landmark index;
            # '0' selects the x coordinate.
            "face_location": lmd_obj_full[1:, 30, 0],
"face_scale": self.calculate_face_height(lmd_obj_full[1:, :, :]),
"yaw_pitch_roll": yaw_pitch_roll,
"motion_direction_start": motion_direction_start,
}
def __len__(self):
return len(self.data)
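

if __name__ == "__main__":
    # Minimal usage sketch. The paths below are hypothetical; each prefix
    # must follow the directory layout __init__ expects, with per-database
    # subdirectories ("VoxCeleb2", "HDTF") containing per-clip files.
    dataset = LatentDataLoader(
        window_size=2,
        frame_jpgs="data/frame_jpgs",
        lmd_feats_prefix="data/landmarks",
        audio_prefix="data/hubert_feats",
        raw_audio_prefix="data/raw_audio",
        motion_latents_prefix="data/motion_latents",
        pose_prefix="data/poses",
        db_name="HDTF",
    )
    print(f"Clips available: {len(dataset)}")
    if len(dataset) > 0:
        sample = dataset[0]
        for key, value in sample.items():
            print(key, getattr(value, "shape", value))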