# polypseg / app.py
# Last commit by lofrienger: "Update app.py" (0b0b458)
# Import necessary libraries and load the model
import gradio as gr
from gradio.layouts import Column, Row
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from unet import UNet # Assuming UNet is the model class
# Per-channel (R, G, B) normalisation statistics used by `segment`.
# NOTE(review): presumably computed on the training data — confirm.
MEAN = np.array([0.4732661, 0.44874457, 0.3948762], dtype=np.float32)
STD = np.array([0.22674961, 0.22012031, 0.2238305], dtype=np.float32)

# Prefer the GPU when CUDA is available; fall back to the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Two-class UNet (3-channel RGB in), restored from the local checkpoint and
# placed in inference mode. nn.Module.to() moves parameters in place and
# returns the same module, so chaining .eval() configures `model` itself.
model = UNet(in_chns=3, class_num=2)
model.load_state_dict(torch.load('unet_model.pth', map_location=device))
model.to(device).eval()
# Example inputs shown under the interface; each inner list holds the
# arguments for one call of `segment` (here: a single image path).
examples = [[f"examples/image_{idx}.jpg"] for idx in range(1, 4)]
# Define the segmentation function
def segment(img):
img = Image.fromarray(img.astype('uint8'), 'RGB')
original_size = img.size # Store the original size
img = img.resize((224, 224), Image.BILINEAR)
img = transforms.ToTensor()(img)
for i in range(3):
img[:, :, i] -= float(MEAN[i])
for i in range(3):
img[:, :, i] /= float(STD[i])
img = img.unsqueeze(0).to(device)
with torch.no_grad():
output = model(img)
output = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze().cpu().numpy()
# Resize the mask back to the original image size
output = Image.fromarray(output.astype('uint8')).resize(original_size, resample=Image.NEAREST)
# Convert the PIL Image back to a numpy array
output = np.array(output)
binary_mask = np.zeros_like(output)
binary_mask[output > 0] = 255
return binary_mask
# Create the Gradio interface: one image in, the predicted binary mask out.
demo = gr.Interface(
    fn=segment,
    inputs="image",
    outputs="image",
    title="<p>S<sup>2</sup>ME: Spatial-Spectral Mutual Teaching and Ensemble Learning</p><p>for Scribble-supervised Polyp Segmentation</p>",
    description="<p>MICCAI 2023, the 26th International Conference on Medical Image Computing and Computer Assisted Intervention</p><p>An Wang, Mengya Xu, Yang Zhang, Mobarakol Islam, and Hongliang Ren</p>",
    # Gradio documents this parameter as one of "never" / "auto" / "manual";
    # the boolean form (False) is deprecated.
    allow_flagging="never",
    examples=examples,
)

# Launch the web app.
demo.launch()