File size: 5,435 Bytes
04eca22
 
 
 
 
2ed7990
04eca22
2ed7990
 
 
 
 
04eca22
 
 
 
 
 
 
 
 
 
861a11c
04eca22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ed7990
04eca22
 
 
 
 
 
 
 
 
2ed7990
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04eca22
2ed7990
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as T
import streamlit as st
from model.u2net import U2NET
from io import BytesIO

# Constants
# Upload size limit in bytes; page_one compares an uploaded file's size against it.
MAX_FILE_SIZE = 5 * 1024 * 1024  # 5 MB
# File opened by fix_image when no upload is supplied.
DEFAULT_IMAGE_PATH = "default_image.png"  # Path to your default image

# Initialize the U2NET model
# Constructed with 3 input channels and 1 output channel; no weights are loaded
# at this point.
u2net = U2NET(in_ch=3, out_ch=1)

def load_model(model, model_path, device):
    """Restore saved weights into ``model`` and place it on ``device``.

    Args:
        model: Module whose ``load_state_dict`` receives the saved weights.
        model_path: Path of the checkpoint read with ``torch.load``.
        device: Device used both as ``map_location`` while loading and as the
            target of the final ``.to`` call.

    Returns:
        The result of ``model.to(device)``.
    """
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)

# Load the model onto the specified device
# Weights are read from "u2net.pth" and the model is placed on the CPU.
# NOTE(review): this runs at module import; if the script is re-executed on each
# interaction the checkpoint is re-read every time — consider caching.
u2net = load_model(model=u2net, model_path="u2net.pth", device="cpu")

# Mean and std for normalization
# Three per-channel values each; consumed by T.Normalize below and reversed in
# denorm_image.
# NOTE(review): these look like the common ImageNet statistics — confirm they
# match how the checkpoint was trained.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Size images are resized to before inference; apply_mask resizes its output
# to the same size.
resize_shape = (320, 320)

# Preprocessing pipeline: resize, convert to tensor, then normalize with the
# mean/std defined above.
transforms = T.Compose([
    T.Resize(resize_shape),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std)
])

def prepare_single_image(image, resize, transforms, device):
    """Turn one image into a single-item batch on ``device``.

    Args:
        image: PIL image, or a NumPy array (wrapped with ``Image.fromarray``).
        resize: Size passed to ``Image.resize``.
        transforms: Callable applied to the resized RGB image; its result must
            support ``unsqueeze`` and ``to``.
        device: Device the batch is moved to.

    Returns:
        The transformed image with a leading batch axis added.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    resized_rgb = pil_image.convert("RGB").resize(resize, resample=Image.BILINEAR)
    tensor = transforms(resized_rgb)
    # unsqueeze(0) inserts the batch dimension in front.
    return tensor.unsqueeze(0).to(device)

def denorm_image(image_tensor):
    """Undo the mean/std normalization and return an 8-bit NumPy image.

    The input is expected to be a 3-D, channel-first tensor: the module-level
    ``std`` and ``mean`` are broadcast over the two trailing axes and the
    result is permuted to channel-last.

    Returns:
        ``uint8`` array with values clamped to [0, 255].
    """
    restored = image_tensor.cpu().clone()
    restored = restored * std.reshape(-1, 1, 1) + mean.reshape(-1, 1, 1)
    scaled = (restored * 255.).clamp(min=0., max=255.)
    channel_last = scaled.permute(1, 2, 0)
    return channel_last.numpy().astype(np.uint8)

def prepare_prediction(model, image_batch):
    """Run ``model`` on ``image_batch`` without gradients and return a mask array.

    The model is switched to eval mode first. Only the first element of the
    model output is used; it is moved to the CPU, its leading axis is squeezed
    (when that axis has size 1) and it is converted to NumPy.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(image_batch)
    primary = outputs[0].cpu()
    return primary.squeeze(dim=0).numpy()

def normPRED(predicted_map):
    """Min-max scale ``predicted_map`` into the range [0, 1].

    Args:
        predicted_map: NumPy array (or array-like) of prediction values.

    Returns:
        Array of the same shape in which the smallest input maps to 0 and the
        largest to 1. A map whose values are all equal yields all zeros.

    Raises:
        ValueError: If ``predicted_map`` is empty (raised by ``np.min``).
    """
    predicted_map = np.asarray(predicted_map)
    mi = np.min(predicted_map)
    span = np.max(predicted_map) - mi
    if span == 0:
        # A constant map would otherwise compute 0 / 0 and fill the result
        # with NaN, which turns into arbitrary values once cast to uint8.
        is_float = np.issubdtype(predicted_map.dtype, np.floating)
        out_dtype = predicted_map.dtype if is_float else np.float64
        return np.zeros(predicted_map.shape, dtype=out_dtype)
    return (predicted_map - mi) / span

