Skip to content

Commit

Permalink
PatchTST: Add support for time features (awslabs#3167)
Browse files Browse the repository at this point in the history
  • Loading branch information
rshyamsundar authored and kashif committed Jun 15, 2024
1 parent ccb48cc commit f64c830
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 20 deletions.
72 changes: 57 additions & 15 deletions src/gluonts/torch/model/patch_tst/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
TestSplitSampler,
ExpectedNumInstanceSampler,
SelectFields,
RenameFields,
)
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
Expand Down Expand Up @@ -74,6 +75,8 @@ class PatchTSTEstimator(PyTorchLightningEstimator):
Number of attention heads in the Transformer encoder which must divide d_model.
dim_feedforward
Size of hidden layers in the Transformer encoder.
num_feat_dynamic_real
Number of dynamic real features in the data (default: 0).
dropout
Dropout probability in the Transformer encoder.
activation
Expand Down Expand Up @@ -115,6 +118,7 @@ def __init__(
d_model: int = 32,
nhead: int = 4,
dim_feedforward: int = 128,
num_feat_dynamic_real: int = 0,
dropout: float = 0.1,
activation: str = "relu",
norm_first: bool = False,
Expand Down Expand Up @@ -151,6 +155,7 @@ def __init__(
self.d_model = d_model
self.nhead = nhead
self.dim_feedforward = dim_feedforward
self.num_feat_dynamic_real = num_feat_dynamic_real
self.dropout = dropout
self.activation = activation
self.norm_first = norm_first
Expand All @@ -166,17 +171,26 @@ def __init__(
)

def create_transformation(self) -> Transformation:
return SelectFields(
[
FieldName.ITEM_ID,
FieldName.INFO,
FieldName.START,
FieldName.TARGET,
],
allow_missing=True,
) + AddObservedValuesIndicator(
target_field=FieldName.TARGET,
output_field=FieldName.OBSERVED_VALUES,
return (
SelectFields(
[
FieldName.ITEM_ID,
FieldName.INFO,
FieldName.START,
FieldName.TARGET,
]
+ (
[FieldName.FEAT_DYNAMIC_REAL]
if self.num_feat_dynamic_real > 0
else []
),
allow_missing=True,
)
+ RenameFields({FieldName.FEAT_DYNAMIC_REAL: FieldName.FEAT_TIME})
+ AddObservedValuesIndicator(
target_field=FieldName.TARGET,
output_field=FieldName.OBSERVED_VALUES,
)
)

def create_lightning_module(self) -> pl.LightningModule:
Expand All @@ -192,6 +206,7 @@ def create_lightning_module(self) -> pl.LightningModule:
"d_model": self.d_model,
"nhead": self.nhead,
"dim_feedforward": self.dim_feedforward,
"num_feat_dynamic_real": self.num_feat_dynamic_real,
"dropout": self.dropout,
"activation": self.activation,
"norm_first": self.norm_first,
Expand Down Expand Up @@ -220,7 +235,10 @@ def _create_instance_splitter(
instance_sampler=instance_sampler,
past_length=self.context_length,
future_length=self.prediction_length,
time_series_fields=[FieldName.OBSERVED_VALUES],
time_series_fields=[FieldName.OBSERVED_VALUES]
+ (
[FieldName.FEAT_TIME] if self.num_feat_dynamic_real > 0 else []
),
dummy_value=self.distr_output.value_in_support,
)

Expand All @@ -239,7 +257,15 @@ def create_training_data_loader(
instances,
batch_size=self.batch_size,
shuffle_buffer_length=shuffle_buffer_length,
field_names=TRAINING_INPUT_NAMES,
field_names=TRAINING_INPUT_NAMES
+ (
[
f"past_{FieldName.FEAT_TIME}",
f"future_{FieldName.FEAT_TIME}",
]
if self.num_feat_dynamic_real > 0
else []
),
output_type=torch.tensor,
num_batches_per_epoch=self.num_batches_per_epoch,
)
Expand All @@ -253,7 +279,15 @@ def create_validation_data_loader(
return as_stacked_batches(
instances,
batch_size=self.batch_size,
field_names=TRAINING_INPUT_NAMES,
field_names=TRAINING_INPUT_NAMES
+ (
[
f"past_{FieldName.FEAT_TIME}",
f"future_{FieldName.FEAT_TIME}",
]
if self.num_feat_dynamic_real > 0
else []
),
output_type=torch.tensor,
)

Expand All @@ -264,7 +298,15 @@ def create_predictor(

return PyTorchPredictor(
input_transform=transformation + prediction_splitter,
input_names=PREDICTION_INPUT_NAMES,
input_names=PREDICTION_INPUT_NAMES
+ (
[
f"past_{FieldName.FEAT_TIME}",
f"future_{FieldName.FEAT_TIME}",
]
if self.num_feat_dynamic_real > 0
else []
),
prediction_net=module,
forecast_generator=self.distr_output.forecast_generator,
batch_size=self.batch_size,
Expand Down
69 changes: 64 additions & 5 deletions src/gluonts/torch/model/patch_tst/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Tuple
from typing import Optional, Tuple

import numpy as np
import torch
Expand All @@ -21,7 +21,7 @@
from gluonts.model import Input, InputSpec
from gluonts.torch.distributions import StudentTOutput
from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
from gluonts.torch.util import unsqueeze_expand, weighted_average
from gluonts.torch.util import take_last, unsqueeze_expand, weighted_average
from gluonts.torch.model.simple_feedforward import make_linear_layer


Expand Down Expand Up @@ -85,6 +85,8 @@ class PatchTSTModel(nn.Module):
Number of time points to predict.
context_length
Number of time steps prior to prediction time that the model takes as inputs.
num_feat_dynamic_real
Number of dynamic real features in the data (default: 0).
distr_output
Distribution to use to evaluate observations and sample predictions.
Default: ``StudentTOutput()``.
Expand All @@ -101,6 +103,7 @@ def __init__(
d_model: int,
nhead: int,
dim_feedforward: int,
num_feat_dynamic_real: int,
dropout: float,
activation: str,
norm_first: bool,
Expand All @@ -120,6 +123,7 @@ def __init__(
self.d_model = d_model
self.padding_patch = padding_patch
self.distr_output = distr_output
self.num_feat_dynamic_real = num_feat_dynamic_real

if scaling == "mean":
self.scaler = MeanScaler(keepdim=True)
Expand All @@ -133,8 +137,11 @@ def __init__(
self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))
self.patch_num += 1

# project from patch_len + 2 features (loc and scale) to d_model
self.patch_proj = make_linear_layer(patch_len + 2, d_model)
# project from `patch_len` + 2 features (`loc` and `scale`) +
# `num_feat_dynamic_real` x `patch_len` to d_model
self.patch_proj = make_linear_layer(
patch_len + 2 + self.num_feat_dynamic_real * patch_len, d_model
)

self.positional_encoding = SinusoidalPositionalEmbedding(
self.patch_num, d_model
Expand Down Expand Up @@ -163,6 +170,28 @@ def __init__(
self.args_proj = self.distr_output.get_args_proj(d_model)

def describe_inputs(self, batch_size=1) -> InputSpec:
if self.num_feat_dynamic_real > 0:
input_spec_feat = {
"past_time_feat": Input(
shape=(
batch_size,
self.context_length,
self.num_feat_dynamic_real,
),
dtype=torch.float,
),
"future_time_feat": Input(
shape=(
batch_size,
self.prediction_length,
self.num_feat_dynamic_real,
),
dtype=torch.float,
),
}
else:
input_spec_feat = {}

return InputSpec(
{
"past_target": Input(
Expand All @@ -171,6 +200,7 @@ def describe_inputs(self, batch_size=1) -> InputSpec:
"past_observed_values": Input(
shape=(batch_size, self.context_length), dtype=torch.float
),
**input_spec_feat,
},
torch.zeros,
)
Expand All @@ -179,6 +209,8 @@ def forward(
self,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
past_time_feat: Optional[torch.Tensor] = None,
future_time_feat: Optional[torch.Tensor] = None,
) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]:
# scale the input
past_target_scaled, loc, scale = self.scaler(
Expand All @@ -192,6 +224,25 @@ def forward(
dimension=1, size=self.patch_len, step=self.stride
)

# do patching for time features as well
if self.num_feat_dynamic_real > 0:
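# concatenate past and future time features and keep only the last
# `context_length` steps, i.e. shift the time-feature window forward by
# `prediction_length` so that it has the same length as the target
# input and can be patched alongside it.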
time_feat = take_last(
torch.cat((past_time_feat, future_time_feat), dim=1),
dim=1,
num=self.context_length,
)

# (bs x T x d) --> (bs x d x T) because the 1D padding is done on
# the last dimension.
time_feat = self.padding_patch_layer(
time_feat.transpose(-2, -1)
).transpose(-2, -1)
time_feat_patches = time_feat.unfold(
dimension=1, size=self.patch_len, step=self.stride
).flatten(-2, -1)

# add loc and scale to past_target_patches as additional features
log_abs_loc = loc.abs().log1p()
log_scale = scale.log()
Expand All @@ -202,6 +253,9 @@ def forward(
)
inputs = torch.cat((past_target_patches, expanded_static_feat), dim=-1)

if self.num_feat_dynamic_real > 0:
inputs = torch.cat((inputs, time_feat_patches), dim=-1)

# project patches
enc_in = self.patch_proj(inputs)
embed_pos = self.positional_encoding(enc_in.size())
Expand All @@ -224,9 +278,14 @@ def loss(
past_observed_values: torch.Tensor,
future_target: torch.Tensor,
future_observed_values: torch.Tensor,
past_time_feat: Optional[torch.Tensor] = None,
future_time_feat: Optional[torch.Tensor] = None,
) -> torch.Tensor:
distr_args, loc, scale = self(
past_target=past_target, past_observed_values=past_observed_values
past_target=past_target,
past_observed_values=past_observed_values,
past_time_feat=past_time_feat,
future_time_feat=future_time_feat,
)
loss = self.distr_output.loss(
target=future_target, distr_args=distr_args, loc=loc, scale=scale
Expand Down
19 changes: 19 additions & 0 deletions test/torch/model/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,25 @@ def test_estimator_constant_dataset(
num_batches_per_epoch=3,
epochs=2,
),
lambda freq, prediction_length: PatchTSTEstimator(
prediction_length=prediction_length,
context_length=2 * prediction_length,
num_feat_dynamic_real=3,
patch_len=16,
batch_size=4,
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
lambda freq, prediction_length: PatchTSTEstimator(
prediction_length=prediction_length,
context_length=2 * prediction_length,
num_feat_dynamic_real=3,
distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]),
patch_len=16,
batch_size=4,
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
lambda freq, prediction_length: WaveNetEstimator(
freq=freq,
prediction_length=prediction_length,
Expand Down

0 comments on commit f64c830

Please sign in to comment.