gaur3009 commited on
Commit
1d57c26
·
verified ·
1 Parent(s): bca0fa8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -22
app.py CHANGED
@@ -1,36 +1,62 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
 
3
  from PIL import Image
4
  import numpy as np
 
 
5
 
6
- # Load U-2-Net segmentation pipeline
7
- pipe = pipeline("image-segmentation", model="BritishWerewolf/U-2-Net")
8
 
9
- # Segmentation function for Gradio
10
- def segment_dress(image: Image.Image):
11
- # Run U-2-Net pipeline
12
- segments = pipe(image)
13
 
14
- # We'll assume the first segment is the foreground (person+clothes)
15
- if not segments:
16
- return image
17
 
18
- # Load and apply mask
19
- mask = Image.open(segments[0]["mask"]).convert("L").resize(image.size)
20
- mask_np = np.array(mask) / 255.0
21
- image_np = np.array(image).astype(np.uint8)
22
 
23
- # Apply mask to image (keep only masked region)
24
- segmented = (image_np * mask_np[..., None]).astype(np.uint8)
25
- segmented_img = Image.fromarray(segmented)
 
26
 
27
- return segmented_img
 
 
 
 
 
 
28
 
29
- # Gradio Interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  gr.Interface(
31
  fn=segment_dress,
32
  inputs=gr.Image(type="pil", label="Upload Image"),
33
- outputs=gr.Image(type="pil", label="Segmented Dress Region"),
34
- title="Dress Segmentation using U-2-Net",
35
- description="Upload an image. The U-2-Net model will segment the main foreground (usually a person wearing a dress)."
36
  ).launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
  from PIL import Image
6
  import numpy as np
7
+ import requests
8
+ from io import BytesIO
9
 
10
+ # U-2-Net architecture (simplified, or import from a .py file if you've saved it)
11
+ # You can get the U-2-Net code from https://github.com/xuebinqin/U-2-Net
12
 
13
+ # For demo, let's download the pre-trained model and use a wrapper instead
14
+ from huggingface_hub import hf_hub_download
 
 
15
 
16
+ # Download u2net.pth from HuggingFace Hub
17
+ model_path = hf_hub_download(repo_id="BritishWerewolf/U-2-Net", filename="u2net.pth")
 
18
 
19
+ # Use a known U2NET implementation (e.g., from https://github.com/xuebinqin/U-2-Net/blob/master/u2net_test.py)
20
+ from u2net import U2NET # Assume you copied the model code as u2net.py
 
 
21
 
22
+ # Load model
23
+ model = U2NET(3, 1)
24
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
25
+ model.eval()
26
 
27
+ # Preprocessing
28
+ transform = transforms.Compose([
29
+ transforms.Resize((320, 320)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
32
+ std=[0.229, 0.224, 0.225])
33
+ ])
34
 
35
+ def segment_dress(image):
36
+ original = image.convert("RGB")
37
+ input_tensor = transform(original).unsqueeze(0)
38
+
39
+ with torch.no_grad():
40
+ d1, _, _, _, _, _, _ = model(input_tensor)
41
+ pred = d1[0][0]
42
+ pred = (pred - pred.min()) / (pred.max() - pred.min())
43
+ pred_np = pred.cpu().numpy()
44
+
45
+ # Resize to original size
46
+ pred_resized = Image.fromarray((pred_np * 255).astype(np.uint8)).resize(original.size)
47
+
48
+ # Apply mask
49
+ mask = np.array(pred_resized) / 255.0
50
+ image_np = np.array(original).astype(np.uint8)
51
+ segmented = (image_np * mask[..., None]).astype(np.uint8)
52
+
53
+ return Image.fromarray(segmented)
54
+
55
+ # Launch Gradio app
56
  gr.Interface(
57
  fn=segment_dress,
58
  inputs=gr.Image(type="pil", label="Upload Image"),
59
+ outputs=gr.Image(type="pil", label="Segmented Dress"),
60
+ title="Dress Segmentation with U-2-Net",
61
+ description="Segments the dress (or full foreground) using U-2-Net from Hugging Face"
62
  ).launch()