Spaces:
Build error
Build error
from torch.utils.data import TensorDataset | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import glob | |
import pickle | |
import random | |
import os | |
import cv2 | |
import tqdm | |
import sys | |
sys.path.append('..') | |
# from utils.cap_aug import CAP_AUG | |
class FaceEmbed(TensorDataset):
    """Image-pair dataset over one or more flat image directories.

    Each item is ``(source_224, source_256, target_256, same_person)``:
    the source image passed through a 224x224 pipeline, the source image
    through a 256x256 pipeline, a target image through the 256x256
    pipeline, and an int flag (1 when the target is a copy of the source,
    0 when it was drawn at random from any directory).

    NOTE: ``TensorDataset.__init__`` is never called; this class only relies
    on overriding ``__getitem__`` and ``__len__``.
    """

    def __init__(self, data_path_list, same_prob=0.8):
        """
        Args:
            data_path_list: directories whose ``*.*g`` files (e.g. .jpg/.png)
                are collected; items are indexed across them in order.
            same_prob: probability that the target is a copy of the source
                (``same_person == 1``).
        """
        self.same_prob = same_prob
        # One list of image paths per directory, plus their sizes for indexing.
        self.datasets = [glob.glob(f'{data_path}/*.*g') for data_path in data_path_list]
        self.N = [len(images) for images in self.datasets]

        def _pipeline(size):
            # ColorJitter is random, so each call on the same image differs.
            # ToTensor + Normalize(0.5, 0.5) maps pixel values into [-1, 1].
            return transforms.Compose([
                transforms.ColorJitter(0.2, 0.2, 0.2, 0.01),
                transforms.Resize((size, size)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])

        # "arcface" presumably refers to an identity encoder expecting 224x224
        # input (inferred from the name only).
        self.transforms_arcface = _pipeline(224)
        self.transforms_base = _pipeline(256)

    @staticmethod
    def _read_rgb(image_path):
        """Load ``image_path`` with OpenCV and return it as an RGB PIL image.

        Raises:
            OSError: if OpenCV cannot read the file (``cv2.imread`` returns
                ``None`` instead of raising).
        """
        bgr = cv2.imread(image_path)
        if bgr is None:
            raise OSError(f'could not read image: {image_path}')
        # OpenCV loads channels as BGR; reverse the last axis to get RGB.
        return Image.fromarray(bgr[:, :, ::-1])

    def _locate(self, item):
        """Map a global index to ``(dataset_index, index_within_dataset)``.

        Negative indices count from the end of the concatenated datasets.

        Raises:
            IndexError: if ``item`` is outside ``[-len(self), len(self))``.
        """
        total = sum(self.N)
        if item < 0:
            item += total
        if not 0 <= item < total:
            raise IndexError(f'index out of range for dataset of size {total}')
        for idx, count in enumerate(self.N):
            if item < count:
                return idx, item
            item -= count
        raise IndexError(f'index out of range for dataset of size {total}')  # unreachable

    def __getitem__(self, item):
        idx, local_item = self._locate(item)
        Xs = self._read_rgb(self.datasets[idx][local_item])
        if random.random() > self.same_prob:
            # Pick a directory uniformly, then an image in it. Empty directories
            # are skipped so random.choice never sees an empty list. The random
            # target may coincidentally be the source image itself.
            candidates = [images for images in self.datasets if images]
            Xt = self._read_rgb(random.choice(random.choice(candidates)))
            same_person = 0
        else:
            Xt = Xs.copy()
            same_person = 1
        return self.transforms_arcface(Xs), self.transforms_base(Xs), self.transforms_base(Xt), same_person

    def __len__(self):
        return sum(self.N)
class FaceEmbedVGG2(TensorDataset):
    """Image-pair dataset over a ``<data_path>/<folder>/<image>`` layout.

    Each item is ``(source_224, source_256, target_256, same_person)``.
    With probability ``1 - same_prob`` the target is a random image from
    the whole dataset (``same_person == 0``; it may coincidentally share
    the source's folder). Otherwise ``same_person == 1`` and the target is
    either a copy of the source or, when ``same_identity`` is set, a random
    image from the source's own folder.

    NOTE: ``TensorDataset.__init__`` is never called; this class only relies
    on overriding ``__getitem__`` and ``__len__``.
    """

    def __init__(self, data_path, same_prob=0.8, same_identity=False):
        """
        Args:
            data_path: root directory containing one sub-folder per group of
                images; files matching ``*.*g`` inside sub-folders are used.
            same_prob: probability of producing a ``same_person == 1`` pair.
            same_identity: if True, "same" targets are drawn from the source's
                folder instead of being a copy of the source image.
        """
        self.same_prob = same_prob
        self.same_identity = same_identity
        self.images_list = glob.glob(f'{data_path}/*/*.*g')
        self.folders_list = glob.glob(f'{data_path}/*')
        # Built from images_list so every same-folder candidate is a real
        # image file and the keys match os.path.dirname() used in __getitem__.
        self.folder2imgs = self._group_by_folder(self.images_list, self.folders_list)
        self.N = len(self.images_list)

        def _pipeline(size):
            # ColorJitter is random, so each call on the same image differs.
            # ToTensor + Normalize(0.5, 0.5) maps pixel values into [-1, 1].
            return transforms.Compose([
                transforms.ColorJitter(0.2, 0.2, 0.2, 0.01),
                transforms.Resize((size, size)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])

        # "arcface" presumably refers to an identity encoder expecting 224x224
        # input (inferred from the name only).
        self.transforms_arcface = _pipeline(224)
        self.transforms_base = _pipeline(256)

    @staticmethod
    def _group_by_folder(image_paths, folders):
        """Return ``{folder: [image paths directly inside it]}``.

        Every entry of ``folders`` gets a key (possibly with an empty list);
        folders that only appear via ``image_paths`` are added as well.
        """
        groups = {folder: [] for folder in folders}
        for path in image_paths:
            groups.setdefault(os.path.dirname(path), []).append(path)
        return groups

    @staticmethod
    def _read_rgb(image_path):
        """Load ``image_path`` with OpenCV and return it as an RGB PIL image.

        Raises:
            OSError: if OpenCV cannot read the file (``cv2.imread`` returns
                ``None`` instead of raising).
        """
        bgr = cv2.imread(image_path)
        if bgr is None:
            raise OSError(f'could not read image: {image_path}')
        # OpenCV loads channels as BGR; reverse the last axis to get RGB.
        return Image.fromarray(bgr[:, :, ::-1])

    def __getitem__(self, item):
        image_path = self.images_list[item]
        Xs = self._read_rgb(image_path)
        if random.random() > self.same_prob:
            Xt = self._read_rgb(random.choice(self.images_list))
            same_person = 0
        elif self.same_identity:
            # os.path.dirname (not a '/' split) so keys match glob output on
            # every platform.
            folder = os.path.dirname(image_path)
            Xt = self._read_rgb(random.choice(self.folder2imgs[folder]))
            same_person = 1
        else:
            Xt = Xs.copy()
            same_person = 1
        return self.transforms_arcface(Xs), self.transforms_base(Xs), self.transforms_base(Xt), same_person

    def __len__(self):
        return self.N