gaur3009 commited on
Commit
3e98665
·
verified ·
1 Parent(s): ec2d9ce

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import cv2
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ from cloth_segmentation.networks.u2net import U2NET # Import U²-Net model
8
+
9
+ # Load U²-Net model
10
+ model_path = "u2net_model/u2net.pth"
11
+ model = U2NET(3, 1)
12
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
13
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} # Remove 'module.' prefix
14
+ model.load_state_dict(state_dict)
15
+ model.eval()
16
+
17
+ def segment_dress(image_np):
18
+ """Detects dress using U²-Net and creates a binary mask."""
19
+
20
+ # Convert image to tensor
21
+ transform_pipeline = transforms.Compose([
22
+ transforms.ToTensor(),
23
+ transforms.Resize((320, 320))
24
+ ])
25
+
26
+ image = Image.fromarray(image_np).convert("RGB")
27
+ input_tensor = transform_pipeline(image).unsqueeze(0)
28
+
29
+ # U²-Net inference
30
+ with torch.no_grad():
31
+ output = model(input_tensor)[0][0].squeeze().cpu().numpy()
32
+
33
+ # Generate binary mask
34
+ dress_mask = (output > 0.5).astype(np.uint8) * 255
35
+ dress_mask = cv2.resize(dress_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
36
+
37
+ return dress_mask
38
+
39
+ def remove_background(image_np):
40
+ """Removes background and replaces it with white while keeping the dress."""
41
+
42
+ # Generate dress mask
43
+ mask = segment_dress(image_np)
44
+
45
+ # Make background white
46
+ white_bg = np.ones_like(image_np) * 255 # White background
47
+ segmented_dress = np.where(mask[..., None] > 128, image_np, white_bg)
48
+
49
+ return Image.fromarray(segmented_dress)
50
+
51
+ # Gradio Interface
52
+ demo = gr.Interface(
53
+ fn=remove_background,
54
+ inputs=gr.Image(type="numpy", label="Upload Dress Image"),
55
+ outputs=gr.Image(type="pil", label="Dress with White Background"),
56
+ title="Dress Segmentation & Background Removal",
57
+ description="Upload a dress image, and this AI model will detect the dress and replace the background with white."
58
+ )
59
+
60
+ if __name__ == "__main__":
61
+ demo.launch()