# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py
# But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll))
# Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py
# But we pass in the loss to avoid recomputation
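# For instance, if two tokens have NLLs of 1.0 and 3.0, exp(mean(nll)) = exp(2.0) ≈ 7.39,
# while mean(exp(nll)) = (e^1 + e^3) / 2 ≈ 11.40, so the order of mean and exp matters.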

from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torchmetrics import Metric

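# Use the optimized CrossEntropyLoss from flash-attn when it is installed; otherwise fall back to
# the standard PyTorch implementation, which computes the same mean cross-entropy.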
try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    CrossEntropyLoss = torch.nn.CrossEntropyLoss

__all__ = ['Perplexity']


class Perplexity(Metric):
    r"""

    Perplexity measures how well a language model predicts a text sample. It's calculated as the average number of bits

    per word a model needs to represent the sample.

    Args:

        kwargs:

            Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Examples:

        >>> import torch

        >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))

        >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))

        >>> target[0, 6:] = -100

        >>> metric = Perplexity(ignore_index=-100)

        >>> metric(preds, target)

        tensor(5.2545)

    """
    is_differentiable = True
    higher_is_better = False
    full_state_update = False
    total_log_probs: Tensor
    count: Tensor

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64),
                       dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum")

        self.loss_fn = CrossEntropyLoss()

    def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None:  # type: ignore
        """Compute and store intermediate statistics for Perplexity.

        Args:

            preds:

                Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].

            target:

                Ground truth values with a shape [batch_size, seq_len].

        """
        count = target.numel()
        if loss is None:
            # CrossEntropyLoss expects flattened [num_tokens, vocab_size] logits and [num_tokens] targets.
            loss = self.loss_fn(preds.reshape(-1, preds.shape[-1]), target.reshape(-1))
        self.total_log_probs += loss.double() * count
        self.count += count

    def compute(self) -> Tensor:
        """Compute the Perplexity.

        Returns:

           Perplexity

        """
        return torch.exp(self.total_log_probs / self.count)
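

if __name__ == "__main__":
    # Illustrative usage sketch (tensor shapes and values here are arbitrary): compute the mean
    # cross-entropy once, e.g. in the training step, and pass it to the metric so the loss is not
    # recomputed, as noted in the header comment.
    torch.manual_seed(0)
    batch_size, seq_len, vocab_size = 2, 8, 5
    logits = torch.randn(batch_size, seq_len, vocab_size)
    target = torch.randint(vocab_size, (batch_size, seq_len))
    # CrossEntropyLoss expects flattened [num_tokens, vocab_size] logits and [num_tokens] targets.
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), target.reshape(-1))
    metric = Perplexity()
    metric.update(logits, target, loss=loss)
    # With a single batch the metric reduces to exp(loss).
    print(metric.compute())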