|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
from torchvision import transforms |
|
from PIL import Image |
|
import numpy as np |
|
import requests |
|
from io import BytesIO |
|
|
|
|
|
|
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
import os

# Resolve the U-2-Net weights file. A local PyTorch checkpoint can be supplied
# via the U2NET_WEIGHTS environment variable, which skips the Hub download.
# NOTE(review): the Hub file "onnx/model.onnx" is an ONNX graph (protobuf),
# not a pickled PyTorch state_dict, so torch.load() cannot read it. TODO
# confirm which PyTorch checkpoint (.pth/.pt) matches the local U2NET class.
model_path = os.environ.get("U2NET_WEIGHTS") or hf_hub_download(
    repo_id="BritishWerewolf/U-2-Net", filename="onnx/model.onnx"
)

if str(model_path).lower().endswith(".onnx"):
    # Fail fast with an actionable message instead of the cryptic unpickling
    # error torch.load raises when handed a non-PyTorch file.
    raise RuntimeError(
        f"{model_path} is an ONNX model, but U2NET.load_state_dict() needs a "
        "PyTorch checkpoint (.pth/.pt). Set U2NET_WEIGHTS to a local PyTorch "
        "weights file or download PyTorch weights from the Hub."
    )

from u2net import U2NET

# presumably U2NET(in_ch, out_ch): 3-channel RGB input, 1-channel output map.
# TODO confirm against the local u2net module.
model = U2NET(3, 1)

# SECURITY(review): torch.load unpickles the file, which can execute code
# embedded in an untrusted downloaded checkpoint; pass weights_only=True once
# the installed PyTorch version (>= 1.13) is confirmed to support it.
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(state_dict)

# Switch to inference mode (affects dropout / batch-norm layers, if present).
model.eval()
|
|
|
|
|
# Preprocessing applied to every uploaded image before inference:
#   1. resize to exactly 320x320 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor in [0, 1],
#   3. normalise each RGB channel with the mean/std below (these match the
#      common ImageNet channel statistics).
transform = transforms.Compose(
    [
        transforms.Resize(size=(320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
def segment_dress(image):
    """Black out the low-saliency regions of *image* using U-2-Net.

    Runs the module-level ``model`` on a resized, normalised copy of the
    image. It min-max rescales the first model output to [0, 1] and resizes
    that map back to the input resolution. Each RGB pixel is then multiplied
    by the map, so background regions fade to black.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading an image.

    Returns:
        An RGB ``PIL.Image`` with the same size as *image*, or ``None`` when
        no image was supplied.
    """
    if image is None:
        # Gradio passes None for an empty image input; nothing to segment.
        return None

    original = image.convert("RGB")
    input_tensor = transform(original).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        # Only the first of the model's seven outputs is used.
        d1, _, _, _, _, _, _ = model(input_tensor)

    pred = d1[0][0]  # first batch item, first channel

    # Min-max normalise to [0, 1]. A constant map has zero range; dividing
    # would produce NaNs whose uint8 cast is undefined, so use an all-zero
    # mask instead.
    lo = pred.min()
    span = pred.max() - lo
    if span > 0:
        pred = (pred - lo) / span
    else:
        pred = torch.zeros_like(pred)
    pred_np = pred.cpu().numpy()

    # Round-trip through PIL to resize the map to the original resolution
    # (PIL sizes are (width, height), matching original.size).
    pred_resized = Image.fromarray((pred_np * 255).astype(np.uint8)).resize(original.size)
    mask = np.array(pred_resized) / 255.0

    # An RGB PIL image already converts to a uint8 array.
    image_np = np.array(original)
    segmented = (image_np * mask[..., None]).astype(np.uint8)

    return Image.fromarray(segmented)
|
|
|
|
|
# Gradio UI: one image in, the masked image out. The interface is bound to
# `demo` (the name the `gradio` CLI reload mode looks for). The server is only
# started when the file runs as a script, so importing this module has no
# side effect of launching a server.
demo = gr.Interface(
    fn=segment_dress,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Segmented Dress"),
    title="Dress Segmentation with U-2-Net",
    description="Segments the dress (or full foreground) using U-2-Net from Hugging Face",
)

if __name__ == "__main__":
    demo.launch()