Spaces:
Sleeping
Sleeping
import os | |
from PIL import Image | |
import numpy as np | |
import torch | |
import torchvision.transforms as T | |
import streamlit as st | |
from model.u2net import U2NET | |
from io import BytesIO | |
# Constants
MAX_FILE_SIZE = 5 << 20  # upload limit in bytes (5 * 1024 * 1024)
DEFAULT_IMAGE_PATH = "default_image.png"  # image shown when nothing has been uploaded
# Build the U2NET network: 3 input channels (RGB), 1 output channel (the mask).
u2net = U2NET(out_ch=1, in_ch=3)
def load_model(model, model_path, device):
    """Load a saved state dict into *model* and move the model to *device*.

    Args:
        model: The ``torch.nn.Module`` whose parameters are overwritten.
        model_path: Path to a file produced by ``torch.save(state_dict)``.
        device: Target device (e.g. ``"cpu"``); also used as ``map_location``
            so tensors are deserialised straight onto it.

    Returns:
        The model after ``.to(device)``.
    """
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)
# Replace the freshly initialised weights with the checkpoint and keep the model on the CPU.
u2net = load_model(u2net, "u2net.pth", "cpu")
# Per-channel (R, G, B) mean and standard deviation used for input normalisation.
mean = torch.tensor((0.485, 0.456, 0.406))
std = torch.tensor((0.229, 0.224, 0.225))
# Square spatial size every image is resized to before inference.
resize_shape = (320, 320)
# Preprocessing pipeline: resize, convert the PIL image to a tensor, normalise per channel.
transforms = T.Compose(
    [
        T.Resize(size=resize_shape),
        T.ToTensor(),
        T.Normalize(mean, std),
    ]
)
def prepare_single_image(image, resize, transforms, device):
    """Turn one image into a model-ready batch of size 1.

    Args:
        image: A PIL image or a NumPy array (converted with ``Image.fromarray``).
        resize: ``(width, height)`` passed to ``PIL.Image.resize``.
        transforms: Callable applied to the resized PIL image; expected to
            return a tensor.
        device: Device the resulting batch is moved to.

    Returns:
        The transformed tensor with a leading batch dimension, on *device*.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    resized = pil_image.convert("RGB").resize(resize, resample=Image.BILINEAR)
    # unsqueeze(0) prepends the batch axis the model expects.
    return transforms(resized).unsqueeze(0).to(device)
def denorm_image(image_tensor):
    """Reverse the per-channel normalisation and return a displayable uint8 array.

    Uses the module-level ``mean`` and ``std`` tensors. The input must be a
    single 3-D tensor with channels first (``permute(1, 2, 0)`` requires
    rank 3). The result is channels-last, scaled to [0, 255] and clamped.
    """
    restored = image_tensor.cpu().clone()
    # Broadcast the 1-D channel statistics over the spatial dimensions.
    restored = restored * std.view(-1, 1, 1) + mean.view(-1, 1, 1)
    restored = (restored * 255.).clamp(0., 255.)
    return restored.permute(1, 2, 0).numpy().astype(np.uint8)
def prepare_prediction(model, image_batch):
    """Run *model* in eval mode and return its first output as a NumPy array.

    The model is switched to ``eval()`` and called under ``torch.no_grad()``.
    Element ``[0]`` of its output (U2NET returns several side outputs) is
    moved to the CPU, and its leading (batch) dimension is squeezed away.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(image_batch)
        primary = outputs[0].cpu().squeeze(0)
    return primary.numpy()
def normPRED(predicted_map):
    """Min-max normalise a prediction map to the range [0, 1].

    Args:
        predicted_map: Array-like of numeric values (converted with
            ``np.asarray``).

    Returns:
        ``(predicted_map - min) / (max - min)``. When the map is constant
        (``max == min``), an all-zero array of the same shape is returned
        instead of the NaNs a 0/0 division would produce. Float inputs keep
        their dtype; integer inputs yield float64.

    Raises:
        ValueError: If *predicted_map* is empty (from ``np.max``).
    """
    predicted_map = np.asarray(predicted_map)
    ma = np.max(predicted_map)
    mi = np.min(predicted_map)
    span = ma - mi
    shifted = predicted_map - mi
    if span == 0:
        # Constant map: nothing to stretch. Return zeros (all background)
        # rather than NaN, which would later turn into garbage uint8 alpha.
        return np.zeros_like(shifted, dtype=np.result_type(shifted, np.float32))
    return shifted / span
