Create app.py
app.py
ADDED
@@ -0,0 +1,61 @@
import torch
import cv2
import numpy as np
import gradio as gr
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET  # Import U²-Net model

# Load U²-Net model
model_path = "u2net_model/u2net.pth"
model = U2NET(3, 1)
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}  # Remove 'module.' prefix left by DataParallel
model.load_state_dict(state_dict)
model.eval()

def segment_dress(image_np):
    """Detect the dress using U²-Net and return a binary mask."""

    # Convert image to a tensor resized to the network's 320x320 input size
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320))
    ])

    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = transform_pipeline(image).unsqueeze(0)

    # U²-Net inference: take the primary output map (d0)
    with torch.no_grad():
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()

    # Generate binary mask and resize it back to the original image resolution
    dress_mask = (output > 0.5).astype(np.uint8) * 255
    dress_mask = cv2.resize(dress_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)

    return dress_mask

def remove_background(image_np):
    """Remove the background and replace it with white while keeping the dress."""

    # Generate dress mask
    mask = segment_dress(image_np)

    # Keep dress pixels, paint everything else white
    white_bg = np.ones_like(image_np) * 255  # White background
    segmented_dress = np.where(mask[..., None] > 128, image_np, white_bg)

    return Image.fromarray(segmented_dress)

# Gradio Interface
demo = gr.Interface(
    fn=remove_background,
    inputs=gr.Image(type="numpy", label="Upload Dress Image"),
    outputs=gr.Image(type="pil", label="Dress with White Background"),
    title="Dress Segmentation & Background Removal",
    description="Upload a dress image, and this AI model will detect the dress and replace the background with white."
)

if __name__ == "__main__":
    demo.launch()
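
As a quick local sanity check outside the Gradio UI, the two functions above can be called directly. The snippet below is a minimal sketch and not part of the committed file: it assumes app.py is importable as a module named app and that a test image exists at the placeholder path sample_dress.jpg.

# Minimal smoke test for the functions defined in app.py (run as a separate script).
# "sample_dress.jpg" is a placeholder path; replace it with a real image.
import numpy as np
from PIL import Image

from app import segment_dress, remove_background  # importing app also loads the U²-Net weights

image_np = np.array(Image.open("sample_dress.jpg").convert("RGB"))

# The mask should be a 2-D uint8 array containing only 0 and 255.
mask = segment_dress(image_np)
print("mask shape:", mask.shape, "values:", np.unique(mask))

# Full pipeline: dress kept, background painted white.
result = remove_background(image_np)
result.save("dress_white_bg.png")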