File size: 1,727 Bytes
a37c14e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Import necessary libraries and load the model
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from unet import UNet  # Assuming UNet is the model class

# Per-channel mean and standard deviation (three values each, one per colour
# channel) that segment() uses to normalise inputs.
# NOTE(review): presumably computed on the training set in [0, 1] scale — confirm.
MEAN = np.array([0.4732661, 0.44874457, 0.3948762], dtype=np.float32)
STD = np.array([0.22674961, 0.22012031, 0.2238305], dtype=np.float32)

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNet(in_chns=3, class_num=2)  # Initialize your model
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a CUDA run can still be loaded on a CPU-only host.
model.load_state_dict(torch.load('unet_model.pth', map_location=device))

model = model.to(device)
# Evaluation mode: disables training-only behaviour (e.g. dropout, batch-norm
# statistic updates) in any layers of the model that have it.
model.eval()

# Define the segmentation function
def segment(img):
    """Segment an image with the loaded model and return a binary mask.

    Args:
        img: Image as a NumPy array, as delivered by the Gradio ``"image"``
            input. H x W x 3 RGB is the expected layout; grayscale and RGBA
            arrays are converted to RGB. ``None`` (e.g. the form was
            submitted without an image) is accepted.

    Returns:
        A ``uint8`` NumPy array with the same height and width as the input,
        holding 255 where the predicted class index is non-zero and 0
        elsewhere. Returns ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # Let PIL infer the mode from the array shape and convert explicitly, so
    # non-RGB inputs are not misinterpreted as packed RGB data.
    pil_img = Image.fromarray(img.astype('uint8')).convert('RGB')
    original_size = pil_img.size  # (width, height), needed to undo the resize

    pil_img = pil_img.resize((224, 224), Image.BILINEAR)

    # ToTensor() yields a float tensor laid out as (C, H, W) scaled to [0, 1],
    # so the per-channel statistics must broadcast over the FIRST axis.
    # (Indexing the last axis would normalise pixel columns, not channels.)
    tensor = transforms.ToTensor()(pil_img)
    mean = torch.as_tensor(MEAN, dtype=tensor.dtype).view(-1, 1, 1)
    std = torch.as_tensor(STD, dtype=tensor.dtype).view(-1, 1, 1)
    tensor = (tensor - mean) / std

    batch = tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # argmax over the class axis; softmax is monotonic, so applying it first
    # would not change which class wins.
    # Assumes the model returns (1, class_num, 224, 224) — TODO confirm.
    labels = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

    # Resize the mask back to the original image size.
    # NOTE(review): BILINEAR is kept from the original code; it only makes
    # sense because the label map is binary (class_num=2). Switch to
    # Image.NEAREST if more classes are ever added.
    resized = Image.fromarray(labels.astype('uint8')).resize(
        original_size, resample=Image.BILINEAR
    )

    # Convert the PIL Image back to a numpy array and map any non-background
    # label to white.
    resized = np.array(resized)
    binary_mask = np.zeros_like(resized)
    binary_mask[resized > 0] = 255

    return binary_mask

# Create a Gradio interface
# Flagging is disabled with the string "never": Gradio deprecated boolean
# values for `allow_flagging` in favour of "never" / "auto" / "manual".
# NOTE(review): newer Gradio releases rename this option to `flagging_mode`;
# switch once the pinned Gradio version is confirmed.
iface = gr.Interface(
    fn=segment,
    inputs="image",
    outputs="image",
    title="Segmentation Model",
    description="Segment objects in an image.",
    allow_flagging="never",
)

# Launch the interface
iface.launch()