Hamam commited on
Commit
2ed7990
·
verified ·
1 Parent(s): 861a11c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -11
app.py CHANGED
@@ -1,10 +1,15 @@
1
- import gradio as gr
2
  import os
3
  from PIL import Image
4
  import numpy as np
5
  import torch
6
  import torchvision.transforms as T
 
7
  from model.u2net import U2NET
 
 
 
 
 
8
 
9
  # Initialize the U2NET model
10
  u2net = U2NET(in_ch=3, out_ch=1)
@@ -88,7 +93,7 @@ def apply_mask(image, mask):
88
  return masked_image
89
 
90
  def segment_image(image):
91
- """Function to be used with Gradio for segmentation."""
92
  # Ensure image is a PIL Image
93
  if isinstance(image, np.ndarray):
94
  image = Image.fromarray(image)
@@ -98,13 +103,57 @@ def segment_image(image):
98
  masked_image = apply_mask(image, prediction_u2net)
99
  return masked_image
100
 
101
- # Define the Gradio interface
102
- iface = gr.Interface(
103
- fn=segment_image,
104
- inputs=gr.Image(type="numpy"),
105
- outputs=gr.Image(type="pil",format="png"),
106
- title="Remove Background From Image"
107
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- # Launch the interface
110
- iface.launch()
 
 
 
 
 
1
  import os
2
  from PIL import Image
3
  import numpy as np
4
  import torch
5
  import torchvision.transforms as T
6
+ import streamlit as st
7
  from model.u2net import U2NET
8
+ from io import BytesIO
9
+
10
+ # Constants
11
+ MAX_FILE_SIZE = 5 * 1024 * 1024 # 5 MB
12
+ DEFAULT_IMAGE_PATH = "default_image.png" # Path to your default image
13
 
14
  # Initialize the U2NET model
15
  u2net = U2NET(in_ch=3, out_ch=1)
 
93
  return masked_image
94
 
95
  def segment_image(image):
96
+ """Function to be used for segmentation."""
97
  # Ensure image is a PIL Image
98
  if isinstance(image, np.ndarray):
99
  image = Image.fromarray(image)
 
103
  masked_image = apply_mask(image, prediction_u2net)
104
  return masked_image
105
 
106
+ def fix_image(upload=None):
107
+ """Processes an uploaded image or a default image."""
108
+ if upload is not None:
109
+ image = Image.open(upload)
110
+ else:
111
+ image = Image.open(DEFAULT_IMAGE_PATH) # Load default image
112
+
113
+ st.image(image, caption='Selected Image', use_column_width=True)
114
+
115
+ if st.button('Segment Image'):
116
+ masked_image = segment_image(image)
117
+ st.image(masked_image, caption='Segmented Image', use_column_width=True, format="PNG")
118
+ # Save the image to a BytesIO object for downloading
119
+ buffered = BytesIO()
120
+ masked_image.save(buffered, format="PNG")
121
+ st.download_button(
122
+ label="Download Segmented Image",
123
+ data=buffered.getvalue(),
124
+ file_name="segmented_image.png",
125
+ mime="image/png"
126
+ )
127
+
128
+ # Define the pages
129
+ def page_one():
130
+ """Page for image segmentation."""
131
+ st.title("Image Segmentation with U2NET")
132
+ st.write("Upload an image to segment it using the U2NET model. The background of the segmented output will be transparent.")
133
+
134
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
135
+
136
+ # Determine image processing
137
+ if uploaded_file is not None:
138
+ if uploaded_file.size > MAX_FILE_SIZE:
139
+ st.error("The uploaded file is too large. Please upload an image smaller than 5MB.")
140
+ else:
141
+ fix_image(upload=uploaded_file)
142
+ else:
143
+ fix_image() # Use default image if none uploaded
144
+
145
+ def page_two():
146
+ """Page for other code."""
147
+ st.title("Other Feature")
148
+ st.write("This page is for the second feature you want to implement.")
149
+ # Add other code or features here
150
+
151
+ # Sidebar navigation
152
+ st.sidebar.title("Navigation")
153
+ page = st.sidebar.radio("Go to", ("Image Segmentation", "Other Feature"))
154
 
155
+ # Page selection logic
156
+ if page == "Image Segmentation":
157
+ page_one()
158
+ elif page == "Other Feature":
159
+ page_two()