File size: 10,986 Bytes
a560c26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Clustering metrics."""

from typing import  Optional, Sequence, Union

from clu import metrics
import flax
import jax
import jax.numpy as jnp
import numpy as np

Ndarray = Union[np.ndarray, jnp.ndarray]


def check_shape(x, expected_shape, name):
  """Check whether shape x is as expected.

  Args:
    x: Any data type with `shape` attribute. If `shape` attribute is not present
      it is assumed to be a scalar with shape ().
    expected_shape: The shape that is expected of x. For example,
      [None, None, 3] can be the `expected_shape` for a color image,
      [4, None, None, 3] if we know that batch size is 4.
    name: Name of `x` to provide informative error messages.

  Raises: ValueError if x's shape does not match expected_shape. Also raises
    ValueError if expected_shape is not a list or tuple.
  """
  if not isinstance(expected_shape, (list, tuple)):
    raise ValueError(
        "expected_shape should be a list or tuple of ints but got "
        f"{expected_shape}.")

  # Scalars have shape () by definition.
  shape = getattr(x, "shape", ())

  if (len(shape) != len(expected_shape) or
      any(j is not None and i != j for i, j in zip(shape, expected_shape))):
    raise ValueError(
        f"Input {name} had shape {shape} but {expected_shape} was expected.")


def _validate_inputs(predicted_segmentations,
                     ground_truth_segmentations,
                     padding_mask,
                     mask = None):
  """Checks that all inputs have the expected shapes.

  Args:
    predicted_segmentations: An array of integers of shape [bs, seq_len, H, W]
      containing model segmentation predictions.
    ground_truth_segmentations: An array of integers of shape [bs, seq_len, H,
      W] containing ground truth segmentations.
    padding_mask: An array of integers of shape [bs, seq_len, H, W] defining
      regions where the ground truth is meaningless, for example because this
      corresponds to regions which were padded during data augmentation. Value 0
      corresponds to padded regions, 1 corresponds to valid regions to be used
      for metric calculation.
    mask: An optional array of boolean mask values of shape [bs]. `True`
      corresponds to actual batch examples whereas `False` corresponds to
      padding.

  Raises:
    ValueError if the inputs are not valid.
  """

  check_shape(
      predicted_segmentations, [None, None, None, None],
      "predicted_segmentations [bs, seq_len, h, w]")
  check_shape(
      ground_truth_segmentations, [None, None, None, None],
      "ground_truth_segmentations [bs, seq_len, h, w]")
  check_shape(
      predicted_segmentations, ground_truth_segmentations.shape,
      "predicted_segmentations [should match ground_truth_segmentations]")
  check_shape(
      padding_mask, ground_truth_segmentations.shape,
      "padding_mask [should match ground_truth_segmentations]")

  if not jnp.issubdtype(predicted_segmentations.dtype, jnp.integer):
    raise ValueError("predicted_segmentations has to be integer-valued. "
                     "Got {}".format(predicted_segmentations.dtype))

  if not jnp.issubdtype(ground_truth_segmentations.dtype, jnp.integer):
    raise ValueError("ground_truth_segmentations has to be integer-valued. "
                     "Got {}".format(ground_truth_segmentations.dtype))

  if not jnp.issubdtype(padding_mask.dtype, jnp.integer):
    raise ValueError("padding_mask has to be integer-valued. "
                     "Got {}".format(padding_mask.dtype))

  if mask is not None:
    check_shape(mask, [None], "mask [bs]")
    if not jnp.issubdtype(mask.dtype, jnp.bool_):
      raise ValueError("mask has to be boolean. Got {}".format(mask.dtype))


def adjusted_rand_index(true_ids, pred_ids,
                        num_instances_true, num_instances_pred,
                        padding_mask = None,
                        ignore_background = False):
  """Computes the adjusted Rand index (ARI), a clustering similarity score.

  Args:
    true_ids: An integer-valued array of shape
      [batch_size, seq_len, H, W]. The true cluster assignment encoded
      as integer ids.
    pred_ids: An integer-valued array of shape
      [batch_size, seq_len, H, W]. The predicted cluster assignment
      encoded as integer ids.
    num_instances_true: An integer, the number of instances in true_ids
      (i.e. max(true_ids) + 1).
    num_instances_pred: An integer, the number of instances in true_ids
      (i.e. max(pred_ids) + 1).
    padding_mask: An array of integers of shape [batch_size, seq_len, H, W]
        defining regions where the ground truth is meaningless, for example
        because this corresponds to regions which were padded during data
        augmentation. Value 0 corresponds to padded regions, 1 corresponds to
        valid regions to be used for metric calculation.
    ignore_background: Boolean, if True, then ignore all pixels where
      true_ids == 0 (default: False).

  Returns:
    ARI scores as a float32 array of shape [batch_size].

  References:
    Lawrence Hubert, Phipps Arabie. 1985. "Comparing partitions"
      https://link.springer.com/article/10.1007/BF01908075
    Wikipedia
      https://en.wikipedia.org/wiki/Rand_index
    Scikit Learn
      http://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html
  """
  # pylint: disable=invalid-name
  true_oh = jax.nn.one_hot(true_ids, num_instances_true)
  pred_oh = jax.nn.one_hot(pred_ids, num_instances_pred)
  if padding_mask is not None:
    true_oh = true_oh * padding_mask[Ellipsis, None]
    # pred_oh = pred_oh * padding_mask[..., None]  # <--  not needed

  if ignore_background:
    true_oh = true_oh[Ellipsis, 1:]  # Remove the background row.

  N = jnp.einsum("bthwc,bthwk->bck", true_oh, pred_oh)
  A = jnp.sum(N, axis=-1)  # row-sum  (batch_size, c)
  B = jnp.sum(N, axis=-2)  # col-sum  (batch_size, k)
  num_points = jnp.sum(A, axis=1)

  rindex = jnp.sum(N * (N - 1), axis=[1, 2])
  aindex = jnp.sum(A * (A - 1), axis=1)
  bindex = jnp.sum(B * (B - 1), axis=1)
  expected_rindex = aindex * bindex / jnp.clip(num_points * (num_points-1), 1)
  max_rindex = (aindex + bindex) / 2
  denominator = max_rindex - expected_rindex
  ari = (rindex - expected_rindex) / denominator

  # There are two cases for which the denominator can be zero:
  # 1. If both label_pred and label_true assign all pixels to a single cluster.
  #    (max_rindex == expected_rindex == rindex == num_points * (num_points-1))
  # 2. If both label_pred and label_true assign max 1 point to each cluster.
  #    (max_rindex == expected_rindex == rindex == 0)
  # In both cases, we want the ARI score to be 1.0:
  return jnp.where(denominator, ari, 1.0)


@flax.struct.dataclass
class Ari(metrics.Average):
  """Adjusted Rand Index (ARI) computed from predictions and labels.

  ARI is a similarity score to compare two clusterings. ARI returns values in
  the range [-1, 1], where 1 corresponds to two identical clusterings (up to
  permutation), i.e. a perfect match between the predicted clustering and the
  ground-truth clustering. A value of (close to) 0 corresponds to chance.
  Negative values corresponds to cases where the agreement between the
  clusterings is less than expected from a random assignment.

  In this implementation, we use ARI to compare predicted instance segmentation
  masks (including background prediction) with ground-truth segmentation
  annotations.
  """

  @classmethod
  def from_model_output(cls,
                        predicted_segmentations,
                        ground_truth_segmentations,
                        padding_mask,
                        ground_truth_max_num_instances,
                        predicted_max_num_instances,
                        ignore_background = False,
                        mask = None,
                        **_):
    """Computation of the ARI clustering metric.

    NOTE: This implementation does not currently support padding masks.

    Args:
      predicted_segmentations: An array of integers of shape
        [bs, seq_len, H, W] containing model segmentation predictions.
      ground_truth_segmentations: An array of integers of shape
        [bs, seq_len, H, W] containing ground truth segmentations.
      padding_mask: An array of integers of shape [bs, seq_len, H, W]
        defining regions where the ground truth is meaningless, for example
        because this corresponds to regions which were padded during data
        augmentation. Value 0 corresponds to padded regions, 1 corresponds to
        valid regions to be used for metric calculation.
      ground_truth_max_num_instances: Maximum number of instances (incl.
        background, which counts as the 0-th instance) possible in the dataset.
      predicted_max_num_instances: Maximum number of predicted instances (incl.
        background).
      ignore_background: If True, then ignore all pixels where
        ground_truth_segmentations == 0 (default: False).
      mask: An optional array of boolean mask values of shape [bs]. `True`
        corresponds to actual batch examples whereas `False` corresponds to
        padding.

    Returns:
      Object of Ari with computed intermediate values.
    """
    _validate_inputs(
        predicted_segmentations=predicted_segmentations,
        ground_truth_segmentations=ground_truth_segmentations,
        padding_mask=padding_mask,
        mask=mask)

    batch_size = predicted_segmentations.shape[0]

    if mask is None:
      mask = jnp.ones(batch_size, dtype=padding_mask.dtype)
    else:
      mask = jnp.asarray(mask, dtype=padding_mask.dtype)

    ari_batch = adjusted_rand_index(
        pred_ids=predicted_segmentations,
        true_ids=ground_truth_segmentations,
        num_instances_true=ground_truth_max_num_instances,
        num_instances_pred=predicted_max_num_instances,
        padding_mask=padding_mask,
        ignore_background=ignore_background)
    return cls(total=jnp.sum(ari_batch * mask), count=jnp.sum(mask))  # pylint: disable=unexpected-keyword-arg


@flax.struct.dataclass
class AriNoBg(Ari):
  """Adjusted Rand Index (ARI), ignoring the ground-truth background label."""

  @classmethod
  def from_model_output(cls, **kwargs):
    """See `Ari` docstring for allowed keyword arguments."""
    return super().from_model_output(**kwargs, ignore_background=True)