def apply_mask(image, mask):
    """Apply the mask to the original image and return the result with transparent background.

    The image is converted to RGB and resized to ``resize_shape``. The
    min-max normalised *mask* then becomes its alpha channel, so low-mask
    (background) pixels become transparent while RGB colours are preserved.

    Args:
        image: PIL image that was fed to the model.
        mask: Raw model prediction. Singleton dimensions are squeezed away;
            the remainder is presumably a 2-D map matching ``resize_shape``
            (it is resized to fit if not).

    Returns:
        A PIL ``RGBA`` image of size ``resize_shape``.
    """
    # Drop singleton (batch/channel) dimensions, e.g. (1, H, W) -> (H, W).
    mask = np.squeeze(mask)
    # Stretch to [0, 1], then to 8-bit alpha values.
    mask = normPRED(mask)
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))  # 2-D uint8 -> mode 'L'

    original_image = image.convert("RGB").resize(resize_shape, resample=Image.BILINEAR)
    masked_image = original_image.convert("RGBA")

    # putalpha requires matching sizes; guard against a mask of a different size.
    if mask_image.size != masked_image.size:
        mask_image = mask_image.resize(masked_image.size, resample=Image.BILINEAR)

    # Use the mask directly as the alpha channel. Compositing against a
    # transparent black canvas (Image.composite) also scales the RGB bands by
    # the mask, which darkens semi-transparent edge pixels into dark halos.
    masked_image.putalpha(mask_image)
    return masked_image
def segment_image(image):
    """Segment *image* with the global U2NET model and return the RGBA cut-out.

    Accepts a PIL image or a NumPy array. The image is preprocessed on the
    CPU with the module-level ``transforms``, passed through ``u2net``, and
    the resulting mask is applied with ``apply_mask``.
    """
    pil_image = image if not isinstance(image, np.ndarray) else Image.fromarray(image)
    batch = prepare_single_image(pil_image, resize_shape, transforms, "cpu")
    return apply_mask(pil_image, prepare_prediction(u2net, batch))
def fix_image(upload=None):
    """Processes an uploaded image or a default image.

    Renders the selected image and a "Segment Image" button. When the button
    is pressed, the image is segmented and shown, and a PNG download button
    is offered.

    Args:
        upload: File-like object from ``st.file_uploader``. When ``None``, the
            image at ``DEFAULT_IMAGE_PATH`` is used instead.
    """
    source = upload if upload is not None else DEFAULT_IMAGE_PATH
    try:
        image = Image.open(source)
    except OSError as err:
        # Covers a missing default image (FileNotFoundError) and uploads PIL
        # cannot decode (UnidentifiedImageError subclasses OSError), so the
        # page shows an error instead of crashing.
        st.error(f"Could not open the image: {err}")
        return

    st.image(image, caption='Selected Image', use_column_width=True)
    if st.button('Segment Image'):
        masked_image = segment_image(image)
        # st.image's keyword is ``output_format``; it has no ``format`` parameter.
        st.image(masked_image, caption='Segmented Image', use_column_width=True, output_format="PNG")
        # Save the image to a BytesIO object for downloading
        buffered = BytesIO()
        masked_image.save(buffered, format="PNG")
        st.download_button(
            label="Download Segmented Image",
            data=buffered.getvalue(),
            file_name="segmented_image.png",
            mime="image/png",
        )
# Page renderers
def page_one():
    """Render the U2NET segmentation page (uploader, size check, processing)."""
    st.title("Image Segmentation with U2NET")
    st.write("Upload an image to segment it using the U2NET model. The background of the segmented output will be transparent.")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is None:
        # Nothing uploaded yet: fall back to the bundled default image.
        fix_image()
        return

    if uploaded_file.size > MAX_FILE_SIZE:
        st.error("The uploaded file is too large. Please upload an image smaller than 5MB.")
        return

    fix_image(upload=uploaded_file)
def page_two():
    """Render the placeholder page reserved for a second feature."""
    st.title("Other Feature")
    st.write("This page is for the second feature you want to implement.")
    # TODO: implement the second feature here.
# Sidebar navigation: the radio selection picks which page function renders.
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", ("Image Segmentation", "Other Feature"))
_page_renderer = {"Image Segmentation": page_one, "Other Feature": page_two}.get(page)
if _page_renderer is not None:
    _page_renderer()