removebg / app.py
Hamam
Update app.py
2ed7990 verified
raw
history blame
5.44 kB
import os
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as T
import streamlit as st
from model.u2net import U2NET
from io import BytesIO
# --- Application constants -------------------------------------------------
# Largest upload the UI accepts, in bytes (5 MB).
MAX_FILE_SIZE = 5 * 1024 ** 2
# Image displayed when the user has not uploaded a file.
DEFAULT_IMAGE_PATH = "default_image.png"

# Build the U2NET network with 3 input channels and 1 output channel.
# Its weights are loaded in a separate step.
u2net = U2NET(in_ch=3, out_ch=1)
def load_model(model, model_path, device):
    """Load saved weights into ``model`` and move it to ``device``.

    Args:
        model: Module whose ``load_state_dict`` receives the saved weights.
        model_path: Path of the checkpoint file read with ``torch.load``.
        device: Target device; also used as ``map_location`` while loading.

    Returns:
        The result of ``model.to(device)``.
    """
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)
# Populate the network with the weights stored in "u2net.pth", on the CPU.
u2net = load_model(model=u2net, model_path="u2net.pth", device="cpu")

# Per-channel statistics used by T.Normalize below (and reversed by
# denorm_image). The values match the commonly used ImageNet mean/std.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Spatial size every image is resized to before entering the network.
resize_shape = (320, 320)

