import os

import spaces
import gradio as gr
from huggingface_hub import snapshot_download

# Define repository and local directory
repo_id = "ai-forever/GHOST-2.0-repo"  # HF repo
local_dir = "./"  # Target local directory

# Read the access token from the environment rather than hardcoding a secret in the source
token = os.environ.get("HF_TOKEN")

# Download the entire repository
snapshot_download(repo_id=repo_id, local_dir=local_dir, token=token)
print(f"Repository downloaded to: {local_dir}")
import cv2
import numpy as np
import torch
import argparse
from torchvision import transforms
import onnxruntime as ort
from PIL import Image
from insightface.app import FaceAnalysis
from omegaconf import OmegaConf
from torchvision.transforms.functional import rgb_to_grayscale

from src.utils.crops import *
from repos.stylematte.stylematte.models import StyleMatte
from src.utils.inference import *
from src.utils.inpainter import LamaInpainter
from src.utils.preblending import calc_pseudo_target_bg
from train_aligner import AlignerModule
from train_blender import BlenderModule
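
# `import spaces` above suggests this demo targets a ZeroGPU Space; if so, the
# handler needs the `spaces.GPU` decorator to get a GPU allocated per call.
# (Adding it here is an assumption, not in the original listing; the decorator
# is a no-op off ZeroGPU hardware.)
@spaces.GPU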
def infer_headswap(source, target):
    def calc_mask(img):
        # Soft head mask from StyleMatte; accepts an HxWx3 numpy image or a
        # CxHxW tensor, in either [0, 255] or [0, 1] range
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img).permute(2, 0, 1).cuda()
        if img.max() > 1.:
            img = img / 255.0
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        input_t = normalize(img)
        input_t = input_t.unsqueeze(0).float()
        with torch.no_grad():
            out = segment_model(input_t)
        result = out[0]
        return result[0]
    def process_img(img, target=False):
        # PIL RGB -> BGR numpy for InsightFace
        full_frames = np.array(img)[:, :, ::-1]
        dets = app.get(full_frames)
        if len(dets) == 0:
            # Retry detection once with generous zero padding, in case the head
            # fills the entire frame
            pad_top, pad_bottom, pad_left, pad_right = (
                full_frames.shape[0] // 2, full_frames.shape[0] // 2,
                full_frames.shape[1] // 2, full_frames.shape[1] // 2
            )
            full_frames = cv2.copyMakeBorder(
                full_frames, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=0)
            dets = app.get(full_frames)
            if len(dets) == 0:
                raise gr.Error(f"No head found on the {'target' if target else 'source'} image")
        kps = dets[0]['kps']
        wide = wide_crop_face(full_frames, kps, return_M=target)
        if target:
            wide, M = wide
        arc = norm_crop(full_frames, kps)
        mask = calc_mask(wide)
        arc = normalize_and_torch(arc)
        wide = normalize_and_torch(wide)
        if target:
            return wide, arc, mask, full_frames, M
        return wide, arc, mask
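    # Pipeline: (1) crop and mask both heads, (2) the Aligner re-enacts the
    # source head with the target's pose and expression, (3) the Blender
    # composites the re-enacted head back into the target frame.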
    wide_source, arc_source, mask_source = process_img(source)
    wide_target, arc_target, mask_target, full_frame, M = process_img(target, target=True)

    # Source gets an extra frame dimension; the masks are expanded so they
    # broadcast against the [B, (T,) C, H, W] face tensors
    wide_source = wide_source.unsqueeze(1)
    arc_source = arc_source.unsqueeze(1)
    source_mask = mask_source.unsqueeze(0).unsqueeze(0).unsqueeze(0)
    target_mask = mask_target.unsqueeze(0).unsqueeze(0)

    # Use the broadcast-ready masks built above (they were otherwise dead code)
    X_dict = {
        'source': {
            'face_arc': arc_source,
            'face_wide': wide_source * source_mask,
            'face_wide_mask': source_mask
        },
        'target': {
            'face_arc': arc_target,
            'face_wide': wide_target * target_mask,
            'face_wide_mask': target_mask
        }
    }
    with torch.no_grad():
        output = aligner(X_dict)

    target_parsing = infer_parsing(wide_target)
    pseudo_norm_target = calc_pseudo_target_bg(wide_target, target_parsing)

    # Soft mask of the aligned head (BGR [-1, 1] -> RGB [0, 1] before matting),
    # then blend the aligned head over the pseudo target background
    soft_mask = calc_mask(((output['fake_rgbs'] * output['fake_segm'])[0, [2, 1, 0], :, :] + 1) / 2)[None]
    new_source = output['fake_rgbs'] * soft_mask[:, None, ...] + pseudo_norm_target * (1 - soft_mask[:, None, ...])

    blender_input = {
        'face_source': new_source,
        'gray_source': rgb_to_grayscale(new_source[0][[2, 1, 0], ...]).unsqueeze(0),
        'face_target': wide_target,
        'mask_source': infer_parsing(output['fake_rgbs'] * output['fake_segm']),
        'mask_target': target_parsing,
        'mask_source_noise': None,
        'mask_target_noise': None,
        'alpha_source': soft_mask
    }

    output_b = blender(blender_input, inpainter=inpainter)

    # [-1, 1] CHW BGR tensor -> uint8 HWC RGB image, then paste the head back
    # into the full frame using the inverse crop transform M
    np_output = np.uint8((output_b['oup'][0].detach().cpu().numpy().transpose((1, 2, 0))[:, :, ::-1] / 2 + 0.5) * 255)
    result = copy_head_back(np_output, full_frame[..., ::-1], M)
    return Image.fromarray(result)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Generator params
    parser.add_argument('--config_a', default='./configs/aligner.yaml', type=str, help='Path to Aligner config')
    parser.add_argument('--config_b', default='./configs/blender.yaml', type=str, help='Path to Blender config')
    parser.add_argument('--source', default='./examples/images/hab.jpg', type=str, help='Path to source image')
    parser.add_argument('--target', default='./examples/images/elon.jpg', type=str, help='Path to target image')
    parser.add_argument('--ckpt_a', default='./aligner_checkpoints/aligner_1020_gaze_final.ckpt', type=str, help='Aligner checkpoint')
    parser.add_argument('--ckpt_b', default='./blender_checkpoints/blender_lama.ckpt', type=str, help='Blender checkpoint')
    parser.add_argument('--save_path', default='result.png', type=str, help='Path to save the result')
    args = parser.parse_args()
    cfg_a = OmegaConf.load(args.config_a)
    cfg_b = OmegaConf.load(args.config_b)
    aligner = AlignerModule(cfg_a)
    # Load the checkpoint once with map_location (the original loaded it twice);
    # Lightning checkpoints keep weights under "state_dict", so fall back to the
    # raw dict if that key is absent
    ckpt_a = torch.load(args.ckpt_a, map_location='cpu')
    aligner.load_state_dict(ckpt_a.get('state_dict', ckpt_a), strict=False)
    aligner.eval()
    aligner.cuda()

    blender = BlenderModule(cfg_b)
    blender.load_state_dict(torch.load(args.ckpt_b, map_location='cpu')['state_dict'], strict=False)
    blender.eval()
    blender.cuda()
    inpainter = LamaInpainter('cpu')

    # InsightFace detector (keypoints only) used for cropping
    app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=['detection'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    # StyleMatte produces the soft head masks consumed by calc_mask
    segment_model = StyleMatte()
    segment_model.load_state_dict(
        torch.load(
            './repos/stylematte/stylematte/checkpoints/stylematte_synth.pth',
            map_location='cpu'
        )
    )
    segment_model = segment_model.cuda()
    segment_model.eval()
    # ONNX SegFormer face-parsing model
    providers = [
        ("CUDAExecutionProvider", {})
    ]
    parsings_session = ort.InferenceSession('./weights/segformer_B5_ce.onnx', providers=providers)
    input_name = parsings_session.get_inputs()[0].name
    output_names = [output.name for output in parsings_session.get_outputs()]

    mean = np.array([0.51315393, 0.48064056, 0.46301059])[None, :, None, None]
    std = np.array([0.21438347, 0.20799829, 0.20304542])[None, :, None, None]

    def infer_parsing(img):
        # BGR [-1, 1] tensor -> RGB [0, 1] -> normalized numpy batch for the ONNX session
        inp = (((img[:, [2, 1, 0], ...] / 2 + 0.5).detach().cpu().numpy() - mean) / std).astype(np.float32)
        return torch.tensor(
            parsings_session.run(output_names, {input_name: inp})[0],
            device='cuda',
            dtype=torch.float32
        )
    # Example images from the CLI arguments (used by the optional smoke test below)
    source_pil = Image.open(args.source)
    target_pil = Image.open(args.target)
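    # Optional smoke test: a minimal sketch, not part of the original demo. It
    # reuses the example images and the otherwise-unused --save_path argument.
    # The RUN_SMOKE_TEST flag is hypothetical; this assumes the example files
    # exist locally and a CUDA device is available.
    if os.environ.get("RUN_SMOKE_TEST") == "1":
        infer_headswap(source_pil, target_pil).save(args.save_path)
        print(f"Smoke-test result saved to: {args.save_path}")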
    with gr.Blocks() as demo:
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    with gr.Row(equal_height=True):
                        input_source = gr.Image(
                            type="pil",
                            label="Input Source"
                        )
                        input_target = gr.Image(
                            type="pil",
                            label="Input Target"
                        )
                    run_button = gr.Button("Generate")
                with gr.Column():
                    result = gr.Image(type='pil', label='Image Output')

        run_button.click(
            fn=infer_headswap,
            inputs=[input_source, input_target],
            outputs=[result]
        )

    demo.launch()
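    # A possible hardening step (an assumption, not in the original): enable
    # Gradio's request queue so concurrent users share the single GPU fairly,
    # e.g. demo.queue(max_size=8).launch() instead of the bare launch above.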