import gradio as gr
import torch
from PIL import Image
import numpy as np
import time
import os
import spaces

try:
    from gen2seg_sd_pipeline import gen2segSDPipeline
    from gen2seg_mae_pipeline import gen2segMAEInstancePipeline
except ImportError as e:
    print(f"Error importing pipeline modules: {e}")
    print("Please ensure gen2seg_sd_pipeline.py and gen2seg_mae_pipeline.py are in the same directory.")
    # Optionally, raise an error or exit if pipelines are critical at startup
    # raise ImportError("Could not import custom pipeline modules. Check file paths.") from e

from transformers import ViTMAEForPreTraining, AutoImageProcessor
# --- Configuration ---
MODEL_IDS = {
    "SD": "reachomk/gen2seg-sd",
    "MAE-H": "reachomk/gen2seg-mae-h",
}

# Check if a GPU is available and set the device accordingly
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Global Variables for Caching Pipelines ---
sd_pipe_global = None
mae_pipe_global = None
# --- Model Loading Functions ---
def get_sd_pipeline():
    """Loads and caches the gen2seg Stable Diffusion pipeline."""
    global sd_pipe_global
    if sd_pipe_global is None:
        model_id_sd = MODEL_IDS["SD"]
        print(f"Attempting to load SD pipeline from Hugging Face Hub: {model_id_sd}")
        try:
            sd_pipe_global = gen2segSDPipeline.from_pretrained(
                model_id_sd,
                use_safetensors=True,
                # torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,  # Optional: use float16 on GPU
            ).to(DEVICE)
            print(f"SD pipeline loaded successfully from {model_id_sd} on {DEVICE}.")
        except Exception as e:
            print(f"Error loading SD pipeline from Hugging Face Hub ({model_id_sd}): {e}")
            sd_pipe_global = None  # Ensure it remains None on failure
            # Do not raise gr.Error here; let the main function handle it.
    return sd_pipe_global
def get_mae_pipeline():
    """Loads and caches the gen2seg MAE-H pipeline."""
    global mae_pipe_global
    if mae_pipe_global is None:
        model_id_mae = MODEL_IDS["MAE-H"]
        print(f"Loading MAE-H pipeline with model {model_id_mae} on {DEVICE}...")
        try:
            model = ViTMAEForPreTraining.from_pretrained(model_id_mae)
            model.to(DEVICE)
            model.eval()  # Set to evaluation mode

            # Load the official MAE-H image processor
            # ("facebook/vit-mae-huge", as in the original app_mae.py)
            image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-huge")

            mae_pipe_global = gen2segMAEInstancePipeline(model=model, image_processor=image_processor)
            # The custom MAE pipeline's model is already on DEVICE.
            print(f"MAE-H pipeline with model {model_id_mae} loaded successfully on {DEVICE}.")
        except Exception as e:
            print(f"Error loading MAE-H model or pipeline from Hugging Face Hub ({model_id_mae}): {e}")
            mae_pipe_global = None  # Ensure it remains None on failure
            # Do not raise gr.Error here; let the main function handle it.
    return mae_pipe_global
