Nikita
added tirex as model
14d91dc
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
import inspect
from abc import ABC, abstractmethod
from typing import Literal, Optional

import torch

from .standard_adapter import ContextType, get_batches
try:
from .gluon import format_gluonts_output, get_gluon_batches
_GLUONTS_AVAILABLE = True
except ImportError:
_GLUONTS_AVAILABLE = False
try:
from .hf_data import get_hfdata_batches
_HF_DATASETS_AVAILABLE = True
except ImportError:
_HF_DATASETS_AVAILABLE = False
# Default column names. Neither constant is referenced in the visible part of
# this file; presumably they are consumed by the dataset adapters -- TODO confirm.
DEF_TARGET_COLUMN = "target"
DEF_META_COLUMNS = ("start", "item_id")
def _format_output(
    quantiles: torch.Tensor,
    means: torch.Tensor,
    sample_meta: list[dict],
    quantile_levels: list[float],
    output_type: Literal["torch", "numpy", "gluonts"],
):
    """Convert forecast tensors into the container selected by ``output_type``.

    Args:
        quantiles: Quantile forecasts as a tensor.
        means: Mean forecasts as a tensor.
        sample_meta: Per-sample metadata; only forwarded to the GluonTS formatter.
        quantile_levels: Quantile levels; only forwarded to the GluonTS formatter.
        output_type: One of ``"torch"``, ``"numpy"`` or ``"gluonts"``.

    Returns:
        ``(quantiles, means)`` moved to the CPU for ``"torch"``, the same pair
        converted to NumPy arrays for ``"numpy"``, or whatever
        ``format_gluonts_output`` returns for ``"gluonts"``.

    Raises:
        ValueError: If ``output_type`` is unknown, or if ``"gluonts"`` is
            requested while the optional GluonTS import failed.
    """
    if output_type == "torch":
        return quantiles.cpu(), means.cpu()

    if output_type == "numpy":
        return quantiles.cpu().numpy(), means.cpu().numpy()

    if output_type == "gluonts":
        # The formatter only exists when the optional import at module load succeeded.
        if not _GLUONTS_AVAILABLE:
            raise ValueError("output_type gluonts needs GluonTS but GluonTS is not available (not installed)!")
        return format_gluonts_output(quantiles, means, sample_meta, quantile_levels)

    raise ValueError(f"Invalid output type: {output_type}")
def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
    """Lazily forecast ``batches``, yielding one formatted result per batch.

    Args:
        batches: Iterable of ``(context_batch, meta_batch)`` pairs.
        fc_func: Callable invoked as ``fc_func(context_batch, **predict_kwargs)``;
            its result is unpacked into ``(quantiles, mean)``.
        quantile_levels: Passed through to ``_format_output``.
        output_type: Passed through to ``_format_output``.
        **predict_kwargs: Forwarded unchanged to ``fc_func``.

    Yields:
        The ``_format_output`` result for each batch, in iteration order.
    """
    # The formatting options are the same for every batch.
    shared_format_args = {"quantile_levels": quantile_levels, "output_type": output_type}

    for context_batch, meta_batch in batches:
        batch_quantiles, batch_mean = fc_func(context_batch, **predict_kwargs)
        formatted = _format_output(
            quantiles=batch_quantiles,
            means=batch_mean,
            sample_meta=meta_batch,
            **shared_format_args,
        )
        yield formatted
def _gen_forecast(fc_func, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs):
    """Forecast every batch and return the results lazily or as one object.

    Args:
        fc_func: Callable invoked as ``fc_func(context_batch, **predict_kwargs)``;
            its result is unpacked into ``(quantiles, mean)``.
        batches: Iterable of ``(context_batch, meta_batch)`` pairs.
        output_type: Passed through to ``_format_output``.
        quantile_levels: Passed through to ``_format_output``.
        yield_per_batch: When truthy, return the generator built by
            ``_as_generator`` instead of computing everything up front.
        **predict_kwargs: Forwarded unchanged to ``fc_func``.

    Returns:
        The ``_as_generator`` generator when ``yield_per_batch`` is truthy;
        otherwise the ``_format_output`` result for all forecasts concatenated
        along dimension 0, with the per-batch metadata merged into one list.
    """
    if yield_per_batch:
        # Lazy mode: nothing is computed until the caller iterates.
        return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs)

    quantile_chunks, mean_chunks, collected_meta = [], [], []
    for context_batch, meta_batch in batches:
        chunk_quantiles, chunk_mean = fc_func(context_batch, **predict_kwargs)
        quantile_chunks.append(chunk_quantiles)
        mean_chunks.append(chunk_mean)
        collected_meta.extend(meta_batch)

    return _format_output(
        quantiles=torch.cat(quantile_chunks, dim=0),
        means=torch.cat(mean_chunks, dim=0),
        sample_meta=collected_meta,
        quantile_levels=quantile_levels,
        output_type=output_type,
    )