# Preprocessing pipeline: resize -> tensor -> normalize.
transforms = T.Compose(
    [
        T.Resize(resize_shape),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]
)
def prepare_single_image(image, resize, transforms, device):
    """Turn one image into a single-item batch ready for the model.

    Args:
        image: A PIL image, or a NumPy array that is first wrapped with
            ``Image.fromarray``.
        resize: Target size passed to ``Image.resize``.
        transforms: Callable applied to the resized RGB image; its result
            must support ``unsqueeze`` and ``to``.
        device: Device the resulting batch is moved to.

    Returns:
        The transformed image with a leading batch dimension, on ``device``.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    resized_rgb = pil_image.convert("RGB").resize(resize, resample=Image.BILINEAR)
    transformed = transforms(resized_rgb)
    # unsqueeze(0) inserts the batch dimension in front.
    return transformed.unsqueeze(0).to(device)
def denorm_image(image_tensor):
    """Undo the normalization and return the image as a uint8 NumPy array.

    Works on a CPU copy of ``image_tensor``, so the argument itself is not
    modified. The module-level ``std`` and ``mean`` are broadcast over the
    two trailing dimensions, values are scaled to the 0-255 range and
    clamped, and the first dimension is moved to the end before conversion.
    """
    restored = image_tensor.cpu().clone()
    restored = restored * std[:, None, None] + mean[:, None, None]
    as_bytes = torch.clamp(restored * 255., min=0., max=255.)
    channels_last = as_bytes.permute(1, 2, 0)
    return channels_last.numpy().astype(np.uint8)
def prepare_prediction(model, image_batch):
    """Run ``model`` on ``image_batch`` and return its first output as NumPy.

    The model is switched to evaluation mode and called with gradient
    tracking disabled. Only element 0 of the model's output is used
    (presumably U2NET's fused map - confirm against the model definition);
    its leading dimension is squeezed away when it has size 1.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(image_batch)
        primary = outputs[0].cpu().squeeze(dim=0)
    return primary.numpy()
def normPRED(predicted_map):
    """Min-max normalize ``predicted_map`` into the range [0, 1].

    Args:
        predicted_map: Non-empty NumPy array of prediction values.

    Returns:
        ``(predicted_map - min) / (max - min)``. When every value is the
        same (max == min) the result is all zeros rather than NaN.
    """
    lowest = np.min(predicted_map)
    span = np.max(predicted_map) - lowest
    if span == 0:
        # A constant map has no range; dividing by it would yield NaNs that
        # turn into garbage once the caller casts to uint8. The numerator is
        # all zeros here, so dividing by 1 gives a clean zero map.
        span = 1
    return (predicted_map - lowest) / span
def apply_mask(image, mask):
    """Cut ``image`` out with ``mask`` and return it on a transparent background.

    The mask is scaled up to the image's own size, so the result keeps the
    original resolution and aspect ratio instead of being shrunk to the
    network's fixed input size.

    Args:
        image: PIL image to cut out.
        mask: Array of prediction values; after ``np.squeeze`` it is
            expected to be a 2-D map.

    Returns:
        An RGBA PIL image the same size as ``image``. Pixels are blended
        towards fully transparent according to the mask.
    """
    # Drop singleton dimensions, then rescale the prediction to 0-255.
    mask = np.squeeze(mask)
    mask = normPRED(mask)
    mask = (mask * 255).astype(np.uint8)
    mask_image = Image.fromarray(mask, mode='L')  # 'L' mode for grayscale

    # Going through RGB first discards any alpha channel the input already
    # had, so the mask alone decides the transparency.
    original_image_rgba = image.convert("RGB").convert("RGBA")

    # Image.composite needs all three images to share one size; resize the
    # mask rather than downscaling the picture.
    if mask_image.size != original_image_rgba.size:
        mask_image = mask_image.resize(
            original_image_rgba.size, resample=Image.BILINEAR
        )

    transparent_background = Image.new(
        "RGBA", original_image_rgba.size, (0, 0, 0, 0)
    )
    return Image.composite(original_image_rgba, transparent_background, mask_image)
def segment_image(image):
    """Segment ``image`` with the module-level U2NET model.

    Accepts a PIL image or a NumPy array (wrapped with ``Image.fromarray``).
    The image is preprocessed on the CPU, passed through ``u2net``, and the
    prediction is applied as a mask by ``apply_mask``, whose result is
    returned.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    batch = prepare_single_image(pil_image, resize_shape, transforms, "cpu")
    prediction = prepare_prediction(u2net, batch)
    return apply_mask(pil_image, prediction)
def fix_image(upload=None):
    """Show an image and, on request, its segmented version with a download.

    Args:
        upload: File-like object to open with PIL. When ``None`` the image
            at ``DEFAULT_IMAGE_PATH`` is used instead.

    The selected image is always displayed. When the 'Segment Image' button
    is pressed, the image is segmented, shown as PNG, and offered for
    download as ``segmented_image.png``. If the image cannot be opened an
    error message is shown and nothing else is rendered.
    """
    source = upload if upload is not None else DEFAULT_IMAGE_PATH
    try:
        image = Image.open(source)
    except OSError:
        # Covers a missing default file and uploads PIL cannot identify.
        st.error("Could not open the image. Please provide a valid JPG or PNG file.")
        return

    st.image(image, caption='Selected Image', use_column_width=True)

    if st.button('Segment Image'):
        masked_image = segment_image(image)
        # st.image takes `output_format`; `format` is not an accepted keyword
        # in current Streamlit. PNG keeps the alpha channel.
        st.image(
            masked_image,
            caption='Segmented Image',
            use_column_width=True,
            output_format="PNG",
        )

        # Serialize to an in-memory PNG so it can be offered for download.
        buffered = BytesIO()
        masked_image.save(buffered, format="PNG")
        st.download_button(
            label="Download Segmented Image",
            data=buffered.getvalue(),
            file_name="segmented_image.png",
            mime="image/png",
        )
# Define the pages
def page_one():
    """Render the segmentation page: title, intro text, uploader, result.

    With no upload the default image is processed. An upload larger than
    ``MAX_FILE_SIZE`` is rejected with an error message; any other upload is
    handed to ``fix_image``.
    """
    st.title("Image Segmentation with U2NET")
    st.write("Upload an image to segment it using the U2NET model. The background of the segmented output will be transparent.")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is None:
        # Nothing uploaded yet: fall back to the default image.
        fix_image()
        return

    if uploaded_file.size > MAX_FILE_SIZE:
        st.error("The uploaded file is too large. Please upload an image smaller than 5MB.")
        return

    fix_image(upload=uploaded_file)
def page_two():
    """Render the placeholder page reserved for a second feature."""
    st.title("Other Feature")
    st.write("This page is for the second feature you want to implement.")
    # Placeholder: the second feature's widgets and logic go here.
# Sidebar navigation: let the user pick a page by its label.
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", ("Image Segmentation", "Other Feature"))

# Dispatch to the function that renders the chosen page.
_page_renderers = {
    "Image Segmentation": page_one,
    "Other Feature": page_two,
}
_render_page = _page_renderers.get(page)
if _render_page is not None:
    _render_page()