Spaces:
Running
Running
import sys | |
import os | |
import torch | |
from . import mtcnn | |
from .face_yolo import face_yolo_detection | |
import argparse | |
from PIL import Image | |
from tqdm import tqdm | |
import random | |
from datetime import datetime | |
# Pick the first CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Module-level MTCNN detector used by handle_image_mtcnn below, placed on DEVICE
# and configured with crop_size=(112, 112) (presumably the aligned-face output
# size — confirm against mtcnn.MTCNN).
mtcnn_model = mtcnn.MTCNN(device=DEVICE, crop_size=(112, 112))
def add_padding(pil_img, top, right, bottom, left, color=(0,0,0)):
    """Return a new image equal to *pil_img* surrounded by solid-colour margins.

    Args:
        pil_img: Source PIL image; its mode is reused for the output canvas.
        top, right, bottom, left: Margin sizes in pixels for each side.
        color: Fill colour for the margins, passed straight to ``Image.new``.

    Returns:
        A new image of size ``(w + right + left, h + top + bottom)`` with the
        original pasted at offset ``(left, top)``. The input is not modified.
    """
    src_w, src_h = pil_img.size
    canvas_size = (src_w + right + left, src_h + top + bottom)
    canvas = Image.new(pil_img.mode, canvas_size, color)
    # The top-left corner of the original lands just inside the left/top margins.
    canvas.paste(pil_img, (left, top))
    return canvas
def handle_image_mtcnn(img_path, pil_img):
    """Detect and align the first face in one image using the shared MTCNN model.

    Args:
        img_path: Path to an image file. Only read when ``pil_img`` is None.
        pil_img: Optional PIL image to use instead of loading ``img_path``.

    Returns:
        ``(bbox, face)`` for the first entry returned by
        ``mtcnn_model.align_multi(img, limit=1)``, or ``(None, None)`` when
        detection raises or returns no faces.

    Raises:
        TypeError: if ``pil_img`` is given but is not a ``PIL.Image.Image``.
        Errors from ``Image.open`` (e.g. a missing file) are not caught here.
    """
    if pil_img is None:
        # Close the file handle deterministically; convert() returns an
        # in-memory copy, so the result stays valid after the file is closed.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
    else:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(pil_img, Image.Image):
            raise TypeError('Face alignment requires PIL image or path')
        img = pil_img
    try:
        bboxes, faces = mtcnn_model.align_multi(img, limit=1)
        # An empty detection surfaces as IndexError here and is reported below.
        return bboxes[0], faces[0]
    except Exception as e:
        # Deliberate best-effort: a failed detection yields (None, None)
        # rather than aborting a batch of images.
        print(f'Face detection failed: {e}')
        return None, None
def get_aligned_face(image_path_or_image_paths, rgb_pil_image=None, algorithm='mtcnn'):
    """Detect/align faces in one or more images with the selected detector.

    Args:
        image_path_or_image_paths: A single image path (str) or a list of paths.
            With ``algorithm='mtcnn'`` this may also be None when
            ``rgb_pil_image`` is supplied.
        rgb_pil_image: Optional PIL image. For 'mtcnn' it is used in place of
            loading each path, so every list entry is computed from this same
            image. It is ignored by the 'yolo' branch.
        algorithm: Either 'mtcnn' or 'yolo'.

    Returns:
        A list with one entry per input. For 'mtcnn' each entry is the
        ``(bbox, face)`` pair from ``handle_image_mtcnn`` (``(None, None)`` on
        failure); for 'yolo' it is ``list(face_yolo_detection(...))``.

    Raises:
        TypeError: if ``image_path_or_image_paths`` is not a list or str
            (None is accepted only for 'mtcnn' together with ``rgb_pil_image``).
        ValueError: if ``algorithm`` is neither 'mtcnn' nor 'yolo'.
    """
    if algorithm == 'mtcnn':
        if isinstance(image_path_or_image_paths, list):
            return [handle_image_mtcnn(path, rgb_pil_image) for path in image_path_or_image_paths]
        if isinstance(image_path_or_image_paths, str):
            return [handle_image_mtcnn(image_path_or_image_paths, rgb_pil_image)]
        if image_path_or_image_paths is None and rgb_pil_image is not None:
            # handle_image_mtcnn never reads the path when an image is supplied.
            return [handle_image_mtcnn(None, rgb_pil_image)]
        raise TypeError("image_path_or_image_paths must be a list or string")

    if algorithm == 'yolo':
        if isinstance(image_path_or_image_paths, list):
            image_paths = image_path_or_image_paths
        elif isinstance(image_path_or_image_paths, str):
            image_paths = [image_path_or_image_paths]
        else:
            raise TypeError("image_path_or_image_paths must be a list or string")
        # NOTE(review): rgb_pil_image is not forwarded to the YOLO detector;
        # only the paths are used.
        results = face_yolo_detection(image_paths,
                                      # yolo_model_path="ckpts/yolo_face_detection/model.pt",
                                      use_batch=True, device=DEVICE)
        return list(results)

    # Previously an unknown algorithm silently returned None.
    raise ValueError(f"Unknown face detection algorithm: {algorithm!r} (expected 'mtcnn' or 'yolo')")