Hamam commited on
Commit
9507932
·
verified ·
1 Parent(s): 9aae90a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -91
app.py CHANGED
@@ -1,15 +1,13 @@
1
- import os
2
- from PIL import Image
3
  import numpy as np
 
4
  import torch
5
  import torchvision.transforms as T
6
- import streamlit as st
7
- from model.u2net import U2NET
8
- from io import BytesIO
9
 
10
- # Constants
11
- MAX_FILE_SIZE = 5 * 1024 * 1024 # 5 MB
12
- DEFAULT_IMAGE_PATH = "image (4).png" # Path to your default image
13
 
14
  # Initialize the U2NET model
15
  u2net = U2NET(in_ch=3, out_ch=1)
@@ -20,7 +18,7 @@ def load_model(model, model_path, device):
20
  return model
21
 
22
  # Load the model onto the specified device
23
- u2net = load_model(model=u2net, model_path="u2net.pth", device="cpu")
24
 
25
  # Mean and std for normalization
26
  mean = torch.tensor([0.485, 0.456, 0.406])
@@ -44,14 +42,6 @@ def prepare_single_image(image, resize, transforms, device):
44
  image_batch = image_trans.unsqueeze(0).to(device) # Add batch dimension
45
  return image_batch
46
 
47
- def denorm_image(image_tensor):
48
- """Denormalize and convert tensor to numpy image."""
49
- image_tensor = image_tensor.cpu().clone()
50
- image_tensor = image_tensor * std[:, None, None] + mean[:, None, None]
51
- image_tensor = torch.clamp(image_tensor * 255., min=0., max=255.)
52
- image_tensor = image_tensor.permute(1, 2, 0).numpy().astype(np.uint8)
53
- return image_tensor
54
-
55
  def prepare_prediction(model, image_batch):
56
  model.eval()
57
  with torch.no_grad():
@@ -67,93 +57,55 @@ def normPRED(predicted_map):
67
 
68
  def apply_mask(image, mask):
69
  """Apply the mask to the original image and return the result with transparent background."""
70
- # Remove the extra dimension if present
71
  mask = np.squeeze(mask)
72
-
73
- # Normalize and convert the mask to uint8
74
  mask = normPRED(mask)
75
  mask = (mask * 255).astype(np.uint8)
76
-
77
- # Convert the mask to a PIL image
78
  mask_image = Image.fromarray(mask, mode='L') # 'L' mode for grayscale
79
-
80
- # Open the original image and resize it
81
  original_image = image.convert("RGB")
82
  original_image = original_image.resize(resize_shape, resample=Image.BILINEAR)
83
-
84
- # Convert original image to RGBA
85
  original_image_rgba = original_image.convert("RGBA")
86
-
87
- # Create a new image with transparency
88
  transparent_background = Image.new("RGBA", original_image_rgba.size, (0, 0, 0, 0))
89
-
90
- # Apply the mask to create an image with alpha channel
91
  masked_image = Image.composite(original_image_rgba, transparent_background, mask_image)
92
-
93
  return masked_image
94
 
95
- def segment_image(image):
96
- """Function to be used for segmentation."""
97
- # Ensure image is a PIL Image
98
- if isinstance(image, np.ndarray):
99
- image = Image.fromarray(image)
 
 
 
 
 
 
 
 
 
100
 
 
 
 
 
101
  image_batch = prepare_single_image(image, resize_shape, transforms, "cpu")
102
  prediction_u2net = prepare_prediction(u2net, image_batch)
103
  masked_image = apply_mask(image, prediction_u2net)
104
- return masked_image
105
 
