reachomk committed
Commit 50238b6 · verified · 1 parent: 37798ad

Update app.py: add example images, switch the UI theme to "shivi/calm_seafoam", and pre-load the SD pipeline at startup

Files changed (1)
app.py (+280 -280)
app.py CHANGED
@@ -1,280 +1,280 @@
import gradio as gr
import torch
from PIL import Image
import numpy as np
import time
import os

# --- Import Custom Pipelines ---
# Ensure these files are in the same directory or accessible in PYTHONPATH
try:
    from gen2seg_sd_pipeline import gen2segSDPipeline
    from gen2seg_mae_pipeline import gen2segMAEInstancePipeline
except ImportError as e:
    print(f"Error importing pipeline modules: {e}")
    print("Please ensure gen2seg_sd_pipeline.py and gen2seg_mae_pipeline.py are in the same directory.")
    # Optionally, raise an error or exit if pipelines are critical at startup
    # raise ImportError("Could not import custom pipeline modules. Check file paths.") from e

from transformers import ViTMAEForPreTraining, AutoImageProcessor

# --- Configuration ---
MODEL_IDS = {
    "SD": "reachomk/gen2seg-sd",
    "MAE-H": "reachomk/gen2seg-mae-h"
}

# Check if a GPU is available and set the device accordingly
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Global Variables for Caching Pipelines ---
sd_pipe_global = None
mae_pipe_global = None

# --- Model Loading Functions ---
def get_sd_pipeline():
    """Loads and caches the gen2seg Stable Diffusion pipeline."""
    global sd_pipe_global
    if sd_pipe_global is None:
        model_id_sd = MODEL_IDS["SD"]
        print(f"Attempting to load SD pipeline from Hugging Face Hub: {model_id_sd}")
        try:
            sd_pipe_global = gen2segSDPipeline.from_pretrained(
                model_id_sd,
                use_safetensors=True,
                # torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32, # Optional: use float16 on GPU
            ).to(DEVICE)
            print(f"SD Pipeline loaded successfully from {model_id_sd} on {DEVICE}.")
        except Exception as e:
            print(f"Error loading SD pipeline from Hugging Face Hub ({model_id_sd}): {e}")
            sd_pipe_global = None # Ensure it remains None on failure
            # Do not raise gr.Error here; let the main function handle it.
    return sd_pipe_global

def get_mae_pipeline():
    """Loads and caches the gen2seg MAE-H pipeline."""
    global mae_pipe_global
    if mae_pipe_global is None:
        model_id_mae = MODEL_IDS["MAE-H"]
        print(f"Loading MAE-H pipeline with model {model_id_mae} on {DEVICE}...")
        try:
            model = ViTMAEForPreTraining.from_pretrained(model_id_mae)
            model.to(DEVICE)
            model.eval() # Set to evaluation mode

            # Load the official MAE-H image processor
            # Using "facebook/vit-mae-huge" as per the original app_mae.py
            image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-huge")

            mae_pipe_global = gen2segMAEInstancePipeline(model=model, image_processor=image_processor)
            # The custom MAE pipeline's model is already on the DEVICE.
            print(f"MAE-H Pipeline with model {model_id_mae} loaded successfully on {DEVICE}.")
        except Exception as e:
            print(f"Error loading MAE-H model or pipeline from Hugging Face Hub ({model_id_mae}): {e}")
            mae_pipe_global = None # Ensure it remains None on failure
            # Do not raise gr.Error here; let the main function handle it.
    return mae_pipe_global

# --- Unified Prediction Function ---
def segment_image(input_image: Image.Image, model_choice: str) -> Image.Image:
    """
    Takes a PIL Image and model choice, performs segmentation, and returns the segmented image.
    """
    if input_image is None:
        raise gr.Error("No image provided. Please upload an image.")

    print(f"Model selected: {model_choice}")
    # Ensure image is in RGB format
    image_rgb = input_image.convert("RGB")
    original_resolution = image_rgb.size # (width, height)
    seg_array = None

    try:
        if model_choice == "SD":
            pipe_sd = get_sd_pipeline()
            if pipe_sd is None:
                raise gr.Error("The SD segmentation pipeline could not be loaded. "
                               "Please check the Space logs for more details, or try again later.")

            print(f"Running SD inference with image size: {image_rgb.size}")
            start_time = time.time()
            with torch.no_grad():
                # The gen2segSDPipeline expects a single image or a list
                # The pipeline's __call__ method handles preprocessing internally
                seg_output = pipe_sd(image_rgb, match_input_resolution=False).prediction # Output is before resize

            # seg_output is expected to be a numpy array (N,H,W,1) or (N,1,H,W) or tensor
            # Based on gen2seg_sd_pipeline.py, if output_type="np" (default), it's [N,H,W,1]
            # If output_type="pt", it's [N,1,H,W]
            # The original app_sd.py converted tensor to numpy and squeezed.
            if isinstance(seg_output, torch.Tensor):
                seg_output = seg_output.cpu().numpy()

            if seg_output.ndim == 4 and seg_output.shape[0] == 1: # Batch size 1
                if seg_output.shape[1] == 1: # Grayscale, (1, 1, H, W)
                    seg_array = seg_output.squeeze(0).squeeze(0).astype(np.uint8)
                elif seg_output.shape[-1] == 1: # Grayscale, (1, H, W, 1)
                    seg_array = seg_output.squeeze(0).squeeze(-1).astype(np.uint8)
                elif seg_output.shape[1] == 3: # RGB, (1, 3, H, W) -> (H, W, 3)
                    seg_array = np.transpose(seg_output.squeeze(0), (1, 2, 0)).astype(np.uint8)
                elif seg_output.shape[-1] == 3: # RGB, (1, H, W, 3)
                    seg_array = seg_output.squeeze(0).astype(np.uint8)
                else: # Fallback for unexpected shapes
                    seg_array = seg_output.squeeze().astype(np.uint8)
            elif seg_output.ndim == 3: # (H, W, C) or (C, H, W)
                seg_array = seg_output.astype(np.uint8)
            elif seg_output.ndim == 2: # (H, W)
                seg_array = seg_output.astype(np.uint8)
            else:
                raise TypeError(f"Unexpected SD segmentation output type/shape: {type(seg_output)}, {seg_output.shape}")
            end_time = time.time()
            print(f"SD Inference completed in {end_time - start_time:.2f} seconds.")

        elif model_choice == "MAE-H":
            pipe_mae = get_mae_pipeline()
            if pipe_mae is None:
                raise gr.Error("The MAE-H segmentation pipeline could not be loaded. "
                               "Please check the Space logs for more details, or try again later.")

            print(f"Running MAE-H inference with image size: {image_rgb.size}")
            start_time = time.time()
            with torch.no_grad():
                # The gen2segMAEInstancePipeline expects a list of images
                # output_type="np" returns a NumPy array
                pipe_output = pipe_mae([image_rgb], output_type="np")
                # Prediction is (batch_size, height, width, 3) for MAE
                prediction_np = pipe_output.prediction[0] # Get the first (and only) image prediction

            end_time = time.time()
            print(f"MAE-H Inference completed in {end_time - start_time:.2f} seconds.")

            if not isinstance(prediction_np, np.ndarray):
                # This case should ideally not be reached if output_type="np"
                prediction_np = prediction_np.cpu().numpy()

            # Ensure it's in the expected (H, W, C) format and uint8
            if prediction_np.ndim == 3 and prediction_np.shape[-1] == 3: # Expected (H, W, 3)
                seg_array = prediction_np.astype(np.uint8)
            else:
                # Attempt to handle other shapes if necessary, or raise error
                raise gr.Error(f"Unexpected MAE-H prediction shape: {prediction_np.shape}. Expected (H, W, 3).")

            # The MAE pipeline already does gamma correction and scaling to 0-255.
            # It also ensures 3 channels.

        else:
            raise gr.Error(f"Invalid model choice: {model_choice}. Please select a valid model.")

        if seg_array is None:
            raise gr.Error("Segmentation array was not generated. An unknown error occurred.")

        print(f"Segmentation array generated with shape: {seg_array.shape}, dtype: {seg_array.dtype}")

        # Convert numpy array to PIL Image
        # Handle grayscale or RGB based on seg_array channels
        if seg_array.ndim == 2: # Grayscale
            segmented_image_pil = Image.fromarray(seg_array, mode='L')
        elif seg_array.ndim == 3 and seg_array.shape[-1] == 3: # RGB
            segmented_image_pil = Image.fromarray(seg_array, mode='RGB')
        elif seg_array.ndim == 3 and seg_array.shape[-1] == 1: # Grayscale with channel dim
            segmented_image_pil = Image.fromarray(seg_array.squeeze(-1), mode='L')
        else:
            raise gr.Error(f"Cannot convert seg_array with shape {seg_array.shape} to PIL Image.")

        # Resize back to original image resolution using LANCZOS for high quality
        segmented_image_pil = segmented_image_pil.resize(original_resolution, Image.Resampling.LANCZOS)

        print(f"Segmented image processed. Output size: {segmented_image_pil.size}, mode: {segmented_image_pil.mode}")
        return segmented_image_pil

    except Exception as e:
        print(f"Error during segmentation with {model_choice}: {e}")
        # Re-raise as gr.Error for Gradio to display, if not already one
        if not isinstance(e, gr.Error):
            # It's often helpful to include the type of the original exception
            error_type = type(e).__name__
            raise gr.Error(f"An error occurred during segmentation: {error_type} - {str(e)}")
        else:
            raise e # Re-raise if it's already a gr.Error

# --- Gradio Interface ---
title = "gen2seg: Generative Models Enable Generalizable Instance Segmentation Demo (SD & MAE-H)"
description = f"""
<div style="text-align: center; font-family: 'Arial', sans-serif;">
    <p>Upload an image and choose a model architecture to see the instance segmentation result generated by the respective model.</p>
    <p>
        Currently, inference is running on CPU.
        Performance will be significantly better on GPU.
    </p>
    <ul>
        <li><strong>SD</strong>: Based on Stable Diffusion 2.
            <a href="https://huggingface.co/{MODEL_IDS['SD']}" target="_blank">Model Link</a>.
            <em>Approx. CPU inference time: ~1-2 minutes per image.</em>
        </li>
        <li><strong>MAE-H</strong>: Based on Masked Autoencoder (Huge).
            <a href="https://huggingface.co/{MODEL_IDS['MAE-H']}" target="_blank">Model Link</a>.
            <em>Approx. CPU inference time: ~15-45 seconds per image.</em>
            If you experience tokenizer artifacts or very dark images, you can use gamma correction to handle this.
        </li>
    </ul>
    <p>
        For faster inference, please check out our GitHub to run the models locally on a GPU:
        <a href="https://github.com/UCDvision/gen2seg" target="_blank">https://github.com/UCDvision/gen2seg</a>
    </p>
    <p>If the demo experiences issues, please open an issue on our GitHub.</p>
    <p>If you have not already, please see our webpage at <a href="https://reachomk.github.io/gen2seg" target="_blank">https://reachomk.github.io/gen2seg</a>.</p>
</div>
"""

article = """
"""

# Define Gradio inputs
input_image_component = gr.Image(type="pil", label="Input Image")
model_choice_component = gr.Dropdown(
    choices=list(MODEL_IDS.keys()),
    value="SD", # Default model
    label="Choose Segmentation Model Architecture"
)

# Define Gradio output
output_image_component = gr.Image(type="pil", label="Segmented Image")

# Example images (ensure these paths are correct if you upload examples to your Space)
# For example, if you create an "examples" folder in your Space repo:
# example_paths = [
#     os.path.join("examples", "example1.jpg"),
#     os.path.join("examples", "example2.png")
# ]
# Filter out non-existent example files to prevent errors
# example_paths = [ex for ex in example_paths if os.path.exists(ex)]
- example_paths = [] # Add paths to example images here if you have them
+ example_paths = ["cats-on-rock-1948.jpg", "dogs.png", "000000484893.jpg", "https://reachomk.github.io/gen2seg/images/comparison/vertical/7.png", "https://reachomk.github.io/gen2seg/images/comparison/horizontal/11.png", "https://reachomk.github.io/gen2seg/images/comparison/vertical/2.jpg"] # Add paths to example images here if you have them

iface = gr.Interface(
    fn=segment_image,
    inputs=[input_image_component, model_choice_component],
    outputs=output_image_component,
    title=title,
    description=description,
    article=article,
    examples=example_paths if example_paths else None, # Pass None if no examples
    allow_flagging="never",
-     theme=gr.themes.Soft() # Using a soft theme for a slightly modern look
+     theme="shivi/calm_seafoam"
)

if __name__ == "__main__":
    # Optional: Pre-load a default model on startup if desired.
    # This can make the first inference faster but increases startup time.
    # print("Attempting to pre-load default SD model on startup...")
-     # try:
-     #     get_sd_pipeline() # Pre-load the default SD model
-     #     print("Default SD model pre-loaded successfully or was already cached.")
-     # except Exception as e:
-     #     print(f"Could not pre-load default SD model: {e}")
+     try:
+         get_sd_pipeline() # Pre-load the default SD model
+         print("Default SD model pre-loaded successfully or was already cached.")
+     except Exception as e:
+         print(f"Could not pre-load default SD model: {e}")

    print("Launching Gradio interface...")
    iface.launch()
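
Note on the new example list: the commented-out filter kept above (example_paths = [ex for ex in example_paths if os.path.exists(ex)]) would now silently drop the three https:// entries, because os.path.exists only checks the local filesystem. Below is a minimal sketch of a filter that keeps remote URLs and verifies only local files; the helper name filter_example_paths is hypothetical and not part of this commit.

import os

def filter_example_paths(paths):
    """Keep remote URLs as-is; keep local paths only if they exist on disk."""
    kept = []
    for p in paths:
        # URLs cannot be checked with os.path.exists; pass them through unchanged.
        if p.startswith(("http://", "https://")) or os.path.exists(p):
            kept.append(p)
        else:
            print(f"Skipping missing example file: {p}")
    return kept

# Hypothetical usage, just before gr.Interface(...) is constructed:
# example_paths = filter_example_paths(example_paths)

Applied that way, the app would behave exactly as committed when all example files are present, and would degrade gracefully (with a log line) if one of the bundled images is missing from the Space repo.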