File size: 3,757 Bytes
9507932
04eca22
9507932
04eca22
 
9507932
2ed7990
9507932
 
04eca22
 
 
 
 
 
 
 
 
 
2eb5aa7
04eca22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9507932
 
 
 
 
 
 
 
 
e16ff8f
 
 
 
04eca22
9507932
04eca22
 
 
759e430
 
 
 
 
 
 
 
 
9507932
 
 
 
 
 
 
 
 
 
 
 
 
e16ff8f
9507932
 
e16ff8f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as T
import io

# U2NET is a project-local class; the ``model.u2net`` module must be
# importable for this script to start.
from model.u2net import U2NET  # Replace with your actual import path

# Construct the network with 3 input channels and 1 output channel.
# Checkpoint weights are loaded separately via load_model() further down.
u2net = U2NET(in_ch=3, out_ch=1)

def load_model(model, model_path, device):
    """Load a saved state dict into ``model`` and move the model to ``device``.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint.
        model_path: Checkpoint location handed to ``torch.load``.
        device: Target device; also used as ``map_location`` while loading.

    Returns:
        The value returned by ``model.to(device)``.
    """
    # NOTE: torch.load deserialises with pickle; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)

# Load the checkpoint file "u2net.pth" into the network and keep it on the
# CPU. This is a module-level statement, so it runs every time the script
# is executed.
u2net = load_model(model=u2net, model_path="u2net.pth", device="cpu")

# Spatial size (as passed to T.Resize and PIL's resize) used for model input.
resize_shape = (320, 320)

# Per-channel mean and std used by T.Normalize (these match the commonly
# used ImageNet statistics).
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Preprocessing pipeline: resize -> tensor conversion -> normalisation.
transforms = T.Compose(
    [
        T.Resize(resize_shape),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]
)

def prepare_single_image(image, resize, transforms, device):
    """Turn one image into a single-item batch tensor on ``device``.

    Args:
        image: A PIL image, or a NumPy array that is first wrapped with
            ``Image.fromarray``.
        resize: Size passed to ``Image.resize`` (bilinear resampling).
        transforms: Callable applied to the resized RGB image; its result
            must support ``unsqueeze`` and ``to``.
        device: Device the resulting batch is moved to.

    Returns:
        The transformed image with a leading batch dimension added.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    rgb_resized = pil_image.convert("RGB").resize(resize, resample=Image.BILINEAR)
    transformed = transforms(rgb_resized)
    # unsqueeze(0) inserts the batch axis at the front.
    return transformed.unsqueeze(0).to(device)

def prepare_prediction(model, image_batch):
    """Run ``model`` in eval mode without gradients and return its first output.

    Args:
        model: Callable module; it is switched to eval mode as a side effect.
        image_batch: Input forwarded unchanged to ``model``.

    Returns:
        ``model(image_batch)[0]`` moved to the CPU, with dimension 0 squeezed
        (removed only if it has size 1), converted to a NumPy array.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(image_batch)
    first_output = outputs[0].cpu()
    return first_output.squeeze(dim=0).numpy()

def normPRED(predicted_map):
    """Min-max normalise ``predicted_map`` into the range [0, 1].

    Args:
        predicted_map: Non-empty array-like of numeric values.

    Returns:
        An array of the same shape where the minimum maps to 0 and the
        maximum maps to 1. If every element is equal (zero range), an
        all-zero floating-point array is returned instead of dividing by
        zero, so callers never receive NaN values from a flat prediction.

    Raises:
        ValueError: If ``predicted_map`` is empty (raised by ``np.min``).
    """
    predicted_map = np.asarray(predicted_map)
    lowest = np.min(predicted_map)
    value_range = np.max(predicted_map) - lowest
    if value_range == 0:
        # A flat map carries no contrast; avoid 0/0 -> NaN.
        if np.issubdtype(predicted_map.dtype, np.floating):
            out_dtype = predicted_map.dtype
        else:
            out_dtype = np.float64
        return np.zeros(predicted_map.shape, dtype=out_dtype)
    return (predicted_map - lowest) / value_range

def apply_mask(image, mask):
    """Cut ``image`` out with ``mask`` onto a fully transparent background.

    Args:
        image: PIL image to be masked.
        mask: Array that, once squeezed, is accepted by ``Image.fromarray``
            as an 8-bit grayscale ('L') image.

    Returns:
        An RGBA PIL image of size ``resize_shape`` (the module-level
        setting), not the size of the input image.
    """
    # Drop singleton axes, rescale with normPRED, then quantise to 8 bits.
    alpha_values = (normPRED(np.squeeze(mask)) * 255).astype(np.uint8)
    alpha = Image.fromarray(alpha_values, mode='L')  # 'L' mode for grayscale

    # Bring the picture to the mask's resolution before compositing.
    foreground = (
        image.convert("RGB")
        .resize(resize_shape, resample=Image.BILINEAR)
        .convert("RGBA")
    )
    backdrop = Image.new("RGBA", foreground.size, (0, 0, 0, 0))
    return Image.composite(foreground, backdrop, alpha)

# Streamlit app setup: page title rendered in the main area.
st.title("Image Segmentation with U2NET")

# Sidebar for file upload and controls. The uploader is restricted to
# png/jpg/jpeg files; the code at the bottom of the script treats a None
# value as "nothing uploaded".
st.sidebar.title("Controls :gear:")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

# Function to handle image and segmentation display
def fix_image(upload=None):
    """Segment an image, show it next to the original, and offer a download.

    Args:
        upload: Source passed to ``Image.open``. When falsy, the bundled
            file "8.jpg" is opened instead.

    Side effects:
        Renders two Streamlit columns (original and segmented image) and
        adds a heading plus a PNG download button to the sidebar.
    """
    source = upload if upload else "8.jpg"
    image = Image.open(source)

    # Segmentation: preprocess -> predict -> cut out the foreground.
    batch = prepare_single_image(image, resize_shape, transforms, "cpu")
    prediction = prepare_prediction(u2net, batch)
    segmented = apply_mask(image, prediction)

    # Original on the left, segmented result on the right.
    left, right = st.columns(2)
    with left:
        st.image(image, caption="Uploaded Image", use_column_width=True)
    with right:
        st.image(segmented, caption='Segmented Image', use_column_width=True)

    # Encode the result as PNG in memory for the download button.
    png_buffer = io.BytesIO()
    segmented.save(png_buffer, format='PNG')
    png_bytes = png_buffer.getvalue()
    st.sidebar.markdown('### Download Segmented Image')
    st.sidebar.download_button(
        label="Download Segmented Image",
        data=png_bytes,
        file_name="segmented_image.png",
        mime="image/png",
    )


# fix_image falls back to its default image when ``upload`` is None, so the
# uploader's value can be forwarded directly whether or not a file was chosen.
fix_image(upload=uploaded_file)