From c5f4b76de482bd09915a2f9209f56d876d3b3acd Mon Sep 17 00:00:00 2001 From: Syama Sundar Rangapuram Date: Tue, 30 Apr 2024 18:29:51 +0200 Subject: [PATCH 1/6] PatchTST: Add support for time features --- .../torch/model/patch_tst/estimator.py | 28 ++++++++++++++---- src/gluonts/torch/model/patch_tst/module.py | 29 +++++++++++++++++-- test/torch/model/test_estimators.py | 19 ++++++++++++ 3 files changed, 69 insertions(+), 7 deletions(-) diff --git a/src/gluonts/torch/model/patch_tst/estimator.py b/src/gluonts/torch/model/patch_tst/estimator.py index 94ae5b5444..1f8280e083 100644 --- a/src/gluonts/torch/model/patch_tst/estimator.py +++ b/src/gluonts/torch/model/patch_tst/estimator.py @@ -115,6 +115,7 @@ def __init__( d_model: int = 32, nhead: int = 4, dim_feedforward: int = 128, + use_feat_dynamic_real: bool = False, dropout: float = 0.1, activation: str = "relu", norm_first: bool = False, @@ -151,6 +152,7 @@ def __init__( self.d_model = d_model self.nhead = nhead self.dim_feedforward = dim_feedforward + self.use_feat_dynamic_real = use_feat_dynamic_real self.dropout = dropout self.activation = activation self.norm_first = norm_first @@ -172,7 +174,10 @@ def create_transformation(self) -> Transformation: FieldName.INFO, FieldName.START, FieldName.TARGET, - ], + ] + ( + [FieldName.FEAT_DYNAMIC_REAL] if self.use_feat_dynamic_real + else [] + ), allow_missing=True, ) + AddObservedValuesIndicator( target_field=FieldName.TARGET, @@ -192,6 +197,7 @@ def create_lightning_module(self) -> pl.LightningModule: "d_model": self.d_model, "nhead": self.nhead, "dim_feedforward": self.dim_feedforward, + "use_feat_dynamic_real": self.use_feat_dynamic_real, "dropout": self.dropout, "activation": self.activation, "norm_first": self.norm_first, @@ -220,7 +226,10 @@ def _create_instance_splitter( instance_sampler=instance_sampler, past_length=self.context_length, future_length=self.prediction_length, - time_series_fields=[FieldName.OBSERVED_VALUES], + time_series_fields=[FieldName.OBSERVED_VALUES] + ( + [FieldName.FEAT_DYNAMIC_REAL] if self.use_feat_dynamic_real + else [] + ), dummy_value=self.distr_output.value_in_support, ) @@ -239,7 +248,10 @@ def create_training_data_loader( instances, batch_size=self.batch_size, shuffle_buffer_length=shuffle_buffer_length, - field_names=TRAINING_INPUT_NAMES, + field_names=TRAINING_INPUT_NAMES + ( + ["past_feat_dynamic_real", "future_feat_dynamic_real"] if + self.use_feat_dynamic_real else [] + ), output_type=torch.tensor, num_batches_per_epoch=self.num_batches_per_epoch, ) @@ -253,7 +265,10 @@ def create_validation_data_loader( return as_stacked_batches( instances, batch_size=self.batch_size, - field_names=TRAINING_INPUT_NAMES, + field_names=TRAINING_INPUT_NAMES + ( + ["past_feat_dynamic_real", "future_feat_dynamic_real"] if + self.use_feat_dynamic_real else [] + ), output_type=torch.tensor, ) @@ -264,7 +279,10 @@ def create_predictor( return PyTorchPredictor( input_transform=transformation + prediction_splitter, - input_names=PREDICTION_INPUT_NAMES, + input_names=PREDICTION_INPUT_NAMES + ( + ["past_feat_dynamic_real", "future_feat_dynamic_real"] if + self.use_feat_dynamic_real else [] + ), prediction_net=module, forecast_generator=self.distr_output.forecast_generator, batch_size=self.batch_size, diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py index 4e829e2ea1..aeb28dc3a9 100644 --- a/src/gluonts/torch/model/patch_tst/module.py +++ b/src/gluonts/torch/model/patch_tst/module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -from typing import Tuple +from typing import Optional, Tuple import numpy as np import torch @@ -95,6 +95,7 @@ def __init__( d_model: int, nhead: int, dim_feedforward: int, + use_feat_dynamic_real: bool, dropout: float, activation: str, norm_first: bool, @@ -114,6 +115,7 @@ def __init__( self.d_model = d_model self.padding_patch = padding_patch self.distr_output = distr_output + self.use_feat_dynamic_real = use_feat_dynamic_real if scaling == "mean": self.scaler = MeanScaler(keepdim=True) @@ -157,6 +159,19 @@ def __init__( self.args_proj = self.distr_output.get_args_proj(d_model) def describe_inputs(self, batch_size=1) -> InputSpec: + if self.use_feat_dynamic_real: + input_spec_feat = { + "past_feat_dynamic_real": Input( + shape=(batch_size, self.context_length), dtype=torch.float + ), + "future_feat_dynamic_real": Input( + shape=(batch_size, self.prediction_length), + dtype=torch.float + ), + } + else: + input_spec_feat = {} + return InputSpec( { "past_target": Input( @@ -165,6 +180,7 @@ def describe_inputs(self, batch_size=1) -> InputSpec: "past_observed_values": Input( shape=(batch_size, self.context_length), dtype=torch.float ), + **input_spec_feat, }, torch.zeros, ) @@ -173,7 +189,11 @@ def forward( self, past_target: torch.Tensor, past_observed_values: torch.Tensor, + past_feat_dynamic_real: Optional[torch.Tensor] = None, + future_feat_dynamic_real: Optional[torch.Tensor] = None, ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: + if self.use_feat_dynamic_real: + assert future_feat_dynamic_real is not None # scale the input past_target_scaled, loc, scale = self.scaler( past_target, past_observed_values @@ -218,9 +238,14 @@ def loss( past_observed_values: torch.Tensor, future_target: torch.Tensor, future_observed_values: torch.Tensor, + past_feat_dynamic_real: Optional[torch.Tensor] = None, + future_feat_dynamic_real: Optional[torch.Tensor] = None, ) -> torch.Tensor: distr_args, loc, scale = self( - past_target=past_target, past_observed_values=past_observed_values + past_target=past_target, + past_observed_values=past_observed_values, + past_feat_dynamic_real=past_feat_dynamic_real, + future_feat_dynamic_real=future_feat_dynamic_real, ) loss = self.distr_output.loss( target=future_target, distr_args=distr_args, loc=loc, scale=scale diff --git a/test/torch/model/test_estimators.py b/test/torch/model/test_estimators.py index 3c2ef3ea51..7c1652e439 100644 --- a/test/torch/model/test_estimators.py +++ b/test/torch/model/test_estimators.py @@ -296,6 +296,25 @@ def test_estimator_constant_dataset( num_batches_per_epoch=3, epochs=2, ), + lambda freq, prediction_length: PatchTSTEstimator( + prediction_length=prediction_length, + context_length=2 * prediction_length, + use_feat_dynamic_real=True, + patch_len=16, + batch_size=4, + num_batches_per_epoch=3, + trainer_kwargs=dict(max_epochs=2), + ), + lambda freq, prediction_length: PatchTSTEstimator( + prediction_length=prediction_length, + context_length=2 * prediction_length, + use_feat_dynamic_real=True, + distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]), + patch_len=16, + batch_size=4, + num_batches_per_epoch=3, + trainer_kwargs=dict(max_epochs=2), + ), lambda freq, prediction_length: WaveNetEstimator( freq=freq, prediction_length=prediction_length, From b2039a4442356ff0c4190a52566956f2bab962af Mon Sep 17 00:00:00 2001 From: Syama Sundar Rangapuram Date: Thu, 2 May 2024 16:17:13 +0200 Subject: [PATCH 2/6] rename feature field --- .../torch/model/patch_tst/estimator.py | 25 +++++++++++++------ src/gluonts/torch/model/patch_tst/module.py | 18 ++++++------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/gluonts/torch/model/patch_tst/estimator.py b/src/gluonts/torch/model/patch_tst/estimator.py index 1f8280e083..166183a1c3 100644 --- a/src/gluonts/torch/model/patch_tst/estimator.py +++ b/src/gluonts/torch/model/patch_tst/estimator.py @@ -30,6 +30,7 @@ TestSplitSampler, ExpectedNumInstanceSampler, SelectFields, + RenameFields, ) from gluonts.torch.model.estimator import PyTorchLightningEstimator from gluonts.torch.model.predictor import PyTorchPredictor @@ -179,6 +180,10 @@ def create_transformation(self) -> Transformation: else [] ), allow_missing=True, + ) + RenameFields( + { + FieldName.FEAT_DYNAMIC_REAL: FieldName.FEAT_TIME + } ) + AddObservedValuesIndicator( target_field=FieldName.TARGET, output_field=FieldName.OBSERVED_VALUES, @@ -227,7 +232,7 @@ def _create_instance_splitter( past_length=self.context_length, future_length=self.prediction_length, time_series_fields=[FieldName.OBSERVED_VALUES] + ( - [FieldName.FEAT_DYNAMIC_REAL] if self.use_feat_dynamic_real + [FieldName.FEAT_TIME] if self.use_feat_dynamic_real else [] ), dummy_value=self.distr_output.value_in_support, @@ -249,8 +254,10 @@ def create_training_data_loader( batch_size=self.batch_size, shuffle_buffer_length=shuffle_buffer_length, field_names=TRAINING_INPUT_NAMES + ( - ["past_feat_dynamic_real", "future_feat_dynamic_real"] if - self.use_feat_dynamic_real else [] + [ + f"past_{FieldName.FEAT_TIME}", + f"future_{FieldName.FEAT_TIME}" + ] if self.use_feat_dynamic_real else [] ), output_type=torch.tensor, num_batches_per_epoch=self.num_batches_per_epoch, @@ -266,8 +273,10 @@ def create_validation_data_loader( instances, batch_size=self.batch_size, field_names=TRAINING_INPUT_NAMES + ( - ["past_feat_dynamic_real", "future_feat_dynamic_real"] if - self.use_feat_dynamic_real else [] + [ + f"past_{FieldName.FEAT_TIME}", + f"future_{FieldName.FEAT_TIME}" + ] if self.use_feat_dynamic_real else [] ), output_type=torch.tensor, ) @@ -280,8 +289,10 @@ def create_predictor( return PyTorchPredictor( input_transform=transformation + prediction_splitter, input_names=PREDICTION_INPUT_NAMES + ( - ["past_feat_dynamic_real", "future_feat_dynamic_real"] if - self.use_feat_dynamic_real else [] + [ + f"past_{FieldName.FEAT_TIME}", + f"future_{FieldName.FEAT_TIME}" + ] if self.use_feat_dynamic_real else [] ), prediction_net=module, forecast_generator=self.distr_output.forecast_generator, diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py index aeb28dc3a9..a0f869c84f 100644 --- a/src/gluonts/torch/model/patch_tst/module.py +++ b/src/gluonts/torch/model/patch_tst/module.py @@ -161,10 +161,10 @@ def __init__( def describe_inputs(self, batch_size=1) -> InputSpec: if self.use_feat_dynamic_real: input_spec_feat = { - "past_feat_dynamic_real": Input( + "past_time_feat": Input( shape=(batch_size, self.context_length), dtype=torch.float ), - "future_feat_dynamic_real": Input( + "future_time_feat": Input( shape=(batch_size, self.prediction_length), dtype=torch.float ), @@ -189,11 +189,11 @@ def forward( self, past_target: torch.Tensor, past_observed_values: torch.Tensor, - past_feat_dynamic_real: Optional[torch.Tensor] = None, - future_feat_dynamic_real: Optional[torch.Tensor] = None, + past_time_feat: Optional[torch.Tensor] = None, + future_time_feat: Optional[torch.Tensor] = None, ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: if self.use_feat_dynamic_real: - assert future_feat_dynamic_real is not None + assert future_time_feat is not None # scale the input past_target_scaled, loc, scale = self.scaler( past_target, past_observed_values @@ -238,14 +238,14 @@ def loss( past_observed_values: torch.Tensor, future_target: torch.Tensor, future_observed_values: torch.Tensor, - past_feat_dynamic_real: Optional[torch.Tensor] = None, - future_feat_dynamic_real: Optional[torch.Tensor] = None, + past_time_feat: Optional[torch.Tensor] = None, + future_time_feat: Optional[torch.Tensor] = None, ) -> torch.Tensor: distr_args, loc, scale = self( past_target=past_target, past_observed_values=past_observed_values, - past_feat_dynamic_real=past_feat_dynamic_real, - future_feat_dynamic_real=future_feat_dynamic_real, + past_time_feat=past_time_feat, + future_time_feat=future_time_feat, ) loss = self.distr_output.loss( target=future_target, distr_args=distr_args, loc=loc, scale=scale From e9821c1f510c5d3f2d93a6fd6e77ed704e06ed6c Mon Sep 17 00:00:00 2001 From: Syama Sundar Rangapuram Date: Sun, 5 May 2024 13:00:56 +0200 Subject: [PATCH 3/6] Update inputs with time features --- .../torch/model/patch_tst/estimator.py | 16 +++---- src/gluonts/torch/model/patch_tst/module.py | 47 +++++++++++++++---- test/torch/model/test_estimators.py | 4 +- 3 files changed, 48 insertions(+), 19 deletions(-) diff --git a/src/gluonts/torch/model/patch_tst/estimator.py b/src/gluonts/torch/model/patch_tst/estimator.py index 166183a1c3..0018ee47e0 100644 --- a/src/gluonts/torch/model/patch_tst/estimator.py +++ b/src/gluonts/torch/model/patch_tst/estimator.py @@ -116,7 +116,7 @@ def __init__( d_model: int = 32, nhead: int = 4, dim_feedforward: int = 128, - use_feat_dynamic_real: bool = False, + num_feat_dynamic_real: int = 0, dropout: float = 0.1, activation: str = "relu", norm_first: bool = False, @@ -153,7 +153,7 @@ def __init__( self.d_model = d_model self.nhead = nhead self.dim_feedforward = dim_feedforward - self.use_feat_dynamic_real = use_feat_dynamic_real + self.num_feat_dynamic_real = num_feat_dynamic_real self.dropout = dropout self.activation = activation self.norm_first = norm_first @@ -176,7 +176,7 @@ def create_transformation(self) -> Transformation: FieldName.START, FieldName.TARGET, ] + ( - [FieldName.FEAT_DYNAMIC_REAL] if self.use_feat_dynamic_real + [FieldName.FEAT_DYNAMIC_REAL] if self.num_feat_dynamic_real > 0 else [] ), allow_missing=True, @@ -202,7 +202,7 @@ def create_lightning_module(self) -> pl.LightningModule: "d_model": self.d_model, "nhead": self.nhead, "dim_feedforward": self.dim_feedforward, - "use_feat_dynamic_real": self.use_feat_dynamic_real, + "num_feat_dynamic_real": self.num_feat_dynamic_real, "dropout": self.dropout, "activation": self.activation, "norm_first": self.norm_first, @@ -232,7 +232,7 @@ def _create_instance_splitter( past_length=self.context_length, future_length=self.prediction_length, time_series_fields=[FieldName.OBSERVED_VALUES] + ( - [FieldName.FEAT_TIME] if self.use_feat_dynamic_real + [FieldName.FEAT_TIME] if self.num_feat_dynamic_real > 0 else [] ), dummy_value=self.distr_output.value_in_support, @@ -257,7 +257,7 @@ def create_training_data_loader( [ f"past_{FieldName.FEAT_TIME}", f"future_{FieldName.FEAT_TIME}" - ] if self.use_feat_dynamic_real else [] + ] if self.num_feat_dynamic_real > 0 else [] ), output_type=torch.tensor, num_batches_per_epoch=self.num_batches_per_epoch, @@ -276,7 +276,7 @@ def create_validation_data_loader( [ f"past_{FieldName.FEAT_TIME}", f"future_{FieldName.FEAT_TIME}" - ] if self.use_feat_dynamic_real else [] + ] if self.num_feat_dynamic_real > 0 else [] ), output_type=torch.tensor, ) @@ -292,7 +292,7 @@ def create_predictor( [ f"past_{FieldName.FEAT_TIME}", f"future_{FieldName.FEAT_TIME}" - ] if self.use_feat_dynamic_real else [] + ] if self.num_feat_dynamic_real > 0 else [] ), prediction_net=module, forecast_generator=self.distr_output.forecast_generator, diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py index a0f869c84f..8623d3aa97 100644 --- a/src/gluonts/torch/model/patch_tst/module.py +++ b/src/gluonts/torch/model/patch_tst/module.py @@ -21,7 +21,7 @@ from gluonts.model import Input, InputSpec from gluonts.torch.distributions import StudentTOutput from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler -from gluonts.torch.util import unsqueeze_expand, weighted_average +from gluonts.torch.util import take_last, unsqueeze_expand, weighted_average from gluonts.torch.model.simple_feedforward import make_linear_layer @@ -95,7 +95,7 @@ def __init__( d_model: int, nhead: int, dim_feedforward: int, - use_feat_dynamic_real: bool, + num_feat_dynamic_real: int, dropout: float, activation: str, norm_first: bool, @@ -115,7 +115,7 @@ def __init__( self.d_model = d_model self.padding_patch = padding_patch self.distr_output = distr_output - self.use_feat_dynamic_real = use_feat_dynamic_real + self.num_feat_dynamic_real = num_feat_dynamic_real if scaling == "mean": self.scaler = MeanScaler(keepdim=True) @@ -130,7 +130,10 @@ def __init__( self.patch_num += 1 # project from patch_len + 2 features (loc and scale) to d_model - self.patch_proj = make_linear_layer(patch_len + 2, d_model) + self.patch_proj = make_linear_layer( + patch_len + 2 + self.num_feat_dynamic_real * patch_len, + d_model + ) self.positional_encoding = SinusoidalPositionalEmbedding( self.patch_num, d_model @@ -159,13 +162,21 @@ def __init__( self.args_proj = self.distr_output.get_args_proj(d_model) def describe_inputs(self, batch_size=1) -> InputSpec: - if self.use_feat_dynamic_real: + if self.num_feat_dynamic_real > 0: input_spec_feat = { "past_time_feat": Input( - shape=(batch_size, self.context_length), dtype=torch.float + shape=( + batch_size, + self.context_length, + self.num_feat_dynamic_real + ), dtype=torch.float ), "future_time_feat": Input( - shape=(batch_size, self.prediction_length), + shape=( + batch_size, + self.prediction_length, + self.num_feat_dynamic_real + ), dtype=torch.float ), } @@ -192,8 +203,6 @@ def forward( past_time_feat: Optional[torch.Tensor] = None, future_time_feat: Optional[torch.Tensor] = None, ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: - if self.use_feat_dynamic_real: - assert future_time_feat is not None # scale the input past_target_scaled, loc, scale = self.scaler( past_target, past_observed_values @@ -206,6 +215,23 @@ def forward( dimension=1, size=self.patch_len, step=self.stride ) + # do patching for time features as well + if self.num_feat_dynamic_real > 0: + time_feat = take_last( + torch.cat((past_time_feat, future_time_feat), dim=1), + dim=1, + num=self.context_length + ) + + # (bs x T x d) --> (bs x d x T) because the 1D padding is done on + # the last dimension. + time_feat = self.padding_patch_layer( + time_feat.transpose(-2, -1) + ).transpose(-2, -1) + time_feat_patches = time_feat.unfold( + dimension=1, size=self.patch_len, step=self.stride + ).flatten(-2, -1) + # add loc and scale to past_target_patches as additional features log_abs_loc = loc.abs().log1p() log_scale = scale.log() @@ -216,6 +242,9 @@ def forward( ) inputs = torch.cat((past_target_patches, expanded_static_feat), dim=-1) + if self.num_feat_dynamic_real > 0: + inputs = torch.cat((inputs, time_feat_patches), dim=-1) + # project patches enc_in = self.patch_proj(inputs) embed_pos = self.positional_encoding(enc_in.size()) diff --git a/test/torch/model/test_estimators.py b/test/torch/model/test_estimators.py index 7c1652e439..caa5ae10ec 100644 --- a/test/torch/model/test_estimators.py +++ b/test/torch/model/test_estimators.py @@ -299,7 +299,7 @@ def test_estimator_constant_dataset( lambda freq, prediction_length: PatchTSTEstimator( prediction_length=prediction_length, context_length=2 * prediction_length, - use_feat_dynamic_real=True, + num_feat_dynamic_real=3, patch_len=16, batch_size=4, num_batches_per_epoch=3, @@ -308,7 +308,7 @@ def test_estimator_constant_dataset( lambda freq, prediction_length: PatchTSTEstimator( prediction_length=prediction_length, context_length=2 * prediction_length, - use_feat_dynamic_real=True, + num_feat_dynamic_real=3, distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]), patch_len=16, batch_size=4, From 6f60881f975d198e08f88ea6b82d09074a4b9a12 Mon Sep 17 00:00:00 2001 From: Syama Sundar Rangapuram Date: Sun, 5 May 2024 13:03:36 +0200 Subject: [PATCH 4/6] formatting checks --- .../torch/model/patch_tst/estimator.py | 71 +++++++++++-------- src/gluonts/torch/model/patch_tst/module.py | 14 ++-- 2 files changed, 48 insertions(+), 37 deletions(-) diff --git a/src/gluonts/torch/model/patch_tst/estimator.py b/src/gluonts/torch/model/patch_tst/estimator.py index 0018ee47e0..b11560e893 100644 --- a/src/gluonts/torch/model/patch_tst/estimator.py +++ b/src/gluonts/torch/model/patch_tst/estimator.py @@ -169,24 +169,26 @@ def __init__( ) def create_transformation(self) -> Transformation: - return SelectFields( - [ - FieldName.ITEM_ID, - FieldName.INFO, - FieldName.START, - FieldName.TARGET, - ] + ( - [FieldName.FEAT_DYNAMIC_REAL] if self.num_feat_dynamic_real > 0 - else [] - ), - allow_missing=True, - ) + RenameFields( - { - FieldName.FEAT_DYNAMIC_REAL: FieldName.FEAT_TIME - } - ) + AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, + return ( + SelectFields( + [ + FieldName.ITEM_ID, + FieldName.INFO, + FieldName.START, + FieldName.TARGET, + ] + + ( + [FieldName.FEAT_DYNAMIC_REAL] + if self.num_feat_dynamic_real > 0 + else [] + ), + allow_missing=True, + ) + + RenameFields({FieldName.FEAT_DYNAMIC_REAL: FieldName.FEAT_TIME}) + + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ) ) def create_lightning_module(self) -> pl.LightningModule: @@ -231,9 +233,9 @@ def _create_instance_splitter( instance_sampler=instance_sampler, past_length=self.context_length, future_length=self.prediction_length, - time_series_fields=[FieldName.OBSERVED_VALUES] + ( - [FieldName.FEAT_TIME] if self.num_feat_dynamic_real > 0 - else [] + time_series_fields=[FieldName.OBSERVED_VALUES] + + ( + [FieldName.FEAT_TIME] if self.num_feat_dynamic_real > 0 else [] ), dummy_value=self.distr_output.value_in_support, ) @@ -253,11 +255,14 @@ def create_training_data_loader( instances, batch_size=self.batch_size, shuffle_buffer_length=shuffle_buffer_length, - field_names=TRAINING_INPUT_NAMES + ( + field_names=TRAINING_INPUT_NAMES + + ( [ f"past_{FieldName.FEAT_TIME}", - f"future_{FieldName.FEAT_TIME}" - ] if self.num_feat_dynamic_real > 0 else [] + f"future_{FieldName.FEAT_TIME}", + ] + if self.num_feat_dynamic_real > 0 + else [] ), output_type=torch.tensor, num_batches_per_epoch=self.num_batches_per_epoch, @@ -272,11 +277,14 @@ def create_validation_data_loader( return as_stacked_batches( instances, batch_size=self.batch_size, - field_names=TRAINING_INPUT_NAMES + ( + field_names=TRAINING_INPUT_NAMES + + ( [ f"past_{FieldName.FEAT_TIME}", - f"future_{FieldName.FEAT_TIME}" - ] if self.num_feat_dynamic_real > 0 else [] + f"future_{FieldName.FEAT_TIME}", + ] + if self.num_feat_dynamic_real > 0 + else [] ), output_type=torch.tensor, ) @@ -288,11 +296,14 @@ def create_predictor( return PyTorchPredictor( input_transform=transformation + prediction_splitter, - input_names=PREDICTION_INPUT_NAMES + ( + input_names=PREDICTION_INPUT_NAMES + + ( [ f"past_{FieldName.FEAT_TIME}", - f"future_{FieldName.FEAT_TIME}" - ] if self.num_feat_dynamic_real > 0 else [] + f"future_{FieldName.FEAT_TIME}", + ] + if self.num_feat_dynamic_real > 0 + else [] ), prediction_net=module, forecast_generator=self.distr_output.forecast_generator, diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py index 8623d3aa97..334cb1f739 100644 --- a/src/gluonts/torch/model/patch_tst/module.py +++ b/src/gluonts/torch/model/patch_tst/module.py @@ -131,8 +131,7 @@ def __init__( # project from patch_len + 2 features (loc and scale) to d_model self.patch_proj = make_linear_layer( - patch_len + 2 + self.num_feat_dynamic_real * patch_len, - d_model + patch_len + 2 + self.num_feat_dynamic_real * patch_len, d_model ) self.positional_encoding = SinusoidalPositionalEmbedding( @@ -168,16 +167,17 @@ def describe_inputs(self, batch_size=1) -> InputSpec: shape=( batch_size, self.context_length, - self.num_feat_dynamic_real - ), dtype=torch.float + self.num_feat_dynamic_real, + ), + dtype=torch.float, ), "future_time_feat": Input( shape=( batch_size, self.prediction_length, - self.num_feat_dynamic_real + self.num_feat_dynamic_real, ), - dtype=torch.float + dtype=torch.float, ), } else: @@ -220,7 +220,7 @@ def forward( time_feat = take_last( torch.cat((past_time_feat, future_time_feat), dim=1), dim=1, - num=self.context_length + num=self.context_length, ) # (bs x T x d) --> (bs x d x T) because the 1D padding is done on From 19e5a1a8ab7bd3e5e2e64fb67683d8149fb3e4e9 Mon Sep 17 00:00:00 2001 From: Syama Sundar Rangapuram Date: Sun, 5 May 2024 13:19:03 +0200 Subject: [PATCH 5/6] update comments --- src/gluonts/torch/model/patch_tst/estimator.py | 2 ++ src/gluonts/torch/model/patch_tst/module.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/gluonts/torch/model/patch_tst/estimator.py b/src/gluonts/torch/model/patch_tst/estimator.py index b11560e893..4eea0be9d7 100644 --- a/src/gluonts/torch/model/patch_tst/estimator.py +++ b/src/gluonts/torch/model/patch_tst/estimator.py @@ -75,6 +75,8 @@ class PatchTSTEstimator(PyTorchLightningEstimator): Number of attention heads in the Transformer encoder which must divide d_model. dim_feedforward Size of hidden layers in the Transformer encoder. + num_feat_dynamic_real + Number of dynamic real features in the data (default: 0). dropout Dropout probability in the Transformer encoder. activation diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py index 334cb1f739..88080ea183 100644 --- a/src/gluonts/torch/model/patch_tst/module.py +++ b/src/gluonts/torch/model/patch_tst/module.py @@ -79,6 +79,8 @@ class PatchTSTModel(nn.Module): Number of time points to predict. context_length Number of time steps prior to prediction time that the model. + num_feat_dynamic_real + Number of dynamic real features in the data (default: 0). distr_output Distribution to use to evaluate observations and sample predictions. Default: ``StudentTOutput()``. @@ -129,7 +131,8 @@ def __init__( self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride)) self.patch_num += 1 - # project from patch_len + 2 features (loc and scale) to d_model + # project from `patch_len` + 2 features (`loc` and `scale`) + + # `num_feat_dynamic_real` x `patch_len` to d_model self.patch_proj = make_linear_layer( patch_len + 2 + self.num_feat_dynamic_real * patch_len, d_model ) From 80071ed0cf8cd5e7cc27f48e063ac355da56dca0 Mon Sep 17 00:00:00 2001 From: Syama Sundar Rangapuram Date: Sun, 5 May 2024 13:37:09 +0200 Subject: [PATCH 6/6] Add a comment --- src/gluonts/torch/model/patch_tst/module.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py index 88080ea183..68370ba843 100644 --- a/src/gluonts/torch/model/patch_tst/module.py +++ b/src/gluonts/torch/model/patch_tst/module.py @@ -220,6 +220,8 @@ def forward( # do patching for time features as well if self.num_feat_dynamic_real > 0: + # shift time features by `prediction_length` so that they are + # aligned with the target input. time_feat = take_last( torch.cat((past_time_feat, future_time_feat), dim=1), dim=1,