import gradio as gr
import torch
from PIL import Image
import numpy as np
import time
import os
import spaces

try:
    from gen2seg_sd_pipeline import gen2segSDPipeline
    from gen2seg_mae_pipeline import gen2segMAEInstancePipeline
except ImportError as e:
    print(f"Error importing pipeline modules: {e}")
    print("Please ensure gen2seg_sd_pipeline.py and gen2seg_mae_pipeline.py are in the same directory.")
    # Optionally, raise an error or exit if pipelines are critical at startup
    # raise ImportError("Could not import custom pipeline modules. Check file paths.") from e

from transformers import ViTMAEForPreTraining, AutoImageProcessor

# --- Configuration ---
MODEL_IDS = {
    "SD": "reachomk/gen2seg-sd",
    "MAE-H": "reachomk/gen2seg-mae-h"
}

# Check if a GPU is available and set the device accordingly
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Global Variables for Caching Pipelines ---
sd_pipe_global = None
mae_pipe_global = None

# --- Model Loading Functions ---
def get_sd_pipeline():
    """Loads and caches the gen2seg Stable Diffusion pipeline."""
    global sd_pipe_global
    if sd_pipe_global is None:
        model_id_sd = MODEL_IDS["SD"]
        print(f"Attempting to load SD pipeline from Hugging Face Hub: {model_id_sd}")
        try:
            sd_pipe_global = gen2segSDPipeline.from_pretrained(
                model_id_sd,
                use_safetensors=True,
                # torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32, # Optional: use float16 on GPU
            ).to(DEVICE)
            print(f"SD Pipeline loaded successfully from {model_id_sd} on {DEVICE}.")
        except Exception as e:
            print(f"Error loading SD pipeline from Hugging Face Hub ({model_id_sd}): {e}")
            sd_pipe_global = None # Ensure it remains None on failure
            # Do not raise gr.Error here; let the main function handle it.
    return sd_pipe_global

def get_mae_pipeline():
    """Loads and caches the gen2seg MAE-H pipeline."""
    global mae_pipe_global
    if mae_pipe_global is None:
        model_id_mae = MODEL_IDS["MAE-H"]
        print(f"Loading MAE-H pipeline with model {model_id_mae} on {DEVICE}...")
        try:
            model = ViTMAEForPreTraining.from_pretrained(model_id_mae)
            model.to(DEVICE)
            model.eval() # Set to evaluation mode

            # Load the official MAE-H image processor
            # Using "facebook/vit-mae-huge" as per the original app_mae.py
            image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-huge")
            
            mae_pipe_global = gen2segMAEInstancePipeline(model=model, image_processor=image_processor)
            # The custom MAE pipeline's model is already on the DEVICE.
            print(f"MAE-H Pipeline with model {model_id_mae} loaded successfully on {DEVICE}.")
        except Exception as e:
            print(f"Error loading MAE-H model or pipeline from Hugging Face Hub ({model_id_mae}): {e}")
            mae_pipe_global = None # Ensure it remains None on failure
            # Do not raise gr.Error here; let the main function handle it.
    return mae_pipe_global

# --- Unified Prediction Function ---
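# Note: on Hugging Face ZeroGPU Spaces, the @spaces.GPU decorator below requests GPU
# hardware for up to ~90 seconds per call; outside of a Space it is effectively a no-op.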
@spaces.GPU(duration=90)
def segment_image(input_image: Image.Image, model_choice: str) -> Image.Image:
    """
    Takes a PIL Image and model choice, performs segmentation, and returns the segmented image.
    """
    if input_image is None:
        raise gr.Error("No image provided. Please upload an image.")

    print(f"Model selected: {model_choice}")
    # Ensure image is in RGB format
    image_rgb = input_image.convert("RGB")
    original_resolution = image_rgb.size # (width, height)
    seg_array = None

    try:
        if model_choice == "SD":
            pipe_sd = get_sd_pipeline()
            if pipe_sd is None:
                raise gr.Error("The SD segmentation pipeline could not be loaded. "
                               "Please check the Space logs for more details, or try again later.")

            print(f"Running SD inference with image size: {image_rgb.size}")
            start_time = time.time()
            with torch.no_grad():
                # The gen2segSDPipeline expects a single image or a list
                # The pipeline's __call__ method handles preprocessing internally
                seg_output = pipe_sd(image_rgb, match_input_resolution=False).prediction # Output is before resize
                
                # seg_output is expected to be a numpy array (N,H,W,1) or (N,1,H,W) or tensor
                # Based on gen2seg_sd_pipeline.py, if output_type="np" (default), it's [N,H,W,1]
                # If output_type="pt", it's [N,1,H,W]
                # The original app_sd.py converted tensor to numpy and squeezed.
                if isinstance(seg_output, torch.Tensor):
                    seg_output = seg_output.cpu().numpy()

                if seg_output.ndim == 4 and seg_output.shape[0] == 1: # Batch size 1
                    if seg_output.shape[1] == 1: # Grayscale, (1, 1, H, W)
                        seg_array = seg_output.squeeze(0).squeeze(0).astype(np.uint8)
                    elif seg_output.shape[-1] == 1: # Grayscale, (1, H, W, 1)
                        seg_array = seg_output.squeeze(0).squeeze(-1).astype(np.uint8)
                    elif seg_output.shape[1] == 3: # RGB, (1, 3, H, W) -> (H, W, 3)
                        seg_array = np.transpose(seg_output.squeeze(0), (1, 2, 0)).astype(np.uint8)
                    elif seg_output.shape[-1] == 3: # RGB, (1, H, W, 3)
                        seg_array = seg_output.squeeze(0).astype(np.uint8)
                    else: # Fallback for unexpected shapes
                        seg_array = seg_output.squeeze().astype(np.uint8)

                elif seg_output.ndim == 3: # (H, W, C) or (C, H, W)
                    seg_array = seg_output.astype(np.uint8)
                elif seg_output.ndim == 2: # (H,W)
                    seg_array = seg_output.astype(np.uint8)
                else:
                    raise TypeError(f"Unexpected SD segmentation output type/shape: {type(seg_output)}, {seg_output.shape}")
            end_time = time.time()
            print(f"SD Inference completed in {end_time - start_time:.2f} seconds.")


        elif model_choice == "MAE-H":
            pipe_mae = get_mae_pipeline()
            if pipe_mae is None:
                raise gr.Error("The MAE-H segmentation pipeline could not be loaded. "
                               "Please check the Space logs for more details, or try again later.")

            print(f"Running MAE-H inference with image size: {image_rgb.size}")
            start_time = time.time()
            with torch.no_grad():
                # The gen2segMAEInstancePipeline expects a list of images
                # output_type="np" returns a NumPy array
                pipe_output = pipe_mae([image_rgb], output_type="np")
                # Prediction is (batch_size, height, width, 3) for MAE
                prediction_np = pipe_output.prediction[0] # Get the first (and only) image prediction
            
            end_time = time.time()
            print(f"MAE-H Inference completed in {end_time - start_time:.2f} seconds.")
            
            if not isinstance(prediction_np, np.ndarray):
                # This case should ideally not be reached if output_type="np"
                prediction_np = prediction_np.cpu().numpy()

            # Ensure it's in the expected (H, W, C) format and uint8
            if prediction_np.ndim == 3 and prediction_np.shape[-1] == 3: # Expected (H, W, 3)
                seg_array = prediction_np.astype(np.uint8)
            else:
                # Attempt to handle other shapes if necessary, or raise error
                raise gr.Error(f"Unexpected MAE-H prediction shape: {prediction_np.shape}. Expected (H, W, 3).")
            
            # The MAE pipeline already does gamma correction and scaling to 0-255.
            # It also ensures 3 channels.

        else:
            raise gr.Error(f"Invalid model choice: {model_choice}. Please select a valid model.")

        if seg_array is None:
             raise gr.Error("Segmentation array was not generated. An unknown error occurred.")

        print(f"Segmentation array generated with shape: {seg_array.shape}, dtype: {seg_array.dtype}")

        # Convert numpy array to PIL Image
        # Handle grayscale or RGB based on seg_array channels
        if seg_array.ndim == 2: # Grayscale
            segmented_image_pil = Image.fromarray(seg_array, mode='L')
        elif seg_array.ndim == 3 and seg_array.shape[-1] == 3: # RGB
            segmented_image_pil = Image.fromarray(seg_array, mode='RGB')
        elif seg_array.ndim == 3 and seg_array.shape[-1] == 1: # Grayscale with channel dim
            segmented_image_pil = Image.fromarray(seg_array.squeeze(-1), mode='L')
        else:
            raise gr.Error(f"Cannot convert seg_array with shape {seg_array.shape} to PIL Image.")

        # Resize back to original image resolution using LANCZOS for high quality
        segmented_image_pil = segmented_image_pil.resize(original_resolution, Image.Resampling.LANCZOS)
        
        print(f"Segmented image processed. Output size: {segmented_image_pil.size}, mode: {segmented_image_pil.mode}")
        return segmented_image_pil

    except Exception as e:
        print(f"Error during segmentation with {model_choice}: {e}")
        # Re-raise as gr.Error for Gradio to display, if not already one
        if not isinstance(e, gr.Error):
            # It's often helpful to include the type of the original exception
            error_type = type(e).__name__
            raise gr.Error(f"An error occurred during segmentation: {error_type} - {str(e)}")
        else:
            raise e # Re-raise if it's already a gr.Error

# --- Gradio Interface ---
title = "gen2seg: Generative Models Enable Generalizable Instance Segmentation Demo (SD & MAE-H)"
description = f"""
<div style="text-align: center; font-family: 'Arial', sans-serif;">
    <p>Upload an image and choose a model architecture to see the instance segmentation result generated by the respective model. </p>
    <p>
        BIG THANKS to Huggingface for funding our demo with their Academic GPU Grant!
    </p>
    <ul>
        <li><strong>SD</strong>: Based on Stable Diffusion 2.
            <a href="https://huggingface.co/{MODEL_IDS['SD']}" target="_blank">Model Link</a>.
        </li>
        <li><strong>MAE-H</strong>: Based on Masked Autoencoder (Huge).
            <a href="https://huggingface.co/{MODEL_IDS['MAE-H']}" target="_blank">Model Link</a>.
            If the output exhibits tokenizer artifacts or looks very dark, applying gamma correction can help.
        </li>
    </ul>
    <p>
    Paper: <a href="https://arxiv.org/abs/2505.15263">https://arxiv.org/abs/2505.15263</a>
    </p>
    <p>
        For faster inference, please check out our GitHub to run the models locally on a GPU:
        <a href="https://github.com/UCDvision/gen2seg" target="_blank">https://github.com/UCDvision/gen2seg</a> or check out our Colab demo <a href="https://colab.research.google.com/drive/10lPBP4figJf7MLp9T1b5cDQeU7MgODw3?usp=sharing" target="_blank">here</a>.
    </p>
    <p>If the demo experiences issues, please open an issue on our GitHub.</p>
    <p>If you have not already, please see our webpage at <a href="https://reachomk.github.io/gen2seg" target="_blank">https://reachomk.github.io/gen2seg</a>.</p>

</div>
"""

article = """
"""

# Define Gradio inputs
input_image_component = gr.Image(type="pil", label="Input Image")
model_choice_component = gr.Dropdown(
    choices=list(MODEL_IDS.keys()),
    value="SD",  # Default model
    label="Choose Segmentation Model Architecture"
)

# Define Gradio output
output_image_component = gr.Image(type="pil", label="Segmented Image")

# Example images (ensure these paths are correct if you upload examples to your Space)
# For example, if you create an "examples" folder in your Space repo:
# example_paths = [
#     os.path.join("examples", "example1.jpg"),
#     os.path.join("examples", "example2.png")
# ]
# Filter out non-existent example files to prevent errors
# example_paths = [ex for ex in example_paths if os.path.exists(ex)]
# Base list of example image paths/URLs
base_example_images = [
    "cats-on-rock-1948.jpg",
    "dogs.png",
    "000000484893.jpg",
    "https://reachomk.github.io/gen2seg/images/comparison/vertical/7.png",
    "https://reachomk.github.io/gen2seg/images/comparison/horizontal/11.png",
    "https://reachomk.github.io/gen2seg/images/comparison/vertical/2.jpg"
]

# Generate examples for each image with both model choices
model_choices_for_examples = list(MODEL_IDS.keys()) # ["SD", "MAE-H"]
formatted_examples = []
for img_path_or_url in base_example_images:
    for model_choice in model_choices_for_examples:
        formatted_examples.append([img_path_or_url, model_choice])
iface = gr.Interface(
    fn=segment_image,
    inputs=[input_image_component, model_choice_component],
    outputs=output_image_component,
    title=title,
    description=description,
    article=article,
    examples=None,  # formatted_examples if formatted_examples else None (example gallery currently disabled)
    allow_flagging="never",
    theme="shivi/calm_seafoam"
)
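
# For quick local testing outside the Gradio UI, the prediction function can also be
# called directly. A minimal sketch (the image path below is a placeholder, not a file
# that ships with this Space):
#
#     result = segment_image(Image.open("my_image.jpg"), "SD")
#     result.save("segmentation.png")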

if __name__ == "__main__":
    # Pre-load the default SD model on startup. This makes the first inference faster
    # at the cost of a slightly longer startup time.
    try:
        get_sd_pipeline() # Pre-load the default SD model
        print("Default SD model pre-loaded successfully or was already cached.")
    except Exception as e:
        print(f"Could not pre-load default SD model: {e}")

    print("Launching Gradio interface...")
    iface.launch()