File size: 2,749 Bytes
14d91dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.


import logging
from abc import abstractmethod

import torch

from ..api_adapter.forecast import ForecastModel

# Module-scoped logger; by default its records propagate to the root logger's handlers.
LOGGER = logging.getLogger(__name__)


class TensorQuantileUniPredictMixin(ForecastModel):
    """Mixin that derives quantile and point forecasts from a raw quantile-tensor model.

    Subclasses implement :meth:`_forecast_tensor` (the raw model call) and
    :attr:`quantiles` (the trained quantile levels). :meth:`_forecast_quantiles`
    then selects or interpolates the requested levels from the raw output.
    """

    @abstractmethod
    def _forecast_tensor(
        self,
        context: torch.Tensor,
        prediction_length: int | None = None,
        **predict_kwargs,
    ) -> torch.Tensor:
        """Return the model's raw quantile forecast for ``context``.

        NOTE(review): ``_forecast_quantiles`` swaps axes 1 and 2 of this output and then
        indexes the last axis by trained quantile, so the output is presumably shaped
        (batch, num_quantiles, prediction_length) -- confirm against implementations.
        """

    @property
    @abstractmethod
    def quantiles(self):
        """Quantile levels the model was trained on, in the order of the output's quantile axis."""

    def _forecast_quantiles(
        self,
        context: torch.Tensor,
        prediction_length: int | None = None,
        quantile_levels: tuple[float, ...] | list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        output_device: str = "cpu",
        auto_cast: bool = False,
        **predict_kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forecast ``context`` and return the requested quantiles plus a median point forecast.

        Args:
            context: Input passed unchanged to :meth:`_forecast_tensor`.
            prediction_length: Forecast horizon, forwarded to :meth:`_forecast_tensor`.
            quantile_levels: Quantile levels to return. Levels the model was trained on are
                selected directly. Any other level triggers interpolation over the trained
                quantiles, and levels outside the trained range are clamped to the extreme
                trained quantiles (a warning is logged).
            output_device: Device the forecast is moved to before post-processing.
            auto_cast: Whether to run the forecast under ``torch.autocast`` for ``self.device``.
            **predict_kwargs: Extra keyword arguments forwarded to :meth:`_forecast_tensor`.

        Returns:
            ``(quantiles, mean)``: ``quantiles`` holds the requested levels on its last axis, in
            the order of ``quantile_levels``. ``mean`` is the 0.5-level slice of the forecast,
            used as the point forecast. Assuming :meth:`_forecast_tensor` returns
            (batch, num_quantiles, prediction_length), the shapes are
            (batch, prediction_length, len(quantile_levels)) and (batch, prediction_length).

        Raises:
            ValueError: If the model's trained quantile levels do not include 0.5.
        """
        with torch.autocast(device_type=self.device.type, enabled=auto_cast):
            predictions = self._forecast_tensor(
                context=context, prediction_length=prediction_length, **predict_kwargs
            ).detach()
        # Move to the requested device and put the quantile axis last.
        predictions = predictions.to(torch.device(output_device)).swapaxes(1, 2)

        training_quantile_levels = list(self.quantiles)

        if set(quantile_levels).issubset(training_quantile_levels):
            # Every requested level was trained: pick the matching slices directly.
            quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
        else:
            if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(
                training_quantile_levels
            ):
                LOGGER.warning(
                    f"Requested quantile levels ({quantile_levels}) fall outside the range of "
                    f"quantiles the model was trained on ({training_quantile_levels}). "
                    "Predictions for out-of-range quantiles will be clamped to the nearest "
                    "boundary of the trained quantiles (i.e., minimum or maximum trained level). "
                    "This can significantly impact prediction accuracy, especially for extreme quantiles. "
                )
            # Pad with copies of the first and last trained quantile (assumed to be the lowest and
            # highest level), then let torch.quantile interpolate linearly over the padded values.
            # With Q trained levels the padded axis has Q + 2 entries, so level q lands at position
            # q * (Q + 1). Levels 0 and 1 hit the padding, which clamps them to the extreme trained
            # quantiles. This is exact linear interpolation in level space only when the trained
            # levels are i / (Q + 1) (e.g. 0.1, ..., 0.9). torch.quantile also sorts values along
            # the axis first, so crossed quantile predictions get reordered.
            augmented_predictions = torch.cat(
                [predictions[..., [0]], predictions, predictions[..., [-1]]],
                dim=-1,
            )
            # torch.quantile accepts only float32/float64 input (autocast may yield half types),
            # and q must live on the same device as the input.
            original_dtype = augmented_predictions.dtype
            compute_dtype = original_dtype if original_dtype in (torch.float32, torch.float64) else torch.float32
            quantiles = (
                torch.quantile(
                    augmented_predictions.to(compute_dtype),
                    q=torch.tensor(quantile_levels, dtype=compute_dtype, device=augmented_predictions.device),
                    dim=-1,
                )
                # torch.quantile puts the q axis first; move it back to the end.
                .permute(1, 2, 0)
                .to(original_dtype)
            )

        # The median (level 0.5) serves as the point forecast.
        if 0.5 not in training_quantile_levels:
            raise ValueError(
                f"The model's trained quantile levels ({training_quantile_levels}) do not include 0.5, "
                "which is required to derive the point forecast."
            )
        mean = predictions[:, :, training_quantile_levels.index(0.5)]
        return quantiles, mean