def apply_mask(image, mask):
    """Composite ``image`` over a transparent canvas using ``mask`` as the selector.

    Args:
        image: PIL image to cut out; it is converted to RGB and resized to the
            module-level ``resize_shape``.
        mask: Prediction array; singleton axes are squeezed away, then it is
            normalized with ``normPRED`` and scaled to 8-bit grayscale.

    Returns:
        RGBA PIL image of size ``resize_shape``.
    """
    # Build the 8-bit grayscale ('L' mode) mask image from the raw prediction.
    alpha = (normPRED(np.squeeze(mask)) * 255).astype(np.uint8)
    alpha_image = Image.fromarray(alpha, mode='L')

    # Bring the picture to the mask's working size and give it an alpha channel.
    foreground = (
        image.convert("RGB")
        .resize(resize_shape, resample=Image.BILINEAR)
        .convert("RGBA")
    )

    # Fully transparent backdrop that shows through wherever the mask is dark.
    backdrop = Image.new("RGBA", foreground.size, (0, 0, 0, 0))
    return Image.composite(foreground, backdrop, alpha_image)

def segment_image(image):
    """Segment ``image`` with the module-level U2NET model.

    Accepts a PIL image or a NumPy array (wrapped with ``Image.fromarray``),
    prepares a batch on the CPU, runs the prediction and returns the result of
    ``apply_mask`` for that prediction.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    batch = prepare_single_image(pil_image, resize_shape, transforms, "cpu")
    prediction = prepare_prediction(u2net, batch)
    return apply_mask(pil_image, prediction)

def fix_image(upload=None):
    """Show the selected image and, on request, its segmented version.

    Args:
        upload: Object accepted by ``Image.open`` (for example the value
            returned by ``st.file_uploader``). When ``None``, the file at
            ``DEFAULT_IMAGE_PATH`` is opened instead.

    Returns:
        None. All output goes to the Streamlit page: the selected image, a
        'Segment Image' button and, once pressed, the segmented image plus a
        PNG download button. If the image cannot be opened an error message
        is shown instead.
    """
    source = upload if upload is not None else DEFAULT_IMAGE_PATH
    try:
        image = Image.open(source)
    except OSError as exc:
        # Covers a missing default file as well as data PIL cannot identify
        # as an image; report it on the page instead of raising.
        st.error(f"Could not open the image: {exc}")
        return

    # NOTE(review): use_column_width is deprecated in recent Streamlit
    # releases; left unchanged here — confirm the target version.
    st.image(image, caption='Selected Image', use_column_width=True)

    if st.button('Segment Image'):
        masked_image = segment_image(image)
        # st.image names this parameter ``output_format``; PNG keeps the
        # transparent background.
        st.image(
            masked_image,
            caption='Segmented Image',
            use_column_width=True,
            output_format="PNG",
        )
        # Save the image to a BytesIO object for downloading
        buffered = BytesIO()
        masked_image.save(buffered, format="PNG")
        st.download_button(
            label="Download Segmented Image",
            data=buffered.getvalue(),
            file_name="segmented_image.png",
            mime="image/png"
        )

# Define the pages
def page_one():
    """Render the segmentation page: title, intro text, uploader and result.

    With no upload the default image is processed; an upload larger than
    ``MAX_FILE_SIZE`` is rejected with an error message; any other upload is
    handed to ``fix_image``.
    """
    st.title("Image Segmentation with U2NET")
    st.write("Upload an image to segment it using the U2NET model. The background of the segmented output will be transparent.")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Nothing uploaded: fall back to the default image.
    if uploaded_file is None:
        fix_image()
        return

    # Reject oversized uploads before any processing.
    if uploaded_file.size > MAX_FILE_SIZE:
        st.error("The uploaded file is too large. Please upload an image smaller than 5MB.")
        return

    fix_image(upload=uploaded_file)

def page_two():
    """Render the placeholder page reserved for a second feature."""
    heading = "Other Feature"
    blurb = "This page is for the second feature you want to implement."
    st.title(heading)
    st.write(blurb)
    # Add other code or features here

# Sidebar navigation
# Maps each sidebar label to the function that renders that page; the dict's
# insertion order is the order of the radio options.
_PAGES = {
    "Image Segmentation": page_one,
    "Other Feature": page_two,
}

st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", tuple(_PAGES))

# Page selection logic
# An unrecognized selection renders nothing.
_render = _PAGES.get(page)
if _render is not None:
    _render()