import sys
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor

# FollowYourEmoji repo modules; point this at your local checkout.
sys.path.append("/path/to/FollowYourEmoji")
from media_pipe import FaceMeshDetector, FaceMeshAlign
from media_pipe.draw_util import FaceMeshVisualizer
def val_collate_fn(samples):
    """Collate validation samples as lists rather than stacked tensors
    (motion sequences can differ in length across samples)."""
    return {
        'ref_frame': [sample['ref_frame'] for sample in samples],
        'clip_image': [sample['clip_image'] for sample in samples],
        'motions': [sample['motions'] for sample in samples],
        'file_name': [sample['file_name'] for sample in samples],
        'lmk_name': [sample['lmk_name'] for sample in samples],
    }
class ValDataset(Dataset):
    def __init__(self, input_path, lmk_path, resolution_w=512, resolution_h=512):
        print(f'Loading dataset from {input_path} and {lmk_path}')
        all_img_paths = self._get_path_files(Path(input_path), file_suffix=['.jpg', '.jpeg', '.png', '.webp'])
        all_lmk_paths = self._get_path_files(Path(lmk_path), file_suffix=['.npy'])
        print(f'Found {len(all_img_paths)} image files and {len(all_lmk_paths)} lmk files')
        print(f"ALL IMG PATHS: {all_img_paths}")
        print(f"ALL LMK PATHS: {all_lmk_paths}")

        # Pair every reference image with every landmark sequence.
        self.all_paths = [
            (img_path, lmk_file)
            for lmk_file in all_lmk_paths
            for img_path in all_img_paths
        ]

        self.W = resolution_w
        self.H = resolution_h
        self.to_tensor = T.ToTensor()
        self.detector = FaceMeshDetector()
        self.aligner = FaceMeshAlign()
        self.clip_image_processor = CLIPImageProcessor()
        self.vis = FaceMeshVisualizer(forehead_edge=False, iris_edge=False, iris_point=True)
    def __len__(self):
        return len(self.all_paths)
    def _get_path_files(self, path, file_suffix):
        """Return a sorted list of files under `path` matching the given suffixes."""
        if path.is_file():
            if path.suffix.lower() not in file_suffix:
                raise ValueError('Path is not a valid image file.')
            all_paths = [path]
        elif path.is_dir():
            all_paths = sorted(
                f for f in path.iterdir()
                if f.is_file() and f.suffix.lower() in file_suffix
            )
            if len(all_paths) == 0:
                raise ValueError('Folder does not contain any matching files.')
        else:
            raise ValueError(f'{path} is neither a file nor a directory.')
        return all_paths
    def get_align_motion(self, ref_lmk, temp_lmks):
        # Align the template landmark sequence to the reference face, then
        # rasterize each aligned frame and stack to (C, T, H, W).
        motions = self.aligner(ref_lmk, temp_lmks)
        motions = [self.to_tensor(motion) for motion in motions]
        motions = torch.stack(motions).permute((1, 0, 2, 3))
        return motions
    def __getitem__(self, index):
        img_path, lmk_path = self.all_paths[index]
        W, H = self.W, self.H

        # resize so the image fits within (W, H), then center crop to exactly (W, H)
        image = Image.open(img_path).convert('RGB')
        scale = min(W / image.size[0], H / image.size[1])
        ref_image = image.resize(
            (int(image.size[0] * scale), int(image.size[1] * scale)))
        w, h = ref_image.size[0], ref_image.size[1]
        ref_image = ref_image.crop((w//2 - W//2, h//2 - H//2, w//2 + W//2, h//2 + H//2))
        ref_image = np.array(ref_image)

        # detect landmarks on the reference image
        ref_lmk_image, ref_lmk = self.detector(ref_image)

        # CLIP-preprocessed copy of the reference image
        clip_image = Image.fromarray(ref_image)
        clip_image = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values[0]

        # reference frame as a (C, 1, H, W) tensor scaled to [-1, 1]
        ref_image = self.to_tensor(ref_image).unsqueeze(1)
        ref_image = ref_image * 2.0 - 1.0

        # template landmark (motion) sequence
        temp_lmks = np.load(lmk_path, allow_pickle=True)

        # align the template landmarks to the reference face and draw them;
        # if no face was detected, draw the raw landmarks instead
        if ref_lmk is not None:
            motions = self.get_align_motion(ref_lmk, temp_lmks)
        else:
            motions = [
                self.vis.draw_landmarks((H, W), lmk['lmks'].astype(np.float32), normed=True)
                for lmk in temp_lmks
            ]
            motions = [self.to_tensor(motion) for motion in motions]
            motions = torch.stack(motions).permute((1, 0, 2, 3))

        example = dict()
        example["file_name"] = img_path.stem
        example["lmk_name"] = lmk_path.stem
        example["motions"] = motions        # values in [0, 1]
        example["ref_frame"] = ref_image    # values in [-1, 1]
        example["ref_lmk_image"] = ref_lmk_image
        example["clip_image"] = clip_image
        return example
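

# Example usage: a minimal sketch of driving this dataset with a DataLoader.
# The two paths below are placeholders (as is /path/to/FollowYourEmoji above);
# batch_size=1 and the printed fields are illustrative choices, not requirements.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = ValDataset(
        input_path="/path/to/reference_images",  # .jpg/.jpeg/.png/.webp files
        lmk_path="/path/to/landmark_sequences",  # .npy landmark sequences
    )
    loader = DataLoader(dataset, batch_size=1, collate_fn=val_collate_fn)
    for batch in loader:
        # batch values are lists (see val_collate_fn):
        #   motions[i]   -> (C, T, H, W) tensor in [0, 1]
        #   ref_frame[i] -> (C, 1, H, W) tensor in [-1, 1]
        print(batch['file_name'], batch['lmk_name'], batch['motions'][0].shape)
        break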