faceparser / app.py
leonelhs's picture
init app
84ce4ca
raw
history blame
2.59 kB
import os
import gradio as gr
import numpy as np
import torch
from PIL import Image
from bisnet import BiSeNet
from huggingface_hub import snapshot_download
from utils import vis_parsing_maps, decode_segmentation_masks, image_to_tensor
# Debug aid: print the installed package versions to stdout when the app starts.
os.system("pip freeze")

REPO_ID = "leonelhs/faceparser"
MODEL_NAME = "79999_iter.pth"

# BiSeNet face-parsing network with 19 output classes.
model = BiSeNet(n_classes=19)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the model repository snapshot and locate the checkpoint inside it.
snapshot_folder = snapshot_download(repo_id=REPO_ID)
model_path = os.path.join(snapshot_folder, MODEL_NAME)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted repo.
model.load_state_dict(torch.load(model_path, map_location=device))
# load_state_dict copies values into the already-allocated parameters (created on
# CPU above), so the module itself must be moved; otherwise it stays on CPU even
# when CUDA is available and CUDA inputs would hit a device-mismatch error.
model.to(device)
# Inference mode (affects layers such as batch-norm / dropout).
model.eval()
def makeOverlay(image, mask):
    """Render a face-parsing mask over its source image.

    The image is resized to 512x512 (bilinear) so it lines up with the mask
    produced by ``predict``, then handed to the ``utils`` visualisation helpers.

    Args:
        image: PIL image the mask was computed from.
        mask: per-pixel class map (anything ``np.asarray`` accepts).

    Returns:
        ``(overlay, colormap)``: the overlay from ``vis_parsing_maps`` and the
        colour map that ``decode_segmentation_masks`` builds from its dark map.
    """
    resized = image.resize((512, 512), Image.BILINEAR)
    parsing = np.asarray(mask)
    dark_map, overlay = vis_parsing_maps(resized, parsing)
    return overlay, decode_segmentation_masks(dark_map)
def predict(image):
    """Run the face-parsing model on ``image`` and return a per-pixel class map.

    The image is resized to 512x512, converted with ``utils.image_to_tensor``,
    given a leading batch dimension and fed to the module-level ``model``.

    Args:
        image: PIL image to parse.

    Returns:
        A numpy array of argmax class indices taken over axis 0 of the first
        model output after removing the batch dimension. Presumably this is an
        (H, W) map for a (1, n_classes, H, W) output; TODO confirm with BiSeNet.
    """
    with torch.no_grad():
        image = image.resize((512, 512), Image.BILINEAR)
        input_tensor = image_to_tensor(image)
        input_tensor = torch.unsqueeze(input_tensor, 0)
        # Put the batch on the same device as the model's weights. Checking
        # torch.cuda.is_available() alone is wrong: the model may still be on
        # CPU even when CUDA exists, which raises a device-mismatch error.
        model_device = next(model.parameters()).device
        input_tensor = input_tensor.to(model_device)
        # The model returns several outputs; index 0 is the one used here.
        output = model(input_tensor)[0]
        return output.squeeze(0).cpu().numpy().argmax(0)
def inference(image):
    """Gradio callback: parse the faces in ``image`` and return the overlay.

    ``makeOverlay`` also produces a colour map, but only the overlay is shown
    in the UI, so the colour map is discarded.
    """
    overlay, _ = makeOverlay(image, predict(image))
    return overlay
# Text shown in the Gradio page header (description) and footer (article).
title = "Face Parser"
description = r"""
## Image face parser for research
This is an implementation of <a href='https://github.com/zllrunning/face-parsing.PyTorch' target='_blank'>face-parsing.PyTorch</a>.
It has no particular purpose other than to start research on AI models.
"""
# NOTE(review): "[email protected]" below looks like an email-obfuscation
# placeholder left by a web scrape; confirm and restore the real contact address.
article = r"""
Questions, doubts, comments, please email 📧 `[email protected]`
This demo is running on a CPU. If you like this project, please make a donation so it can run on a GPU, or just give us a <a href='https://github.com/leonelhs/zeroscratches/' target='_blank'>Github ⭐</a>
<a href="https://www.buymeacoffee.com/leonelhs"><img src="https://img.buymeacoffee.com/button-api/?text=Buy me a coffee&emoji=&slug=leonelhs&button_colour=FFDD00&font_colour=000000&font_family=Cookie&outline_colour=000000&coffee_colour=ffffff" /></a>
<center><img src='https://visitor-badge.glitch.me/badge?page_id=zeroscratches.visitor-badge' alt='visitor badge'></center>
"""

# One image in (delivered to `inference` as a PIL image), one image out.
demo = gr.Interface(
    inference,
    [gr.Image(type="pil", label="Input")],
    [gr.Image(type="numpy", label="Image face parsed")],
    title=title,
    description=description,
    article=article,
)
# Enable Gradio's request queue, then start the web server.
demo.queue().launch()