import random
import numpy as np
from functools import partial
from torch.utils.data import Dataset, WeightedRandomSampler
import torch.nn.functional as F
import torch
import math
import decord
from einops import rearrange
from more_itertools import sliding_window
from omegaconf import ListConfig
import torchaudio
import soundfile as sf
from torchvision.transforms import RandomHorizontalFlip
from audiomentations import Compose, AddGaussianNoise, PitchShift
from safetensors.torch import load_file
from tqdm import tqdm
import cv2
from sgm.data.data_utils import (
create_masks_from_landmarks_full_size,
create_face_mask_from_landmarks,
create_masks_from_landmarks_box,
create_masks_from_landmarks_mouth,
)
from sgm.data.mask import face_mask_cheeks_batch
torchaudio.set_audio_backend("sox_io")  # deprecated in recent torchaudio releases
decord.bridge.set_bridge("torch")
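# With the bridge set to torch, decord readers return torch tensors directly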
def exists(x):
return x is not None
def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
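    """Pad or trim `audio` along its last dimension to a fixed length.

    The target length is `max_len_raw` samples when given, otherwise
    `max_len_sec * sr`; if neither is set, the audio is returned unchanged.
    """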
len_file = audio.shape[-1]
if max_len_sec or max_len_raw:
max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
        if len_file < int(max_len):
            # Pad with zeros up to the target length
            extended_wav = torch.nn.functional.pad(
                audio, (0, int(max_len) - len_file), "constant"
            )
        else:
            # Trim down to the target length
            extended_wav = audio[:, : int(max_len)]
    else:
        extended_wav = audio
    return extended_wav
# Similar to regular video dataset but trades flexibility for speed
class VideoDataset(Dataset):
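    """Talking-head video dataset with aligned audio features.

    Each item is a dict containing target frames (or precomputed latents),
    per-frame audio embeddings, clean and noisy conditioning frames, and,
    depending on the configuration, emotions, face masks, and raw audio.
    """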
def __init__(
self,
filelist,
resize_size=None,
audio_folder="Audio",
video_folder="CroppedVideos",
emotions_folder="emotions",
landmarks_folder=None,
audio_emb_folder=None,
video_extension=".avi",
audio_extension=".wav",
audio_rate=16000,
latent_folder=None,
audio_in_video=False,
fps=25,
num_frames=5,
need_cond=True,
step=1,
mode="prediction",
scale_audio=False,
augment=False,
augment_audio=False,
use_latent=False,
latent_type="stable",
latent_scale=1, # For backwards compatibility
from_audio_embedding=False,
load_all_possible_indexes=False,
audio_emb_type="wavlm",
cond_noise=[-3.0, 0.5],
motion_id=255.0,
data_mean=None,
data_std=None,
use_latent_condition=False,
skip_frames=0,
get_separate_id=False,
virtual_increase=1,
filter_by_length=False,
select_randomly=False,
balance_datasets=True,
use_emotions=False,
get_original_frames=False,
add_extra_audio_emb=False,
expand_box=0.0,
nose_index=28,
what_mask="full",
get_masks=False,
):
self.audio_folder = audio_folder
self.from_audio_embedding = from_audio_embedding
self.audio_emb_type = audio_emb_type
self.cond_noise = cond_noise
self.latent_condition = use_latent_condition
        precomputed_latent = latent_type  # non-empty string doubles as a "precomputed" flag
self.audio_emb_folder = (
audio_emb_folder if audio_emb_folder is not None else audio_folder
)
self.skip_frames = skip_frames
self.get_separate_id = get_separate_id
self.fps = fps
self.virtual_increase = virtual_increase
self.select_randomly = select_randomly
self.use_emotions = use_emotions
self.emotions_folder = emotions_folder
self.get_original_frames = get_original_frames
self.add_extra_audio_emb = add_extra_audio_emb
self.expand_box = expand_box
self.nose_index = nose_index
self.landmarks_folder = landmarks_folder
self.what_mask = what_mask
self.get_masks = get_masks
assert not (exists(data_mean) ^ exists(data_std)), (
"Both data_mean and data_std should be provided"
)
if data_mean is not None:
data_mean = rearrange(torch.as_tensor(data_mean), "c -> c () () ()")
data_std = rearrange(torch.as_tensor(data_std), "c -> c () () ()")
self.data_mean = data_mean
self.data_std = data_std
self.motion_id = motion_id
self.latent_folder = (
latent_folder if latent_folder is not None else video_folder
)
self.audio_in_video = audio_in_video
self.filelist = []
self.audio_filelist = []
self.landmark_filelist = [] if get_masks else None
with open(filelist, "r") as files:
for f in files.readlines():
f = f.rstrip()
audio_path = f.replace(video_folder, audio_folder).replace(
video_extension, audio_extension
)
self.filelist += [f]
self.audio_filelist += [audio_path]
if self.get_masks:
landmark_path = f.replace(video_folder, landmarks_folder).replace(
video_extension, ".npy"
)
self.landmark_filelist += [landmark_path]
self.resize_size = resize_size
if use_latent and not precomputed_latent:
self.resize_size *= 4 if latent_type in ["stable", "ldm"] else 8
self.scale_audio = scale_audio
self.step = step
self.use_latent = use_latent
self.precomputed_latent = precomputed_latent
self.latent_type = latent_type
self.latent_scale = latent_scale
self.video_ext = video_extension
self.video_folder = video_folder
self.augment = augment
self.maybe_augment = RandomHorizontalFlip(p=0.5) if augment else lambda x: x
self.maybe_augment_audio = (
Compose(
[
AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.002, p=0.25),
# TimeStretch(min_rate=0.8, max_rate=1.25, p=0.3),
PitchShift(min_semitones=-1, max_semitones=1, p=0.25),
# Shift(min_fraction=-0.5, max_fraction=0.5, p=0.333),
]
)
if augment_audio
else lambda x, sample_rate: x
)
self.maybe_augment_audio = partial(
self.maybe_augment_audio, sample_rate=audio_rate
)
        self.mode = mode
        if mode == "interpolation":
            # Interpolation needs no condition frame: the first and last
            # frames of the window serve as the condition
            need_cond = False
        # When need_cond is set, one extra frame is extracted beyond num_frames
        self.need_cond = need_cond
        if get_separate_id:
            # Used by the conditional model when the condition does not live
            # on the temporal dimension
            self.need_cond = True
        num_frames = num_frames if not self.need_cond else num_frames + 1
vr = decord.VideoReader(self.filelist[0])
self.video_rate = math.ceil(vr.get_avg_fps())
print(f"Video rate: {self.video_rate}")
self.audio_rate = audio_rate
a2v_ratio = fps / float(self.audio_rate)
self.samples_per_frame = math.ceil(1 / a2v_ratio)
        if get_separate_id:
            assert mode == "prediction", (
                "Separate identity frame is only supported for prediction mode"
            )
            # The separate identity frame replaces the extra condition frame,
            # so undo the +1 applied above
            self.need_cond = True
            num_frames -= 1
self.num_frames = num_frames
self.load_all_possible_indexes = load_all_possible_indexes
if load_all_possible_indexes:
self._indexes = self._get_indexes(
self.filelist, self.audio_filelist, self.landmark_filelist
)
else:
if filter_by_length:
self._indexes = self.filter_by_length(
self.filelist, self.audio_filelist, self.landmark_filelist
)
else:
if self.get_masks:
self._indexes = list(
zip(self.filelist, self.audio_filelist, self.landmark_filelist)
)
else:
self._indexes = list(
zip(
self.filelist,
self.audio_filelist,
[None] * len(self.filelist),
)
)
self.balance_datasets = balance_datasets
if self.balance_datasets:
self.weights = self._calculate_weights()
self.sampler = WeightedRandomSampler(
self.weights, num_samples=len(self._indexes), replacement=True
)
def __len__(self):
return len(self._indexes) * self.virtual_increase
def _load_landmarks(self, filename, original_size, target_size, indexes):
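        """Load landmarks for the given frame `indexes` and build a face mask.

        The mask variant is chosen by `self.what_mask` (full, box, heart,
        mouth, or face) and resized to `target_size` with nearest-neighbor
        interpolation.
        """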
landmarks = np.load(filename, allow_pickle=True)[indexes, :]
if self.what_mask == "full":
mask = create_masks_from_landmarks_full_size(
landmarks,
original_size[0],
original_size[1],
offset=self.expand_box,
nose_index=self.nose_index,
)
elif self.what_mask == "box":
mask = create_masks_from_landmarks_box(
landmarks,
(original_size[0], original_size[1]),
box_expand=self.expand_box,
nose_index=self.nose_index,
)
elif self.what_mask == "heart":
mask = face_mask_cheeks_batch(
original_size, landmarks, box_expand=0.0, show_nose=True
)
elif self.what_mask == "mouth":
mask = create_masks_from_landmarks_mouth(
landmarks,
(original_size[0], original_size[1]),
box_expand=0.01,
nose_index=self.nose_index,
)
else:
mask = create_face_mask_from_landmarks(
landmarks, original_size[0], original_size[1], mask_expand=0.05
)
# Interpolate the mask to the target size
mask = F.interpolate(
mask.unsqueeze(1).float(), size=target_size, mode="nearest"
)
return mask, landmarks
def get_emotions(self, video_file, video_indexes):
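        """Load per-frame valence, arousal, and emotion labels for a video."""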
emotions_path = video_file.replace(
self.video_folder, self.emotions_folder
).replace(self.video_ext, ".pt")
emotions = torch.load(emotions_path)
return (
emotions["valence"][video_indexes],
emotions["arousal"][video_indexes],
emotions["labels"][video_indexes],
)
def get_frame_indices(self, total_video_frames, select_randomly=False, start_idx=0):
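        """Pick `self.num_frames` frame indices from a video.

        Either sampled uniformly at random (then sorted to keep temporal
        order) or as a window starting at a random offset with
        `self.skip_frames` frames skipped between consecutive samples.
        """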
if select_randomly:
# Randomly select self.num_frames indices from the available range
available_indices = list(range(start_idx, total_video_frames))
if len(available_indices) < self.num_frames:
raise ValueError(
"Not enough frames in the video to sample with given parameters."
)
indexes = random.sample(available_indices, self.num_frames)
return sorted(indexes) # Sort to maintain temporal order
else:
# Calculate the maximum possible start index
max_start_idx = total_video_frames - (
(self.num_frames - 1) * (self.skip_frames + 1) + 1
)
            # Generate a random start index (randint needs max_start_idx > start_idx)
            if max_start_idx > start_idx:
                start_idx = np.random.randint(start_idx, max_start_idx)
            else:
                raise ValueError(
                    "Not enough frames in the video to sample with given parameters."
                )
# Generate the indices
indexes = [
start_idx + i * (self.skip_frames + 1) for i in range(self.num_frames)
]
return indexes
def _load_audio(self, filename, max_len_sec, start=None, indexes=None):
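        """Read `max_len_sec` seconds of audio starting at `start` seconds,
        downmix to mono, optionally augment, and pad/trim to the exact length.
        """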
audio, sr = sf.read(
filename,
start=math.ceil(start * self.audio_rate),
frames=math.ceil(self.audio_rate * max_len_sec),
always_2d=True,
) # e.g (16000, 1)
audio = audio.T # (1, 16000)
assert sr == self.audio_rate, (
f"Audio rate is {sr} but should be {self.audio_rate}"
)
audio = audio.mean(0, keepdims=True)
audio = self.maybe_augment_audio(audio)
audio = torch.from_numpy(audio).float()
# audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=self.audio_rate)
audio = trim_pad_audio(audio, self.audio_rate, max_len_sec=max_len_sec)
return audio[0]
def ensure_shape(self, tensors):
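        """Pad or trim each per-frame audio chunk to `self.samples_per_frame`
        samples, then concatenate the chunks along the first dimension.
        """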
target_length = self.samples_per_frame
processed_tensors = []
for tensor in tensors:
current_length = tensor.shape[1]
diff = current_length - target_length
            assert abs(diff) <= 5, (
                f"Expected length {target_length}, but got {current_length}"
            )
if diff < 0:
# Calculate how much padding is needed
padding_needed = target_length - current_length
# Pad the tensor
padded_tensor = F.pad(tensor, (0, padding_needed))
processed_tensors.append(padded_tensor)
elif diff > 0:
# Trim the tensor
trimmed_tensor = tensor[:, :target_length]
processed_tensors.append(trimmed_tensor)
else:
# If it's already the correct size
processed_tensors.append(tensor)
return torch.cat(processed_tensors)
def normalize_latents(self, latents):
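        """Standardize latents to zero mean and 0.5 std using the dataset
        statistics; a no-op when no statistics were provided.
        """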
if self.data_mean is not None:
# Normalize latents to 0 mean and 0.5 std
latents = ((latents - self.data_mean) / self.data_std) * 0.5
return latents
def convert_indexes(self, indexes_25fps, fps_from=25, fps_to=60):
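        """Map frame indices from `fps_from` to the corresponding indices at
        `fps_to` (e.g. 25 fps indices to 60 fps indices).
        """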
ratio = fps_to / fps_from
indexes_60fps = [int(index * ratio) for index in indexes_25fps]
return indexes_60fps
def _get_frames_and_audio(self, idx):
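        """Assemble one sample: frames or latents, audio features, clean and
        noisy conditioning frames, and optional emotions, masks, landmarks.
        """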
if self.load_all_possible_indexes:
indexes, video_file, audio_file, land_file = self._indexes[idx]
if self.audio_in_video:
vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
else:
vr = decord.VideoReader(video_file)
len_video = len(vr)
if "AA_processed" in video_file or "1000actors_nsv" in video_file:
len_video *= 25 / 60
len_video = int(len_video)
else:
video_file, audio_file, land_file = self._indexes[idx]
if self.audio_in_video:
vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
else:
vr = decord.VideoReader(video_file)
len_video = len(vr)
if "AA_processed" in video_file or "1000actors_nsv" in video_file:
len_video *= 25 / 60
len_video = int(len_video)
indexes = self.get_frame_indices(
len_video,
select_randomly=self.select_randomly,
start_idx=120 if "1000actors_nsv" in video_file else 0,
)
if self.get_separate_id:
id_idx = np.random.randint(0, len_video)
indexes.insert(0, id_idx)
if "AA_processed" in video_file or "1000actors_nsv" in video_file:
video_indexes = self.convert_indexes(indexes, fps_from=25, fps_to=60)
audio_file = audio_file.replace("_output_output", "")
if self.audio_emb_type == "wav2vec2" and "AA_processed" in video_file:
audio_path_extra = ".safetensors"
else:
audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"
video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
audio_path_extra_extra = (
".pt" if "AA_processed" in video_file else "_beats_emb.pt"
)
else:
video_indexes = indexes
audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"
video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
audio_path_extra_extra = "_beats_emb.pt"
emotions = None
if self.use_emotions:
emotions = self.get_emotions(video_file, video_indexes)
if self.get_separate_id:
emotions = (emotions[0][1:], emotions[1][1:], emotions[2][1:])
raw_audio = None
if self.audio_in_video:
raw_audio, frames_video = vr.get_batch(video_indexes)
raw_audio = rearrange(self.ensure_shape(raw_audio), "f s -> (f s)")
if self.use_latent and self.precomputed_latent:
latent_file = video_file.replace(self.video_ext, video_path_extra).replace(
self.video_folder, self.latent_folder
)
frames = load_file(latent_file)["latents"][video_indexes, :, :, :]
if frames.shape[-1] != 64:
print(f"Frames shape: {frames.shape}, video file: {video_file}")
frames = rearrange(frames, "t c h w -> c t h w") * self.latent_scale
frames = self.normalize_latents(frames)
else:
if self.audio_in_video:
frames = frames_video.permute(3, 0, 1, 2).float()
else:
frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()
if raw_audio is None:
# Audio is not in video
raw_audio = self._load_audio(
audio_file,
max_len_sec=frames.shape[1] / self.fps,
start=indexes[0] / self.fps,
# indexes=indexes,
)
if not self.from_audio_embedding:
audio = raw_audio
audio_frames = rearrange(audio, "(f s) -> f s", s=self.samples_per_frame)
else:
audio = load_file(
audio_file.replace(self.audio_folder, self.audio_emb_folder).split(".")[
0
]
+ audio_path_extra
)["audio"]
audio_frames = audio[indexes, :]
if self.add_extra_audio_emb:
audio_extra = torch.load(
audio_file.replace(self.audio_folder, self.audio_emb_folder).split(
"."
)[0]
+ audio_path_extra_extra
)
audio_extra = audio_extra[indexes, :]
audio_frames = torch.cat([audio_frames, audio_extra], dim=-1)
audio_frames = (
audio_frames[1:] if self.need_cond else audio_frames
) # Remove audio of first frame
if self.get_original_frames:
original_frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()
original_frames = self.scale_and_crop((original_frames / 255.0) * 2 - 1)
original_frames = (
original_frames[:, 1:] if self.need_cond else original_frames
)
else:
original_frames = None
if not self.use_latent or (self.use_latent and not self.precomputed_latent):
frames = self.scale_and_crop((frames / 255.0) * 2 - 1)
target = frames[:, 1:] if self.need_cond else frames
if self.mode == "prediction":
if self.use_latent:
if self.audio_in_video:
clean_cond = (
frames_video[0].unsqueeze(0).permute(3, 0, 1, 2).float()
)
else:
clean_cond = (
vr[video_indexes[0]].unsqueeze(0).permute(3, 0, 1, 2).float()
)
original_size = clean_cond.shape[-2:]
clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1).squeeze(
0
)
if self.latent_condition:
noisy_cond = frames[:, 0]
else:
noisy_cond = clean_cond
else:
clean_cond = frames[:, 0]
noisy_cond = clean_cond
elif self.mode == "interpolation":
if self.use_latent:
if self.audio_in_video:
clean_cond = frames_video[[0, -1]].permute(3, 0, 1, 2).float()
else:
clean_cond = (
vr.get_batch([video_indexes[0], video_indexes[-1]])
.permute(3, 0, 1, 2)
.float()
)
original_size = clean_cond.shape[-2:]
clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1)
if self.latent_condition:
noisy_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
else:
noisy_cond = clean_cond
else:
clean_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
noisy_cond = clean_cond
        # Add noise to the conditional frame
        if self.cond_noise and isinstance(self.cond_noise, (list, ListConfig)):
            # cond_noise holds (mean, std) of a log-normal noise strength
            cond_noise = (
                self.cond_noise[0] + self.cond_noise[1] * torch.randn((1,))
            ).exp()
            noisy_cond = noisy_cond + cond_noise * torch.randn_like(noisy_cond)
        else:
            noisy_cond = noisy_cond + self.cond_noise * torch.randn_like(noisy_cond)
            cond_noise = self.cond_noise
if self.get_masks:
target_size = (
(self.resize_size, self.resize_size)
if not self.use_latent
else (self.resize_size // 8, self.resize_size // 8)
)
            masks, landmarks = self._load_landmarks(
                land_file, original_size, target_size, video_indexes
            )
            # Landmarks are only needed to build the masks and are not returned
            landmarks = None
masks = (
masks.permute(1, 0, 2, 3)[:, 1:]
if self.need_cond
else masks.permute(1, 0, 2, 3)
)
else:
masks = None
landmarks = None
return (
original_frames,
clean_cond,
noisy_cond,
target,
audio_frames,
raw_audio,
cond_noise,
emotions,
masks,
landmarks,
)
    def filter_by_length(self, video_filelist, audio_filelist, landmark_filelist=None):
        """Drop videos shorter than the minimum sampling window."""

        def with_opencv(filename):
            video = cv2.VideoCapture(filename)
            frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
            video.release()
            return int(frame_count)

        if landmark_filelist is None:
            landmark_filelist = [None] * len(video_filelist)
        filtered = []
        min_length = (self.num_frames - 1) * (self.skip_frames + 1) + 1
        for vid_file, audio_file, land_file in tqdm(
            zip(video_filelist, audio_filelist, landmark_filelist),
            total=len(video_filelist),
            desc="Filtering",
        ):
            len_video = with_opencv(vid_file)
            # Skip videos too short for one sampling window
            if len_video < min_length:
                continue
            filtered.append((vid_file, audio_file, land_file))
        print(f"New number of files: {len(filtered)}")
        return filtered
    def _get_indexes(self, video_filelist, audio_filelist, landmark_filelist=None):
        """Enumerate every sliding window of frame indices for every video."""
        indexes = []
        self.og_shape = None
        if landmark_filelist is None:
            landmark_filelist = [None] * len(video_filelist)
        for vid_file, audio_file, land_file in zip(
            video_filelist, audio_filelist, landmark_filelist
        ):
            vr = decord.VideoReader(vid_file)
            if self.og_shape is None:
                self.og_shape = vr[0].shape[-2]
            len_video = len(vr)
            # Skip videos shorter than one window
            if len_video < self.num_frames:
                continue
            possible_indexes = list(
                sliding_window(range(len_video), self.num_frames)
            )[:: self.step]
            indexes.extend(
                (list(win), vid_file, audio_file, land_file)
                for win in possible_indexes
            )
        print("Indexes", len(indexes), "\n")
        return indexes
def scale_and_crop(self, video):
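        """Resize the shorter side to `self.resize_size`, center-crop to a
        square, and apply the optional horizontal-flip augmentation.
        """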
h, w = video.shape[-2], video.shape[-1]
# scale shorter side to resolution
if self.resize_size is not None:
scale = self.resize_size / min(h, w)
if h < w:
target_size = (self.resize_size, math.ceil(w * scale))
else:
target_size = (math.ceil(h * scale), self.resize_size)
video = F.interpolate(
video,
size=target_size,
mode="bilinear",
align_corners=False,
antialias=True,
)
# center crop
h, w = video.shape[-2], video.shape[-1]
w_start = (w - self.resize_size) // 2
h_start = (h - self.resize_size) // 2
video = video[
:,
:,
h_start : h_start + self.resize_size,
w_start : w_start + self.resize_size,
]
return self.maybe_augment(video)
def _calculate_weights(self):
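        """Compute per-sample weights that balance the AA_processed,
        1000actors_nsv, and remaining subsets for the weighted sampler.
        """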
aa_processed_count = sum(
1
for item in self._indexes
if "AA_processed" in (item[1] if len(item) == 3 else item[0])
)
nsv_processed_count = sum(
1
for item in self._indexes
if "1000actors_nsv" in (item[1] if len(item) == 3 else item[0])
)
other_count = len(self._indexes) - aa_processed_count - nsv_processed_count
aa_processed_weight = 1 / aa_processed_count if aa_processed_count > 0 else 0
nsv_processed_weight = 1 / nsv_processed_count if nsv_processed_count > 0 else 0
other_weight = 1 / other_count if other_count > 0 else 0
print(
f"AA processed count: {aa_processed_count}, NSV processed count: {nsv_processed_count}, other count: {other_count}"
)
print(f"AA processed weight: {aa_processed_weight}")
print(f"NSV processed weight: {nsv_processed_weight}")
print(f"Other weight: {other_weight}")
weights = [
aa_processed_weight
if "AA_processed" in (item[1] if len(item) == 3 else item[0])
else nsv_processed_weight
if "1000actors_nsv" in (item[1] if len(item) == 3 else item[0])
else other_weight
for item in self._indexes
]
return weights
def __getitem__(self, idx):
        if self.balance_datasets:
            # Draw a dataset-balanced index from the weighted sampler
            idx = next(iter(self.sampler))
try:
(
original_frames,
clean_cond,
noisy_cond,
target,
audio,
raw_audio,
cond_noise,
emotions,
masks,
landmarks,
) = self._get_frames_and_audio(idx % len(self._indexes))
except Exception as e:
print(f"Error with index {idx}: {e}")
return self.__getitem__(np.random.randint(0, len(self)))
out_data = {}
if original_frames is not None:
out_data["original_frames"] = original_frames
if audio is not None:
out_data["audio_emb"] = audio
out_data["raw_audio"] = raw_audio
if self.use_emotions:
out_data["valence"] = emotions[0]
out_data["arousal"] = emotions[1]
out_data["emo_labels"] = emotions[2]
if self.use_latent:
input_key = "latents"
else:
input_key = "frames"
out_data[input_key] = target
if noisy_cond is not None:
out_data["cond_frames"] = noisy_cond
out_data["cond_frames_without_noise"] = clean_cond
if cond_noise is not None:
out_data["cond_aug"] = cond_noise
if masks is not None:
out_data["masks"] = masks
out_data["gt"] = target
if landmarks is not None:
out_data["landmarks"] = landmarks
out_data["motion_bucket_id"] = torch.tensor([self.motion_id])
out_data["fps_id"] = torch.tensor([self.fps - 1])
out_data["num_video_frames"] = self.num_frames
out_data["image_only_indicator"] = torch.zeros(self.num_frames)
return out_data
if __name__ == "__main__":
    # Minimal smoke test; the filelist path below is machine-specific
    dataset = VideoDataset(
        "/vol/paramonos2/projects/antoni/datasets/mahnob/filelist_videos_val.txt",
        resize_size=256,
        num_frames=25,
    )
    print(len(dataset))

    idx = np.random.randint(0, len(dataset))
    sample = dataset[idx]
    print(sample["cond_frames"].shape, sample["frames"].shape)

    # Save the identity (conditioning) frame and the target frames as images
    image_identity = (
        sample["cond_frames_without_noise"].permute(1, 2, 0).numpy() + 1
    ) / 2 * 255
    cv2.imwrite("image_identity.png", image_identity[:, :, ::-1])

    for i in range(25):
        image = (sample["frames"][:, i].permute(1, 2, 0).numpy() + 1) / 2 * 255
        cv2.imwrite(f"tmp_vid_dataset/image_{i}.png", image[:, :, ::-1])