Spaces:
Running
on
T4
Running
on
T4
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
import logging | |
from abc import abstractmethod | |
import torch | |
from ..api_adapter.forecast import ForecastModel | |
LOGGER = logging.getLogger() | |
class TensorQuantileUniPredictMixin(ForecastModel):
    """Mixin that derives requested quantile levels from a tensor-valued quantile forecast.

    Implementers provide ``_forecast_tensor`` (the raw forecast) and ``quantiles``
    (the levels that forecast contains). ``_forecast_quantiles`` then selects the
    requested levels directly, or interpolates them when they are not all available.
    """

    @abstractmethod
    def _forecast_tensor(
        self,
        context: torch.Tensor,
        prediction_length: int | None = None,
        **predict_kwargs,
    ) -> torch.Tensor:
        """Return the raw forecast tensor for ``context``.

        ``_forecast_quantiles`` swaps axes 1 and 2 of the result and afterwards indexes
        the last axis by quantile position, so the result must be 3-dimensional with the
        quantile axis at position 1.
        NOTE(review): presumably (batch, num_quantiles, prediction_length) — confirm
        against the concrete implementations.
        """
        pass

    @property
    def quantiles(self):
        """Quantile levels contained in the forecast tensor.

        Read by ``_forecast_quantiles`` as an iterable whose positions correspond to
        positions along the quantile axis of the forecast. It is accessed as an
        attribute (``self.quantiles``), hence a property rather than a method.
        """
        pass

    def _forecast_quantiles(
        self,
        context: torch.Tensor,
        prediction_length: int | None = None,
        # The default list is only read, never mutated, so sharing it across calls is safe.
        quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        output_device: str = "cpu",
        auto_cast: bool = False,
        **predict_kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forecast ``context`` and return the requested quantiles plus a point forecast.

        Args:
            context: Input passed through to ``_forecast_tensor``.
            prediction_length: Passed through to ``_forecast_tensor``.
            quantile_levels: Quantile levels to return. If every level is one of
                ``self.quantiles`` (exact float match) the values are picked directly;
                otherwise they are interpolated from the trained quantiles.
            output_device: Device the forecast is moved to before post-processing.
            auto_cast: Whether the forecast runs under ``torch.autocast``.
            **predict_kwargs: Extra keyword arguments for ``_forecast_tensor``.

        Returns:
            ``(quantiles, mean)``: ``quantiles`` has the requested levels on its last
            axis (in the order given); ``mean`` is the trained 0.5 quantile, i.e. the
            median is used as the point forecast.

        Raises:
            ValueError: If 0.5 is not among ``self.quantiles``.
        """
        with torch.autocast(device_type=self.device.type, enabled=auto_cast):
            predictions = self._forecast_tensor(
                context=context, prediction_length=prediction_length, **predict_kwargs
            ).detach()
        # Move the quantile axis last so it can be indexed with ``[..., idx]``.
        predictions = predictions.to(torch.device(output_device)).swapaxes(1, 2)

        training_quantile_levels = list(self.quantiles)

        if set(quantile_levels).issubset(set(training_quantile_levels)):
            # Every requested level was trained: pick the matching slices directly.
            quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
        else:
            if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(
                training_quantile_levels
            ):
                # Use the module logger; ``logging.warning`` would implicitly call
                # ``basicConfig()`` and alter the application's logging setup.
                LOGGER.warning(
                    f"Requested quantile levels ({quantile_levels}) fall outside the range of "
                    f"quantiles the model was trained on ({training_quantile_levels}). "
                    "Predictions for out-of-range quantiles will be clamped to the nearest "
                    "boundary of the trained quantiles (i.e., minimum or maximum trained level). "
                    "This can significantly impact prediction accuracy, especially for extreme quantiles. "
                )
            # Interpolate quantiles: duplicate the first and last trained quantile so the
            # empirical quantile over the last axis clamps at the trained boundaries.
            # NOTE(review): treating the trained quantile values as samples only lines up
            # with the requested levels if the trained levels are evenly spaced
            # (e.g. 0.1 ... 0.9) — confirm for models with other level grids.
            augmented_predictions = torch.cat(
                [predictions[..., [0]], predictions, predictions[..., [-1]]],
                dim=-1,
            )
            quantiles = torch.quantile(
                augmented_predictions,
                # ``q`` must live on the same device as the input, otherwise
                # ``torch.quantile`` fails for any non-CPU ``output_device``.
                q=torch.tensor(
                    quantile_levels,
                    dtype=augmented_predictions.dtype,
                    device=augmented_predictions.device,
                ),
                dim=-1,
            ).permute(1, 2, 0)  # torch.quantile puts the q axis first; move it last again.

        # median as mean
        mean = predictions[:, :, training_quantile_levels.index(0.5)]
        return quantiles, mean