Spaces:
Sleeping
Sleeping
File size: 5,435 Bytes
04eca22 2ed7990 04eca22 2ed7990 04eca22 861a11c 04eca22 2ed7990 04eca22 2ed7990 04eca22 2ed7990 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import os
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as T
import streamlit as st
from model.u2net import U2NET
from io import BytesIO
# --- Configuration -----------------------------------------------------------
# Largest upload page_one accepts, in bytes (5 MiB).
MAX_FILE_SIZE = 5 * 1024 ** 2
# Image displayed (and segmented) when the user has not uploaded a file.
DEFAULT_IMAGE_PATH = "default_image.png"

# U2NET network with 3 input channels and 1 output channel; its weights are
# restored from disk by the load_model() call below.
u2net = U2NET(in_ch=3, out_ch=1)
def load_model(model, model_path, device):
    """Load saved weights into ``model`` and move it to ``device``.

    Args:
        model: The ``torch.nn.Module`` whose parameters are restored.
        model_path: Path of a file written by ``torch.save`` that holds a
            state dict matching ``model``.
        device: Target device (e.g. ``"cpu"``); checkpoint tensors are mapped
            onto it while loading.

    Returns:
        The module returned by ``model.to(device)``, with the weights loaded.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)
# Restore the pretrained weights; this app runs inference on the CPU.
# NOTE(review): Streamlit re-executes the whole script on every widget
# interaction, so the weights are presumably reloaded on each rerun --
# consider caching the model (e.g. st.cache_resource) if that is too slow.
u2net = load_model(model=u2net, model_path="u2net.pth", device="cpu")

# Per-channel normalization statistics (they match the commonly used ImageNet
# RGB mean/std). denorm_image() reuses them to undo the normalization.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Square size (320 x 320) the network input is resized to.
resize_shape = (320, 320)

# Input pipeline: resize, convert to a float tensor, then standardize each
# channel with the statistics above.
transforms = T.Compose(
    [
        T.Resize(resize_shape),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]
)
def prepare_single_image(image, resize, transforms, device):
    """Turn one image into a batched model input.

    Args:
        image: PIL image, or NumPy array converted with ``Image.fromarray``.
        resize: Target size passed to ``PIL.Image.resize``.
        transforms: Callable applied to the resized PIL image that returns a
            tensor (segment_image passes the module-level pipeline).
        device: Device the resulting batch is moved to.

    Returns:
        The ``transforms`` output with a leading axis of size 1 prepended,
        placed on ``device``.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    resized = pil_image.convert("RGB").resize(resize, resample=Image.BILINEAR)
    tensor = transforms(resized)
    # unsqueeze(0) prepends a batch dimension of size 1.
    return tensor.unsqueeze(0).to(device)
def denorm_image(image_tensor):
    """Undo the input normalization and return a ``uint8`` NumPy image.

    Args:
        image_tensor: Channel-first (C, H, W) tensor that was normalized with
            the module-level ``mean``/``std``; may live on any device.

    Returns:
        Channel-last (H, W, C) array with values clamped to 0..255 and cast
        (truncated) to ``uint8``.
    """
    pixels = image_tensor.cpu().clone()
    # Reshape the per-channel stats to (C, 1, 1) so they broadcast over H and W.
    pixels = pixels * std.view(-1, 1, 1) + mean.view(-1, 1, 1)
    pixels = (pixels * 255.0).clamp(min=0.0, max=255.0)
    return pixels.permute(1, 2, 0).numpy().astype(np.uint8)
def prepare_prediction(model, image_batch):
    """Run ``model`` for inference and return its first output as NumPy.

    The model is switched to eval mode and called without gradient tracking.
    Element ``[0]`` of its output is moved to the CPU, its first axis is
    squeezed away (only if that axis has size 1), and it is returned as a
    NumPy array.

    Args:
        model: ``torch.nn.Module`` whose output supports indexing with ``[0]``.
        image_batch: Input tensor (e.g. from prepare_single_image).

    Returns:
        ``numpy.ndarray`` built from the model's first output.
    """
    model.eval()
    with torch.no_grad():
        first_output = model(image_batch)[0].cpu()
    return first_output.squeeze(0).numpy()
def normPRED(predicted_map):
    """Min-max normalize ``predicted_map`` to the range [0, 1].

    Args:
        predicted_map: Non-empty array-like (any shape) of prediction scores.

    Returns:
        Float array of the same shape in which the minimum maps to 0 and the
        maximum to 1. A constant map (max == min) has no range to stretch, so
        an all-zero array is returned instead of dividing by zero.
    """
    predicted_map = np.asarray(predicted_map)
    ma = np.max(predicted_map)
    mi = np.min(predicted_map)
    value_range = ma - mi
    if value_range == 0:
        # 0/0 would produce NaNs, which have no meaningful value once the
        # caller scales the map and casts it to uint8.
        return np.zeros(predicted_map.shape, dtype=float)
    return (predicted_map - mi) / value_range
