File size: 2,091 Bytes
2a2e2dd
1d57c26
 
 
2a2e2dd
 
1d57c26
 
2a2e2dd
1d57c26
 
2a2e2dd
1d57c26
 
2a2e2dd
1d57c26
732979b
2a2e2dd
1d57c26
 
2a2e2dd
1d57c26
 
 
 
2a2e2dd
1d57c26
 
 
 
 
 
 
2a2e2dd
1d57c26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a2e2dd
 
 
1d57c26
 
 
2a2e2dd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

# U-2-Net architecture (simplified, or import from a .py file if you've saved it)
# You can get the U-2-Net code from https://github.com/xuebinqin/U-2-Net

# For this demo, download pre-trained weights from the Hugging Face Hub instead of training the model here
from huggingface_hub import hf_hub_download

# Location of the pre-trained U-2-Net weights on the Hugging Face Hub.
# Override without editing code via the U2NET_REPO_ID / U2NET_WEIGHTS_FILE
# environment variables.
U2NET_REPO_ID = os.environ.get("U2NET_REPO_ID", "BritishWerewolf/U-2-Net")
U2NET_WEIGHTS_FILE = os.environ.get("U2NET_WEIGHTS_FILE", "onnx/model.onnx")

# The PyTorch model below is loaded with torch.load() + load_state_dict(),
# which needs a PyTorch state_dict checkpoint (e.g. u2net.pth). An ONNX export
# is a protobuf graph, not a torch checkpoint, so torch.load() would fail with
# an opaque unpickling error. Fail early, before downloading, with a clear fix.
if U2NET_WEIGHTS_FILE.lower().endswith(".onnx"):
    raise RuntimeError(
        f"{U2NET_REPO_ID}/{U2NET_WEIGHTS_FILE} is an ONNX export and cannot be "
        "loaded with torch.load(). Set U2NET_REPO_ID / U2NET_WEIGHTS_FILE to a "
        "PyTorch U-2-Net state_dict checkpoint (e.g. u2net.pth)."
    )

model_path = hf_hub_download(repo_id=U2NET_REPO_ID, filename=U2NET_WEIGHTS_FILE)

# Use a known U2NET implementation (e.g., from https://github.com/xuebinqin/U-2-Net/blob/master/u2net_test.py)
from u2net import U2NET  # Assume you copied the model code as u2net.py

# Load model on CPU. The checkpoint comes from the network and torch.load is
# pickle-based, so weights_only=True restricts unpickling to tensors and plain
# containers (sufficient for a state_dict).
model = U2NET(3, 1)
model.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu"), weights_only=True)
)
model.eval()

# Preprocessing pipeline applied to every input image: resize to a fixed
# 320x320, convert to a float tensor scaled to [0, 1], then normalise each
# channel with the common ImageNet mean/std values.
_INPUT_SIZE = (320, 320)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

_steps = [
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_steps)

def segment_dress(image):
    """Keep the salient foreground of *image* and fade the rest to black.

    The image is preprocessed with the module-level ``transform`` (320x320),
    run through the module-level ``model``, and the first model output is
    min-max normalised to [0, 1]. That map is resized back to the input's
    size and multiplied into every RGB channel.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading an image.

    Returns:
        The masked RGB ``PIL.Image.Image`` with the same size as the input,
        or ``None`` if no image was given.
    """
    if image is None:
        # Gradio passes None for an empty image input; clear the output
        # instead of crashing on None.convert().
        return None

    original = image.convert("RGB")
    input_tensor = transform(original).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        d1, _, _, _, _, _, _ = model(input_tensor)
        pred = d1[0][0]  # batch item 0, channel 0
        lo, hi = pred.min(), pred.max()
        if hi > lo:
            pred = (pred - lo) / (hi - lo)
        else:
            # Constant map: min-max scaling would divide by zero and yield
            # NaNs, whose uint8 cast is undefined. Keep the raw values,
            # clamped into [0, 1], instead.
            pred = pred.clamp(0.0, 1.0)
        pred_np = pred.cpu().numpy()

    # 8-bit grayscale mask, resized back to the original resolution.
    mask_img = Image.fromarray((pred_np * 255).astype(np.uint8))
    mask_img = mask_img.resize(original.size)

    # Scale each RGB channel by the mask (broadcast over the channel axis).
    mask = np.asarray(mask_img, dtype=np.float64) / 255.0
    image_np = np.asarray(original)  # RGB PIL image -> uint8 array
    segmented = (image_np * mask[..., None]).astype(np.uint8)

    return Image.fromarray(segmented)

# Wire segment_dress into a single-image-in / single-image-out Gradio UI
# and start the web server.
image_in = gr.Image(type="pil", label="Upload Image")
image_out = gr.Image(type="pil", label="Segmented Dress")

demo = gr.Interface(
    fn=segment_dress,
    inputs=image_in,
    outputs=image_out,
    title="Dress Segmentation with U-2-Net",
    description="Segments the dress (or full foreground) using U-2-Net from Hugging Face",
)
demo.launch()