106
- def fix_image(upload=None):
107
- """Processes an uploaded image or a default image."""
108
- if upload is not None:
109
- image = Image.open(upload)
110
- else:
111
- image = Image.open(DEFAULT_IMAGE_PATH) # Load default image
112
-
113
- st.image(image, caption='Selected Image', use_column_width=True)
114
-
115
- if st.button('Segment Image'):
116
- masked_image = segment_image(image)
117
- st.image(masked_image, caption='Segmented Image', use_column_width=True, format="PNG")
118
- # Save the image to a BytesIO object for downloading
119
- buffered = BytesIO()
120
- masked_image.save(buffered, format="PNG")
121
- st.download_button(
122
- label="Download Segmented Image",
123
- data=buffered.getvalue(),
124
- file_name="segmented_image.png",
125
- mime="image/png"
126
- )
127
-
128
- # Define the pages
129
- def page_one():
130
- """Page for image segmentation."""
131
- st.title("Image Segmentation with U2NET")
132
- st.write("Upload an image to segment it using the U2NET model. The background of the segmented output will be transparent.")
133
-
134
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
135
-
136
- # Determine image processing
137
- if uploaded_file is not None:
138
- if uploaded_file.size > MAX_FILE_SIZE:
139
- st.error("The uploaded file is too large. Please upload an image smaller than 5MB.")
140
- else:
141
- fix_image(upload=uploaded_file)
142
- else:
143
- fix_image() # Use default image if none uploaded
144
-
145
- def page_two():
146
- """Page for other code."""
147
- st.title("Other Feature")
148
- st.write("This page is for the second feature you want to implement.")
149
- # Add other code or features here
150
-
151
- # Sidebar navigation
152
- st.sidebar.title("Navigation")
153
- page = st.sidebar.radio("Go to", ("Image Segmentation", "Other Feature"))
154
-
155
- # Page selection logic
156
- if page == "Image Segmentation":
157
- page_one()
158
- elif page == "Other Feature":
159
- page_two()
 
1
+ %%file app.py
2
+ import streamlit as st
3
  import numpy as np
4
+ from PIL import Image
5
  import torch
6
  import torchvision.transforms as T
7
+ import io
 
 
8
 
9
+ # Assuming you have the U2NET model defined somewhere
10
+ from model.u2net import U2NET # Replace with your actual import path
 
11
 
12
  # Initialize the U2NET model
13
  u2net = U2NET(in_ch=3, out_ch=1)
 
18
  return model
19
 
20
  # Load the model onto the specified device
21
+ u2net = load_model(model=u2net, model_path="/content/u2net.pth", device="cpu")
22
 
23
  # Mean and std for normalization
24
  mean = torch.tensor([0.485, 0.456, 0.406])
 
42
  image_batch = image_trans.unsqueeze(0).to(device) # Add batch dimension
43
  return image_batch
44
 
 
 
 
 
 
 
 
 
45
  def prepare_prediction(model, image_batch):
46
  model.eval()
47
  with torch.no_grad():
 
57
 
58
  def apply_mask(image, mask):
59
  """Apply the mask to the original image and return the result with transparent background."""
 
60
  mask = np.squeeze(mask)
 
 
61
  mask = normPRED(mask)
62
  mask = (mask * 255).astype(np.uint8)
 
 
63
  mask_image = Image.fromarray(mask, mode='L') # 'L' mode for grayscale
 
 
64
  original_image = image.convert("RGB")
65
  original_image = original_image.resize(resize_shape, resample=Image.BILINEAR)
 
 
66
  original_image_rgba = original_image.convert("RGBA")
 
 
67
  transparent_background = Image.new("RGBA", original_image_rgba.size, (0, 0, 0, 0))
 
 
68
  masked_image = Image.composite(original_image_rgba, transparent_background, mask_image)
 
69
  return masked_image
70
 
71
+ # Streamlit app setup
72
+ st.title("Image Segmentation with U2NET")
73
+
74
+ # Sidebar for file upload and controls
75
+ st.sidebar.title("Controls :gear:")
76
+ uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
77
+
78
+ # Function to handle image and segmentation display
79
+ def fix_image(upload=None):
80
+ if upload is None:
81
+ st.write("Please upload an image.")
82
+ return
83
+
84
+ image = Image.open(upload)
85
 
86
+ # Display the original image
87
+ st.image(image, caption="Uploaded Image", use_column_width=True)
88
+
89
+ # Prepare image for segmentation
90
  image_batch = prepare_single_image(image, resize_shape, transforms, "cpu")
91
  prediction_u2net = prepare_prediction(u2net, image_batch)
92
  masked_image = apply_mask(image, prediction_u2net)
 
93
 
94
+ # Display segmented image
95
+ st.image(masked_image, caption='Segmented Image', use_column_width=True)
96
+
97
+ # Provide download option for segmented image
98
+ buf = io.BytesIO()
99
+ masked_image.save(buf, format='PNG')
100
+ byte_im = buf.getvalue()
101
+ st.sidebar.markdown('### Download Segmented Image')
102
+ st.sidebar.download_button(
103
+ label="Download Segmented Image",
104
+ data=byte_im,
105
+ file_name="segmented_image.png",
106
+ mime="image/png"
107
+ )
108
+
109
+ # Handle image processing based on user input
110
+ if uploaded_file is not None:
111
+ fix_image(upload=uploaded_file)