# --- Unified Prediction Function ---
# The `spaces` import above implies this Space runs on ZeroGPU, which requires the
# @spaces.GPU decorator so a GPU is allocated for each inference call.
@spaces.GPU
def segment_image(input_image: Image.Image, model_choice: str) -> Image.Image:
    """
    Takes a PIL Image and a model choice, performs segmentation, and returns the segmented image.
    """
    if input_image is None:
        raise gr.Error("No image provided. Please upload an image.")

    print(f"Model selected: {model_choice}")

    # Ensure the image is in RGB format
    image_rgb = input_image.convert("RGB")
    original_resolution = image_rgb.size  # (width, height)

    seg_array = None
    try:
        if model_choice == "SD":
            pipe_sd = get_sd_pipeline()
            if pipe_sd is None:
                raise gr.Error("The SD segmentation pipeline could not be loaded. "
                               "Please check the Space logs for more details, or try again later.")
            print(f"Running SD inference with image size: {image_rgb.size}")
            start_time = time.time()
            with torch.no_grad():
                # The gen2segSDPipeline accepts a single image or a list;
                # its __call__ method handles preprocessing internally.
                seg_output = pipe_sd(image_rgb, match_input_resolution=False).prediction  # Output is before resize

            # seg_output is expected to be a NumPy array or a tensor.
            # Based on gen2seg_sd_pipeline.py: with output_type="np" (the default) it is (N, H, W, 1);
            # with output_type="pt" it is (N, 1, H, W).
            # The original app_sd.py converted the tensor to NumPy and squeezed it.
            if isinstance(seg_output, torch.Tensor):
                seg_output = seg_output.cpu().numpy()

            if seg_output.ndim == 4 and seg_output.shape[0] == 1:  # Batch size 1
                if seg_output.shape[1] == 1:  # Grayscale, (1, 1, H, W)
                    seg_array = seg_output.squeeze(0).squeeze(0).astype(np.uint8)
                elif seg_output.shape[-1] == 1:  # Grayscale, (1, H, W, 1)
                    seg_array = seg_output.squeeze(0).squeeze(-1).astype(np.uint8)
                elif seg_output.shape[1] == 3:  # RGB, (1, 3, H, W) -> (H, W, 3)
                    seg_array = np.transpose(seg_output.squeeze(0), (1, 2, 0)).astype(np.uint8)
                elif seg_output.shape[-1] == 3:  # RGB, (1, H, W, 3)
                    seg_array = seg_output.squeeze(0).astype(np.uint8)
                else:  # Fallback for unexpected shapes
                    seg_array = seg_output.squeeze().astype(np.uint8)
            elif seg_output.ndim == 3:  # (H, W, C) or (C, H, W)
                seg_array = seg_output.astype(np.uint8)
            elif seg_output.ndim == 2:  # (H, W)
                seg_array = seg_output.astype(np.uint8)
            else:
                raise TypeError(f"Unexpected SD segmentation output type/shape: {type(seg_output)}, {seg_output.shape}")
            end_time = time.time()
            print(f"SD inference completed in {end_time - start_time:.2f} seconds.")
        elif model_choice == "MAE-H":
            pipe_mae = get_mae_pipeline()
            if pipe_mae is None:
                raise gr.Error("The MAE-H segmentation pipeline could not be loaded. "
                               "Please check the Space logs for more details, or try again later.")
            print(f"Running MAE-H inference with image size: {image_rgb.size}")
            start_time = time.time()
            with torch.no_grad():
                # The gen2segMAEInstancePipeline expects a list of images;
                # output_type="np" returns a NumPy array.
                pipe_output = pipe_mae([image_rgb], output_type="np")
                # The prediction is (batch_size, height, width, 3) for MAE.
                prediction_np = pipe_output.prediction[0]  # First (and only) image prediction
            end_time = time.time()
            print(f"MAE-H inference completed in {end_time - start_time:.2f} seconds.")

            if not isinstance(prediction_np, np.ndarray):
                # This case should not be reached with output_type="np"
                prediction_np = prediction_np.cpu().numpy()

            # Ensure the prediction is in the expected (H, W, C) format and uint8
            if prediction_np.ndim == 3 and prediction_np.shape[-1] == 3:  # Expected (H, W, 3)
                seg_array = prediction_np.astype(np.uint8)
            else:
                raise gr.Error(f"Unexpected MAE-H prediction shape: {prediction_np.shape}. Expected (H, W, 3).")
            # The MAE pipeline already applies gamma correction, scales to 0-255,
            # and ensures 3 channels.
        else:
            raise gr.Error(f"Invalid model choice: {model_choice}. Please select a valid model.")
        if seg_array is None:
            raise gr.Error("Segmentation array was not generated. An unknown error occurred.")

        print(f"Segmentation array generated with shape: {seg_array.shape}, dtype: {seg_array.dtype}")

        # Convert the NumPy array to a PIL Image, grayscale or RGB depending on channels
        if seg_array.ndim == 2:  # Grayscale
            segmented_image_pil = Image.fromarray(seg_array, mode="L")
        elif seg_array.ndim == 3 and seg_array.shape[-1] == 3:  # RGB
            segmented_image_pil = Image.fromarray(seg_array, mode="RGB")
        elif seg_array.ndim == 3 and seg_array.shape[-1] == 1:  # Grayscale with channel dim
            segmented_image_pil = Image.fromarray(seg_array.squeeze(-1), mode="L")
        else:
            raise gr.Error(f"Cannot convert seg_array with shape {seg_array.shape} to PIL Image.")

        # Resize back to the original image resolution using LANCZOS for high quality
        segmented_image_pil = segmented_image_pil.resize(original_resolution, Image.Resampling.LANCZOS)

        print(f"Segmented image processed. Output size: {segmented_image_pil.size}, mode: {segmented_image_pil.mode}")
        return segmented_image_pil

    except Exception as e:
        print(f"Error during segmentation with {model_choice}: {e}")
        # Re-raise as gr.Error for Gradio to display, if not already one.
        # Including the type of the original exception is often helpful.
        if not isinstance(e, gr.Error):
            error_type = type(e).__name__
            raise gr.Error(f"An error occurred during segmentation: {error_type} - {str(e)}")
        else:
            raise e  # Re-raise if it's already a gr.Error
