# gen2seg official inference pipeline code for the MAE model
#
# Please see our project website at https://reachomk.github.io/gen2seg
#
# Additionally, if you use our code please cite our paper, along with the two works above.

from dataclasses import dataclass
from typing import List, Union

import numpy as np
import torch
from einops import rearrange
from PIL import Image

from diffusers import DiffusionPipeline
from diffusers.utils import BaseOutput, logging
from transformers import AutoImageProcessor

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class gen2segMAEInstanceOutput(BaseOutput):
    """
    Output class for the ViTMAE Instance Segmentation Pipeline.

    Args:
        prediction (`np.ndarray` or `torch.Tensor`):
            Predicted instance segmentation maps. The output has shape
            `(batch_size, height, width, 3)` with pixel values scaled to [0, 255].
    """

    prediction: Union[np.ndarray, torch.Tensor]


class gen2segMAEInstancePipeline(DiffusionPipeline):
    r"""
    Pipeline for instance segmentation using a fine-tuned ViTMAEForPreTraining model.

    This pipeline takes one or more input images and returns an instance segmentation
    prediction for each image. The model is assumed to have been fine-tuned with an
    instance segmentation loss, and the reconstruction is performed by rearranging the
    model's patch logits back into an image.

    Args:
        model (`ViTMAEForPreTraining`):
            The fine-tuned ViTMAE model.
        image_processor (`AutoImageProcessor`):
            The image processor responsible for preprocessing input images.
    """

    def __init__(self, model, image_processor):
        super().__init__()
        self.register_modules(model=model, image_processor=image_processor)
        self.model = model
        self.image_processor = image_processor

    def check_inputs(
        self,
        image: Union[
            Image.Image,
            np.ndarray,
            torch.Tensor,
            List[Union[Image.Image, np.ndarray, torch.Tensor]],
        ],
    ) -> List:
        if not isinstance(image, list):
            image = [image]
        # Additional input validations can be added here if desired.
        return image

    @torch.no_grad()
    def __call__(
        self,
        image: Union[
            Image.Image,
            np.ndarray,
            torch.Tensor,
            List[Union[Image.Image, np.ndarray, torch.Tensor]],
        ],
        output_type: str = "np",
        **kwargs,
    ) -> gen2segMAEInstanceOutput:
        r"""
        The call method of the pipeline.

        Args:
            image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, or a list of these):
                The input image(s) for instance segmentation. For arrays/tensors,
                expected values are in [0, 1].
            output_type (`str`, optional, defaults to `"np"`):
                The format of the output prediction. Choose `"np"` for a NumPy array
                or `"pt"` for a PyTorch tensor.
            **kwargs:
                Additional keyword arguments passed to the image processor.

        Returns:
            [`gen2segMAEInstanceOutput`]: An output object containing the predicted
            instance segmentation maps.
        """
        # 1. Check and prepare input images.
        images = self.check_inputs(image)
        inputs = self.image_processor(images=images, return_tensors="pt", **kwargs)
        pixel_values = inputs["pixel_values"].to(self.device)

        # 2. Forward pass through the model.
        outputs = self.model(pixel_values=pixel_values)
        logits = outputs.logits  # Expected shape: (B, num_patches, patch_dim)

        # 3. Retrieve patch size and image size from the model configuration.
        patch_size = self.model.config.patch_size  # e.g., 16
        image_size = self.model.config.image_size  # e.g., 224
        grid_size = image_size // patch_size

        # 4. Rearrange logits into the reconstructed image.
        #    The logits are reshaped from (B, num_patches, patch_dim) to (B, 3, H, W).
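        #    Example with the illustrative numbers above (224x224 input, 16x16 patches):
        #    grid_size = 224 // 16 = 14, num_patches = 14 * 14 = 196,
        #    patch_dim = 16 * 16 * 3 = 768, so (B, 196, 768) -> (B, 3, 224, 224).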
        reconstructed = rearrange(
            logits,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            h=grid_size,
            p1=patch_size,
            p2=patch_size,
            c=3,
        )

        # 5. Post-process the reconstructed output.
        #    For each sample, shift and scale the prediction to [0, 255].
        predictions = []
        for i in range(reconstructed.shape[0]):
            sample = reconstructed[i]
            min_val = torch.abs(sample.min())
            max_val = torch.abs(sample.max())
            sample = (sample + min_val) / (max_val + min_val + 1e-5)
            # Sometimes the image is very dark, so we perform gamma correction to
            # "brighten" it. In practice we can set this value to whatever we want,
            # or disable it entirely.
            sample = sample**0.6
            sample = sample * 255.0
            predictions.append(sample)
        prediction_tensor = torch.stack(predictions, dim=0).permute(0, 2, 3, 1)

        # 6. Format the output.
        if output_type == "np":
            prediction = prediction_tensor.cpu().numpy()
        else:
            prediction = prediction_tensor
        return gen2segMAEInstanceOutput(prediction=prediction)
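

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the pipeline itself): shows one way
# to drive the pipeline end to end. The checkpoint path and file names below
# are placeholders/assumptions -- substitute your own fine-tuned gen2seg MAE
# checkpoint and images.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import ViTMAEForPreTraining

    ckpt = "path/to/finetuned-gen2seg-mae"  # placeholder checkpoint id/path
    model = ViTMAEForPreTraining.from_pretrained(ckpt)
    image_processor = AutoImageProcessor.from_pretrained(ckpt)

    # Assumption: the fine-tuned checkpoint should see every patch at inference,
    # so disable ViTMAE's random masking.
    model.config.mask_ratio = 0.0

    pipe = gen2segMAEInstancePipeline(model=model, image_processor=image_processor)
    pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    img = Image.open("example.jpg").convert("RGB")  # placeholder input image
    out = pipe(img, output_type="np")

    # out.prediction has shape (1, H, W, 3) with values in [0, 255].
    seg = out.prediction[0].astype(np.uint8)
    Image.fromarray(seg).save("instance_seg.png")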