Commit f2c957e
Feat: Introducing horizon weighting to the distribution losses similar to point losses like MAE. This has proven useful in our applications. Also cleaning notebook outputs.

mwamsojo committed Dec 19, 2024
1 parent df8c431 commit f2c957e
Showing 3 changed files with 89 additions and 27 deletions.
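For context, a minimal usage sketch of the new parameter (hypothetical values; `horizon_weight` is expected to be array-like with a `.flatten()` method, such as a NumPy array, since the constructor converts it via `torch.Tensor(horizon_weight.flatten())`):

```python
import numpy as np
from neuralforecast.losses.pytorch import DistributionLoss

# Hypothetical 12-step horizon; down-weight later forecast steps.
# The weight vector must match the model's horizon length H.
H = 12
horizon_weight = np.linspace(1.0, 0.5, H)

loss = DistributionLoss(
    distribution="Normal",
    level=[80, 90],
    horizon_weight=horizon_weight,  # new in this commit
)
```

As with `horizon_weight` in point losses such as MAE, a length mismatch against the horizon triggers an assertion when the loss is computed.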
84 changes: 58 additions & 26 deletions nbs/losses.pytorch.ipynb
@@ -2475,7 +2475,7 @@
 "\n",
 "    \"\"\"\n",
 "    def __init__(self, distribution, level=[80, 90], quantiles=None,\n",
-"                 num_samples=1000, return_params=False, **distribution_kwargs):\n",
+"                 num_samples=1000, return_params=False, horizon_weight = None, **distribution_kwargs):\n",
 "        super(DistributionLoss, self).__init__()\n",
 "\n",
 "        qs, self.output_names = level_to_outputs(level)\n",
@@ -2489,6 +2489,12 @@
 "        self.quantiles = torch.nn.Parameter(qs, requires_grad=False)\n",
 "        num_qk = len(self.quantiles)\n",
 "\n",
+"        # Generate a horizon weight tensor from the array\n",
+"        if horizon_weight is not None:\n",
+"            horizon_weight = torch.Tensor(horizon_weight.flatten())\n",
+"        self.horizon_weight = horizon_weight\n",
+"\n",
+"\n",
 "        if \"num_pieces\" not in distribution_kwargs:\n",
 "            num_pieces = 5\n",
 "        else:\n",
@@ -2610,36 +2616,62 @@
 "\n",
 "        return samples, sample_mean, quants\n",
 "\n",
-"    def __call__(self,\n",
-"                 y: torch.Tensor,\n",
-"                 distr_args: torch.Tensor,\n",
-"                 mask: Union[torch.Tensor, None] = None):\n",
-"        \"\"\"\n",
-"        Computes the negative log-likelihood objective function. \n",
-"        To estimate the following predictive distribution:\n",
-"\n",
-"        $$\\mathrm{P}(\\mathbf{y}_{\\\\tau}\\,|\\,\\\\theta) \\\\quad \\mathrm{and} \\\\quad -\\log(\\mathrm{P}(\\mathbf{y}_{\\\\tau}\\,|\\,\\\\theta))$$\n",
-"\n",
-"        where $\\\\theta$ represents the distributions parameters. It aditionally \n",
-"        summarizes the objective signal using a weighted average using the `mask` tensor. \n",
-"\n",
-"        **Parameters**<br>\n",
-"        `y`: tensor, Actual values.<br>\n",
-"        `distr_args`: Constructor arguments for the underlying Distribution type.<br>\n",
-"        `loc`: Optional tensor, of the same shape as the batch_shape + event_shape\n",
-"        of the resulting distribution.<br>\n",
-"        `scale`: Optional tensor, of the same shape as the batch_shape+event_shape \n",
-"        of the resulting distribution.<br>\n",
-"        `mask`: tensor, Specifies date stamps per serie to consider in loss.<br>\n",
-"\n",
-"        **Returns**<br>\n",
-"        `loss`: scalar, weighted loss function against which backpropagation will be performed.<br>\n",
-"        \"\"\"\n",
-"        # Instantiate Scaled Decoupled Distribution\n",
-"        distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs)\n",
-"        loss_values = -distr.log_prob(y)\n",
-"        loss_weights = mask\n",
-"        return weighted_average(loss_values, weights=loss_weights)"
+"\n",
+"\n",
+"    def _compute_weights(self, y, mask):\n",
+"        \"\"\"\n",
+"        Compute final weights for each datapoint (based on all weights and all masks)\n",
+"        Set horizon_weight to a ones[H] tensor if not set.\n",
+"        If set, check that it has the same length as the horizon in x.\n",
+"        \"\"\"\n",
+"        if mask is None:\n",
+"            mask = torch.ones_like(y, device=y.device)\n",
+"        else:\n",
+"            mask = mask.unsqueeze(1)  # Add Q dimension.\n",
+"\n",
+"        # get uniform weights if none\n",
+"        if self.horizon_weight is None:\n",
+"            self.horizon_weight = torch.ones(mask.shape[-1])\n",
+"        else:\n",
+"            assert mask.shape[-1] == len(self.horizon_weight), \\\n",
+"                'horizon_weight must have same length as Y'\n",
+"        weights = self.horizon_weight.clone()\n",
+"        weights = torch.ones_like(mask, device=mask.device) * weights.to(mask.device)\n",
+"        return weights * mask\n",
+"    \n",
+"\n",
+"    def __call__(self,\n",
+"                 y: torch.Tensor,\n",
+"                 distr_args: torch.Tensor,\n",
+"                 mask: Union[torch.Tensor, None] = None):\n",
+"        \"\"\"\n",
+"        Computes the negative log-likelihood objective function. \n",
+"        To estimate the following predictive distribution:\n",
+"\n",
+"        $$\\mathrm{P}(\\mathbf{y}_{\\\\tau}\\,|\\,\\\\theta) \\\\quad \\mathrm{and} \\\\quad -\\log(\\mathrm{P}(\\mathbf{y}_{\\\\tau}\\,|\\,\\\\theta))$$\n",
+"\n",
+"        where $\\\\theta$ represents the distributions parameters. It aditionally \n",
+"        summarizes the objective signal using a weighted average using the `mask` tensor. \n",
+"        \n",
+"        **Parameters**<br>\n",
+"        `y`: tensor, Actual values.<br>\n",
+"        `distr_args`: Constructor arguments for the underlying Distribution type.<br>\n",
+"        `loc`: Optional tensor, of the same shape as the batch_shape + event_shape\n",
+"        of the resulting distribution.<br>\n",
+"        `scale`: Optional tensor, of the same shape as the batch_shape+event_shape \n",
+"        of the resulting distribution.<br>\n",
+"        `mask`: tensor, Specifies date stamps per serie to consider in loss.<br>\n",
+"\n",
+"        **Returns**<br>\n",
+"        `loss`: scalar, weighted loss function against which backpropagation will be performed.<br>\n",
+"        \"\"\"\n",
+"        # Instantiate Scaled Decoupled Distribution\n",
+"        distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs)\n",
+"        loss_values = -distr.log_prob(y)\n",
+"        loss_weights = self._compute_weights(y=y, mask=mask)\n",
+"        return weighted_average(loss_values, weights=loss_weights)"
 ]
 },
 {
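The notebook change above mirrors the library code below; as a quick illustration, here is a self-contained sketch of the broadcasting that `_compute_weights` relies on (shapes are assumptions inferred from the `# Add Q dimension.` comment: batch × quantile-singleton × horizon):

```python
import torch

# Assumed shapes: mask arrives as [B, H] and gains a singleton Q dimension;
# horizon_weight (length H) then broadcasts across batch and quantiles.
B, H = 2, 4
mask = torch.ones(B, H).unsqueeze(1)                # [B, 1, H]
horizon_weight = torch.tensor([1.0, 1.0, 0.5, 0.5])

weights = torch.ones_like(mask) * horizon_weight    # broadcasts over the last dim
final_weights = weights * mask                      # [B, 1, H]
print(final_weights.shape)                          # torch.Size([2, 1, 4])
```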
2 changes: 2 additions & 0 deletions neuralforecast/_modidx.py
@@ -284,6 +284,8 @@
                                                            'neuralforecast/losses/pytorch.py'),
     'neuralforecast.losses.pytorch.DistributionLoss.__init__': ( 'losses.pytorch.html#distributionloss.__init__',
                                                                  'neuralforecast/losses/pytorch.py'),
+    'neuralforecast.losses.pytorch.DistributionLoss._compute_weights': ( 'losses.pytorch.html#distributionloss._compute_weights',
+                                                                         'neuralforecast/losses/pytorch.py'),
     'neuralforecast.losses.pytorch.DistributionLoss.get_distribution': ( 'losses.pytorch.html#distributionloss.get_distribution',
                                                                          'neuralforecast/losses/pytorch.py'),
     'neuralforecast.losses.pytorch.DistributionLoss.sample': ( 'losses.pytorch.html#distributionloss.sample',
30 changes: 29 additions & 1 deletion neuralforecast/losses/pytorch.py
@@ -1873,6 +1873,7 @@ def __init__(
         quantiles=None,
         num_samples=1000,
         return_params=False,
+        horizon_weight=None,
         **distribution_kwargs,
     ):
         super(DistributionLoss, self).__init__()
@@ -1888,6 +1889,11 @@
         self.quantiles = torch.nn.Parameter(qs, requires_grad=False)
         num_qk = len(self.quantiles)

+        # Generate a horizon weight tensor from the array
+        if horizon_weight is not None:
+            horizon_weight = torch.Tensor(horizon_weight.flatten())
+        self.horizon_weight = horizon_weight
+
         if "num_pieces" not in distribution_kwargs:
             num_pieces = 5
         else:
@@ -2011,6 +2017,28 @@ def sample(self, distr_args: torch.Tensor, num_samples: Optional[int] = None):

         return samples, sample_mean, quants

+    def _compute_weights(self, y, mask):
+        """
+        Compute final weights for each datapoint (based on all weights and all masks)
+        Set horizon_weight to a ones[H] tensor if not set.
+        If set, check that it has the same length as the horizon in x.
+        """
+        if mask is None:
+            mask = torch.ones_like(y, device=y.device)
+        else:
+            mask = mask.unsqueeze(1)  # Add Q dimension.
+
+        # get uniform weights if none
+        if self.horizon_weight is None:
+            self.horizon_weight = torch.ones(mask.shape[-1])
+        else:
+            assert mask.shape[-1] == len(
+                self.horizon_weight
+            ), "horizon_weight must have same length as Y"
+        weights = self.horizon_weight.clone()
+        weights = torch.ones_like(mask, device=mask.device) * weights.to(mask.device)
+        return weights * mask
+
     def __call__(
         self,
         y: torch.Tensor,
@@ -2041,7 +2069,7 @@ def __call__(
         # Instantiate Scaled Decoupled Distribution
         distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs)
         loss_values = -distr.log_prob(y)
-        loss_weights = mask
+        loss_weights = self._compute_weights(y=y, mask=mask)
         return weighted_average(loss_values, weights=loss_weights)

 # %% ../../nbs/losses.pytorch.ipynb 74
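The weighted loss is finally reduced by `weighted_average`; below is a rough sketch of the behavior assumed here (the real helper lives in `neuralforecast.losses.pytorch` and may treat dimensions and zero weights differently):

```python
import torch

def weighted_average_sketch(loss_values: torch.Tensor,
                            weights: torch.Tensor,
                            eps: float = 1e-8) -> torch.Tensor:
    # Weighted mean of the per-point NLL; eps guards against an all-zero mask.
    return (loss_values * weights).sum() / (weights.sum() + eps)

# Toy check: uniform weights reduce to a plain mean.
loss_values = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
weights = torch.ones_like(loss_values)
print(weighted_average_sketch(loss_values, weights))  # tensor(2.5000)
```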
