# Spaces: Running on T4
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
from abc import ABC, abstractmethod
from typing import Literal, Optional

import torch

from .standard_adapter import ContextType, get_batches
# GluonTS support is optional: record whether the adapter module (and with it
# GluonTS) could be imported so that callers get a clear error later on.
try:
    from .gluon import format_gluonts_output, get_gluon_batches
except ImportError:
    _GLUONTS_AVAILABLE = False
else:
    _GLUONTS_AVAILABLE = True

# Same pattern for the Hugging Face `datasets` adapter.
try:
    from .hf_data import get_hfdata_batches
except ImportError:
    _HF_DATASETS_AVAILABLE = False
else:
    _HF_DATASETS_AVAILABLE = True

# Default column names; not referenced in this part of the file
# (presumably consumed by the dataset adapters -- verify against their callers).
DEF_TARGET_COLUMN = "target"
DEF_META_COLUMNS = ("start", "item_id")
def _format_output(
    quantiles: torch.Tensor,
    means: torch.Tensor,
    sample_meta: list[dict],
    quantile_levels: list[float],
    output_type: Literal["torch", "numpy", "gluonts"],
):
    """Convert raw forecast tensors into the representation requested by the caller.

    Args:
        quantiles: Quantile forecasts as a tensor.
        means: Mean forecasts as a tensor.
        sample_meta: Per-sample metadata; only forwarded to the GluonTS formatter.
        quantile_levels: Quantile levels; only forwarded to the GluonTS formatter.
        output_type: One of "torch", "numpy" or "gluonts".

    Returns:
        - "torch": `(quantiles, means)` as CPU tensors.
        - "numpy": `(quantiles, means)` as NumPy arrays.
        - "gluonts": whatever `format_gluonts_output` returns.

    Raises:
        ValueError: If `output_type` is unknown, or if "gluonts" is requested
            while GluonTS could not be imported.
    """
    if output_type == "torch":
        return quantiles.cpu(), means.cpu()
    if output_type == "numpy":
        return quantiles.cpu().numpy(), means.cpu().numpy()
    if output_type == "gluonts":
        # The formatter only exists when the optional GluonTS import succeeded.
        if not _GLUONTS_AVAILABLE:
            raise ValueError("output_type 'gluonts' requires GluonTS, but GluonTS is not available (not installed)!")
        return format_gluonts_output(quantiles, means, sample_meta, quantile_levels)
    raise ValueError(f"Invalid output type: {output_type}")
def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
    """Lazily forecast batch by batch, yielding each formatted result.

    Args:
        batches: Iterable of `(context_batch, meta_batch)` pairs.
        fc_func: Callable returning `(quantiles, mean)` for one context batch.
        quantile_levels: Passed through to `_format_output`.
        output_type: Passed through to `_format_output`.
        **predict_kwargs: Forwarded to `fc_func` on every call.

    Yields:
        The result of `_format_output` for each batch, in input order.
    """
    for context_batch, meta_batch in batches:
        forecast = fc_func(context_batch, **predict_kwargs)
        batch_quantiles, batch_mean = forecast
        yield _format_output(batch_quantiles, batch_mean, meta_batch, quantile_levels, output_type)
def _gen_forecast(fc_func, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs):
    """Run `fc_func` over all batches and return the forecasts.

    Args:
        fc_func: Callable returning `(quantiles, mean)` tensors for one context batch.
        batches: Iterable of `(context_batch, meta_batch)` pairs.
        output_type: Passed through to `_format_output`.
        quantile_levels: Passed through to `_format_output`.
        yield_per_batch: If truthy, return a generator of per-batch results
            instead of one aggregated result.
        **predict_kwargs: Forwarded to `fc_func` on every call.

    Returns:
        Either the generator created by `_as_generator`, or the formatted
        forecasts of all batches concatenated along dimension 0.
    """
    if yield_per_batch:
        # Streaming mode: nothing is computed until the caller iterates.
        return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs)

    quantile_chunks, mean_chunks, collected_meta = [], [], []
    for context_batch, meta_batch in batches:
        batch_quantiles, batch_mean = fc_func(context_batch, **predict_kwargs)
        quantile_chunks.append(batch_quantiles)
        mean_chunks.append(batch_mean)
        collected_meta += meta_batch

    # Stitch the per-batch tensors back together along the batch dimension.
    return _format_output(
        torch.cat(quantile_chunks, dim=0),
        torch.cat(mean_chunks, dim=0),
        collected_meta,
        quantile_levels,
        output_type,
    )
