"""Gradio demo for S2ME scribble-supervised polyp segmentation.

Loads a pretrained UNet checkpoint and serves a web UI that turns an uploaded
image into a binary (0/255) foreground mask at the image's original size.
"""

# Import necessary libraries and load the model
import gradio as gr
from gradio.layouts import Column, Row  # noqa: F401  (currently unused)
import torch
import numpy as np
from PIL import Image
from torchvision import transforms

from unet import UNet  # Assuming UNet is the model class

# Per-channel RGB statistics, applied after ToTensor (i.e. on [0, 1]-scaled pixels).
MEAN = np.array([0.4732661 , 0.44874457, 0.3948762 ], dtype=np.float32)
STD = np.array([0.22674961, 0.22012031, 0.2238305 ], dtype=np.float32)

# (width, height) the network is fed; the predicted mask is resized back afterwards.
INPUT_SIZE = (224, 224)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNet(in_chns=3, class_num=2)  # Initialize your model
model.load_state_dict(torch.load('unet_model.pth', map_location=device))
model = model.to(device)
model.eval()

# ToTensor yields a (C, H, W) float tensor in [0, 1]; Normalize then applies the
# statistics per *channel*. (The previous manual loop indexed img[:, :, i], which
# on a (C, H, W) tensor touches width column i, so channels were never normalized.)
# Built once at import instead of on every request.
_to_model_input = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN.tolist(), std=STD.tolist()),
])

# Define your examples
examples = [
    ["examples/image_1.jpg"],
    ["examples/image_2.jpg"],
    ["examples/image_3.jpg"],
]


def _preprocess(img):
    """Turn an image array into a normalized (1, 3, H, W) batch on `device`.

    Returns the batch and the original PIL size (width, height) so the mask
    can be resized back.
    """
    # convert('RGB') also accepts grayscale (H, W) and RGBA (H, W, 4) arrays,
    # which Image.fromarray(arr, 'RGB') would reject or misinterpret.
    pil_img = Image.fromarray(np.asarray(img).astype(np.uint8)).convert('RGB')
    original_size = pil_img.size
    resized = pil_img.resize(INPUT_SIZE, Image.BILINEAR)
    batch = _to_model_input(resized).unsqueeze(0).to(device)
    return batch, original_size


def _postprocess(logits, original_size):
    """Turn (1, num_classes, H, W) logits into a uint8 0/255 mask of `original_size`."""
    # argmax over the class axis; softmax is monotonic, so it would not change it.
    labels = logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
    # NEAREST keeps the resized mask made of valid class labels (no interpolation).
    label_img = Image.fromarray(labels).resize(original_size, resample=Image.NEAREST)
    # Every non-background class is reported as foreground (255).
    return np.where(np.asarray(label_img) > 0, 255, 0).astype(np.uint8)


# Define the segmentation function
def segment(img):
    """Segment polyps in an image.

    Args:
        img: Image array from the Gradio input (presumably H x W x 3 uint8 with
            Gradio's defaults), or None when submitted without an image.

    Returns:
        An H x W uint8 array at the input's original resolution, 255 where the
        model predicts foreground and 0 elsewhere; None if no image was given.
    """
    if img is None:
        return None
    batch, original_size = _preprocess(img)
    with torch.no_grad():
        logits = model(batch)
    return _postprocess(logits, original_size)


# Create a Gradio interface
demo = gr.Interface(
    fn=segment,
    inputs="image",
    outputs="image",
    title=(
        "S2ME: Spatial-Spectral Mutual Teaching and Ensemble Learning "
        "for Scribble-supervised Polyp Segmentation"
    ),
    description=(
        "MICCAI 2023, the 26th International Conference on Medical Image "
        "Computing and Computer Assisted Intervention<br>"
        "An Wang, Mengya Xu, Yang Zhang, Mobarakol Islam, and Hongliang Ren"
    ),
    allow_flagging="never",  # string form; bool values are a deprecated alias
    examples=examples,  # Add your examples here
)

# Launch the interface
if __name__ == "__main__":
    demo.launch()