lofrienger commited on
Commit
cec2d7c
·
1 Parent(s): f13580c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -5
app.py CHANGED
@@ -1,5 +1,6 @@
1
  # Import necessary libraries and load the model
2
  import gradio as gr
 
3
  import torch
4
  import numpy as np
5
  from PIL import Image
@@ -17,6 +18,14 @@ model.load_state_dict(torch.load('unet_model.pth', map_location=device))
17
  model = model.to(device)
18
  model.eval()
19
 
 
 
 
 
 
 
 
 
20
  # Define the segmentation function
21
  def segment(img):
22
  img = Image.fromarray(img.astype('uint8'), 'RGB')
@@ -35,7 +44,7 @@ def segment(img):
35
  output = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze().cpu().numpy()
36
 
37
  # Resize the mask back to the original image size
38
- output = Image.fromarray(output.astype('uint8')).resize(original_size, resample=Image.BILINEAR)
39
 
40
  # Convert the PIL Image back to a numpy array
41
  output = np.array(output)
@@ -44,10 +53,18 @@ def segment(img):
44
 
45
  return binary_mask
46
 
 
 
47
  # Create a Gradio interface
48
- iface = gr.Interface(fn=segment, inputs="image", outputs="image", title="Segmentation Model",
49
- description="Segment objects in an image.",
50
- allow_flagging=False)
 
 
 
 
 
51
 
52
  # Launch the interface
53
- iface.launch()
 
 
1
  # Import necessary libraries and load the model
2
  import gradio as gr
3
+ from gradio.layouts import Column, Row
4
  import torch
5
  import numpy as np
6
  from PIL import Image
 
18
  model = model.to(device)
19
  model.eval()
20
 
21
+
22
+ # Define your examples
23
+ examples = [
24
+ ["examples/image_1.jpg", "examples/gt_1.png", "examples/scribble_1.png"],
25
+ # ["path_to_image2", "path_to_gt_mask2", "path_to_scribble_mask2"],
26
+ # ["path_to_image3", "path_to_gt_mask3", "path_to_scribble_mask3"]
27
+ ]
28
+
29
  # Define the segmentation function
30
  def segment(img):
31
  img = Image.fromarray(img.astype('uint8'), 'RGB')
 
44
  output = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze().cpu().numpy()
45
 
46
  # Resize the mask back to the original image size
47
+ output = Image.fromarray(output.astype('uint8')).resize(original_size, resample=Image.NEAREST)
48
 
49
  # Convert the PIL Image back to a numpy array
50
  output = np.array(output)
 
53
 
54
  return binary_mask
55
 
56
+
57
+ gr.Button('hhhhh')
58
  # Create a Gradio interface
59
+ demo = gr.Interface(fn=segment,
60
+ inputs="image",
61
+ outputs="image",
62
+ title="<p>S<sup>2</sup>ME: Spatial-Spectral Mutual Teaching and Ensemble Learning</p><p>for Scribble-supervised Polyp Segmentation</p>",
63
+ description="<p>MICCAI 2023, the 26th International Conference on Medical Image Computing and Computer Assisted Intervention</p><p>An Wang, Mengya Xu, Yang Zhang, Mobarakol Islam, and Hongliang Ren</p>",
64
+ allow_flagging=False,
65
+ examples=examples) # Add your examples here
66
+
67
 
68
  # Launch the interface
69
+ demo.launch()
70
+