# gen2seg official inference pipeline code for the MAE (ViTMAE) model
#
# Please see our project website at https://reachomk.github.io/gen2seg
#
# Additionally, if you use our code please cite our paper, along with the two works above.
from dataclasses import dataclass
from typing import Union, List, Optional

import torch
import numpy as np
from PIL import Image
from einops import rearrange

from diffusers import DiffusionPipeline
from diffusers.utils import BaseOutput, logging
from transformers import AutoImageProcessor

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class gen2segMAEInstanceOutput(BaseOutput):
    """
    Output class for the ViTMAE Instance Segmentation Pipeline.

    Args:
        prediction (`np.ndarray` or `torch.Tensor`):
            Predicted instance segmentation maps. The output has shape
            `(batch_size, height, width, 3)` with pixel values scaled to [0, 255].
    """

    prediction: Union[np.ndarray, torch.Tensor]


class gen2segMAEInstancePipeline(DiffusionPipeline):
    r"""
    Pipeline for Instance Segmentation using a fine-tuned ViTMAEForPreTraining model.

    This pipeline takes one or more input images and returns an instance segmentation
    prediction for each image. The model is assumed to have been fine-tuned with an
    instance segmentation loss, and the reconstruction is performed by rearranging the
    model's patch logits back into an image.

    Args:
        model (`ViTMAEForPreTraining`):
            The fine-tuned ViTMAE model.
        image_processor (`AutoImageProcessor`):
            The image processor responsible for preprocessing input images.
    """

    def __init__(self, model, image_processor):
        super().__init__()
        self.register_modules(model=model, image_processor=image_processor)
        self.model = model
        self.image_processor = image_processor

    def check_inputs(
        self,
        image: Union[Image.Image, np.ndarray, torch.Tensor, List[Union[Image.Image, np.ndarray, torch.Tensor]]],
    ) -> List:
        if not isinstance(image, list):
            image = [image]
        # Additional input validations can be added here if desired.
        return image

    @torch.no_grad()
    def __call__(
        self,
        image: Union[Image.Image, np.ndarray, torch.Tensor, List[Union[Image.Image, np.ndarray, torch.Tensor]]],
        output_type: str = "np",
        **kwargs,
    ) -> gen2segMAEInstanceOutput:
        r"""
        The call method of the pipeline.

        Args:
            image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, or a list of these):
                The input image(s) for instance segmentation. For arrays/tensors, expected values are in [0, 1].
            output_type (`str`, optional, defaults to `"np"`):
                The format of the output prediction. Choose `"np"` for a NumPy array or `"pt"` for a PyTorch tensor.
            **kwargs:
                Additional keyword arguments passed to the image processor.

        Returns:
            [`gen2segMAEInstanceOutput`]:
                An output object containing the predicted instance segmentation maps.
        """
        # 1. Check and prepare input images.
        images = self.check_inputs(image)
        inputs = self.image_processor(images=images, return_tensors="pt", **kwargs)
        pixel_values = inputs["pixel_values"].to(self.device)

        # 2. Forward pass through the model.
        outputs = self.model(pixel_values=pixel_values)
        logits = outputs.logits  # Expected shape: (B, num_patches, patch_dim)

        # 3. Retrieve patch size and image size from the model configuration.
        patch_size = self.model.config.patch_size  # e.g., 16
        image_size = self.model.config.image_size  # e.g., 224
        grid_size = image_size // patch_size
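        # With the example defaults above, grid_size = 224 // 16 = 14, i.e. a 14 x 14 grid of 196 patches.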

        # 4. Rearrange logits into the reconstructed image.
        #    The logits are reshaped from (B, num_patches, patch_dim) to (B, 3, H, W).
        reconstructed = rearrange(
            logits,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            h=grid_size,
            p1=patch_size,
            p2=patch_size,
            c=3,
        )
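        # Note: einops infers `w` from the patch count, and patch_dim must equal p1 * p2 * c
        # (e.g., 16 * 16 * 3 = 768 for the default ViTMAE configuration).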

        # 5. Post-process the reconstructed output.
        #    For each sample, shift and scale the prediction to [0, 255].
        predictions = []
        for i in range(reconstructed.shape[0]):
            sample = reconstructed[i]
            min_val = torch.abs(sample.min())
            max_val = torch.abs(sample.max())
            sample = (sample + min_val) / (max_val + min_val + 1e-5)
            # Sometimes the image is very dark, so we perform gamma correction to "brighten" it.
            # In practice this value can be set to anything, or the correction disabled entirely.
            sample = sample ** 0.6
            sample = sample * 255.0
            predictions.append(sample)
        prediction_tensor = torch.stack(predictions, dim=0).permute(0, 2, 3, 1)

        # 6. Format the output.
        if output_type == "np":
            prediction = prediction_tensor.cpu().numpy()
        else:
            prediction = prediction_tensor

        return gen2segMAEInstanceOutput(prediction=prediction)
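

# ---------------------------------------------------------------------------
# Example usage (a minimal sketch, not part of the pipeline itself). The
# checkpoint path below is a placeholder, not an official repo id; substitute
# whichever gen2seg fine-tuned ViTMAE checkpoint you actually use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import ViTMAEForPreTraining

    checkpoint = "PATH/TO/FINETUNED_MAE_CHECKPOINT"  # placeholder path (assumption)
    model = ViTMAEForPreTraining.from_pretrained(checkpoint)
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)

    pipe = gen2segMAEInstancePipeline(model=model, image_processor=image_processor)
    pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    img = Image.open("example.jpg").convert("RGB")  # any RGB input image
    output = pipe(image=[img], output_type="np")
    pred = output.prediction[0]  # (H, W, 3) array with values in [0, 255]
    Image.fromarray(pred.astype(np.uint8)).save("segmentation.png")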