def _common_forecast_doc():
    """Return the documentation text describing the arguments and return values common to all `forecast*` entry points.

    The text contains no interpolated values, so it is a plain string literal
    rather than an f-string.
    """
    return """
    This method takes historical context data as input and outputs probabilistic forecasts.

    Args:
        output_type (Literal["torch", "numpy", "gluonts"], optional):
            Specifies the desired format of the returned forecasts:
            - "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, |quantile_levels|]
            - "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, |quantile_levels|]
            - "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects.
            Defaults to "torch".
        batch_size (int, optional): The number of time series instances to process concurrently by the model.
            Defaults to 512. Must be $>= 1$.
        quantile_levels (List[float], optional): Quantile levels for which predictions should be generated.
            Defaults to (0.1, 0.2, ..., 0.9).
        yield_per_batch (bool, optional): If `True`, the method will act as a generator, yielding
            forecasts batch by batch as they are computed.
            Defaults to `False`.
        **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
            prediction mechanism of the pre-trained model. Refer to the model's
            internal prediction method documentation for available options.

    Returns:
        The return type depends on `output_type` and `yield_per_batch`:
        - If `yield_per_batch` is `True`: An iterator that yields forecasts. Each yielded item
          will correspond to a batch of forecasts in the format specified by `output_type`.
        - If `yield_per_batch` is `False`: A single object containing all forecasts.
        - If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean).
        - If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean).
        - If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts.
    """
