# Spaces: Running on T4
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
import pandas as pd | |
import torch | |
from gluonts.dataset.common import Dataset | |
from gluonts.dataset.field_names import FieldName | |
from gluonts.model.forecast import QuantileForecast | |
from .standard_adapter import _batch_pad_iterable | |
# Default entry key under which a series stores its values (FieldName.TARGET, i.e. "target").
DEF_TARGET_COLUMN = FieldName.TARGET
# Default entry keys that are copied into the per-series metadata dict when present.
DEF_META_COLUMNS = (FieldName.START, FieldName.ITEM_ID)
def _get_gluon_ts_map(**gluon_kwargs):
    """Build a function that turns one dataset entry into ``(context, meta)``.

    Recognised keyword arguments:
        target_column: Key of the entry that holds the series values.
            Defaults to ``DEF_TARGET_COLUMN``.
        meta_columns: Keys to copy from the entry into the metadata dict.
            Defaults to ``DEF_META_COLUMNS``.

    Returns:
        A callable taking a single entry and returning a tuple of the values
        as a ``torch.Tensor`` and a metadata dict. The dict holds every
        requested meta key that exists in the entry plus ``"length"``, the
        ``len`` of the tensor.
    """
    value_key = gluon_kwargs.get("target_column", DEF_TARGET_COLUMN)
    meta_keys = gluon_kwargs.get("meta_columns", DEF_META_COLUMNS)

    def extract_gluon(series):
        context = torch.Tensor(series[value_key])

        # Copy only the keys this entry actually provides; absent ones are skipped.
        info = {}
        for key in meta_keys:
            if key in series:
                info[key] = series[key]

        # Recorded so the forecast start can later be offset past the context.
        info["length"] = len(context)
        return context, info

    return extract_gluon
def get_gluon_batches(gluonDataset: Dataset, batch_size: int, **gluon_kwargs):
    """Convert a GluonTS dataset into batches via ``_batch_pad_iterable``.

    Args:
        gluonDataset: Iterable of dataset entries.
        batch_size: Forwarded unchanged to ``_batch_pad_iterable``.
        **gluon_kwargs: Forwarded to ``_get_gluon_ts_map`` (``target_column``,
            ``meta_columns``).

    Returns:
        Whatever ``_batch_pad_iterable`` returns for the lazily mapped entries
        (presumably padded batches of ``(context, meta)`` pairs -- confirm in
        ``standard_adapter``).
    """
    to_context_and_meta = _get_gluon_ts_map(**gluon_kwargs)
    # ``map`` stays lazy, so entries are converted only as batches are consumed.
    converted_entries = map(to_context_and_meta, gluonDataset)
    return _batch_pad_iterable(converted_entries, batch_size)
def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict], quantile_levels):
    """Wrap batched quantile and mean predictions into ``QuantileForecast`` objects.

    Args:
        quantile_forecasts: Batched quantile predictions, one entry per series
            along the first axis. Each entry is concatenated along dim 1 with
            the mean column, so the layout is presumably
            ``(batch, horizon, num_quantiles)`` -- TODO confirm with the caller.
        mean_forecasts: Batched mean predictions; entry ``i`` gets an extra
            axis via ``unsqueeze(1)`` before concatenation (presumably
            ``(batch, horizon)``).
        meta: One dict per series. Reads ``FieldName.START``,
            ``FieldName.ITEM_ID``, ``"length"`` and ``"freq"``; all optional.
        quantile_levels: Iterable of quantile levels. Each level is converted
            with ``str`` and used as the key of the matching forecast row.

    Returns:
        A list with one ``QuantileForecast`` per series. Each forecast starts
        ``meta["length"]`` periods after the series start; its rows are the
        quantiles followed by a final ``"mean"`` row.
    """
    num_series = quantile_forecasts.shape[0]
    if num_series == 0:
        return []

    # Loop-invariant, so build it once. Materialising up front also keeps the
    # keys correct for every series when ``quantile_levels`` is a one-shot iterable.
    forecast_keys = [str(level) for level in quantile_levels] + ["mean"]

    forecasts = []
    for i in range(num_series):
        item_meta = meta[i]

        if FieldName.START in item_meta:
            start_date = item_meta[FieldName.START]
        else:
            # Placeholder start, built only when the series carries none.
            start_date = pd.Period("01-01-2000", freq=item_meta.get("freq", "h"))
        # The forecast begins right after the observed context.
        start_date += item_meta.get("length", 0)

        # Append the mean as an extra column, then transpose so that each row
        # of the array corresponds to one forecast key.
        stacked = torch.cat((quantile_forecasts[i], mean_forecasts[i].unsqueeze(1)), dim=1)
        forecasts.append(
            QuantileForecast(
                forecast_arrays=stacked.T.cpu().numpy(),
                start_date=start_date,
                item_id=item_meta.get(FieldName.ITEM_ID, None),
                # Fresh list per forecast so the objects never share key storage.
                forecast_keys=list(forecast_keys),
            )
        )
    return forecasts