Spaces:
Sleeping
Sleeping
File size: 1,727 Bytes
a37c14e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
# Import necessary libraries and load the model
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from unet import UNet # Assuming UNet is the model class
# Normalization statistics used by `segment`: one float32 value per input channel.
MEAN = np.asarray([0.4732661, 0.44874457, 0.3948762], dtype=np.float32)
STD = np.asarray([0.22674961, 0.22012031, 0.2238305], dtype=np.float32)
# Prefer a CUDA device when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Build the network: 3 input channels, 2 output classes.
model = UNet(in_chns=3, class_num=2)
# Load trained weights from 'unet_model.pth' (path is relative to the working
# directory). map_location remaps tensors onto `device`, so a checkpoint saved
# on a GPU still loads on a CPU-only host instead of raising at startup.
model.load_state_dict(torch.load("unet_model.pth", map_location=device))
model = model.to(device)
# Switch to inference mode (affects layers such as dropout/batch-norm, if present).
model.eval()
# Define the segmentation function
def normalize_channels(tensor, mean, std):
    """Normalize a channel-first tensor independently per channel.

    Computes ``(tensor[c] - mean[c]) / std[c]`` for every channel ``c``.

    Args:
        tensor: Float tensor of shape (C, H, W).
        mean: Sequence or array of C per-channel means.
        std: Sequence or array of C per-channel standard deviations.

    Returns:
        A new tensor with the same shape, dtype and device as ``tensor``.
    """
    # Reshape the statistics to (C, 1, 1) so they broadcast over H and W,
    # i.e. along the channel axis (axis 0) rather than a spatial axis.
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).reshape(-1, 1, 1)
    return (tensor - mean_t) / std_t


def segment(img):
    """Predict a binary foreground mask for a single image.

    Args:
        img: Array-like image (grayscale, RGB or RGBA) as delivered by the
            Gradio ``"image"`` input, or ``None`` when no image was supplied.

    Returns:
        ``None`` if ``img`` is ``None``; otherwise a uint8 array of shape
        (height, width) matching the input, holding 255 where the model
        predicts a non-zero class and 0 elsewhere.
    """
    if img is None:
        return None
    # convert("RGB") accepts grayscale and RGBA uploads as well as RGB ones.
    pil_img = Image.fromarray(np.asarray(img).astype("uint8")).convert("RGB")
    original_size = pil_img.size  # PIL reports (width, height)
    pil_img = pil_img.resize((224, 224), Image.BILINEAR)
    # ToTensor gives a (3, 224, 224) float tensor in [0, 1]; normalize per channel.
    x = normalize_channels(transforms.ToTensor()(pil_img), MEAN, STD)
    x = x.unsqueeze(0).to(device)  # add the batch dimension
    with torch.no_grad():
        logits = model(x)
    # Class index per pixel (class axis is dim 1). Softmax is monotonic, so
    # taking argmax of the raw scores selects the same class.
    labels = torch.argmax(logits, dim=1)[0].cpu().numpy().astype("uint8")
    # Resize the label map back to the original image size.
    labels = np.array(Image.fromarray(labels).resize(original_size, resample=Image.BILINEAR))
    return np.where(labels > 0, 255, 0).astype(np.uint8)
# Create a Gradio interface
iface = gr.Interface(
    fn=segment,
    inputs="image",
    outputs="image",
    title="Segmentation Model",
    description="Segment objects in an image.",
    # Gradio takes a string here ("never" / "auto" / "manual"); passing a
    # boolean is a deprecated form that only triggers a conversion warning.
    allow_flagging="never",
)
# Launch the interface
iface.launch()