Spaces:
Build error
Build error
File size: 10,986 Bytes
a560c26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 |
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Clustering metrics."""
from typing import Optional, Sequence, Union
from clu import metrics
import flax
import jax
import jax.numpy as jnp
import numpy as np
# Type alias covering both NumPy and JAX arrays.
Ndarray = Union[np.ndarray, jnp.ndarray]
def check_shape(x, expected_shape, name):
  """Verifies that `x.shape` agrees with `expected_shape`.

  Args:
    x: Any value. Its `shape` attribute is compared; values without a `shape`
      attribute are treated as scalars, i.e. as having shape ().
    expected_shape: List or tuple giving the required shape. A `None` entry
      accepts any size along that axis, e.g. [None, None, 3] for an RGB image
      or [4, None, None, 3] when the batch size is known to be 4.
    name: Human-readable name of `x`, used in error messages.

  Raises:
    ValueError: if `expected_shape` is neither a list nor a tuple, or if the
      rank or any constrained dimension of `x` differs from `expected_shape`.
  """
  if not isinstance(expected_shape, (list, tuple)):
    raise ValueError(
        "expected_shape should be a list or tuple of ints but got "
        f"{expected_shape}.")

  # Anything lacking a `shape` attribute counts as a scalar.
  actual = getattr(x, "shape", ())

  # Ranks must agree, and every non-None expected dimension must match.
  shape_ok = len(actual) == len(expected_shape) and all(
      want is None or have == want
      for have, want in zip(actual, expected_shape))
  if not shape_ok:
    raise ValueError(
        f"Input {name} had shape {actual} but {expected_shape} was expected.")
def _validate_inputs(predicted_segmentations,
                     ground_truth_segmentations,
                     padding_mask,
                     mask=None):
  """Validates the shapes and dtypes of the segmentation-metric inputs.

  Args:
    predicted_segmentations: Integer array of shape [bs, seq_len, H, W] holding
      the model's segmentation predictions.
    ground_truth_segmentations: Integer array of shape [bs, seq_len, H, W]
      holding the ground-truth segmentations.
    padding_mask: Integer array with the same shape as
      `ground_truth_segmentations`. 0 marks padded regions (e.g. introduced by
      data augmentation) whose ground truth is meaningless; 1 marks valid
      regions used for metric calculation.
    mask: Optional boolean array of shape [bs]. True marks real batch examples,
      False marks padding examples.

  Raises:
    ValueError: if any input has an unexpected shape or dtype.
  """
  any_rank4 = [None] * 4
  check_shape(predicted_segmentations, any_rank4,
              "predicted_segmentations [bs, seq_len, h, w]")
  check_shape(ground_truth_segmentations, any_rank4,
              "ground_truth_segmentations [bs, seq_len, h, w]")

  # Read only after the rank check above has passed, so a missing `.shape`
  # is reported as a ValueError by check_shape rather than an AttributeError.
  gt_shape = ground_truth_segmentations.shape
  for arr, label in ((predicted_segmentations, "predicted_segmentations"),
                     (padding_mask, "padding_mask")):
    check_shape(arr, gt_shape,
                f"{label} [should match ground_truth_segmentations]")

  integer_inputs = (
      ("predicted_segmentations", predicted_segmentations),
      ("ground_truth_segmentations", ground_truth_segmentations),
      ("padding_mask", padding_mask),
  )
  for label, arr in integer_inputs:
    if not jnp.issubdtype(arr.dtype, jnp.integer):
      raise ValueError(
          "{} has to be integer-valued. Got {}".format(label, arr.dtype))

  if mask is None:
    return
  check_shape(mask, [None], "mask [bs]")
  if not jnp.issubdtype(mask.dtype, jnp.bool_):
    raise ValueError("mask has to be boolean. Got {}".format(mask.dtype))
def adjusted_rand_index(true_ids, pred_ids,
                        num_instances_true, num_instances_pred,
                        padding_mask=None,
                        ignore_background=False):
  """Computes the adjusted Rand index (ARI), a clustering similarity score.

  Args:
    true_ids: An integer-valued array of shape
      [batch_size, seq_len, H, W]. The true cluster assignment encoded
      as integer ids.
    pred_ids: An integer-valued array of shape
      [batch_size, seq_len, H, W]. The predicted cluster assignment
      encoded as integer ids.
    num_instances_true: An integer, the number of instances in true_ids
      (i.e. max(true_ids) + 1).
    num_instances_pred: An integer, the number of instances in pred_ids
      (i.e. max(pred_ids) + 1).
    padding_mask: An array of integers of shape [batch_size, seq_len, H, W]
      defining regions where the ground truth is meaningless, for example
      because this corresponds to regions which were padded during data
      augmentation. Value 0 corresponds to padded regions, 1 corresponds to
      valid regions to be used for metric calculation.
    ignore_background: Boolean, if True, then ignore all pixels where
      true_ids == 0 (default: False).

  Returns:
    ARI scores as a float32 array of shape [batch_size]. Examples for which
    the ARI denominator is zero (see comment below) score exactly 1.0.

  References:
    Lawrence Hubert, Phipps Arabie. 1985. "Comparing partitions"
      https://link.springer.com/article/10.1007/BF01908075
    Wikipedia
      https://en.wikipedia.org/wiki/Rand_index
    Scikit Learn
      http://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html
  """
  # pylint: disable=invalid-name
  true_oh = jax.nn.one_hot(true_ids, num_instances_true)
  pred_oh = jax.nn.one_hot(pred_ids, num_instances_pred)
  if padding_mask is not None:
    # Masking the ground-truth one-hot alone is enough: padded pixels then
    # have an all-zero true row and contribute nothing to N below.
    true_oh = true_oh * padding_mask[..., None]
  if ignore_background:
    # Drop the column for ground-truth id 0 so background pixels are not
    # counted in N.
    true_oh = true_oh[..., 1:]

  # Contingency table: N[b, c, k] counts pixels with true id c and predicted
  # id k, summed over time and space.
  N = jnp.einsum("bthwc,bthwk->bck", true_oh, pred_oh)
  A = jnp.sum(N, axis=-1)  # row-sum (batch_size, c)
  B = jnp.sum(N, axis=-2)  # col-sum (batch_size, k)
  num_points = jnp.sum(A, axis=1)

  # The pair counts use n * (n - 1) rather than n choose 2. The factor of 2
  # is common to every term below and cancels in the final ratio.
  rindex = jnp.sum(N * (N - 1), axis=(1, 2))
  aindex = jnp.sum(A * (A - 1), axis=1)
  bindex = jnp.sum(B * (B - 1), axis=1)
  expected_rindex = aindex * bindex / jnp.maximum(
      num_points * (num_points - 1), 1)
  max_rindex = (aindex + bindex) / 2
  denominator = max_rindex - expected_rindex

  # There are two cases for which the denominator can be zero:
  # 1. If both label_pred and label_true assign all pixels to a single cluster.
  #    (max_rindex == expected_rindex == rindex == num_points * (num_points-1))
  # 2. If both label_pred and label_true assign max 1 point to each cluster.
  #    (max_rindex == expected_rindex == rindex == 0)
  # In both cases, we want the ARI score to be 1.0. Divide by a safe
  # denominator so no NaN/inf is ever materialized (which would trip
  # jax_debug_nans and poison gradients through jnp.where).
  nonzero = denominator != 0
  safe_denominator = jnp.where(nonzero, denominator, 1.0)
  ari = (rindex - expected_rindex) / safe_denominator
  return jnp.where(nonzero, ari, 1.0)
@flax.struct.dataclass
class Ari(metrics.Average):
  """Adjusted Rand Index (ARI) computed from predictions and labels.

  ARI is a similarity score to compare two clusterings. ARI returns values in
  the range [-1, 1], where 1 corresponds to two identical clusterings (up to
  permutation), i.e. a perfect match between the predicted clustering and the
  ground-truth clustering. A value of (close to) 0 corresponds to chance.
  Negative values correspond to cases where the agreement between the
  clusterings is less than expected from a random assignment.

  In this implementation, we use ARI to compare predicted instance segmentation
  masks (including background prediction) with ground-truth segmentation
  annotations.
  """

  @classmethod
  def from_model_output(cls,
                        predicted_segmentations,
                        ground_truth_segmentations,
                        padding_mask,
                        ground_truth_max_num_instances,
                        predicted_max_num_instances,
                        ignore_background=False,
                        mask=None,
                        **_):
    """Computation of the ARI clustering metric.

    Pixels where `padding_mask` is 0 are excluded from the per-example score,
    and batch examples where `mask` is False are excluded from the average.
    Any extra keyword arguments are accepted and ignored.

    Args:
      predicted_segmentations: An array of integers of shape
        [bs, seq_len, H, W] containing model segmentation predictions.
      ground_truth_segmentations: An array of integers of shape
        [bs, seq_len, H, W] containing ground truth segmentations.
      padding_mask: An array of integers of shape [bs, seq_len, H, W]
        defining regions where the ground truth is meaningless, for example
        because this corresponds to regions which were padded during data
        augmentation. Value 0 corresponds to padded regions, 1 corresponds to
        valid regions to be used for metric calculation.
      ground_truth_max_num_instances: Maximum number of instances (incl.
        background, which counts as the 0-th instance) possible in the dataset.
      predicted_max_num_instances: Maximum number of predicted instances (incl.
        background).
      ignore_background: If True, then ignore all pixels where
        ground_truth_segmentations == 0 (default: False).
      mask: An optional array of boolean mask values of shape [bs]. `True`
        corresponds to actual batch examples whereas `False` corresponds to
        padding.

    Returns:
      Object of Ari whose `total` is the sum of per-example ARI scores over
      unmasked examples and whose `count` is the number of unmasked examples.

    Raises:
      ValueError: if the inputs fail shape/dtype validation.
    """
    _validate_inputs(
        predicted_segmentations=predicted_segmentations,
        ground_truth_segmentations=ground_truth_segmentations,
        padding_mask=padding_mask,
        mask=mask)

    batch_size = predicted_segmentations.shape[0]

    # Cast the per-example mask to padding_mask's (integer) dtype so it can
    # both weight the scores and be summed into the example count.
    if mask is None:
      mask = jnp.ones(batch_size, dtype=padding_mask.dtype)
    else:
      mask = jnp.asarray(mask, dtype=padding_mask.dtype)

    ari_batch = adjusted_rand_index(
        pred_ids=predicted_segmentations,
        true_ids=ground_truth_segmentations,
        num_instances_true=ground_truth_max_num_instances,
        num_instances_pred=predicted_max_num_instances,
        padding_mask=padding_mask,
        ignore_background=ignore_background)
    return cls(total=jnp.sum(ari_batch * mask), count=jnp.sum(mask))  # pylint: disable=unexpected-keyword-arg
@flax.struct.dataclass
class AriNoBg(Ari):
  """Adjusted Rand Index (ARI), ignoring the ground-truth background label.

  Identical to `Ari` except that `ignore_background` is always True, so pixels
  whose ground-truth id is 0 do not contribute to the score.
  """

  @classmethod
  def from_model_output(cls, **kwargs):
    """See `Ari` docstring for allowed keyword arguments.

    Any `ignore_background` entry in `kwargs` is overridden with True rather
    than raising a duplicate-keyword TypeError.
    """
    # Merge instead of passing `ignore_background=True` alongside **kwargs:
    # the latter fails if the caller's kwargs already contain that key.
    forwarded = {**kwargs, "ignore_background": True}
    return super().from_model_output(**forwarded)
|