import io

import numpy as np
import streamlit as st
import torch
import torchvision.transforms as T
from PIL import Image

# Assuming you have the U2NET model defined somewhere
from model.u2net import U2NET  # Replace with your actual import path

DEVICE = "cpu"
MODEL_PATH = "u2net.pth"
DEFAULT_IMAGE_PATH = "8.jpg"


def load_model(model, model_path, device):
    """Load the state dict at *model_path* into *model* and move it to *device*.

    Args:
        model: An instantiated ``torch.nn.Module`` whose architecture matches
            the checkpoint.
        model_path: Path to a file saved with ``torch.save(model.state_dict())``.
        device: Target device string or ``torch.device``.

    Returns:
        The model on *device*, switched to eval mode.
    """
    # NOTE(review): on torch >= 1.13, consider torch.load(..., weights_only=True)
    # so that loading an untrusted checkpoint cannot execute pickled code.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


@st.cache_resource
def _get_u2net(model_path=MODEL_PATH, device=DEVICE):
    """Build U2NET and load its weights once per Streamlit server process.

    Streamlit re-executes this whole script on every widget interaction.
    Without caching, the weights would be re-read from disk on each rerun.
    (``st.cache_resource`` requires Streamlit >= 1.18.)
    """
    return load_model(U2NET(in_ch=3, out_ch=1), model_path=model_path, device=device)


# Module-level handle used by the app below.
u2net = _get_u2net()

# ImageNet mean / std used to normalize the network input.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])
resize_shape = (320, 320)

transforms = T.Compose([
    T.Resize(resize_shape),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
])


def prepare_single_image(image, resize, transforms, device):
    """Prepare a single image for prediction.

    Args:
        image: A PIL image or a NumPy array convertible with ``Image.fromarray``.
        resize: ``(width, height)`` to resize to before applying *transforms*.
        transforms: Callable mapping a PIL image to a CHW tensor.
        device: Device the returned batch is moved to.

    Returns:
        A tensor with a leading batch dimension of size 1.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    image_resize = image.resize(resize, resample=Image.BILINEAR)
    image_trans = transforms(image_resize)
    return image_trans.unsqueeze(0).to(device)  # Add batch dimension


def prepare_prediction(model, image_batch):
    """Run *model* on *image_batch* and return the first output as a NumPy array.

    The first element of the model's output is taken as the mask.
    NOTE(review): presumably this is U2NET's fused saliency map; confirm
    against the model definition.
    """
    model.eval()
    with torch.no_grad():
        results = model(image_batch)
    mask = torch.squeeze(results[0].cpu(), dim=0)
    return mask.numpy()


def normPRED(predicted_map):
    """Min-max normalize *predicted_map* into the range [0, 1].

    A constant map (max == min) would divide by zero and produce NaNs.
    In that case an all-zero map is returned instead.
    """
    ma = np.max(predicted_map)
    mi = np.min(predicted_map)
    span = ma - mi
    if span == 0:
        return np.zeros_like(predicted_map, dtype=np.float32)
    return (predicted_map - mi) / span


def apply_mask(image, mask):
    """Apply the mask to the original image and return it with a transparent background.

    Args:
        image: The PIL image the prediction was computed from.
        mask: Raw model output; it must squeeze to a 2-D array.

    Returns:
        An RGBA PIL image with the same size as *image*. Each pixel's opacity
        follows the min-max-normalized mask.
    """
    mask = normPRED(np.squeeze(mask))
    mask = (mask * 255).astype(np.uint8)
    # A 2-D uint8 array is converted to a grayscale ('L') image automatically.
    mask_image = Image.fromarray(mask)
    # Going through RGB first discards any alpha channel in the source.
    original_image_rgba = image.convert("RGB").convert("RGBA")
    # Upscale the low-resolution mask to the source size rather than
    # squashing the source to the model's 320x320 input. Squashing
    # distorted the aspect ratio and threw away resolution.
    mask_image = mask_image.resize(original_image_rgba.size, resample=Image.BILINEAR)
    transparent_background = Image.new("RGBA", original_image_rgba.size, (0, 0, 0, 0))
    return Image.composite(original_image_rgba, transparent_background, mask_image)


# Streamlit app setup
st.title("Image Segmentation with U2NET")

# Sidebar for file upload and controls
st.sidebar.title("Controls :gear:")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])


def fix_image(upload=None):
    """Segment *upload* (or the default image) and render the results.

    Shows the original and the cut-out side by side, and adds a sidebar
    button to download the cut-out as PNG. An unreadable image is reported
    with ``st.error`` instead of crashing the app.
    """
    source = upload if upload is not None else DEFAULT_IMAGE_PATH
    try:
        image = Image.open(source)
        image.load()  # Force decoding now so corrupt files fail inside the try.
    except OSError as err:  # Includes PIL.UnidentifiedImageError.
        st.error(f"Could not read the image: {err}")
        return

    # Prepare image for segmentation
    image_batch = prepare_single_image(image, resize_shape, transforms, DEVICE)
    prediction_u2net = prepare_prediction(u2net, image_batch)
    masked_image = apply_mask(image, prediction_u2net)

    # Display the original and segmented images side by side
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Uploaded Image", use_column_width=True)
    with col2:
        st.image(masked_image, caption='Segmented Image', use_column_width=True)

    # Provide download option for segmented image
    buf = io.BytesIO()
    masked_image.save(buf, format='PNG')
    st.sidebar.markdown('### Download Segmented Image')
    st.sidebar.download_button(
        label="Download Segmented Image",
        data=buf.getvalue(),
        file_name="segmented_image.png",
        mime="image/png",
    )


# fix_image falls back to the default image when nothing is uploaded.
fix_image(upload=uploaded_file)