def apply_mask(image, mask):
    """Cut out the foreground of ``image`` using a predicted ``mask``.

    The mask is min-max normalized with ``normPRED``, scaled to 0..255 and
    used as the alpha channel of the image, so low-scoring (background)
    pixels become transparent.

    Args:
        image: PIL image (or NumPy array convertible with ``Image.fromarray``)
            that the mask was predicted for.
        mask: Array holding one 2-D prediction map; singleton axes (such as a
            leading channel axis) are squeezed away.

    Returns:
        An RGBA ``PIL.Image`` with the same size as ``image``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Drop singleton axes, e.g. a (1, H, W) prediction becomes (H, W).
    mask = np.squeeze(mask)
    mask = (normPRED(mask) * 255).astype(np.uint8)
    # A 2-D uint8 array is converted to an 8-bit grayscale ("L") image.
    mask_image = Image.fromarray(mask)

    rgba = image.convert("RGB").convert("RGBA")
    # Upscale the low-resolution mask to the image instead of shrinking the
    # image to the mask, so the cut-out keeps the original size and aspect
    # ratio.
    if mask_image.size != rgba.size:
        mask_image = mask_image.resize(rgba.size, resample=Image.BILINEAR)
    # Use the mask directly as straight (non-premultiplied) alpha. Compositing
    # over a transparent black layer would also scale the RGB values by the
    # mask and darken semi-transparent edges.
    rgba.putalpha(mask_image)
    return rgba
def segment_image(image):
    """Run the U2NET pipeline on ``image`` and return the cut-out.

    Args:
        image: PIL image or NumPy array.

    Returns:
        The result of ``apply_mask`` for the image and the model's
        prediction.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    model_input = prepare_single_image(pil_image, resize_shape, transforms, "cpu")
    prediction = prepare_prediction(u2net, model_input)
    return apply_mask(pil_image, prediction)
def fix_image(upload=None):
    """Display an image and, on request, its segmented cut-out.

    Args:
        upload: Uploaded file object (anything ``PIL.Image.open`` accepts).
            When ``None``, the image at ``DEFAULT_IMAGE_PATH`` is used.

    Shows the selected image and a "Segment Image" button. When the button is
    pressed, the image is segmented with ``segment_image`` and the result is
    shown and offered as a PNG download. If the image cannot be opened or
    decoded, an error message is shown instead of raising.
    """
    source = upload if upload is not None else DEFAULT_IMAGE_PATH
    try:
        image = Image.open(source)
        # Image.open is lazy; decode now so corrupt or truncated files fail
        # inside this try block rather than later while rendering.
        image.load()
    except OSError:
        # Covers a missing default file and PIL.UnidentifiedImageError
        # (an OSError subclass) for files that are not images.
        st.error("The selected image could not be opened. Please upload a valid JPG or PNG image.")
        return

    st.image(image, caption='Selected Image', use_column_width=True)

    if not st.button('Segment Image'):
        return

    masked_image = segment_image(image)
    # st.image's encoding parameter is ``output_format`` (it has no ``format``
    # keyword); PNG keeps the transparent background.
    st.image(masked_image, caption='Segmented Image', use_column_width=True, output_format="PNG")

    # Encode the result as PNG in memory for the download button.
    buffered = BytesIO()
    masked_image.save(buffered, format="PNG")
    st.download_button(
        label="Download Segmented Image",
        data=buffered.getvalue(),
        file_name="segmented_image.png",
        mime="image/png",
    )
# Define the pages
def page_one():
    """Render the image-segmentation page.

    Lets the user upload a JPG/PNG file. Files larger than ``MAX_FILE_SIZE``
    bytes are rejected with an error; otherwise the upload (or, when nothing
    has been uploaded, the default image) is handed to ``fix_image``.
    """
    st.title("Image Segmentation with U2NET")
    st.write("Upload an image to segment it using the U2NET model. The background of the segmented output will be transparent.")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is None:
        fix_image()  # Use default image if none uploaded
        return

    if uploaded_file.size > MAX_FILE_SIZE:
        # Derive the limit shown to the user from the constant so the message
        # cannot drift out of sync with the check ({:g} renders 5.0 as "5").
        limit_mb = MAX_FILE_SIZE / (1024 * 1024)
        st.error(f"The uploaded file is too large. Please upload an image smaller than {limit_mb:g}MB.")
        return

    fix_image(upload=uploaded_file)
def page_two():
    """Render the placeholder page reserved for a second feature."""
    heading = "Other Feature"
    st.title(heading)
    st.write("This page is for the second feature you want to implement.")
    # TODO: build the second feature on this page.
# Sidebar navigation: each radio option label maps to the function rendering it.
_PAGE_RENDERERS = {
    "Image Segmentation": page_one,
    "Other Feature": page_two,
}

st.sidebar.title("Navigation")
# The options are the mapping's keys, in insertion order.
page = st.sidebar.radio("Go to", tuple(_PAGE_RENDERERS))

_render_page = _PAGE_RENDERERS.get(page)
if _render_page is not None:
    _render_page()
|