def _common_forecast_doc():
    """Return the documentation text shared by the ``forecast*`` methods."""
    # A plain literal is enough: the text contains no placeholders.
    return """
This method takes historical context data as input and outputs probabilistic forecasts.
Args:
output_type (Literal["torch", "numpy", "gluonts"], optional):
Specifies the desired format of the returned forecasts:
- "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, |quantile_levels|]
- "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, |quantile_levels|]
- "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects.
Defaults to "torch".
batch_size (int, optional): The number of time series instances to process concurrently by the model.
Defaults to 512. Must be $>= 1$.
quantile_levels (List[float], optional): Quantile levels for which predictions should be generated.
Defaults to (0.1, 0.2, ..., 0.9).
yield_per_batch (bool, optional): If `True`, the method will act as a generator, yielding
forecasts batch by batch as they are computed.
Defaults to `False`.
**predict_kwargs: Additional keyword arguments that are passed directly to the underlying
prediction mechanism of the pre-trained model. Refer to the model's
internal prediction method documentation for available options.
Returns:
The return type depends on `output_type` and `yield_per_batch`:
- If `yield_per_batch` is `True`: An iterator that yields forecasts. Each yielded item
will correspond to a batch of forecasts in the format specified by `output_type`.
- If `yield_per_batch` is `False`: A single object containing all forecasts.
- If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean).
- If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean).
- If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts.
"""
def _attach_forecast_doc(method_specific_doc: str):
    """Build a decorator that installs the shared forecast docs as ``__doc__``.

    An f-string placed as the first statement of a function is not a
    docstring: it is evaluated and thrown away on every call and ``__doc__``
    stays ``None``. The documentation is therefore assigned explicitly, once,
    when the method is defined.

    Args:
        method_specific_doc: Documentation specific to the decorated method
            (typically an ``Args:`` section); appended after the shared text.

    Returns:
        A decorator that sets ``__doc__`` on the function and returns the
        function itself.
    """

    def decorator(func):
        # Each part is normalised separately so that differing indentation
        # does not leak into the combined text.
        shared_doc = inspect.cleandoc(_common_forecast_doc())
        func.__doc__ = shared_doc + "\n\n" + inspect.cleandoc(method_specific_doc)
        return func

    return decorator


class ForecastModel(ABC):
    """Abstract base class providing the public forecasting entry points.

    Subclasses implement ``_forecast_quantiles``. The ``forecast*`` methods
    only differ in how the input is split into batches; all of them delegate
    to ``_gen_forecast``.
    """

    @abstractmethod
    def _forecast_quantiles(self, batch, **predict_kwargs):
        """Forecast a single batch.

        The result is unpacked by the callers as ``(quantiles, mean)`` and the
        per-batch results are concatenated with ``torch.cat(..., dim=0)``.

        Args:
            batch: The context batch produced by the batching helper.
            **predict_kwargs: Extra keyword arguments given to the public
                ``forecast*`` method.
        """

    @_attach_forecast_doc(
        """
        Args:
            context (ContextType): The historical "context" data of the time series:
                - `torch.Tensor`: 1D `[context_length]` or 2D `[batch_dim, context_length]` tensor
                - `np.ndarray`: 1D `[context_length]` or 2D `[batch_dim, context_length]` array
                - `List[torch.Tensor]`: List of 1D tensors (samples with different lengths get padded per batch)
                - `List[np.ndarray]`: List of 1D arrays (samples with different lengths get padded per batch)
        """
    )
    def forecast(
        self,
        context: ContextType,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        **predict_kwargs,
    ):
        # The docstring is installed by the decorator above.
        # NOTE: kept as an assert so the raised exception type stays unchanged.
        assert batch_size >= 1, "Batch size must be >= 1"
        batches = get_batches(context, batch_size)
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )

    @_attach_forecast_doc(
        """
        Args:
            gluonDataset: A GluonTS dataset object containing the historical
                time series data.
            data_kwargs (dict, optional): Additional keyword arguments passed to the
                GluonTS batching function (`get_gluon_batches`). Defaults to no
                extra arguments.

        Raises:
            ValueError: If GluonTS is not installed.
        """
    )
    def forecast_gluon(
        self,
        gluonDataset,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        data_kwargs: Optional[dict] = None,
        **predict_kwargs,
    ):
        # The docstring is installed by the decorator above.
        assert batch_size >= 1, "Batch size must be >= 1"
        if not _GLUONTS_AVAILABLE:
            raise ValueError("forecast_gluon needs GluonTS but GluonTS is not available (not installed)!")
        # ``None`` replaces the former shared ``{}`` default.
        batches = get_gluon_batches(gluonDataset, batch_size, **(data_kwargs or {}))
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )

    @_attach_forecast_doc(
        """
        Args:
            hf_dataset (datasets.Dataset): A Hugging Face `Dataset` object containing the
                historical time series data.
            data_kwargs (dict, optional): Additional keyword arguments passed to the
                Hugging Face batching function (`get_hfdata_batches`). Defaults to no
                extra arguments.

        Raises:
            ValueError: If the Hugging Face `datasets` package is not installed.
        """
    )
    def forecast_hfdata(
        self,
        hf_dataset,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        data_kwargs: Optional[dict] = None,
        **predict_kwargs,
    ):
        # The docstring is installed by the decorator above.
        assert batch_size >= 1, "Batch size must be >= 1"
        if not _HF_DATASETS_AVAILABLE:
            raise ValueError(
                "forecast_hfdata needs HuggingFace datasets but datasets is not available (not installed)!"
            )
        # ``None`` replaces the former shared ``{}`` default.
        batches = get_hfdata_batches(hf_dataset, batch_size, **(data_kwargs or {}))
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )