# (Extraction residue from the hosting page banner, "Spaces: Running on Zero"; not part of the module.)
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os | |
import argparse | |
import cv2 | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
import numpy as np | |
import facer | |
import facer.transform | |
from copy import deepcopy | |
import PIL | |
def resize_image(image, max_size=1024):
    """Downscale an image array so that neither side exceeds ``max_size``.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged (the very same object, no copy); images are never upscaled.

    Args:
        image: ``np.ndarray`` laid out as (height, width) or
            (height, width, channels).
        max_size: Maximum allowed length, in pixels, of the longer side.

    Returns:
        The input array itself when no resize is needed, otherwise the array
        produced by ``cv2.resize``.
    """
    # Only the two leading axes matter, so single-channel 2-D arrays are
    # accepted as well as multi-channel ones.
    height, width = image.shape[:2]
    if width <= max_size and height <= max_size:
        return image

    if width > height:
        new_width = max_size
        new_height = int((height / width) * max_size)
    else:
        new_height = max_size
        new_width = int((width / height) * max_size)

    # Truncation can round the short side of an extremely elongated image
    # down to 0, which cv2.resize rejects; keep at least one pixel.
    new_width = max(new_width, 1)
    new_height = max(new_height, 1)
    return cv2.resize(image, (new_width, new_height))
def _fit_long_side(width, height, max_size):
    """Return ``(new_width, new_height)`` whose longer side equals ``max_size``."""
    if width > height:
        new_width = max_size
        new_height = int((height / width) * max_size)
    else:
        new_height = max_size
        new_width = int((width / height) * max_size)
    # Truncation can push the short side of a very elongated image to 0,
    # which both PIL and OpenCV reject; keep at least one pixel.
    return max(new_width, 1), max(new_height, 1)


def open_and_resize_image(image_file, max_size=1024, return_type='numpy'):
    """Load an image and rescale it so that its longer side equals ``max_size``.

    The image is always rescaled (smaller images are enlarged as well) and
    the aspect ratio is preserved.

    Args:
        image_file: Path to an image file (``str``), a ``PIL.Image.Image``,
            or an ``np.ndarray`` laid out as (height, width[, channels]).
        max_size: Target length, in pixels, of the longer side.
        return_type: ``'numpy'`` to get an ``np.ndarray``; for path / PIL
            inputs any other value returns the resized ``PIL.Image.Image``.
            Array inputs only support ``'numpy'``.

    Returns:
        For path / PIL inputs: an RGB ``np.ndarray`` when ``return_type`` is
        ``'numpy'``, otherwise the resized PIL image in its original mode.
        For array inputs: the array produced by ``cv2.resize``.

    Raises:
        TypeError: If ``image_file`` is none of the supported types.
        AssertionError: If an array input is combined with a ``return_type``
            other than ``'numpy'``.
    """
    if isinstance(image_file, (str, Image.Image)):
        if isinstance(image_file, str):
            # resize() returns a fully loaded, independent image, so the
            # underlying file can be closed as soon as it has been produced.
            with Image.open(image_file) as opened:
                img = opened.resize(_fit_long_side(*opened.size, max_size))
        else:
            img = image_file.resize(_fit_long_side(*image_file.size, max_size))
        if return_type == 'numpy':
            return np.array(img.convert('RGB'))
        return img

    if isinstance(image_file, np.ndarray):
        # Only the two leading axes are needed, so 2-D arrays work too.
        height, width = image_file.shape[:2]
        img = cv2.resize(image_file, _fit_long_side(width, height, max_size))
        assert return_type == 'numpy'
        return img

    raise TypeError("Do not support this img type")
