[FEAT] support providing DataLoader arguments to optimize GPU usage #1186

Merged: 13 commits, Nov 8, 2024
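This PR adds a `dataloader_kwargs` argument to all models and forwards it to the underlying PyTorch `DataLoader`, deprecating the older `num_workers_loader` argument. A minimal usage sketch of the new argument (Autoformer is one of the models touched in this diff; the `AirPassengersDF` sample-data import is an assumption about the package's utilities):

```python
from neuralforecast import NeuralForecast
from neuralforecast.models import Autoformer
from neuralforecast.utils import AirPassengersDF

model = Autoformer(
    h=12,
    input_size=24,
    max_steps=100,
    # New in this PR: forwarded verbatim to torch.utils.data.DataLoader.
    dataloader_kwargs={"num_workers": 4, "pin_memory": True},
)
nf = NeuralForecast(models=[model], freq="M")
nf.fit(df=AirPassengersDF)
```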
14 changes: 13 additions & 1 deletion nbs/common.base_model.ipynb
@@ -358,13 +358,25 @@
" datamodule_constructor = TimeSeriesDataModule\n",
" else:\n",
" datamodule_constructor = _DistributedTimeSeriesDataModule\n",
" \n",
" dataloader_kwargs = self.dataloader_kwargs if self.dataloader_kwargs is not None else {}\n",
" \n",
" if self.num_workers_loader != 0: # value is not at its default\n",
" warnings.warn(\n",
" \"The `num_workers_loader` argument is deprecated and will be removed in a future version. \"\n",
" \"Please provide num_workers through `dataloader_kwargs`, e.g. \"\n",
" f\"`dataloader_kwargs={{'num_workers': {self.num_workers_loader}}}`\",\n",
" category=FutureWarning,\n",
" )\n",
" dataloader_kwargs['num_workers'] = self.num_workers_loader\n",
"\n",
" datamodule = datamodule_constructor(\n",
" dataset=dataset, \n",
" batch_size=batch_size,\n",
" valid_batch_size=valid_batch_size,\n",
" num_workers=self.num_workers_loader,\n",
" drop_last=self.drop_last_loader,\n",
" shuffle_train=shuffle_train,\n",
" **dataloader_kwargs\n",
" )\n",
"\n",
" if self.val_check_steps > self.max_steps:\n",
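The hunk above merges the deprecated `num_workers_loader` value into `dataloader_kwargs` before building the data module. The same merge-and-deprecate pattern as a standalone sketch (the function name is illustrative; in the diff this logic lives inline in the base model):

```python
import warnings

def merge_loader_kwargs(num_workers_loader=0, dataloader_kwargs=None):
    """Standalone sketch of the deprecation shim shown in the hunk above."""
    kwargs = dataloader_kwargs if dataloader_kwargs is not None else {}
    if num_workers_loader != 0:  # value is not at its default
        warnings.warn(
            "The `num_workers_loader` argument is deprecated and will be removed "
            "in a future version. Please provide num_workers through "
            f"`dataloader_kwargs`, e.g. `dataloader_kwargs={{'num_workers': {num_workers_loader}}}`",
            category=FutureWarning,
        )
        kwargs["num_workers"] = num_workers_loader
    return kwargs
```

Note that an explicit `num_workers_loader` overwrites any `num_workers` key already in `dataloader_kwargs`, so old call sites keep their behavior during the deprecation window.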
2 changes: 2 additions & 0 deletions nbs/common.base_multivariate.ipynb
@@ -109,6 +109,7 @@
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
@@ -173,6 +174,7 @@
"\n",
" # DataModule arguments\n",
" self.num_workers_loader = num_workers_loader\n",
" self.dataloader_kwargs = dataloader_kwargs\n",
" self.drop_last_loader = drop_last_loader\n",
" # used by on_validation_epoch_end hook\n",
" self.validation_step_outputs = []\n",
4 changes: 3 additions & 1 deletion nbs/common.base_recurrent.ipynb
@@ -115,6 +115,7 @@
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
@@ -172,6 +173,7 @@
"\n",
" # DataModule arguments\n",
" self.num_workers_loader = num_workers_loader\n",
" self.dataloader_kwargs = dataloader_kwargs\n",
" self.drop_last_loader = drop_last_loader\n",
" # used by on_validation_epoch_end hook\n",
" self.validation_step_outputs = []\n",
@@ -536,7 +538,7 @@
" self._check_exog(dataset)\n",
" self._restart_seed(random_seed)\n",
" data_module_kwargs = self._set_quantile_for_iqloss(**data_module_kwargs)\n",
"\n",
" \n",
" if step_size > 1:\n",
" raise Exception('Recurrent models do not support step_size > 1')\n",
"\n",
2 changes: 2 additions & 0 deletions nbs/common.base_windows.ipynb
@@ -119,6 +119,7 @@
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
@@ -188,6 +189,7 @@
"\n",
" # DataModule arguments\n",
" self.num_workers_loader = num_workers_loader\n",
" self.dataloader_kwargs = dataloader_kwargs\n",
" self.drop_last_loader = drop_last_loader\n",
" # used by on_validation_epoch_end hook\n",
" self.validation_step_outputs = []\n",
9 changes: 7 additions & 2 deletions nbs/docs/tutorials/18_adding_models.ipynb
@@ -269,7 +269,6 @@
" step_size: int = 1,\n",
" scaler_type: str = 'identity',\n",
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
" **trainer_kwargs):\n",
" # Inherit BaseWindows class\n",
@@ -415,7 +414,13 @@
]
}
],
"metadata": {},
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
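The tutorial's example model drops `num_workers_loader` from its signature. A custom model that wants to expose the new knob can accept and forward `dataloader_kwargs` to its base class, mirroring the built-in model changes below. A sketch only, not part of the tutorial's code, with the other required `BaseWindows` arguments elided into `**kwargs`:

```python
from neuralforecast.common._base_windows import BaseWindows

class MyModel(BaseWindows):
    # Sketch: BaseWindows also requires loss, learning_rate, etc., which a
    # real model would list explicitly; here they pass through **kwargs.
    def __init__(self, h, input_size, dataloader_kwargs=None, **kwargs):
        super().__init__(
            h=h,
            input_size=input_size,
            # New in this PR: handed to the TimeSeriesDataModule at fit time.
            dataloader_kwargs=dataloader_kwargs,
            **kwargs,
        )
```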
3 changes: 3 additions & 0 deletions nbs/models.autoformer.ipynb
@@ -462,6 +462,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
"\t*References*<br>\n",
@@ -511,6 +512,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(Autoformer, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -539,6 +541,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
3 changes: 3 additions & 0 deletions nbs/models.bitcn.ipynb
@@ -182,6 +182,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
"\n",
" **References**<br> \n",
@@ -224,6 +225,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(BiTCN, self).__init__(\n",
" h=h,\n",
@@ -253,6 +255,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs\n",
" )\n",
"\n",
3 changes: 3 additions & 0 deletions nbs/models.deepar.ipynb
@@ -187,6 +187,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
"\n",
" **References**<br>\n",
@@ -234,6 +235,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs = None,\n",
" **trainer_kwargs):\n",
"\n",
" if exclude_insample_y:\n",
@@ -276,6 +278,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" self.horizon_backup = self.h # Used because h=0 during training\n",
3 changes: 3 additions & 0 deletions nbs/models.deepnpts.ipynb
@@ -125,6 +125,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
"\n",
" **References**<br>\n",
@@ -169,6 +170,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs = None,\n",
" **trainer_kwargs):\n",
"\n",
" if exclude_insample_y:\n",
@@ -208,6 +210,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" self.h = h\n",
3 changes: 3 additions & 0 deletions nbs/models.dilated_rnn.ipynb
@@ -394,6 +394,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br> \n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
" \"\"\"\n",
" # Class attributes\n",
@@ -433,6 +434,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(DilatedRNN, self).__init__(\n",
" h=h,\n",
@@ -458,6 +460,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs\n",
" )\n",
"\n",
3 changes: 3 additions & 0 deletions nbs/models.dlinear.ipynb
@@ -166,6 +166,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
"\t*References*<br>\n",
@@ -206,6 +207,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(DLinear, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -234,6 +236,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs)\n",
" \n",
" # Architecture\n",
5 changes: 4 additions & 1 deletion nbs/models.fedformer.ipynb
@@ -455,6 +455,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
" \"\"\"\n",
@@ -503,6 +504,7 @@
" optimizer_kwargs=None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(FEDformer, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -529,7 +531,8 @@
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs, \n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs, \n",
" **trainer_kwargs)\n",
" # Architecture\n",
" self.label_len = int(np.ceil(input_size * decoder_input_size_multiplier))\n",
3 changes: 3 additions & 0 deletions nbs/models.gru.ipynb
@@ -128,6 +128,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
" \"\"\"\n",
" # Class attributes\n",
@@ -168,6 +169,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(GRU, self).__init__(\n",
" h=h,\n",
@@ -193,6 +195,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs\n",
" )\n",
"\n",
3 changes: 3 additions & 0 deletions nbs/models.informer.ipynb
@@ -310,6 +310,7 @@
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
" `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
" `dataloader_kwargs`: dict, optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`. <br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
"\t*References*<br>\n",
@@ -359,6 +360,7 @@
" optimizer_kwargs = None,\n",
" lr_scheduler = None,\n",
" lr_scheduler_kwargs = None,\n",
" dataloader_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(Informer, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -387,6 +389,7 @@
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" dataloader_kwargs=dataloader_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
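In line with the PR title, the `DataLoader` arguments that most often improve GPU utilization are the standard torch ones. A typical combination to pass as `dataloader_kwargs` (the values are machine-dependent starting points, not taken from this diff):

```python
# All keys are standard torch.utils.data.DataLoader arguments.
dataloader_kwargs = {
    "num_workers": 4,            # CPU workers preparing batches in parallel
    "pin_memory": True,          # page-locked memory speeds host-to-GPU copies
    "persistent_workers": True,  # keep workers alive across epochs
    "prefetch_factor": 2,        # batches each worker loads ahead of time
}
```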