[FEAT] Add option to modify the default configure_optimizers() behavior #1015

Open
wants to merge 11 commits into main
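For context, here is a minimal usage sketch of what the new argument enables. Only the `config_optimizers` keyword on the common base classes is established by this diff; the concrete model (`NHITS`), its remaining arguments, and the optimizer choice are illustrative assumptions.

```python
import torch
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS  # illustrative; any model built on the patched base classes

# A callable with the same signature as LightningModule.configure_optimizers:
# it receives the module and may return anything Lightning accepts from that
# hook, e.g. a bare optimizer.
def my_optimizers(module):
    return torch.optim.AdamW(module.parameters(), lr=1e-3)

model = NHITS(h=12, input_size=24, max_steps=100, config_optimizers=my_optimizers)
nf = NeuralForecast(models=[model], freq="M")
# nf.fit(df)  # df: a long-format DataFrame with unique_id, ds and y columns
```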
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
@@ -69,7 +69,7 @@ nbdev_export
If you're working on the local interface you can just use `nbdev_test --n_workers 1 --do_print --timing`.

### Cleaning notebooks
Since the notebooks output cells can vary from run to run (even if they produce the same outputs) the notebooks are cleaned before committing them. Please make sure to run `nbdev_clean --clear_all` before committing your changes. If you clean the library's notebooks with this command please backtrack the changes you make to the example notebooks `git checkout nbs/examples`, unless you intend to change the examples.
Contributor Author

Revise this outdated path

Since the notebooks' output cells can vary from run to run (even if they produce the same outputs), the notebooks are cleaned before committing them. Please make sure to run `nbdev_clean --clear_all` before committing your changes. If you clean the library's notebooks with this command, please revert the changes it makes to the example notebooks with `git checkout nbs/docs`, unless you intend to change the examples.

## Do you want to contribute to the documentation?

@@ -78,6 +78,6 @@ Since the notebooks output cells can vary from run to run (even if they produce
1. Find the relevant notebook.
2. Make your changes.
3. Run all cells.
4. If you are modifying library notebooks (not in `nbs/examples`), clean all outputs using `Edit > Clear All Outputs`.
4. If you are modifying library notebooks (not in `nbs/docs`), clean all outputs using `Edit > Clear All Outputs`.
5. Run `nbdev_preview`.
6. Clean the notebook metadata using `nbdev_clean`.
63 changes: 15 additions & 48 deletions nbs/common.base_model.ipynb
@@ -34,7 +34,6 @@
"import random\n",
"import warnings\n",
"from contextlib import contextmanager\n",
"from copy import deepcopy\n",
"from dataclasses import dataclass\n",
"\n",
"import fsspec\n",
@@ -121,15 +120,12 @@
" random_seed,\n",
" loss,\n",
" valid_loss,\n",
" optimizer,\n",
" optimizer_kwargs,\n",
" lr_scheduler,\n",
" lr_scheduler_kwargs,\n",
" futr_exog_list,\n",
" hist_exog_list,\n",
" stat_exog_list,\n",
" max_steps,\n",
" early_stop_patience_steps,\n",
" config_optimizers=None,\n",
" **trainer_kwargs,\n",
" ):\n",
" super().__init__()\n",
@@ -150,18 +146,8 @@
" self.train_trajectories = []\n",
" self.valid_trajectories = []\n",
"\n",
" # Optimization\n",
" if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
" raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
" self.optimizer = optimizer\n",
" self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}\n",
"\n",
" # lr scheduler\n",
" if lr_scheduler is not None and not issubclass(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):\n",
" raise TypeError(\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
" self.lr_scheduler = lr_scheduler\n",
" self.lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}\n",
"\n",
" # function has the same signature as LightningModule's configure_optimizer\n",
" self.config_optimizers = config_optimizers\n",
"\n",
" # Variables\n",
" self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n",
@@ -399,39 +385,20 @@
" random.seed(self.random_seed)\n",
"\n",
" def configure_optimizers(self):\n",
" if self.optimizer:\n",
" optimizer_signature = inspect.signature(self.optimizer)\n",
" optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
" if 'lr' in optimizer_signature.parameters:\n",
" if 'lr' in optimizer_kwargs:\n",
" warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
" optimizer_kwargs['lr'] = self.learning_rate\n",
" optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
" else:\n",
" if self.optimizer_kwargs:\n",
" warnings.warn(\n",
" \"ignoring optimizer_kwargs as the optimizer is not specified\"\n",
" ) \n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" if self.config_optimizers is not None:\n",
" # return the customized optimizer settings if specified\n",
" return self.config_optimizers(self)\n",
" \n",
" lr_scheduler = {'frequency': 1, 'interval': 'step'}\n",
" if self.lr_scheduler:\n",
" lr_scheduler_signature = inspect.signature(self.lr_scheduler)\n",
" lr_scheduler_kwargs = deepcopy(self.lr_scheduler_kwargs)\n",
" if 'optimizer' in lr_scheduler_signature.parameters:\n",
" if 'optimizer' in lr_scheduler_kwargs:\n",
" warnings.warn(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\")\n",
" del lr_scheduler_kwargs['optimizer']\n",
" lr_scheduler['scheduler'] = self.lr_scheduler(optimizer=optimizer, **lr_scheduler_kwargs)\n",
" else:\n",
" if self.lr_scheduler_kwargs:\n",
" warnings.warn(\n",
" \"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\"\n",
" ) \n",
" lr_scheduler['scheduler'] = torch.optim.lr_scheduler.StepLR(\n",
" # default choice\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" scheduler = {\n",
" \"scheduler\": torch.optim.lr_scheduler.StepLR(\n",
" optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5\n",
" )\n",
" return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}\n",
" ),\n",
" \"frequency\": 1,\n",
" \"interval\": \"step\",\n",
" }\n",
" return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
"\n",
" def get_test_size(self):\n",
" return self.test_size\n",
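As a complement to the rewritten `configure_optimizers` above, here is a hedged sketch of a user-supplied callable that also attaches a learning-rate scheduler. The body is an assumption about typical usage; what the diff establishes is only that the callable is invoked as `self.config_optimizers(self)` and that its return value is handed back to Lightning unchanged, so it must follow the `LightningModule.configure_optimizers` return contract.

```python
import torch

def custom_config_optimizers(module):
    # `module` is the model itself (a LightningModule), so attributes such as
    # module.parameters() and module.learning_rate are available here.
    optimizer = torch.optim.AdamW(
        module.parameters(), lr=module.learning_rate, weight_decay=1e-2
    )
    # T_max is assumed here; in practice it would be aligned with max_steps.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1_000)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step", "frequency": 1},
    }
```

Because the returned value is passed through verbatim, returning a bare optimizer, a list of optimizers, or any other form accepted by Lightning works equally well.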
10 changes: 2 additions & 8 deletions nbs/common.base_multivariate.ipynb
@@ -104,25 +104,19 @@
" drop_last_loader=False,\n",
" random_seed=1, \n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" config_optimizers=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs, \n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
" max_steps=max_steps,\n",
" early_stop_patience_steps=early_stop_patience_steps,\n",
" config_optimizers=config_optimizers,\n",
" **trainer_kwargs,\n",
" )\n",
"\n",
12 changes: 3 additions & 9 deletions nbs/common.base_recurrent.ipynb
@@ -110,25 +110,19 @@
" drop_last_loader=False,\n",
" random_seed=1, \n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" config_optimizers=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
" max_steps=max_steps,\n",
" early_stop_patience_steps=early_stop_patience_steps, \n",
" early_stop_patience_steps=early_stop_patience_steps,\n",
" config_optimizers=config_optimizers,\n",
" **trainer_kwargs,\n",
" )\n",
"\n",
12 changes: 3 additions & 9 deletions nbs/common.base_windows.ipynb
@@ -114,25 +114,19 @@
" drop_last_loader=False,\n",
" random_seed=1,\n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" config_optimizers=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
" max_steps=max_steps,\n",
" early_stop_patience_steps=early_stop_patience_steps, \n",
" early_stop_patience_steps=early_stop_patience_steps,\n",
" config_optimizers=config_optimizers,\n",
" **trainer_kwargs,\n",
" )\n",
"\n",