class ForecastModel(ABC):
    """Base class providing batched forecasting entry points on top of a model-specific predictor.

    Subclasses implement `_forecast_quantiles`. The public `forecast*` methods
    split their input into batches, call the predictor on each batch and convert
    the results into the requested output format.
    """

    @abstractmethod
    def _forecast_quantiles(self, batch, **predict_kwargs):
        """Predict quantiles and mean for a single batch of contexts.

        Args:
            batch: One batch of context data as produced by the batching helpers.
            **predict_kwargs: Model-specific prediction options.

        Returns:
            A `(quantiles, mean)` pair of `torch.Tensor` objects. Results of
            several batches are concatenated along dimension 0.
        """

    def forecast(
        self,
        context: ContextType,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        **predict_kwargs,
    ):
        """Generate probabilistic forecasts from in-memory context data.

        Args:
            context (ContextType): The historical "context" data of the time series:
                - `torch.Tensor`: 1D `[context_length]` or 2D `[batch_dim, context_length]` tensor
                - `np.ndarray`: 1D `[context_length]` or 2D `[batch_dim, context_length]` array
                - `List[torch.Tensor]`: List of 1D tensors (samples with different lengths get padded per batch)
                - `List[np.ndarray]`: List of 1D arrays (samples with different lengths get padded per batch)
            output_type: Format of the returned forecasts:
                - "torch": `torch.Tensor` objects `[batch_dim, forecast_len, |quantile_levels|]`
                - "numpy": `numpy.ndarray` objects with the same layout
                - "gluonts": a list of GluonTS `Forecast` objects
            batch_size: Number of time series processed concurrently by the model. Must be >= 1.
            quantile_levels: Quantile levels of the forecasts; forwarded to the output formatting step.
            yield_per_batch: If `True`, return a generator that yields the forecasts
                batch by batch instead of one aggregated result.
            **predict_kwargs: Forwarded to `_forecast_quantiles` for every batch.

        Returns:
            With `yield_per_batch=False`, all forecasts in the format selected by
            `output_type` (`(quantiles, mean)` for "torch"/"numpy"). With
            `yield_per_batch=True`, a generator yielding one such result per batch.

        Raises:
            ValueError: If `batch_size` is smaller than 1.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not batch_size >= 1:
            raise ValueError("Batch size must be >= 1")
        batches = get_batches(context, batch_size)
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )

    def forecast_gluon(
        self,
        gluonDataset,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        data_kwargs: Optional[dict] = None,
        **predict_kwargs,
    ):
        """Generate probabilistic forecasts for a GluonTS dataset.

        Args:
            gluonDataset: A GluonTS dataset object containing the historical time series data.
            output_type: Format of the returned forecasts:
                - "torch": `torch.Tensor` objects `[batch_dim, forecast_len, |quantile_levels|]`
                - "numpy": `numpy.ndarray` objects with the same layout
                - "gluonts": a list of GluonTS `Forecast` objects
            batch_size: Number of time series processed concurrently by the model. Must be >= 1.
            quantile_levels: Quantile levels of the forecasts; forwarded to the output formatting step.
            yield_per_batch: If `True`, return a generator that yields the forecasts
                batch by batch instead of one aggregated result.
            data_kwargs: Additional keyword arguments forwarded to `get_gluon_batches`.
                `None` (the default) means no extra arguments.
            **predict_kwargs: Forwarded to `_forecast_quantiles` for every batch.

        Returns:
            With `yield_per_batch=False`, all forecasts in the format selected by
            `output_type` (`(quantiles, mean)` for "torch"/"numpy"). With
            `yield_per_batch=True`, a generator yielding one such result per batch.

        Raises:
            ValueError: If `batch_size` is smaller than 1 or GluonTS could not be imported.
        """
        if not batch_size >= 1:
            raise ValueError("Batch size must be >= 1")
        if not _GLUONTS_AVAILABLE:
            raise ValueError("forecast_gluon requires GluonTS, but GluonTS is not available (not installed)!")
        batches = get_gluon_batches(gluonDataset, batch_size, **(data_kwargs or {}))
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )

    def forecast_hfdata(
        self,
        hf_dataset,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        data_kwargs: Optional[dict] = None,
        **predict_kwargs,
    ):
        """Generate probabilistic forecasts for a Hugging Face dataset.

        Args:
            hf_dataset: A Hugging Face `Dataset` object containing the historical time series data.
            output_type: Format of the returned forecasts:
                - "torch": `torch.Tensor` objects `[batch_dim, forecast_len, |quantile_levels|]`
                - "numpy": `numpy.ndarray` objects with the same layout
                - "gluonts": a list of GluonTS `Forecast` objects
            batch_size: Number of time series processed concurrently by the model. Must be >= 1.
            quantile_levels: Quantile levels of the forecasts; forwarded to the output formatting step.
            yield_per_batch: If `True`, return a generator that yields the forecasts
                batch by batch instead of one aggregated result.
            data_kwargs: Additional keyword arguments forwarded to `get_hfdata_batches`.
                `None` (the default) means no extra arguments.
            **predict_kwargs: Forwarded to `_forecast_quantiles` for every batch.

        Returns:
            With `yield_per_batch=False`, all forecasts in the format selected by
            `output_type` (`(quantiles, mean)` for "torch"/"numpy"). With
            `yield_per_batch=True`, a generator yielding one such result per batch.

        Raises:
            ValueError: If `batch_size` is smaller than 1 or the `datasets` adapter could not be imported.
        """
        if not batch_size >= 1:
            raise ValueError("Batch size must be >= 1")
        if not _HF_DATASETS_AVAILABLE:
            raise ValueError(
                "forecast_hfdata requires HuggingFace datasets, but datasets is not available (not installed)!"
            )
        batches = get_hfdata_batches(hf_dataset, batch_size, **(data_kwargs or {}))
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )