File size: 3,753 Bytes
04eca22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
import os
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as T
from model.u2net import U2NET

# Initialize the U2NET model
# Constructed with 3 input channels and 1 output channel. At this point the
# instance has no checkpoint loaded; load_model() below fills in the weights.
u2net = U2NET(in_ch=3, out_ch=1)

def load_model(model, model_path, device):
    """Load saved weights into *model* and move it to *device*.

    Args:
        model: Module exposing ``load_state_dict`` and ``to``.
        model_path: Location of the checkpoint read with ``torch.load``.
        device: Target device; also used as ``map_location`` so the stored
            tensors are remapped while the checkpoint is being read.

    Returns:
        The result of ``model.to(device)``.
    """
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)

# Load the model onto the specified device
# NOTE(review): "/content/u2net.pth" is a hard-coded absolute path — confirm the
# checkpoint actually lives there in the environment this script is deployed to.
u2net = load_model(model=u2net, model_path="/content/u2net.pth", device="cpu")

# Mean and std for normalization
# One value per channel. They are passed to T.Normalize in the pipeline below
# and read again by denorm_image() to reverse that normalization.
# NOTE(review): these look like the usual ImageNet statistics — confirm they
# match the preprocessing the checkpoint was trained with.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Target size handed to the resizing steps of the preprocessing.
resize_shape = (320, 320)

# Preprocessing pipeline: resize, convert to a tensor, then normalize each
# channel with the mean/std defined above.
transforms = T.Compose([
    T.Resize(resize_shape),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std)
])

def prepare_single_image(image, resize, transforms, device):
    """Convert one image into a single-item batch placed on *device*.

    Args:
        image: PIL image, or a NumPy array that is first wrapped with
            ``Image.fromarray``.
        resize: Size passed to PIL's ``resize`` (bilinear resampling).
        transforms: Callable applied to the resized RGB image; its result
            must support ``unsqueeze`` and ``to``.
        device: Device the resulting batch is moved to.

    Returns:
        The transformed image with an extra leading dimension of size 1.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    rgb_image = pil_image.convert("RGB")
    resized = rgb_image.resize(resize, resample=Image.BILINEAR)
    transformed = transforms(resized)
    # unsqueeze(0) prepends the batch dimension.
    return transformed.unsqueeze(0).to(device)

def denorm_image(image_tensor):
    """Undo the mean/std normalization and return a uint8 NumPy image.

    Args:
        image_tensor: Channel-first tensor (it is permuted with ``(1, 2, 0)``),
            whose channel count must broadcast against the module-level
            ``mean`` and ``std``.

    Returns:
        Channel-last ``np.uint8`` array with values clamped to 0..255.
    """
    # Work on a CPU copy so the caller's tensor is left untouched.
    restored = image_tensor.cpu().clone()
    # [:, None, None] lets the per-channel statistics broadcast over the two
    # trailing dimensions.
    restored = restored * std[:, None, None] + mean[:, None, None]
    scaled = (restored * 255.).clamp(min=0., max=255.)
    channels_last = scaled.permute(1, 2, 0)
    return channels_last.numpy().astype(np.uint8)

def prepare_prediction(model, image_batch):
    """Run *model* on *image_batch* and return its first output as NumPy.

    The model is switched to eval mode and called with gradient tracking
    disabled. Only element ``[0]`` of whatever the model returns is used.

    Args:
        model: Callable module exposing ``eval``.
        image_batch: Input passed straight to the model.

    Returns:
        ``results[0]`` moved to the CPU, with dimension 0 squeezed away when
        it has size 1, converted to a NumPy array.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(image_batch)
    first_output = outputs[0].cpu()
    return first_output.squeeze(dim=0).numpy()

def normPRED(predicted_map):
    """Min-max normalize *predicted_map* into the range [0, 1].

    Args:
        predicted_map: Non-empty NumPy array of predictions.

    Returns:
        Array of the same shape where the minimum maps to 0 and the maximum
        to 1. A constant map (max == min) yields all zeros instead of the
        NaNs a plain division by zero would produce.

    Raises:
        ValueError: If *predicted_map* is empty (raised by ``np.min``).
    """
    lowest = np.min(predicted_map)
    highest = np.max(predicted_map)
    value_range = highest - lowest
    shifted = predicted_map - lowest
    if value_range == 0:
        # Every entry equals the minimum, so `shifted` is already all zeros;
        # dividing by 1 keeps the usual float result type without 0/0 -> NaN.
        value_range = 1
    return shifted / value_range

def apply_mask(image, mask):
    """Apply the mask to the image and return the result with a transparent background.

    Args:
        image: PIL image to cut out; it is converted to RGB here.
        mask: Predicted map as an array. Singleton dimensions are squeezed
            away, and the remainder is expected to be 2-D (height, width).

    Returns:
        RGBA PIL image at the mask's resolution, composited over a fully
        transparent canvas using the normalized mask.
    """
    # Remove the extra dimension if present
    mask = np.squeeze(mask)

    # Normalize to [0, 1], then scale to the 0..255 range of an 8-bit mask
    mask = normPRED(mask)
    mask = (mask * 255).astype(np.uint8)

    # Convert the mask to a PIL image
    mask_image = Image.fromarray(mask, mode='L')  # 'L' mode for grayscale

    # Resize the picture to the mask's own size rather than to a module-level
    # constant: Image.composite requires both images and the mask to have
    # identical sizes, and PIL sizes are (width, height) while the mask array
    # is (height, width), so taking the size from the mask image avoids any
    # mismatch.
    original_image = image.convert("RGB")
    original_image = original_image.resize(mask_image.size, resample=Image.BILINEAR)

    # Convert original image to RGBA
    original_image_rgba = original_image.convert("RGBA")

    # Fully transparent canvas that shows through wherever the mask is dark
    transparent_background = Image.new("RGBA", original_image_rgba.size, (0, 0, 0, 0))

    # Blend picture and canvas, weighting each pixel by the mask
    masked_image = Image.composite(original_image_rgba, transparent_background, mask_image)

    return masked_image

def segment_image(image):
    """Function to be used with Gradio for segmentation.

    Args:
        image: The uploaded picture as a NumPy array or PIL image, or
            ``None`` when no image was provided.

    Returns:
        The RGBA PIL image produced by ``apply_mask``, or ``None`` when
        *image* is ``None``.
    """
    # Nothing to segment: return None instead of failing with an
    # AttributeError further down the pipeline.
    if image is None:
        return None

    # Ensure image is a PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image_batch = prepare_single_image(image, resize_shape, transforms, "cpu")
    prediction_u2net = prepare_prediction(u2net, image_batch)
    return apply_mask(image, prediction_u2net)

# Define the Gradio interface
# The input component delivers the upload to segment_image as a NumPy array
# (type="numpy"); the output component takes the returned PIL image and is
# configured for PNG, a format that can carry the alpha channel.
iface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="pil",format="png"),
    title="Image Segmentation with U2NET",
    description="Upload an image to segment it using the U2NET model. The background of the segmented output will be transparent."
)

# Launch the interface
# Executed at module level (there is no __main__ guard), so this also runs
# whenever the module is imported.
iface.launch()