Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DynamicNBEATs model #1191

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
419 changes: 419 additions & 0 deletions experiments/nbeats_basis/nbeats_basis_experiment.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions nbs/models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,8 @@
" default_config = {\n",
" \"input_size_multiplier\": [1, 2, 3, 4, 5],\n",
" \"h\": None,\n",
" \"basis\": tune.choice([\"polynomial\", \"changepoint\"]),\n",
" \"n_basis\": tune.choice([2, 5]),\n",
" \"learning_rate\": tune.loguniform(1e-4, 1e-1),\n",
" \"scaler_type\": tune.choice([None, 'robust', 'standard']),\n",
" \"max_steps\": tune.choice([500, 1000]),\n",
Expand Down
211 changes: 185 additions & 26 deletions nbs/models.nbeats.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,11 @@
"from typing import Tuple, Optional\n",
"\n",
"import numpy as np\n",
"from numpy.polynomial.legendre import Legendre\n",
"from numpy.polynomial.chebyshev import Chebyshev\n",
"import torch\n",
"import torch.nn as nn\n",
"from scipy.interpolate import BSpline\n",
"\n",
"from neuralforecast.losses.pytorch import MAE\n",
"from neuralforecast.common._base_windows import BaseWindows"
Expand All @@ -84,6 +87,143 @@
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3b21a80",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def generate_legendre_basis(length, n_basis):\n",
" \"\"\"\n",
" Generates Legendre polynomial basis functions.\n",
"\n",
" Parameters:\n",
" - n_points (int): Number of data points.\n",
" - n_functions (int): Number of basis functions to generate.\n",
"\n",
" Returns:\n",
" - legendre_basis (ndarray): An array of Legendre basis functions.\n",
" \"\"\"\n",
" x = np.linspace(-1, 1, length) # Legendre polynomials are defined on [-1, 1]\n",
" legendre_basis = np.zeros((length, n_basis))\n",
" for i in range(n_basis):\n",
" # Legendre polynomial of degree i\n",
" P_i = Legendre.basis(i)\n",
" legendre_basis[:, i] = P_i(x)\n",
" return legendre_basis\n",
"\n",
"def generate_polynomial_basis(length, n_basis):\n",
" \"\"\"\n",
" Generates standard polynomial basis functions.\n",
"\n",
" Parameters:\n",
" - n_points (int): Number of data points.\n",
" - n_functions (int): Number of polynomial functions to generate.\n",
"\n",
" Returns:\n",
" - poly_basis (ndarray): An array of polynomial basis functions.\n",
" \"\"\"\n",
" return np.concatenate([np.power(np.arange(length, dtype=float) / length, i)[None, :]\n",
" for i in range(n_basis)]).T\n",
"\n",
"\n",
"def generate_changepoint_basis(length, n_basis):\n",
" \"\"\"\n",
" Generates changepoint basis functions with automatically spaced changepoints.\n",
"\n",
" Parameters:\n",
" - n_points (int): Number of data points.\n",
" - n_functions (int): Number of changepoint functions to generate.\n",
"\n",
" Returns:\n",
" - changepoint_basis (ndarray): An array of changepoint basis functions.\n",
" \"\"\"\n",
" x = np.linspace(0, 1, length)[:, None] # Shape: (length, 1)\n",
" changepoint_locations = np.linspace(0, 1, n_basis + 1)[1:][None, :] # Shape: (1, n_basis)\n",
" return np.maximum(0, x - changepoint_locations)\n",
"\n",
"def generate_piecewise_linear_basis(length, n_basis):\n",
" \"\"\"\n",
" Generates piecewise linear basis functions (linear splines).\n",
"\n",
" Parameters:\n",
" - n_points (int): Number of data points.\n",
" - n_functions (int): Number of piecewise linear basis functions to generate.\n",
"\n",
" Returns:\n",
" - pw_linear_basis (ndarray): An array of piecewise linear basis functions.\n",
" \"\"\"\n",
" x = np.linspace(0, 1, length)\n",
" knots = np.linspace(0, 1, n_basis+1)\n",
" pw_linear_basis = np.zeros((length, n_basis))\n",
" for i in range(1, n_basis):\n",
" pw_linear_basis[:, i] = np.maximum(0, np.minimum((x - knots[i-1]) / (knots[i] - knots[i-1]), (knots[i+1] - x) / (knots[i+1] - knots[i])))\n",
" return pw_linear_basis\n",
"\n",
"def generate_linear_hat_basis(length, n_basis):\n",
" x = np.linspace(0, 1, length)[:, None] # Shape: (length, 1)\n",
" centers = np.linspace(0, 1, n_basis)[None, :] # Shape: (1, n_basis)\n",
" width = 1.0 / (n_basis - 1)\n",
" \n",
" # Create triangular functions using piecewise linear equations\n",
" return np.maximum(0, 1 - np.abs(x - centers) / width)\n",
"\n",
"def generate_spline_basis(length, n_basis):\n",
" \"\"\"\n",
" Generates cubic spline basis functions.\n",
"\n",
" Parameters:\n",
" - n_points (int): Number of data points.\n",
" - n_functions (int): Number of basis functions.\n",
"\n",
" Returns:\n",
" - spline_basis (ndarray): An array of cubic spline basis functions.\n",
" \"\"\"\n",
" if n_basis < 4:\n",
" raise ValueError(f\"To use the spline basis, n_basis must be set to 4 or more. Current value is {n_basis}\")\n",
" x = np.linspace(0, 1, length)\n",
" knots = np.linspace(0, 1, n_basis - 2)\n",
" t = np.concatenate(([0, 0, 0], knots, [1, 1, 1]))\n",
" degree = 3\n",
" # Create basis coefficient matrix once\n",
" coefficients = np.eye(n_basis)\n",
" # Create single BSpline object with all coefficients\n",
" spline = BSpline(t, coefficients.T, degree)\n",
" return spline(x)\n",
"\n",
"def generate_chebyshev_basis(length, n_basis):\n",
" \"\"\"\n",
" Generates Chebyshev polynomial basis functions.\n",
"\n",
" Parameters:\n",
" - n_points (int): Number of data points.\n",
" - n_functions (int): Number of Chebyshev polynomials to generate.\n",
"\n",
" Returns:\n",
" - chebyshev_basis (ndarray): An array of Chebyshev polynomial basis functions.\n",
" \"\"\"\n",
" x = np.linspace(-1, 1, length)\n",
" chebyshev_basis = np.zeros((length, n_basis))\n",
" for i in range(n_basis):\n",
" T_i = Chebyshev.basis(i)\n",
" chebyshev_basis[:, i] = T_i(x)\n",
" return chebyshev_basis\n",
"\n",
"def get_basis(length, n_basis, basis):\n",
" basis_dict = {\n",
" 'legendre': generate_legendre_basis,\n",
" 'polynomial': generate_polynomial_basis,\n",
" 'changepoint': generate_changepoint_basis,\n",
" 'piecewise_linear': generate_piecewise_linear_basis,\n",
" 'linear_hat': generate_linear_hat_basis,\n",
" 'spline': generate_spline_basis,\n",
" 'chebyshev': generate_chebyshev_basis\n",
" }\n",
" return basis_dict[basis](length, n_basis+1)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -107,19 +247,19 @@
" return backcast, forecast\n",
"\n",
"class TrendBasis(nn.Module):\n",
" def __init__(self, degree_of_polynomial: int,\n",
" backcast_size: int, forecast_size: int,\n",
" out_features: int=1):\n",
" def __init__(self, \n",
" n_basis: int,\n",
" backcast_size: int,\n",
" forecast_size: int,\n",
" out_features: int=1,\n",
" basis='polynomial'):\n",
" super().__init__()\n",
" self.out_features = out_features\n",
" polynomial_size = degree_of_polynomial + 1\n",
" self.backcast_basis = nn.Parameter(\n",
" torch.tensor(np.concatenate([np.power(np.arange(backcast_size, dtype=float) / backcast_size, i)[None, :]\n",
" for i in range(polynomial_size)]), dtype=torch.float32), requires_grad=False)\n",
" torch.tensor(get_basis(backcast_size, n_basis, basis).T, dtype=torch.float32), requires_grad=False)\n",
" self.forecast_basis = nn.Parameter(\n",
" torch.tensor(np.concatenate([np.power(np.arange(forecast_size, dtype=float) / forecast_size, i)[None, :]\n",
" for i in range(polynomial_size)]), dtype=torch.float32), requires_grad=False)\n",
" \n",
" torch.tensor(get_basis(forecast_size, n_basis, basis).T, dtype=torch.float32), requires_grad=False)\n",
"\n",
" def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
" polynomial_size = self.forecast_basis.shape[0] # [polynomial_size, L+H]\n",
" backcast_theta = theta[:, :polynomial_size]\n",
Expand All @@ -130,8 +270,10 @@
" return backcast, forecast\n",
"\n",
"class SeasonalityBasis(nn.Module):\n",
" def __init__(self, harmonics: int, \n",
" backcast_size: int, forecast_size: int,\n",
" def __init__(self, \n",
" harmonics: int, \n",
" backcast_size: int, \n",
" forecast_size: int,\n",
" out_features: int=1):\n",
" super().__init__()\n",
" self.out_features = out_features\n",
Expand Down Expand Up @@ -191,8 +333,6 @@
" basis: nn.Module, \n",
" dropout_prob: float, \n",
" activation: str):\n",
" \"\"\"\n",
" \"\"\"\n",
" super().__init__()\n",
"\n",
" self.dropout_prob = dropout_prob\n",
Expand All @@ -209,7 +349,6 @@
"\n",
" if self.dropout_prob>0:\n",
" raise NotImplementedError('dropout')\n",
" #hidden_layers.append(nn.Dropout(p=self.dropout_prob))\n",
"\n",
" output_layer = [nn.Linear(in_features=mlp_units[-1][1], out_features=n_theta)]\n",
" layers = hidden_layers + output_layer\n",
Expand Down Expand Up @@ -245,7 +384,8 @@
" `h`: int, forecast horizon.<br>\n",
" `input_size`: int, considered autorregresive inputs (lags), y=[1,2,3,4] input_size=2 -> lags=[1,2].<br>\n",
" `n_harmonics`: int, Number of harmonic terms for seasonality stack type. Note that len(n_harmonics) = len(stack_types). Note that it will only be used if a seasonality stack is used.<br>\n",
" `n_polynomials`: int, polynomial degree for trend stack. Note that len(n_polynomials) = len(stack_types). Note that it will only be used if a trend stack is used.<br>\n",
" `basis`: str, Type of basis function to use in the trend stack. Choose one from ['legendre', 'polynomial', 'changepoint', 'piecewise_linear', 'linear_hat', 'spline', 'chebyshev']<br>\n",
" `n_basis`: int, the number of basis functions for the trend stack. Note that it will only be used if a trend stack is used.<br>\n",
" `stack_types`: List[str], List of stack types. Subset from ['seasonality', 'trend', 'identity'].<br>\n",
" `n_blocks`: List[int], Number of blocks for each stack. Note that len(n_blocks) = len(stack_types).<br>\n",
" `mlp_units`: List[List[int]], Structure of hidden layers for each stack type. Each internal list should contain the number of units of each hidden layer. Note that len(n_hidden) = len(stack_types).<br>\n",
Expand Down Expand Up @@ -291,7 +431,8 @@
" h,\n",
" input_size,\n",
" n_harmonics: int = 2,\n",
" n_polynomials: int = 2,\n",
" n_basis: int = 2,\n",
" basis: str = 'polynomial',\n",
" stack_types: list = ['identity', 'trend', 'seasonality'],\n",
" n_blocks: list = [1, 1, 1],\n",
" mlp_units: list = 3 * [[512, 512]],\n",
Expand Down Expand Up @@ -364,18 +505,23 @@
" dropout_prob_theta=dropout_prob_theta,\n",
" activation=activation,\n",
" shared_weights=shared_weights,\n",
" n_polynomials=n_polynomials, \n",
" n_harmonics=n_harmonics)\n",
" n_harmonics=n_harmonics,\n",
" n_basis=n_basis,\n",
" basis_type=basis)\n",
" self.blocks = torch.nn.ModuleList(blocks)\n",
"\n",
" def create_stack(self, stack_types, \n",
" def create_stack(self, \n",
" stack_types, \n",
" n_blocks, \n",
" input_size, \n",
" h, \n",
" mlp_units, \n",
" dropout_prob_theta, \n",
" activation, shared_weights,\n",
" n_polynomials, n_harmonics): \n",
" activation, \n",
" shared_weights,\n",
" n_harmonics, \n",
" n_basis, \n",
" basis_type): \n",
"\n",
" block_list = []\n",
" for i in range(len(stack_types)):\n",
Expand All @@ -389,14 +535,17 @@
" n_theta = 2 * (self.loss.outputsize_multiplier + 1) * \\\n",
" int(np.ceil(n_harmonics / 2 * h) - (n_harmonics - 1))\n",
" basis = SeasonalityBasis(harmonics=n_harmonics,\n",
" backcast_size=input_size,forecast_size=h,\n",
" backcast_size=input_size,\n",
" forecast_size=h,\n",
" out_features=self.loss.outputsize_multiplier)\n",
"\n",
" elif stack_types[i] == 'trend':\n",
" n_theta = (self.loss.outputsize_multiplier + 1) * (n_polynomials + 1)\n",
" basis = TrendBasis(degree_of_polynomial=n_polynomials,\n",
" backcast_size=input_size,forecast_size=h,\n",
" out_features=self.loss.outputsize_multiplier)\n",
" n_theta = (self.loss.outputsize_multiplier + 1) * (n_basis + 1)\n",
" basis = TrendBasis(n_basis=n_basis,\n",
" backcast_size=input_size,\n",
" forecast_size=h,\n",
" out_features=self.loss.outputsize_multiplier,\n",
" basis=basis_type)\n",
"\n",
" elif stack_types[i] == 'identity':\n",
" n_theta = input_size + self.loss.outputsize_multiplier * h\n",
Expand Down Expand Up @@ -658,6 +807,8 @@
"Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
"\n",
"model = NBEATS(h=12, input_size=24,\n",
" basis='polynomial',\n",
" n_basis=5,\n",
" loss=DistributionLoss(distribution='Poisson', level=[80, 90]),\n",
" stack_types = ['identity', 'trend', 'seasonality'],\n",
" max_steps=100,\n",
Expand Down Expand Up @@ -687,6 +838,14 @@
"plt.legend()\n",
"plt.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c87058ca",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
18 changes: 17 additions & 1 deletion neuralforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,23 @@
'neuralforecast.models.nbeats.TrendBasis.__init__': ( 'models.nbeats.html#trendbasis.__init__',
'neuralforecast/models/nbeats.py'),
'neuralforecast.models.nbeats.TrendBasis.forward': ( 'models.nbeats.html#trendbasis.forward',
'neuralforecast/models/nbeats.py')},
'neuralforecast/models/nbeats.py'),
'neuralforecast.models.nbeats.generate_changepoint_basis': ( 'models.nbeats.html#generate_changepoint_basis',
'neuralforecast/models/nbeats.py'),
'neuralforecast.models.nbeats.generate_chebyshev_basis': ( 'models.nbeats.html#generate_chebyshev_basis',
'neuralforecast/models/nbeats.py'),
'neuralforecast.models.nbeats.generate_legendre_basis': ( 'models.nbeats.html#generate_legendre_basis',
'neuralforecast/models/nbeats.py'),
'neuralforecast.models.nbeats.generate_linear_hat_basis': ( 'models.nbeats.html#generate_linear_hat_basis',
'neuralforecast/models/nbeats.py'),
'neuralforecast.models.nbeats.generate_piecewise_linear_basis': ( 'models.nbeats.html#generate_piecewise_linear_basis',
'neuralforecast/models/nbeats.py'),
'neuralforecast.models.nbeats.generate_polynomial_basis': ( 'models.nbeats.html#generate_polynomial_basis',
'neuralforecast/models/nbeats.py'),
'neuralforecast.models.nbeats.generate_spline_basis': ( 'models.nbeats.html#generate_spline_basis',
'neuralforecast/models/nbeats.py'),
'neuralforecast.models.nbeats.get_basis': ( 'models.nbeats.html#get_basis',
'neuralforecast/models/nbeats.py')},
'neuralforecast.models.nbeatsx': { 'neuralforecast.models.nbeatsx.ExogenousBasis': ( 'models.nbeatsx.html#exogenousbasis',
'neuralforecast/models/nbeatsx.py'),
'neuralforecast.models.nbeatsx.ExogenousBasis.__init__': ( 'models.nbeatsx.html#exogenousbasis.__init__',
Expand Down
2 changes: 2 additions & 0 deletions neuralforecast/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,8 @@ class AutoNBEATS(BaseAuto):
default_config = {
"input_size_multiplier": [1, 2, 3, 4, 5],
"h": None,
"basis": tune.choice(["polynomial", "changepoint"]),
"n_basis": tune.choice([2, 5]),
"learning_rate": tune.loguniform(1e-4, 1e-1),
"scaler_type": tune.choice([None, "robust", "standard"]),
"max_steps": tune.choice([500, 1000]),
Expand Down
Loading
Loading