# Dress_seg / app.py — uploaded by gaur3009 to Hugging Face
# (commit 732979b, "Update app.py", 2.09 kB).
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import requests
from io import BytesIO
# U-2-Net architecture: this app imports the U2NET class from a local u2net.py
# (copy the model code from https://github.com/xuebinqin/U-2-Net) and loads
# pre-trained weights downloaded from the Hugging Face Hub.
from huggingface_hub import hf_hub_download
# Download the U-2-Net checkpoint from the Hugging Face Hub.
# The network below is the PyTorch ``U2NET`` class and its weights are loaded
# with ``torch.load`` + ``load_state_dict``, so the file must be a PyTorch
# state_dict (e.g. a ``u2net.pth``). An ONNX export cannot be read this way.
# Point these two constants at a PyTorch checkpoint.
U2NET_REPO_ID = "BritishWerewolf/U-2-Net"
U2NET_FILENAME = "onnx/model.onnx"
if U2NET_FILENAME.lower().endswith(".onnx"):
    # Fail fast with an actionable message instead of an opaque unpickling
    # error from torch.load, and before downloading a file we cannot use.
    raise RuntimeError(
        f"{U2NET_FILENAME!r} is an ONNX model, but this app loads a PyTorch "
        "U2NET state_dict with torch.load. Set U2NET_REPO_ID/U2NET_FILENAME "
        "to a PyTorch checkpoint (e.g. u2net.pth)."
    )
model_path = hf_hub_download(repo_id=U2NET_REPO_ID, filename=U2NET_FILENAME)

# Use the reference U2NET implementation
# (https://github.com/xuebinqin/U-2-Net/blob/master/model/u2net.py).
from u2net import U2NET  # assumes the model code was copied as u2net.py

# Build the network and load the weights on CPU.
# U2NET(3, 1): RGB input, single-channel mask output.
# weights_only=True restricts unpickling to tensors and plain containers,
# so a downloaded checkpoint cannot execute arbitrary code.
model = U2NET(3, 1)
model.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu"), weights_only=True)
)
model.eval()  # inference mode for layers such as BatchNorm/Dropout
# Preprocessing pipeline applied to every input image:
#   1. resize to a fixed 320x320,
#   2. convert the PIL image to a float tensor with values in [0, 1],
#   3. normalise each RGB channel with the ImageNet mean / std.
transform = transforms.Compose(
    [
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def segment_dress(image):
    """Black out the background of ``image`` using the U-2-Net prediction map.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        An RGB PIL image of the same size as ``image``. Each pixel is scaled
        by the min-max normalised prediction at that location, so low-scoring
        regions go towards black. Returns ``None`` if no image was given.
    """
    if image is None:
        # Gradio passes None for an empty image input; nothing to segment.
        return None

    original = image.convert("RGB")
    input_tensor = transform(original).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        # The model returns several outputs; only the first (d1) is used.
        d1 = model(input_tensor)[0]
    pred = d1[0][0]  # first batch item, first channel

    # Min-max normalise to [0, 1]. A constant map (max == min) would divide by
    # zero and yield NaNs, whose uint8 cast is undefined, so treat it as
    # "no foreground" instead.
    pred_min = pred.min()
    pred_range = pred.max() - pred_min
    if pred_range > 0:
        pred = (pred - pred_min) / pred_range
    else:
        pred = torch.zeros_like(pred)
    pred_np = pred.cpu().numpy()

    # Resize the mask back to the original image size (PIL uses (width, height)).
    pred_resized = Image.fromarray((pred_np * 255).astype(np.uint8)).resize(original.size)

    # Apply the mask: broadcast the 2-D mask over the RGB channels.
    mask = np.array(pred_resized) / 255.0
    image_np = np.array(original).astype(np.uint8)
    segmented = (image_np * mask[..., None]).astype(np.uint8)
    return Image.fromarray(segmented)
# Gradio UI: one image in, one segmented image out.
# Bound to ``demo`` (the name Gradio's reload mode looks for) and launched only
# when run as a script, so importing this module does not start a server.
demo = gr.Interface(
    fn=segment_dress,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Segmented Dress"),
    title="Dress Segmentation with U-2-Net",
    description="Segments the dress (or full foreground) using U-2-Net from Hugging Face",
)

if __name__ == "__main__":
    demo.launch()