gaur3009 commited on
Commit
159ad53
·
verified ·
1 Parent(s): 0210521

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -47
app.py CHANGED
@@ -1,62 +1,43 @@
1
  import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- from torchvision import transforms
5
- from PIL import Image
6
  import numpy as np
 
7
  import requests
8
- from io import BytesIO
9
-
10
- # U-2-Net architecture (simplified, or import from a .py file if you've saved it)
11
- # You can get the U-2-Net code from https://github.com/xuebinqin/U-2-Net
12
-
13
- # For demo, let's download the pre-trained model and use a wrapper instead
14
- from huggingface_hub import hf_hub_download
15
-
16
- # Download u2net.pth from HuggingFace Hub
17
- model_path = hf_hub_download(repo_id="BritishWerewolf/U-2-Net", filename="onnx/model.onnx")
18
-
19
- # Use a known U2NET implementation (e.g., from https://github.com/xuebinqin/U-2-Net/blob/master/u2net_test.py)
20
- from u2net import U2NET # Assume you copied the model code as u2net.py
21
 
22
- # Load model
23
- model = U2NET(3, 1)
24
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=False))
25
- model.eval()
26
 
27
- # Preprocessing
28
- transform = transforms.Compose([
29
- transforms.Resize((320, 320)),
30
- transforms.ToTensor(),
31
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
32
- std=[0.229, 0.224, 0.225])
33
- ])
34
 
 
35
  def segment_dress(image):
36
- original = image.convert("RGB")
37
- input_tensor = transform(original).unsqueeze(0)
38
-
39
- with torch.no_grad():
40
- d1, _, _, _, _, _, _ = model(input_tensor)
41
- pred = d1[0][0]
42
- pred = (pred - pred.min()) / (pred.max() - pred.min())
43
- pred_np = pred.cpu().numpy()
44
 
45
- # Resize to original size
46
- pred_resized = Image.fromarray((pred_np * 255).astype(np.uint8)).resize(original.size)
47
-
48
- # Apply mask
49
- mask = np.array(pred_resized) / 255.0
50
- image_np = np.array(original).astype(np.uint8)
51
- segmented = (image_np * mask[..., None]).astype(np.uint8)
 
52
 
53
- return Image.fromarray(segmented)
54
 
55
- # Launch Gradio app
56
  gr.Interface(
57
  fn=segment_dress,
58
  inputs=gr.Image(type="pil", label="Upload Image"),
59
  outputs=gr.Image(type="pil", label="Segmented Dress"),
60
- title="Dress Segmentation with U-2-Net",
61
- description="Segments the dress (or full foreground) using U-2-Net from Hugging Face"
62
  ).launch()
 
1
  import gradio as gr
2
+ import onnxruntime as ort
 
 
 
3
  import numpy as np
4
+ from PIL import Image
5
  import requests
6
+ from torchvision import transforms
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ # Load ONNX model
9
+ ort_session = ort.InferenceSession("model.onnx") # Ensure model.onnx is in your app folder
 
 
10
 
11
+ # Preprocessing function
12
+ def preprocess(image):
13
+ image = image.resize((320, 320)).convert("RGB")
14
+ image_np = np.array(image).astype(np.float32) / 255.0
15
+ image_np = image_np.transpose(2, 0, 1) # HWC -> CHW
16
+ image_np = np.expand_dims(image_np, axis=0) # Add batch dimension
17
+ return image_np
18
 
19
+ # Inference + Postprocessing
20
  def segment_dress(image):
21
+ input_tensor = preprocess(image)
22
+ inputs = {ort_session.get_inputs()[0].name: input_tensor}
23
+ outputs = ort_session.run(None, inputs)
 
 
 
 
 
24
 
25
+ pred = outputs[0][0][0]
26
+ pred = (pred - pred.min()) / (pred.max() - pred.min())
27
+ pred_img = Image.fromarray((pred * 255).astype(np.uint8)).resize(image.size)
28
+
29
+ # Apply mask to image
30
+ image_np = np.array(image.convert("RGB"))
31
+ mask = np.array(pred_img).astype(np.float32) / 255.0
32
+ masked = (image_np * mask[..., None]).astype(np.uint8)
33
 
34
+ return Image.fromarray(masked)
35
 
36
+ # Gradio app
37
  gr.Interface(
38
  fn=segment_dress,
39
  inputs=gr.Image(type="pil", label="Upload Image"),
40
  outputs=gr.Image(type="pil", label="Segmented Dress"),
41
+ title="U-2-Net Dress Segmentation (ONNX)",
42
+ description="Upload an image to segment foreground using U-2-Net ONNX model"
43
  ).launch()