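"""Gradio demo for S²ME: Spatial-Spectral Mutual Teaching and Ensemble Learning
for Scribble-supervised Polyp Segmentation (MICCAI 2023).

Loads a trained UNet checkpoint (unet_model.pth) and serves a simple
image-in, binary-mask-out web interface.
"""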
# Imports: Gradio for the web UI, PyTorch/torchvision for inference, PIL/NumPy for image handling
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from unet import UNet  # Local module providing the UNet model class

# Per-channel normalization statistics (assumed to match the training data)
MEAN = np.array([0.4732661, 0.44874457, 0.3948762], dtype=np.float32)
STD = np.array([0.22674961, 0.22012031, 0.2238305], dtype=np.float32)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the UNet (3 input channels, 2 output classes: background and polyp) and load the trained weights
model = UNet(in_chns=3, class_num=2)
model.load_state_dict(torch.load('unet_model.pth', map_location=device))
model = model.to(device)
model.eval()


# Example images shown in the interface
examples = [
    ["examples/image_1.jpg"],
    ["examples/image_2.jpg"],
    ["examples/image_3.jpg"]
]

# Segmentation function: takes an RGB image as a numpy array, returns a binary mask (0/255)
def segment(img):
    img = Image.fromarray(img.astype('uint8'), 'RGB')
    original_size = img.size  # (width, height), used to restore the mask to the input resolution

    # Resize to the network input size and apply channel-wise normalization
    img = img.resize((224, 224), Image.BILINEAR)
    img = transforms.ToTensor()(img)           # shape (C, H, W), values in [0, 1]
    img = transforms.Normalize(MEAN, STD)(img)

    img = img.unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(img)
    # Per-pixel class prediction: 0 = background, 1 = polyp
    output = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze().cpu().numpy()

    # Resize the mask back to the original image size (nearest neighbour preserves class labels)
    output = Image.fromarray(output.astype('uint8')).resize(original_size, resample=Image.NEAREST)

    # Map the foreground class to 255 for display
    output = np.array(output)
    binary_mask = np.zeros_like(output)
    binary_mask[output > 0] = 255

    return binary_mask


# Create the Gradio interface
demo = gr.Interface(fn=segment,
                    inputs="image",
                    outputs="image",
                    title="<p>S<sup>2</sup>ME: Spatial-Spectral Mutual Teaching and Ensemble Learning</p><p>for Scribble-supervised Polyp Segmentation</p>",
                    description="<p>MICCAI 2023, the 26th International Conference on Medical Image Computing and Computer Assisted Intervention</p><p>An Wang, Mengya Xu, Yang Zhang, Mobarakol Islam, and Hongliang Ren</p>",
                    allow_flagging="never",  # Gradio expects "never"/"manual"/"auto" rather than a boolean
                    examples=examples)


# Launch the interface
demo.launch()