# gen2seg official inference pipeline code for the ViT-MAE model
#
# Please see our project website at https://reachomk.github.io/gen2seg
#
# Additionally, if you use our code please cite our paper.

from dataclasses import dataclass
from typing import Union, List, Optional

import torch
import numpy as np
from PIL import Image
from einops import rearrange

from diffusers import DiffusionPipeline
from diffusers.utils import BaseOutput, logging
from transformers import AutoImageProcessor

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class gen2segMAEInstanceOutput(BaseOutput):
    """
    Output class for the ViTMAE Instance Segmentation Pipeline.

    Args:
        prediction (`np.ndarray` or `torch.Tensor`):
            Predicted instance segmentation maps. The output has shape
            `(batch_size, height, width, 3)` with pixel values scaled to [0, 255].
    """
    prediction: Union[np.ndarray, torch.Tensor]


class gen2segMAEInstancePipeline(DiffusionPipeline):
    r"""
    Pipeline for Instance Segmentation using a fine-tuned ViTMAEForPreTraining model.

    This pipeline takes one or more input images and returns an instance segmentation
    prediction for each image. The model is assumed to have been fine-tuned using an instance
    segmentation loss, and the reconstruction is performed by rearranging the model’s
    patch logits into an image.

    Args:
        model (`ViTMAEForPreTraining`):
            The fine-tuned ViTMAE model.
        image_processor (`AutoImageProcessor`):
            The image processor responsible for preprocessing input images.
    """
    def __init__(self, model, image_processor):
        super().__init__()
        self.register_modules(model=model, image_processor=image_processor)
        self.model = model
        self.image_processor = image_processor

    def check_inputs(
        self,
        image: Union[Image.Image, np.ndarray, torch.Tensor, List[Union[Image.Image, np.ndarray, torch.Tensor]]]
    ) -> List:
        if not isinstance(image, list):
            image = [image]
        # Additional input validations can be added here if desired.
        return image

    @torch.no_grad()
    def __call__(
        self,
        image: Union[Image.Image, np.ndarray, torch.Tensor, List[Union[Image.Image, np.ndarray, torch.Tensor]]],
        output_type: str = "np",
        **kwargs
    ) -> gen2segMAEInstanceOutput:
        r"""
        The call method of the pipeline.

        Args:
            image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, or a list of these):
                The input image(s) for instance segmentation. For arrays/tensors, expected values are in [0, 1].
            output_type (`str`, optional, defaults to `"np"`):
                The format of the output prediction. Choose `"np"` for a NumPy array or `"pt"` for a PyTorch tensor.
            **kwargs:
                Additional keyword arguments passed to the image processor.

        Returns:
            [`gen2segMAEInstanceOutput`]:
                An output object containing the predicted instance segmentation maps.
        """
        # 1. Check and prepare input images.
        images = self.check_inputs(image)
        inputs = self.image_processor(images=images, return_tensors="pt", **kwargs)
        pixel_values = inputs["pixel_values"].to(self.device)

        # 2. Forward pass through the model.
        outputs = self.model(pixel_values=pixel_values)
        logits = outputs.logits  # Expected shape: (B, num_patches, patch_dim)

        # 3. Retrieve patch size and image size from the model configuration.
        patch_size = self.model.config.patch_size  # e.g., 16
        image_size = self.model.config.image_size    # e.g., 224
        grid_size = image_size // patch_size

        # 4. Rearrange logits into the reconstructed image.
        #    The logits are reshaped from (B, num_patches, patch_dim) to (B, 3, H, W).
        reconstructed = rearrange(
            logits,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            h=grid_size,
            p1=patch_size,
            p2=patch_size,
            c=3,
        )

        # 5. Post-process the reconstructed output.
        #    For each sample, shift and scale the prediction to [0, 255].
        predictions = []
        for i in range(reconstructed.shape[0]):
            sample = reconstructed[i]
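            # Shift by |min| and scale by (|max| + |min|) to map the raw values into
            # [0, 1]; this assumes the raw prediction spans negative and positive values.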
            min_val = torch.abs(sample.min())
            max_val = torch.abs(sample.max())
            sample = (sample + min_val) / (max_val + min_val + 1e-5)
            # Sometimes the prediction is very dark, so we apply gamma correction to "brighten" it.
            # In practice this exponent can be set to any value, or the correction disabled entirely.
            sample = sample**0.6
            sample = sample * 255.0
            predictions.append(sample)
        prediction_tensor = torch.stack(predictions, dim=0).permute(0, 2, 3, 1)

        # 6. Format the output.
        if output_type == "np":
            prediction = prediction_tensor.cpu().numpy()
        else:
            prediction = prediction_tensor
        return gen2segMAEInstanceOutput(prediction=prediction)
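

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the pipeline). The
# checkpoint path below is a placeholder for the fine-tuned gen2seg MAE
# weights, and the checkpoint is assumed to be configured with mask_ratio=0.0
# so that no patches are masked at inference time.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import ViTMAEForPreTraining

    checkpoint = "path/to/finetuned-gen2seg-mae"  # placeholder checkpoint path
    model = ViTMAEForPreTraining.from_pretrained(checkpoint)
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)

    pipe = gen2segMAEInstancePipeline(model=model, image_processor=image_processor)
    pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    img = Image.open("example.jpg").convert("RGB")
    output = pipe(img, output_type="np")

    # output.prediction has shape (1, H, W, 3) with values in [0, 255].
    Image.fromarray(output.prediction[0].astype(np.uint8)).save("instance_map.png")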