File size: 3,684 Bytes
9507932
04eca22
9507932
04eca22
 
9507932
2ed7990
9507932
 
04eca22
 
 
 
 
 
 
 
 
 
2eb5aa7
04eca22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9507932
 
 
 
 
 
 
 
 
 
 
 
 
 
04eca22
9507932
 
 
 
04eca22
 
 
 
9507932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import streamlit as st
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as T
import io

# Assuming you have the U2NET model defined somewhere
from model.u2net import U2NET  # Replace with your actual import path

# Initialize the U2NET model.
#
# Streamlit re-executes this whole script on every widget interaction, so
# constructing the network and reading the checkpoint at module level would
# repeat that work on every rerun. st.cache_resource keeps one loaded instance
# across reruns (NOTE(review): it is shared between sessions, so inference on
# it must be safe to run concurrently). Older Streamlit releases without
# cache_resource fall back to a no-op decorator, i.e. load-every-run as before.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


def load_model(model, model_path, device):
    """Load a saved state dict into *model* and move the model to *device*.

    Args:
        model: A ``torch.nn.Module``; its parameters are overwritten in place
            by ``load_state_dict``.
        model_path: Path to a file readable by ``torch.load`` holding a state
            dict that matches *model*.
        device: Target device (e.g. ``"cpu"``). Also used as ``map_location``
            so tensors saved on a different device can be loaded here.

    Returns:
        The model after ``.to(device)``.
    """
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    return model


@_cache_resource
def _load_u2net(model_path="u2net.pth", device="cpu"):
    """Build a 3-in / 1-out U2NET and load its weights (cached across reruns)."""
    return load_model(model=U2NET(in_ch=3, out_ch=1), model_path=model_path, device=device)


# Load the model onto the specified device
u2net = _load_u2net()

# Mean and std for normalization (the commonly used ImageNet RGB statistics)
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Network input size. torchvision's Resize reads (h, w) while PIL's resize
# reads (w, h); the shape is square, so the ordering difference is harmless.
resize_shape = (320, 320)

transforms = T.Compose([
    T.Resize(resize_shape),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std)
])

def prepare_single_image(image, resize, transforms, device):
    """Turn one image into a model-ready batch of size one.

    Args:
        image: A PIL image, or a NumPy array that ``Image.fromarray`` accepts.
        resize: Size handed to ``PIL.Image.resize``, i.e. (width, height).
        transforms: Callable applied to the resized PIL image; its result must
            support ``.unsqueeze`` and ``.to`` (a tensor).
        device: Device the returned batch is moved to.

    Returns:
        The transformed image with a leading batch axis, on *device*.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    rgb_image = pil_image.convert("RGB")
    tensor = transforms(rgb_image.resize(resize, resample=Image.BILINEAR))
    # Prepend the batch axis before moving to the target device.
    return tensor.unsqueeze(0).to(device)

def prepare_prediction(model, image_batch):
    """Run *model* on *image_batch* for inference and return a NumPy array.

    The model is put in eval mode and evaluated under ``torch.no_grad()``.
    Only the first element of its output is kept. That element is moved to
    the CPU and its leading axis is squeezed (a no-op unless that axis has
    size 1).
    """
    model.eval()
    with torch.no_grad():
        first_output = model(image_batch)[0]
    return first_output.cpu().squeeze(dim=0).numpy()

def normPRED(predicted_map):
    """Min-max normalise *predicted_map* into the range [0, 1].

    Args:
        predicted_map: NumPy array (or array-like) of numeric values. It must
            be non-empty; ``np.max`` raises ``ValueError`` on an empty array.

    Returns:
        A floating-point array of the same shape, where the minimum maps to 0
        and the maximum maps to 1. If all values are equal, the range is zero.
        Dividing by it would produce NaN/inf, which later become arbitrary
        bytes when cast to ``uint8``. In that case an all-zero array is
        returned instead.
    """
    predicted_map = np.asarray(predicted_map)
    ma = np.max(predicted_map)
    mi = np.min(predicted_map)
    value_range = ma - mi
    if value_range == 0:
        # result_type keeps float32 inputs float32 and promotes ints to float.
        return np.zeros_like(predicted_map, dtype=np.result_type(predicted_map, 1.0))
    return (predicted_map - mi) / value_range

def apply_mask(image, mask):
    """Apply the mask to the original image and return the result with transparent background.

    Args:
        image: A PIL image (the user's upload, at its own size).
        mask: Predicted saliency map as an array. Singleton axes are squeezed
            away, so presumably (1, H, W) as returned by prepare_prediction;
            it must reduce to 2-D.

    Returns:
        An RGBA PIL image the same size as *image*. The mask is min-max
        normalised, scaled to 8-bit and used as the ``Image.composite`` mask.
        Where it is 255 the image is kept, where it is 0 the pixel is fully
        transparent, and values in between blend.
    """
    # Squeeze to 2-D so Image.fromarray produces a single-channel image.
    alpha = normPRED(np.squeeze(mask))
    mask_image = Image.fromarray((alpha * 255).astype(np.uint8), mode='L')  # 'L' mode for grayscale

    # Converting to RGB first drops any alpha the upload already had, so the
    # mask is the only source of transparency in the result.
    original_image_rgba = image.convert("RGB").convert("RGBA")

    # The mask is at the network's input resolution. Upsample it to the
    # image instead of shrinking the image to the mask, so non-square uploads
    # keep their aspect ratio and full resolution.
    mask_image = mask_image.resize(original_image_rgba.size, resample=Image.BILINEAR)

    transparent_background = Image.new("RGBA", original_image_rgba.size, (0, 0, 0, 0))
    return Image.composite(original_image_rgba, transparent_background, mask_image)

# Streamlit app setup: page heading plus a sidebar that hosts the controls.
st.title("Image Segmentation with U2NET")

_sidebar = st.sidebar
_sidebar.title("Controls :gear:")
# file_uploader returns None until the user picks a file.
uploaded_file = _sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

# Render the upload, its segmented cut-out, and a download button.
def fix_image(upload=None):
    """Display *upload*, its U2NET segmentation, and a PNG download button.

    Args:
        upload: A file-like object that ``Image.open`` accepts (such as the
            value of ``st.file_uploader``), or ``None``. When ``None``, a prompt
            asking for an upload is written and nothing else is rendered.
    """
    if upload is not None:
        source = Image.open(upload)
        st.image(source, caption="Uploaded Image", use_column_width=True)

        # Input batch -> raw prediction -> RGBA image with transparent background.
        batch = prepare_single_image(source, resize_shape, transforms, "cpu")
        result = apply_mask(source, prepare_prediction(u2net, batch))
        st.image(result, caption='Segmented Image', use_column_width=True)

        # Encode the result as PNG in memory for the download button.
        with io.BytesIO() as png_buffer:
            result.save(png_buffer, format='PNG')
            png_bytes = png_buffer.getvalue()

        st.sidebar.markdown('### Download Segmented Image')
        st.sidebar.download_button(
            label="Download Segmented Image",
            data=png_bytes,
            file_name="segmented_image.png",
            mime="image/png",
        )
    else:
        st.write("Please upload an image.")

# Handle image processing based on user input. fix_image itself writes an
# upload prompt when uploaded_file is None, so it is called unconditionally.
# Guarding here would leave that prompt unreachable.
fix_image(upload=uploaded_file)