Commit cec2d7c
Parent(s): f13580c
Update app.py
app.py CHANGED
@@ -1,5 +1,6 @@
 # Import necessary libraries and load the model
 import gradio as gr
+from gradio.layouts import Column, Row
 import torch
 import numpy as np
 from PIL import Image
@@ -17,6 +18,14 @@ model.load_state_dict(torch.load('unet_model.pth', map_location=device))
 model = model.to(device)
 model.eval()
 
+
+# Define your examples
+examples = [
+    ["examples/image_1.jpg", "examples/gt_1.png", "examples/scribble_1.png"],
+    # ["path_to_image2", "path_to_gt_mask2", "path_to_scribble_mask2"],
+    # ["path_to_image3", "path_to_gt_mask3", "path_to_scribble_mask3"]
+]
+
 # Define the segmentation function
 def segment(img):
     img = Image.fromarray(img.astype('uint8'), 'RGB')
@@ -35,7 +44,7 @@ def segment(img):
     output = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze().cpu().numpy()
 
     # Resize the mask back to the original image size
-    output = Image.fromarray(output.astype('uint8')).resize(original_size, resample=Image.
+    output = Image.fromarray(output.astype('uint8')).resize(original_size, resample=Image.NEAREST)
 
     # Convert the PIL Image back to a numpy array
     output = np.array(output)
@@ -44,10 +53,18 @@ def segment(img):
 
     return binary_mask
 
+
+gr.Button('hhhhh')
 # Create a Gradio interface
-
-
-
+demo = gr.Interface(fn=segment,
+                    inputs="image",
+                    outputs="image",
+                    title="<p>S<sup>2</sup>ME: Spatial-Spectral Mutual Teaching and Ensemble Learning</p><p>for Scribble-supervised Polyp Segmentation</p>",
+                    description="<p>MICCAI 2023, the 26th International Conference on Medical Image Computing and Computer Assisted Intervention</p><p>An Wang, Mengya Xu, Yang Zhang, Mobarakol Islam, and Hongliang Ren</p>",
+                    allow_flagging=False,
+                    examples=examples)  # Add your examples here
+
 
 # Launch the interface
-
+demo.launch()
+
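For orientation, here is a minimal standalone sketch of the Interface/launch wiring this commit sets up, with a placeholder in place of the real UNet inference (the actual app loads unet_model.pth and runs the model inside segment()). Two details are assumptions for the sketch, not part of the commit: the example row is trimmed to the image path alone, since the interface declares a single image input, and allow_flagging is written in its string form.

# Minimal sketch (hypothetical): mirrors the interface wiring from the commit
# without the UNet model, so it can run on its own.
import gradio as gr
import numpy as np

def segment(img):
    # Placeholder for the real UNet inference: return an empty mask with the
    # same height/width as the input image.
    return np.zeros(img.shape[:2], dtype=np.uint8)

# One example row per input component; the interface takes a single image, so
# only the image path (taken from the commit's examples list) is used here.
examples = [["examples/image_1.jpg"]]

demo = gr.Interface(fn=segment,
                    inputs="image",
                    outputs="image",
                    title="S2ME: Scribble-supervised Polyp Segmentation",
                    allow_flagging="never",  # string form; the commit passes False
                    examples=examples)

# Launch the interface
demo.launch()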