# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
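
"""Model-agnostic forecasting entry points.

This module batches heterogeneous historical context data, runs a `ForecastModel`'s
quantile predictions on each batch, and formats the results as torch tensors,
numpy arrays, or GluonTS `Forecast` objects.
"""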

from abc import ABC, abstractmethod
from typing import Literal

import torch

from .standard_adapter import ContextType, get_batches

try:
    from .gluon import format_gluonts_output, get_gluon_batches

    _GLUONTS_AVAILABLE = True
except ImportError:
    _GLUONTS_AVAILABLE = False

try:
    from .hf_data import get_hfdata_batches

    _HF_DATASETS_AVAILABLE = True
except ImportError:
    _HF_DATASETS_AVAILABLE = False


DEF_TARGET_COLUMN = "target"
DEF_META_COLUMNS = ("start", "item_id")


def _format_output(
    quantiles: torch.Tensor,
    means: torch.Tensor,
    sample_meta: list[dict],
    quantile_levels: list[float],
    output_type: Literal["torch", "numpy", "gluonts"],
):
    if output_type == "torch":
        return quantiles.cpu(), means.cpu()
    elif output_type == "numpy":
        return quantiles.cpu().numpy(), means.cpu().numpy()
    elif output_type == "gluonts":
        if not _GLUONTS_AVAILABLE:
            raise ValueError("output_type 'gluonts' requires GluonTS, but GluonTS is not available (not installed)!")
        return format_gluonts_output(quantiles, means, sample_meta, quantile_levels)
    else:
        raise ValueError(f"Invalid output type: {output_type}")


def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
    for batch_ctx, batch_meta in batches:
        quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
        yield _format_output(
            quantiles=quantiles,
            means=mean,
            sample_meta=batch_meta,
            quantile_levels=quantile_levels,
            output_type=output_type,
        )


def _gen_forecast(fc_func, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs):
    if yield_per_batch:
        return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs)

    prediction_q = []
    prediction_m = []
    sample_meta = []
    for batch_ctx, batch_meta in batches:
        quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
        prediction_q.append(quantiles)
        prediction_m.append(mean)
        sample_meta.extend(batch_meta)

    prediction_q = torch.cat(prediction_q, dim=0)
    prediction_m = torch.cat(prediction_m, dim=0)

    return _format_output(
        quantiles=prediction_q,
        means=prediction_m,
        sample_meta=sample_meta,
        quantile_levels=quantile_levels,
        output_type=output_type,
    )


def _common_forecast_doc():
    common_doc = f"""
        This method takes historical context data as input and outputs probabilistic forecasts.

        Args:
            output_type (Literal["torch", "numpy", "gluonts"], optional):
                Specifies the desired format of the returned forecasts:
                - "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, |quantile_levels|]
                - "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, |quantile_levels|]
                - "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects.
                Defaults to "torch".

            batch_size (int, optional): The number of time series instances to process concurrently by the model.
                                        Defaults to 512. Must be >= 1.

            quantile_levels (List[float], optional): Quantile levels for which predictions should be generated.
                                                     Defaults to (0.1, 0.2, ..., 0.9).

            yield_per_batch (bool, optional): If `True`, the method will act as a generator, yielding
                                              forecasts batch by batch as they are computed.
                                              Defaults to `False`.

            **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
                              prediction mechanism of the pre-trained model. Refer to the model's
                              internal prediction method documentation for available options.

        Returns:
            The return type depends on `output_type` and `yield_per_batch`:
                - If `yield_per_batch` is `True`: An iterator that yields forecasts. Each yielded item
                  will correspond to a batch of forecasts in the format specified by `output_type`.
                - If `yield_per_batch` is `False`: A single object containing all forecasts.
                  - If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean).
                  - If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean).
                  - If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts.
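
        Example:
            An illustrative sketch of consuming the generator returned when `yield_per_batch=True`
            (here with `output_type="torch"`); `model` and `data` are placeholders for a concrete
            model instance and any input accepted by the respective method, and the same pattern
            applies to `forecast`, `forecast_gluon`, and `forecast_hfdata`:

                for batch_quantiles, batch_means in model.forecast(data, yield_per_batch=True):
                    ...  # one (quantiles, means) pair per batch of up to `batch_size` series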
        """
    return common_doc


class ForecastModel(ABC):
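    """Abstract base class for pre-trained forecast models.

    Subclasses implement `_forecast_quantiles`, which maps a batch of context series to
    `(quantiles, means)` tensors; the `forecast*` methods defined here add input batching
    and output formatting on top of it.
    """
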
    @abstractmethod
    def _forecast_quantiles(self, batch, **predict_kwargs):
        pass

    def forecast(
        self,
        context: ContextType,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        **predict_kwargs,
    ):
        f"""
        {_common_forecast_doc()}

        Args:
            context (ContextType): The historical "context" data of the time series:
                - `torch.Tensor`: 1D `[context_length]` or 2D `[batch_dim, context_length]` tensor
                - `np.ndarray`: 1D `[context_length]` or 2D `[batch_dim, context_length]` array
                - `List[torch.Tensor]`: List of 1D tensors (samples with different lengths get padded per batch)
                - `List[np.ndarray]`: List of 1D arrays (samples with different lengths get padded per batch)
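
        Example:
            An illustrative sketch with variable-length series (`model` is a placeholder
            for a concrete `ForecastModel` subclass instance):

                series = [torch.randn(300), torch.randn(150), torch.randn(512)]
                quantiles, means = model.forecast(series, batch_size=2)
                # Shorter series are padded per batch; `quantiles` has shape
                # [3, forecast_len, 9] for the default quantile levels.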
        """
        assert batch_size >= 1, "Batch size must be >= 1"
        batches = get_batches(context, batch_size)
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )

    def forecast_gluon(
        self,
        gluonDataset,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        data_kwargs: dict = {},
        **predict_kwargs,
    ):
        f"""
        {_common_forecast_doc()}

        Args:
            gluonDataset (gluonts.dataset.common.Dataset): A GluonTS dataset object containing the
                                                            historical time series data.

            data_kwargs (dict, optional): Additional keyword arguments passed to the
                                          GluonTS data processing function (`get_gluon_batches`).
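
        Example:
            An illustrative sketch assuming GluonTS is installed (uses the GluonTS `ListDataset`
            helper; `model` is a placeholder for a concrete `ForecastModel` subclass instance):

                from gluonts.dataset.common import ListDataset

                dataset = ListDataset(
                    [dict(target=[1.0, 2.0, 3.0, 4.0], start="2024-01-01")],
                    freq="D",
                )
                forecasts = model.forecast_gluon(dataset, output_type="gluonts")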
        """
        assert batch_size >= 1, "Batch size must be >= 1"
        if not _GLUONTS_AVAILABLE:
            raise ValueError("forecast_gluon requires GluonTS, but GluonTS is not available (not installed)!")
        batches = get_gluon_batches(gluonDataset, batch_size, **data_kwargs)
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )

    def forecast_hfdata(
        self,
        hf_dataset,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        data_kwargs: dict = {},
        **predict_kwargs,
    ):
        f"""
        {_common_forecast_doc()}

        Args:
            hf_dataset (datasets.Dataset): A Hugging Face `Dataset` object containing the
                                           historical time series data.

            data_kwargs (dict, optional): Additional keyword arguments passed to the
                                          Hugging Face `datasets` processing function (`get_hfdata_batches`).
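
        Example:
            An illustrative sketch assuming the Hugging Face `datasets` package is installed;
            the column layout below (a `target` sequence plus `start`/`item_id` metadata) is
            inferred from the defaults in this module and may differ for other datasets
            (`model` is a placeholder for a concrete `ForecastModel` subclass instance):

                import datasets

                hf_dataset = datasets.Dataset.from_dict(
                    dict(
                        target=[[1.0, 2.0, 3.0, 4.0]],
                        start=["2024-01-01"],
                        item_id=["series_0"],
                    )
                )
                quantiles, means = model.forecast_hfdata(hf_dataset)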
        """
        assert batch_size >= 1, "Batch size must be >= 1"
        if not _HF_DATASETS_AVAILABLE:
            raise ValueError(
                "forecast_hfdata requires Hugging Face datasets, but datasets is not available (not installed)!"
            )
        batches = get_hfdata_batches(hf_dataset, batch_size, **data_kwargs)
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )
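

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library API): a minimal ForecastModel
# subclass returning constant forecasts, shown only to demonstrate how
# `forecast()` wires batching, `_forecast_quantiles`, and output formatting
# together. The class name, the fixed horizon, and the assumption that batches
# arrive as padded [batch_dim, context_length] tensors are placeholders here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ConstantForecaster(ForecastModel):
        def _forecast_quantiles(self, batch, prediction_length: int = 24, **predict_kwargs):
            # `batch` is assumed to be a padded [batch_dim, context_length] tensor.
            batch_dim = batch.shape[0]
            num_quantiles = 9  # matches the default quantile_levels (0.1, ..., 0.9)
            quantiles = torch.zeros(batch_dim, prediction_length, num_quantiles)
            means = quantiles[..., num_quantiles // 2]  # median column as the point forecast
            return quantiles, means

    demo_model = _ConstantForecaster()
    demo_context = torch.randn(4, 128)  # 4 series, 128 historical steps each
    demo_quantiles, demo_means = demo_model.forecast(demo_context, batch_size=2)
    print(demo_quantiles.shape, demo_means.shape)  # expected: [4, 24, 9] and [4, 24]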