def loose_warp_face(input_image, face_detector, face_target_shape=(512, 512), scale=1.3, face_parser=None, device=None, croped_face_scale=3, bg_value = 0, croped_face_y_offset=0.0):
    """Detect the first face in an image and return aligned / cropped views of it.

    Only the first detected face is of concern; any further detections are
    ignored.

    Args:
        input_image: Image path (``str``), ``PIL.Image.Image``, or
            ``np.ndarray`` with dtype ``np.uint8``. Paths are loaded through
            ``open_and_resize_image``; arrays are channel-swapped with
            ``cv2.COLOR_RGB2BGR`` and downscaled through ``resize_image``;
            PIL images are converted to RGB and used at their given size.
        face_detector: Face detector. It is called with the BCHW image tensor
            and must return a mapping with ``'rects'``, ``'points'`` and
            ``'image_ids'`` (empty when nothing is found); it must also expose
            ``parameters()``. May be ``None`` to skip detection entirely.
        face_target_shape: Output resolution of the aligned face.
        scale: Scale of the output image w.r.t. the face it contains; the
            landmark template is shrunk towards the centre by this factor.
        face_parser: Optional face parser, called as
            ``face_parser(image, face_data)``; its ``['seg']['logits']`` are
            used to mask the aligned face.
        device: Device for the image tensor. Overridden by the device of the
            detector's parameters whenever ``face_detector`` is given.
        croped_face_scale: Size of the square context crop relative to the
            detected face rectangle.
        bg_value: Value written into masked-out pixels of the aligned face.
        croped_face_y_offset: Vertical shift of the context-crop centre, as a
            fraction of the detected face height.

    Returns:
        dict with the keys

        * ``'cropped_face_masked'``: PIL image of the aligned face with pixels
          whose warped segmentation label is 0 set to ``bg_value`` (same as
          ``'cropped_face'`` when no parser is given).
        * ``'cropped_face'``: PIL image of the aligned face without masking.
        * ``'cropped_img'``: PIL image of the square context crop, or ``None``
          without a detector.
        * ``'cropped_face_mask'``: 0/1 float tensor of the warped mask, or
          ``None`` without a parser.
        * ``'align_face'``: PIL image of the aligned face before masking, or
          ``None`` without a detector.

        Every value is ``None`` when the detector finds no face.
    """
    # Five-point landmark template given on a 112x112 reference, normalised
    # to [0, 1] and then shrunk around the centre by `scale`.
    _normalized_face_target_pts = torch.tensor([
        [38.2946, 51.6963],
        [73.5318, 51.5014],
        [56.0252, 71.7366],
        [41.5493, 92.3655],
        [70.729904, 92.2041]]) / 112.0
    target_pts = ((_normalized_face_target_pts -
                   torch.tensor([0.5, 0.5])) / scale
                  + torch.tensor([0.5, 0.5]))

    if face_detector is not None:
        device = next(face_detector.parameters()).device

    if isinstance(input_image, str):
        # Resized on load to avoid OOM on high-resolution images.
        np_img = open_and_resize_image(input_image)[:, :, :3]
        image_tensor_hwc = torch.from_numpy(np_img)
    elif isinstance(input_image, Image.Image):
        # convert('RGB') keeps RGB/RGBA inputs identical to dropping the alpha
        # channel, and additionally makes grayscale / palette images usable.
        image_tensor_hwc = torch.from_numpy(np.array(input_image.convert('RGB')))
        assert image_tensor_hwc.dtype == torch.uint8
    else:
        assert isinstance(input_image, np.ndarray), 'Type %s of input_image is unsupported!' % type(input_image)
        assert input_image.dtype == np.uint8, 'dtype %s of input np.ndarray is unsupported!' % input_image.dtype
        # NOTE(review): this swaps the channel order while the other input
        # paths stay RGB; presumably arrays arrive in OpenCV's BGR order.
        # Confirm with callers.
        input_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)[:, :, :3]
        input_image = resize_image(input_image)
        image_tensor_hwc = torch.from_numpy(input_image)
    img_height, img_width = image_tensor_hwc.shape[:2]

    image_pt_bchw_255 = facer.hwc2bchw(image_tensor_hwc).to(device)
    res = {'cropped_face_masked': None, 'cropped_face': None, 'cropped_img': None, 'cropped_face_mask': None, 'align_face': None}

    if face_detector is not None:
        # Detector errors propagate to the caller instead of dropping into an
        # interactive debugger.
        face_data = face_detector(image_pt_bchw_255)
        if len(face_data) == 0 or len(face_data['rects']) == 0:
            return res

        if face_parser is not None:
            with torch.inference_mode():
                faces = face_parser(image_pt_bchw_255, face_data)
                seg_logits = faces['seg']['logits']
                seg_probs = seg_logits.softmax(dim=1)
                # Per-pixel class index of the first face, shape (1, 1, H, W).
                seg_probs = seg_probs.argmax(dim=1).unsqueeze(1)[:1]

        # ---- square context crop around the first detected face ----
        face_rects = face_data['rects'][:1]
        x1, y1, x2, y2 = (int(v.item()) for v in face_rects[0][:4])
        face_width = x2 - x1
        face_height = y2 - y1
        center_x = int(0.5 * (x1 + x2))
        center_y = int(0.5 * (y1 + y2)) + croped_face_y_offset * face_height
        croped_face_width = face_width * croped_face_scale
        croped_face_height = face_height * croped_face_scale
        # Enlarge the rectangle, clamped to the image bounds.
        x1 = max(int(center_x - 0.5 * croped_face_width), 0)
        x2 = min(int(center_x + 0.5 * croped_face_width), img_width - 1)
        y1 = max(int(center_y - 0.5 * croped_face_height), 0)
        y2 = min(int(center_y + 0.5 * croped_face_height), img_height - 1)
        # Shrink the clamped rectangle to a centred square.
        croped_face_height = y2 - y1
        croped_face_width = x2 - x1
        center_x = int(0.5 * (x1 + x2))
        center_y = int(0.5 * (y1 + y2))
        croped_face_len = min(croped_face_height, croped_face_width)
        x1 = int(center_x - 0.5 * croped_face_len)
        y1 = int(center_y - 0.5 * croped_face_len)
        x2 = x1 + croped_face_len
        y2 = y1 + croped_face_len
        croped_image_pt_bchw_255 = image_pt_bchw_255[:, :, y1:y2, x1:x2]

        # ---- landmark-based alignment warp ----
        face_points = face_data['points'][:1]
        batch_inds = face_data['image_ids'][:1]
        matrix_align = facer.transform.get_face_align_matrix(
            face_points, face_target_shape,
            target_pts=(target_pts * torch.tensor(face_target_shape)))
        grid = facer.transform.make_tanh_warp_grid(
            matrix_align, 0.0, face_target_shape, image_pt_bchw_255.shape[2:],)
        image = F.grid_sample(
            image_pt_bchw_255.float()[batch_inds],
            grid, 'bilinear', align_corners=False)

        # The float -> uint8 conversion allocates new storage, so this
        # snapshot is unaffected by the in-place masking below.
        image_align_raw = facer.bchw2hwc(image).to(torch.uint8).cpu().numpy()
        image_align_raw = Image.fromarray(image_align_raw)
        image_croped = facer.bchw2hwc(croped_image_pt_bchw_255).to(torch.uint8).cpu().numpy()
        image_croped = Image.fromarray(image_croped)

        if face_parser is not None:
            # Keep an unmasked copy because `image` is modified in place.
            image_no_mask = image.clone()
            # Broadcast the single label channel across the image channels.
            new_size = list(seg_probs.shape)
            new_size[1] = image.shape[1]
            seg_probs = seg_probs.expand(new_size)
            assert seg_probs.shape[0] == 1 and image.shape[0] == 1, 'mask shape {}, != image shape {}'.format(seg_probs.shape, image.shape)
            mask_img = F.grid_sample(seg_probs.float(), grid, 'bilinear', align_corners=False)
            # Label 0 is treated as background (presumably the parser's
            # background class; confirm against its label map).
            image[mask_img == 0] = bg_value
            mask_img[mask_img != 0] = 1
            assert mask_img.shape[0] == 1
        else:
            image_no_mask = image
            mask_img = None
    else:
        # No detector: pass the input image through unchanged.
        image = image_pt_bchw_255
        image_no_mask = image_pt_bchw_255
        image_align_raw = None
        image_croped = None
        # Must be defined here too; the result dict below always reads it.
        mask_img = None

    image = facer.bchw2hwc(image).to(torch.uint8).cpu().numpy()
    image_no_mask = facer.bchw2hwc(image_no_mask).to(torch.uint8).cpu().numpy()
    res.update({'cropped_face_masked': Image.fromarray(image), 'cropped_face': Image.fromarray(image_no_mask), 'cropped_img':image_croped, 'cropped_face_mask': mask_img, 'align_face': image_align_raw})
    return res
def tight_warp_face(input_image, face_detector, face_parser=None, device=None):
    """Run ``loose_warp_face`` with a 112x112 output and ``scale=1``.

    Args:
        input_image: Forwarded unchanged to ``loose_warp_face``.
        face_detector: Forwarded unchanged to ``loose_warp_face``.
        face_parser: Forwarded unchanged to ``loose_warp_face``.
        device: Forwarded unchanged to ``loose_warp_face``.

    Returns:
        Whatever ``loose_warp_face`` returns for these settings.
    """
    warp_options = dict(
        face_target_shape=(112, 112),
        scale=1,
        face_parser=face_parser,
        device=device,
    )
    return loose_warp_face(input_image, face_detector, **warp_options)