From 0193ed756edd681e8e09139f2e2df99b581bbaab Mon Sep 17 00:00:00 2001 From: Syama Sundar Rangapuram Date: Tue, 28 May 2024 13:16:49 +0200 Subject: [PATCH] PatchTST: Add support for time features (#3167) --- .../torch/model/patch_tst/estimator.py | 72 +++++++++++++++---- src/gluonts/torch/model/patch_tst/module.py | 69 ++++++++++++++++-- test/torch/model/test_estimators.py | 19 +++++ 3 files changed, 140 insertions(+), 20 deletions(-) diff --git a/src/gluonts/torch/model/patch_tst/estimator.py b/src/gluonts/torch/model/patch_tst/estimator.py index 94ae5b5444..4eea0be9d7 100644 --- a/src/gluonts/torch/model/patch_tst/estimator.py +++ b/src/gluonts/torch/model/patch_tst/estimator.py @@ -30,6 +30,7 @@ TestSplitSampler, ExpectedNumInstanceSampler, SelectFields, + RenameFields, ) from gluonts.torch.model.estimator import PyTorchLightningEstimator from gluonts.torch.model.predictor import PyTorchPredictor @@ -74,6 +75,8 @@ class PatchTSTEstimator(PyTorchLightningEstimator): Number of attention heads in the Transformer encoder which must divide d_model. dim_feedforward Size of hidden layers in the Transformer encoder. + num_feat_dynamic_real + Number of dynamic real features in the data (default: 0). dropout Dropout probability in the Transformer encoder. activation @@ -115,6 +118,7 @@ def __init__( d_model: int = 32, nhead: int = 4, dim_feedforward: int = 128, + num_feat_dynamic_real: int = 0, dropout: float = 0.1, activation: str = "relu", norm_first: bool = False, @@ -151,6 +155,7 @@ def __init__( self.d_model = d_model self.nhead = nhead self.dim_feedforward = dim_feedforward + self.num_feat_dynamic_real = num_feat_dynamic_real self.dropout = dropout self.activation = activation self.norm_first = norm_first @@ -166,17 +171,26 @@ def __init__( ) def create_transformation(self) -> Transformation: - return SelectFields( - [ - FieldName.ITEM_ID, - FieldName.INFO, - FieldName.START, - FieldName.TARGET, - ], - allow_missing=True, - ) + AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, + return ( + SelectFields( + [ + FieldName.ITEM_ID, + FieldName.INFO, + FieldName.START, + FieldName.TARGET, + ] + + ( + [FieldName.FEAT_DYNAMIC_REAL] + if self.num_feat_dynamic_real > 0 + else [] + ), + allow_missing=True, + ) + + RenameFields({FieldName.FEAT_DYNAMIC_REAL: FieldName.FEAT_TIME}) + + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ) ) def create_lightning_module(self) -> pl.LightningModule: @@ -192,6 +206,7 @@ def create_lightning_module(self) -> pl.LightningModule: "d_model": self.d_model, "nhead": self.nhead, "dim_feedforward": self.dim_feedforward, + "num_feat_dynamic_real": self.num_feat_dynamic_real, "dropout": self.dropout, "activation": self.activation, "norm_first": self.norm_first, @@ -220,7 +235,10 @@ def _create_instance_splitter( instance_sampler=instance_sampler, past_length=self.context_length, future_length=self.prediction_length, - time_series_fields=[FieldName.OBSERVED_VALUES], + time_series_fields=[FieldName.OBSERVED_VALUES] + + ( + [FieldName.FEAT_TIME] if self.num_feat_dynamic_real > 0 else [] + ), dummy_value=self.distr_output.value_in_support, ) @@ -239,7 +257,15 @@ def create_training_data_loader( instances, batch_size=self.batch_size, shuffle_buffer_length=shuffle_buffer_length, - field_names=TRAINING_INPUT_NAMES, + field_names=TRAINING_INPUT_NAMES + + ( + [ + f"past_{FieldName.FEAT_TIME}", + f"future_{FieldName.FEAT_TIME}", + ] + if self.num_feat_dynamic_real > 0 + else [] + ), output_type=torch.tensor, num_batches_per_epoch=self.num_batches_per_epoch, ) @@ -253,7 +279,15 @@ def create_validation_data_loader( return as_stacked_batches( instances, batch_size=self.batch_size, - field_names=TRAINING_INPUT_NAMES, + field_names=TRAINING_INPUT_NAMES + + ( + [ + f"past_{FieldName.FEAT_TIME}", + f"future_{FieldName.FEAT_TIME}", + ] + if self.num_feat_dynamic_real > 0 + else [] + ), output_type=torch.tensor, ) @@ -264,7 +298,15 @@ def create_predictor( return PyTorchPredictor( input_transform=transformation + prediction_splitter, - input_names=PREDICTION_INPUT_NAMES, + input_names=PREDICTION_INPUT_NAMES + + ( + [ + f"past_{FieldName.FEAT_TIME}", + f"future_{FieldName.FEAT_TIME}", + ] + if self.num_feat_dynamic_real > 0 + else [] + ), prediction_net=module, forecast_generator=self.distr_output.forecast_generator, batch_size=self.batch_size, diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py index 3a59f80299..5f5c3e9fbc 100644 --- a/src/gluonts/torch/model/patch_tst/module.py +++ b/src/gluonts/torch/model/patch_tst/module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -from typing import Tuple +from typing import Optional, Tuple import numpy as np import torch @@ -21,7 +21,7 @@ from gluonts.model import Input, InputSpec from gluonts.torch.distributions import StudentTOutput from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler -from gluonts.torch.util import unsqueeze_expand, weighted_average +from gluonts.torch.util import take_last, unsqueeze_expand, weighted_average from gluonts.torch.model.simple_feedforward import make_linear_layer @@ -85,6 +85,8 @@ class PatchTSTModel(nn.Module): Number of time points to predict. context_length Number of time steps prior to prediction time that the model. + num_feat_dynamic_real + Number of dynamic real features in the data (default: 0). distr_output Distribution to use to evaluate observations and sample predictions. Default: ``StudentTOutput()``. @@ -101,6 +103,7 @@ def __init__( d_model: int, nhead: int, dim_feedforward: int, + num_feat_dynamic_real: int, dropout: float, activation: str, norm_first: bool, @@ -120,6 +123,7 @@ def __init__( self.d_model = d_model self.padding_patch = padding_patch self.distr_output = distr_output + self.num_feat_dynamic_real = num_feat_dynamic_real if scaling == "mean": self.scaler = MeanScaler(keepdim=True) @@ -133,8 +137,11 @@ def __init__( self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride)) self.patch_num += 1 - # project from patch_len + 2 features (loc and scale) to d_model - self.patch_proj = make_linear_layer(patch_len + 2, d_model) + # project from `patch_len` + 2 features (`loc` and `scale`) + + # `num_feat_dynamic_real` x `patch_len` to d_model + self.patch_proj = make_linear_layer( + patch_len + 2 + self.num_feat_dynamic_real * patch_len, d_model + ) self.positional_encoding = SinusoidalPositionalEmbedding( self.patch_num, d_model @@ -163,6 +170,28 @@ def __init__( self.args_proj = self.distr_output.get_args_proj(d_model) def describe_inputs(self, batch_size=1) -> InputSpec: + if self.num_feat_dynamic_real > 0: + input_spec_feat = { + "past_time_feat": Input( + shape=( + batch_size, + self.context_length, + self.num_feat_dynamic_real, + ), + dtype=torch.float, + ), + "future_time_feat": Input( + shape=( + batch_size, + self.prediction_length, + self.num_feat_dynamic_real, + ), + dtype=torch.float, + ), + } + else: + input_spec_feat = {} + return InputSpec( { "past_target": Input( @@ -171,6 +200,7 @@ def describe_inputs(self, batch_size=1) -> InputSpec: "past_observed_values": Input( shape=(batch_size, self.context_length), dtype=torch.float ), + **input_spec_feat, }, torch.zeros, ) @@ -179,6 +209,8 @@ def forward( self, past_target: torch.Tensor, past_observed_values: torch.Tensor, + past_time_feat: Optional[torch.Tensor] = None, + future_time_feat: Optional[torch.Tensor] = None, ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: # scale the input past_target_scaled, loc, scale = self.scaler( @@ -192,6 +224,25 @@ def forward( dimension=1, size=self.patch_len, step=self.stride ) + # do patching for time features as well + if self.num_feat_dynamic_real > 0: + # shift time features by `prediction_length` so that they are + # aligned with the target input. + time_feat = take_last( + torch.cat((past_time_feat, future_time_feat), dim=1), + dim=1, + num=self.context_length, + ) + + # (bs x T x d) --> (bs x d x T) because the 1D padding is done on + # the last dimension. + time_feat = self.padding_patch_layer( + time_feat.transpose(-2, -1) + ).transpose(-2, -1) + time_feat_patches = time_feat.unfold( + dimension=1, size=self.patch_len, step=self.stride + ).flatten(-2, -1) + # add loc and scale to past_target_patches as additional features log_abs_loc = loc.abs().log1p() log_scale = scale.log() @@ -202,6 +253,9 @@ def forward( ) inputs = torch.cat((past_target_patches, expanded_static_feat), dim=-1) + if self.num_feat_dynamic_real > 0: + inputs = torch.cat((inputs, time_feat_patches), dim=-1) + # project patches enc_in = self.patch_proj(inputs) embed_pos = self.positional_encoding(enc_in.size()) @@ -224,9 +278,14 @@ def loss( past_observed_values: torch.Tensor, future_target: torch.Tensor, future_observed_values: torch.Tensor, + past_time_feat: Optional[torch.Tensor] = None, + future_time_feat: Optional[torch.Tensor] = None, ) -> torch.Tensor: distr_args, loc, scale = self( - past_target=past_target, past_observed_values=past_observed_values + past_target=past_target, + past_observed_values=past_observed_values, + past_time_feat=past_time_feat, + future_time_feat=future_time_feat, ) loss = self.distr_output.loss( target=future_target, distr_args=distr_args, loc=loc, scale=scale diff --git a/test/torch/model/test_estimators.py b/test/torch/model/test_estimators.py index 3c2ef3ea51..caa5ae10ec 100644 --- a/test/torch/model/test_estimators.py +++ b/test/torch/model/test_estimators.py @@ -296,6 +296,25 @@ def test_estimator_constant_dataset( num_batches_per_epoch=3, epochs=2, ), + lambda freq, prediction_length: PatchTSTEstimator( + prediction_length=prediction_length, + context_length=2 * prediction_length, + num_feat_dynamic_real=3, + patch_len=16, + batch_size=4, + num_batches_per_epoch=3, + trainer_kwargs=dict(max_epochs=2), + ), + lambda freq, prediction_length: PatchTSTEstimator( + prediction_length=prediction_length, + context_length=2 * prediction_length, + num_feat_dynamic_real=3, + distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]), + patch_len=16, + batch_size=4, + num_batches_per_epoch=3, + trainer_kwargs=dict(max_epochs=2), + ), lambda freq, prediction_length: WaveNetEstimator( freq=freq, prediction_length=prediction_length,