# --- Gradio Interface ---
title = "gen2seg: Generative Models Enable Generalizable Instance Segmentation Demo (SD & MAE-H)"

description = f"""
<div style="text-align: center; font-family: 'Arial', sans-serif;">
    <p>Upload an image and choose a model architecture to see the instance segmentation result generated by the respective model.</p>
    <p>
        BIG THANKS to Hugging Face for funding our demo with their Academic GPU Grant!
    </p>
    <ul>
        <li><strong>SD</strong>: Based on Stable Diffusion 2.
            <a href="https://huggingface.co/{MODEL_IDS['SD']}" target="_blank">Model Link</a>.
        </li>
        <li><strong>MAE-H</strong>: Based on Masked Autoencoder (Huge).
            <a href="https://huggingface.co/{MODEL_IDS['MAE-H']}" target="_blank">Model Link</a>.
            If you experience tokenizer artifacts or very dark images, you can apply gamma correction to handle this.
        </li>
    </ul>
    <p>
        Paper: <a href="https://arxiv.org/abs/2505.15263" target="_blank">https://arxiv.org/abs/2505.15263</a>
    </p>
    <p>
        For faster inference, please check out our GitHub to run the models locally on a GPU:
        <a href="https://github.com/UCDvision/gen2seg" target="_blank">https://github.com/UCDvision/gen2seg</a>, or check out our Colab demo <a href="https://colab.research.google.com/drive/10lPBP4figJf7MLp9T1b5cDQeU7MgODw3?usp=sharing" target="_blank">here</a>.
    </p>
    <p>If the demo experiences issues, please open an issue on our GitHub.</p>
    <p>If you have not already, please see our webpage at <a href="https://reachomk.github.io/gen2seg" target="_blank">https://reachomk.github.io/gen2seg</a>.</p>
</div>
"""
article = """
"""
# Define Gradio inputs
input_image_component = gr.Image(type="pil", label="Input Image")
model_choice_component = gr.Dropdown(
    choices=list(MODEL_IDS.keys()),
    value="SD",  # Default model
    label="Choose Segmentation Model Architecture",
)

# Define Gradio output
output_image_component = gr.Image(type="pil", label="Segmented Image")

# Example images (ensure these paths are correct if you upload examples to your Space).
# For example, if you create an "examples" folder in your Space repo:
# example_paths = [
#     os.path.join("examples", "example1.jpg"),
#     os.path.join("examples", "example2.png"),
# ]
# Filter out non-existent example files to prevent errors:
# example_paths = [ex for ex in example_paths if os.path.exists(ex)]
# Base list of example image paths/URLs
base_example_images = [
    "cats-on-rock-1948.jpg",
    "dogs.png",
    "000000484893.jpg",
    "https://reachomk.github.io/gen2seg/images/comparison/vertical/7.png",
    "https://reachomk.github.io/gen2seg/images/comparison/horizontal/11.png",
    "https://reachomk.github.io/gen2seg/images/comparison/vertical/2.jpg",
]

# Generate examples pairing each image with both model choices
model_choices_for_examples = list(MODEL_IDS.keys())  # ["SD", "MAE-H"]
formatted_examples = []
for img_path_or_url in base_example_images:
    for model_choice in model_choices_for_examples:
        formatted_examples.append([img_path_or_url, model_choice])
iface = gr.Interface(
    fn=segment_image,
    inputs=[input_image_component, model_choice_component],
    outputs=output_image_component,
    title=title,
    description=description,
    article=article,
    examples=None,  # Examples are disabled for now; pass formatted_examples to enable them
    allow_flagging="never",
    theme="shivi/calm_seafoam",
)
if __name__ == "__main__":
    # Optional: pre-load a default model on startup if desired. This makes the
    # first inference faster but increases startup time.
    try:
        get_sd_pipeline()  # Pre-load the default SD model
        print("Default SD model pre-loaded successfully or was already cached.")
    except Exception as e:
        print(f"Could not pre-load default SD model: {e}")

    print("Launching Gradio interface...")
    iface.launch()