import gradio as gr
import torch
from PIL import Image
import numpy as np
import time
import os
import spaces
try:
    from gen2seg_sd_pipeline import gen2segSDPipeline
    from gen2seg_mae_pipeline import gen2segMAEInstancePipeline
except ImportError as e:
    print(f"Error importing pipeline modules: {e}")
    print("Please ensure gen2seg_sd_pipeline.py and gen2seg_mae_pipeline.py are in the same directory.")
    # Optionally, raise an error or exit here if the pipelines are critical at startup:
    # raise ImportError("Could not import custom pipeline modules. Check file paths.") from e
from transformers import ViTMAEForPreTraining, AutoImageProcessor
# --- Configuration ---
MODEL_IDS = {
    "SD": "reachomk/gen2seg-sd",
    "MAE-H": "reachomk/gen2seg-mae-h"
}
# Check if a GPU is available and set the device accordingly
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
# --- Global Variables for Caching Pipelines ---
sd_pipe_global = None
mae_pipe_global = None
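# Pipelines are loaded lazily on the first request for each architecture and cached in
# these globals, so model weights are not re-downloaded or re-initialized per inference.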
# --- Model Loading Functions ---
def get_sd_pipeline():
"""Loads and caches the gen2seg Stable Diffusion pipeline."""
global sd_pipe_global
if sd_pipe_global is None:
model_id_sd = MODEL_IDS["SD"]
print(f"Attempting to load SD pipeline from Hugging Face Hub: {model_id_sd}")
try:
sd_pipe_global = gen2segSDPipeline.from_pretrained(
model_id_sd,
use_safetensors=True,
# torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32, # Optional: use float16 on GPU
).to(DEVICE)
print(f"SD Pipeline loaded successfully from {model_id_sd} on {DEVICE}.")
except Exception as e:
print(f"Error loading SD pipeline from Hugging Face Hub ({model_id_sd}): {e}")
sd_pipe_global = None # Ensure it remains None on failure
# Do not raise gr.Error here; let the main function handle it.
return sd_pipe_global
def get_mae_pipeline():
"""Loads and caches the gen2seg MAE-H pipeline."""
global mae_pipe_global
if mae_pipe_global is None:
model_id_mae = MODEL_IDS["MAE-H"]
print(f"Loading MAE-H pipeline with model {model_id_mae} on {DEVICE}...")
try:
model = ViTMAEForPreTraining.from_pretrained(model_id_mae)
model.to(DEVICE)
model.eval() # Set to evaluation mode
# Load the official MAE-H image processor
# Using "facebook/vit-mae-huge" as per the original app_mae.py
image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-huge")
mae_pipe_global = gen2segMAEInstancePipeline(model=model, image_processor=image_processor)
# The custom MAE pipeline's model is already on the DEVICE.
print(f"MAE-H Pipeline with model {model_id_mae} loaded successfully on {DEVICE}.")
except Exception as e:
print(f"Error loading MAE-H model or pipeline from Hugging Face Hub ({model_id_mae}): {e}")
mae_pipe_global = None # Ensure it remains None on failure
# Do not raise gr.Error here; let the main function handle it.
return mae_pipe_global
# --- Unified Prediction Function ---
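# On ZeroGPU Spaces, @spaces.GPU requests GPU hardware for the duration of each call
# (here up to 90 seconds); on Spaces with a dedicated GPU it should effectively be a no-op.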
@spaces.GPU(duration=90)
def segment_image(input_image: Image.Image, model_choice: str) -> Image.Image:
"""
Takes a PIL Image and model choice, performs segmentation, and returns the segmented image.
"""
if input_image is None:
raise gr.Error("No image provided. Please upload an image.")
print(f"Model selected: {model_choice}")
# Ensure image is in RGB format
image_rgb = input_image.convert("RGB")
original_resolution = image_rgb.size # (width, height)
seg_array = None
try:
if model_choice == "SD":
pipe_sd = get_sd_pipeline()
if pipe_sd is None:
raise gr.Error("The SD segmentation pipeline could not be loaded. "
"Please check the Space logs for more details, or try again later.")
print(f"Running SD inference with image size: {image_rgb.size}")
start_time = time.time()
with torch.no_grad():
# The gen2segSDPipeline expects a single image or a list
# The pipeline's __call__ method handles preprocessing internally
seg_output = pipe_sd(image_rgb, match_input_resolution=False).prediction # Output is before resize
# seg_output is expected to be a numpy array (N,H,W,1) or (N,1,H,W) or tensor
# Based on gen2seg_sd_pipeline.py, if output_type="np" (default), it's [N,H,W,1]
# If output_type="pt", it's [N,1,H,W]
# The original app_sd.py converted tensor to numpy and squeezed.
if isinstance(seg_output, torch.Tensor):
seg_output = seg_output.cpu().numpy()
if seg_output.ndim == 4 and seg_output.shape[0] == 1: # Batch size 1
if seg_output.shape[1] == 1: # Grayscale, (1, 1, H, W)
seg_array = seg_output.squeeze(0).squeeze(0).astype(np.uint8)
elif seg_output.shape[-1] == 1: # Grayscale, (1, H, W, 1)
seg_array = seg_output.squeeze(0).squeeze(-1).astype(np.uint8)
elif seg_output.shape[1] == 3: # RGB, (1, 3, H, W) -> (H, W, 3)
seg_array = np.transpose(seg_output.squeeze(0), (1, 2, 0)).astype(np.uint8)
elif seg_output.shape[-1] == 3: # RGB, (1, H, W, 3)
seg_array = seg_output.squeeze(0).astype(np.uint8)
else: # Fallback for unexpected shapes
seg_array = seg_output.squeeze().astype(np.uint8)
elif seg_output.ndim == 3: # (H, W, C) or (C, H, W)
seg_array = seg_output.astype(np.uint8)
elif seg_output.ndim == 2: # (H,W)
seg_array = seg_output.astype(np.uint8)
else:
raise TypeError(f"Unexpected SD segmentation output type/shape: {type(seg_output)}, {seg_output.shape}")
end_time = time.time()
print(f"SD Inference completed in {end_time - start_time:.2f} seconds.")
elif model_choice == "MAE-H":
pipe_mae = get_mae_pipeline()
if pipe_mae is None:
raise gr.Error("The MAE-H segmentation pipeline could not be loaded. "
"Please check the Space logs for more details, or try again later.")
print(f"Running MAE-H inference with image size: {image_rgb.size}")
start_time = time.time()
with torch.no_grad():
# The gen2segMAEInstancePipeline expects a list of images
# output_type="np" returns a NumPy array
pipe_output = pipe_mae([image_rgb], output_type="np")
# Prediction is (batch_size, height, width, 3) for MAE
prediction_np = pipe_output.prediction[0] # Get the first (and only) image prediction
end_time = time.time()
print(f"MAE-H Inference completed in {end_time - start_time:.2f} seconds.")
if not isinstance(prediction_np, np.ndarray):
# This case should ideally not be reached if output_type="np"
prediction_np = prediction_np.cpu().numpy()
# Ensure it's in the expected (H, W, C) format and uint8
if prediction_np.ndim == 3 and prediction_np.shape[-1] == 3: # Expected (H, W, 3)
seg_array = prediction_np.astype(np.uint8)
else:
# Attempt to handle other shapes if necessary, or raise error
raise gr.Error(f"Unexpected MAE-H prediction shape: {prediction_np.shape}. Expected (H, W, 3).")
# The MAE pipeline already does gamma correction and scaling to 0-255.
# It also ensures 3 channels.
else:
raise gr.Error(f"Invalid model choice: {model_choice}. Please select a valid model.")
if seg_array is None:
raise gr.Error("Segmentation array was not generated. An unknown error occurred.")
print(f"Segmentation array generated with shape: {seg_array.shape}, dtype: {seg_array.dtype}")
# Convert numpy array to PIL Image
# Handle grayscale or RGB based on seg_array channels
if seg_array.ndim == 2: # Grayscale
segmented_image_pil = Image.fromarray(seg_array, mode='L')
elif seg_array.ndim == 3 and seg_array.shape[-1] == 3: # RGB
segmented_image_pil = Image.fromarray(seg_array, mode='RGB')
elif seg_array.ndim == 3 and seg_array.shape[-1] == 1: # Grayscale with channel dim
segmented_image_pil = Image.fromarray(seg_array.squeeze(-1), mode='L')
else:
raise gr.Error(f"Cannot convert seg_array with shape {seg_array.shape} to PIL Image.")
# Resize back to original image resolution using LANCZOS for high quality
segmented_image_pil = segmented_image_pil.resize(original_resolution, Image.Resampling.LANCZOS)
print(f"Segmented image processed. Output size: {segmented_image_pil.size}, mode: {segmented_image_pil.mode}")
return segmented_image_pil
except Exception as e:
print(f"Error during segmentation with {model_choice}: {e}")
# Re-raise as gr.Error for Gradio to display, if not already one
if not isinstance(e, gr.Error):
# It's often helpful to include the type of the original exception
error_type = type(e).__name__
raise gr.Error(f"An error occurred during segmentation: {error_type} - {str(e)}")
else:
raise e # Re-raise if it's already a gr.Error
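
# Example of calling segment_image directly (e.g. for a quick local sanity check).
# "dogs.png" is one of the bundled example images listed below; the output filename
# is arbitrary.
#   result = segment_image(Image.open("dogs.png"), "SD")
#   result.save("dogs_instance_seg.png")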
# --- Gradio Interface ---
title = "gen2seg: Generative Models Enable Generalizable Instance Segmentation Demo (SD & MAE-H)"
description = f"""
<div style="text-align: center; font-family: 'Arial', sans-serif;">
<p>Upload an image and choose a model architecture to see the instance segmentation result generated by the respective model. </p>
<p>
BIG THANKS to Huggingface for funding our demo with their Academic GPU Grant!
</p>
<ul>
<li><strong>SD</strong>: Based on Stable Diffusion 2.
<a href="https://huggingface.co/{MODEL_IDS['SD']}" target="_blank">Model Link</a>.
</li>
<li><strong>MAE-H</strong>: Based on Masked Autoencoder (Huge).
<a href="https://huggingface.co/{MODEL_IDS['MAE-H']}" target="_blank">Model Link</a>.
If you experience tokenizer artifacts or very dark images, you can use gamma correction to handle this.
</li>
</ul>
<p>
Paper: <a href="https://arxiv.org/abs/2505.15263">https://arxiv.org/abs/2505.15263</a>
</p>
<p>
For faster inference, please check out our GitHub to run the models locally on a GPU:
<a href="https://github.com/UCDvision/gen2seg" target="_blank">https://github.com/UCDvision/gen2seg</a> or check out our Colab demo <a href="https://colab.research.google.com/drive/10lPBP4figJf7MLp9T1b5cDQeU7MgODw3?usp=sharing" target="_blank">here</a>.
</p>
<p>If the demo experiences issues, please open an issue on our GitHub.</p>
<p>If you have not already, please see our webpage at <a href="https://reachomk.github.io/gen2seg" target="_blank">https://reachomk.github.io/gen2seg</a>.</p>
</div>
"""
article = """
"""
# Define Gradio inputs
input_image_component = gr.Image(type="pil", label="Input Image")
model_choice_component = gr.Dropdown(
    choices=list(MODEL_IDS.keys()),
    value="SD",  # Default model
    label="Choose Segmentation Model Architecture"
)
# Define Gradio output
output_image_component = gr.Image(type="pil", label="Segmented Image")
# Example images (ensure these paths are correct if you upload examples to your Space)
# For example, if you create an "examples" folder in your Space repo:
# example_paths = [
# os.path.join("examples", "example1.jpg"),
# os.path.join("examples", "example2.png")
# ]
# Filter out non-existent example files to prevent errors
# example_paths = [ex for ex in example_paths if os.path.exists(ex)]
# Base list of example image paths/URLs
base_example_images = [
    "cats-on-rock-1948.jpg",
    "dogs.png",
    "000000484893.jpg",
    "https://reachomk.github.io/gen2seg/images/comparison/vertical/7.png",
    "https://reachomk.github.io/gen2seg/images/comparison/horizontal/11.png",
    "https://reachomk.github.io/gen2seg/images/comparison/vertical/2.jpg"
]
# Generate examples for each image with both model choices
model_choices_for_examples = list(MODEL_IDS.keys()) # ["SD", "MAE-H"]
formatted_examples = []
for img_path_or_url in base_example_images:
    for model_choice in model_choices_for_examples:
        formatted_examples.append([img_path_or_url, model_choice])
iface = gr.Interface(
    fn=segment_image,
    inputs=[input_image_component, model_choice_component],
    outputs=output_image_component,
    title=title,
    description=description,
    article=article,
    examples=None,  # formatted_examples if formatted_examples else None,
    allow_flagging="never",
    theme="shivi/calm_seafoam"
)
if __name__ == "__main__":
    # Optional: Pre-load a default model on startup if desired.
    # This can make the first inference faster but increases startup time.
    # print("Attempting to pre-load default SD model on startup...")
    try:
        get_sd_pipeline()  # Pre-load the default SD model
        print("Default SD model pre-loaded successfully or was already cached.")
    except Exception as e:
        print(f"Could not pre-load default SD model: {e}")
    print("Launching Gradio interface...")
    iface.launch()
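
# Note: if the Space sees concurrent traffic, enabling Gradio's request queue
# (iface.queue().launch()) may help; the direct launch above is kept to match the
# original behavior.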