diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index 2df85f3b8..0fc7e0cae 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -15,8 +15,11 @@ A clear and concise description of what the bug is.
 * OS : (e.g. Linux, Windows, MacOS)
 * PyTorch version : (e.g. 2.0.1, 1.13, >=1.8, <1.10)
-* Python version : (e.g. 3.8, 3.11
-* reproducible codes :
+* Python version : (e.g. 3.8, 3.11)
+* pytorch-optimizer version : (e.g. 3.3.0)
+* reproducible code : please share your reproducible code, scripts, or links. If sharing the code is complicated, you can write minimal code by hand that reproduces the bug!
+
+Here's an [example](https://github.com/kozistr/pytorch_optimizer/issues/305#issue-2721453417).
 
 ## Log
diff --git a/README.md b/README.md
index d565c09f5..6f550c2f8 100644
--- a/README.md
+++ b/README.md
@@ -10,11 +10,11 @@
 ## The reasons why you use `pytorch-optimizer`.
 
-1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
-2. Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion`
-3. Easy to use, clean, and tested codes
-4. Active maintenance
-5. Somewhat a bit more optimized compared to the original implementation
+* Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Including many variants such as `Cautious`, `AdamD`, and `Gradient Centralization`
+* Easy to use, clean, and tested code
+* Somewhat more optimized than the original implementations
+* Active maintenance
 
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
diff --git a/docs/changelogs/v3.3.1.md b/docs/changelogs/v3.3.1.md
index 291e8341a..9b6eed730 100644
--- a/docs/changelogs/v3.3.1.md
+++ b/docs/changelogs/v3.3.1.md
@@ -2,9 +2,17 @@
 
 ### Feature
 
+* Support the `Cautious` variant for the `AdaShift` optimizer. (#310)
+* Save the state of the `Lookahead` optimizer itself, not just the inner optimizer. (#310)
+
 ### Bug
 
 * Fix `bias_correction` in `AdamG` optimizer. (#305, #308)
+* Fix a potential bug when loading the state of the `Lookahead` optimizer. (#306, #310)
+
+### Docs
+
+* Add more visualizations. (#310)
 
 ### Contributions
diff --git a/docs/index.md b/docs/index.md
index d565c09f5..6f550c2f8 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -10,11 +10,11 @@
 ## The reasons why you use `pytorch-optimizer`.
 
-1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
-2. Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion`
-3. Easy to use, clean, and tested codes
-4. Active maintenance
-5. Somewhat a bit more optimized compared to the original implementation
+* Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Including many variants such as `Cautious`, `AdamD`, and `Gradient Centralization`
+* Easy to use, clean, and tested code
+* Somewhat more optimized than the original implementations
+* Active maintenance
 
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
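
For reference, a minimal reproduction snippet of the kind the updated issue template asks for might look like the following (illustrative only; the optimizer, shapes, and failure point are made up):

```python
import torch
from pytorch_optimizer import AdamP  # any optimizer from this package

model = torch.nn.Linear(4, 1)
optimizer = AdamP(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()  # <- describe the unexpected behavior observed at this point
```
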
diff --git a/docs/visualization.md b/docs/visualization.md
index 961035763..0d797b81d 100644
--- a/docs/visualization.md
+++ b/docs/visualization.md
@@ -74,6 +74,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaPNM.png)
 
+### AdaShift
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaShift.png)
+
 ### AdaSmooth
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaSmooth.png)
@@ -170,6 +174,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Lamb.png)
 
+### LaProp
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_LaProp.png)
+
 ### LARS
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_LARS.png)
@@ -186,6 +194,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_MSVAG.png)
 
+### Muon
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Muon.png)
+
 ### Nero
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Nero.png)
@@ -238,6 +250,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_ScheduleFreeAdamW.png)
 
+### ScheduleFreeRAdam
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_ScheduleFreeRAdam.png)
+
 ### ScheduleFreeSGD
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_ScheduleFreeSGD.png)
@@ -368,6 +384,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaPNM.png)
 
+### AdaShift
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaShift.png)
+
 ### AdaSmooth
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaSmooth.png)
@@ -464,6 +484,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Lamb.png)
 
+### LaProp
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_LaProp.png)
+
 ### LARS
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_LARS.png)
@@ -480,6 +504,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_MSVAG.png)
 
+### Muon
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Muon.png)
+
 ### Nero
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Nero.png)
@@ -532,6 +560,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_ScheduleFreeAdamW.png)
 
+### ScheduleFreeRAdam
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_ScheduleFreeRAdam.png)
+
 ### ScheduleFreeSGD
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_ScheduleFreeSGD.png)
diff --git a/docs/visualizations/rastrigin_AdaShift.png b/docs/visualizations/rastrigin_AdaShift.png
new file mode 100644
index 000000000..d99aea838
Binary files /dev/null and b/docs/visualizations/rastrigin_AdaShift.png differ
diff --git a/docs/visualizations/rastrigin_AdamG.png b/docs/visualizations/rastrigin_AdamG.png
index 62724d85c..7bb51a1f3 100644
Binary files a/docs/visualizations/rastrigin_AdamG.png and b/docs/visualizations/rastrigin_AdamG.png differ
diff --git a/docs/visualizations/rastrigin_LaProp.png b/docs/visualizations/rastrigin_LaProp.png
new file mode 100644
index 000000000..ff077af57
Binary files /dev/null and b/docs/visualizations/rastrigin_LaProp.png differ
diff --git a/docs/visualizations/rastrigin_Muon.png b/docs/visualizations/rastrigin_Muon.png
new file mode 100644
index 000000000..a7cea61da
Binary files /dev/null and b/docs/visualizations/rastrigin_Muon.png differ
diff --git a/docs/visualizations/rastrigin_ScheduleFreeRAdam.png b/docs/visualizations/rastrigin_ScheduleFreeRAdam.png
new file mode 100644
index 000000000..01f943801
Binary files /dev/null and b/docs/visualizations/rastrigin_ScheduleFreeRAdam.png differ
diff --git a/docs/visualizations/rosenbrock_AdaShift.png b/docs/visualizations/rosenbrock_AdaShift.png
new file mode 100644
index 000000000..462ecc73c
Binary files /dev/null and b/docs/visualizations/rosenbrock_AdaShift.png differ
diff --git a/docs/visualizations/rosenbrock_AdamG.png b/docs/visualizations/rosenbrock_AdamG.png
index e8f1e5b05..f7eab6eb9 100644
Binary files a/docs/visualizations/rosenbrock_AdamG.png and b/docs/visualizations/rosenbrock_AdamG.png differ
diff --git a/docs/visualizations/rosenbrock_LaProp.png b/docs/visualizations/rosenbrock_LaProp.png
new file mode 100644
index 000000000..6a29ee9db
Binary files /dev/null and b/docs/visualizations/rosenbrock_LaProp.png differ
diff --git a/docs/visualizations/rosenbrock_Muon.png b/docs/visualizations/rosenbrock_Muon.png
new file mode 100644
index 000000000..a86b0b496
Binary files /dev/null and b/docs/visualizations/rosenbrock_Muon.png differ
diff --git a/docs/visualizations/rosenbrock_ScheduleFreeRAdam.png b/docs/visualizations/rosenbrock_ScheduleFreeRAdam.png
new file mode 100644
index 000000000..feaeba150
Binary files /dev/null and b/docs/visualizations/rosenbrock_ScheduleFreeRAdam.png differ
diff --git a/examples/visualize_optimizers.py b/examples/visualize_optimizers.py
index 44c41925e..80f645877 100644
--- a/examples/visualize_optimizers.py
+++ b/examples/visualize_optimizers.py
@@ -31,6 +31,8 @@ def execute_steps(func, initial_state, optimizer_class, optimizer_config, num_it
     if optimizer_class.__name__ == 'Ranger21':
         optimizer_config.update({'num_iterations': num_iters})
+    if optimizer_class.__name__ == 'AdaShift':
+        optimizer_config.update({'keep_num': 1})
 
     optimizer = optimizer_class([x], **optimizer_config)
 
@@ -155,7 +157,7 @@ def main():
     optimizers = [
         (optimizer, -6, 0.5)
         for optimizer_name, optimizer in OPTIMIZERS.items()
-        if optimizer_name.lower() not in {'alig', 'lomo', 'adalomo', 'bsam', 'adammini'}
+        if optimizer_name.lower() not in {'alig', 'lomo', 'adalomo', 'bsam', 'adammini', 'demo'}
     ]
 
     optimizers.extend([(torch.optim.AdamW, -6, 0.5), (torch.optim.Adam, -6, 0.5), (torch.optim.SGD, -6, -1.0)])
diff --git a/pytorch_optimizer/optimizer/a2grad.py b/pytorch_optimizer/optimizer/a2grad.py
index 6da2604f5..46ac668fe 100644
--- a/pytorch_optimizer/optimizer/a2grad.py
+++ b/pytorch_optimizer/optimizer/a2grad.py
@@ -1,5 +1,5 @@
 import math
-from typing import Optional
+from typing import Literal, Optional
 
 import torch
 
@@ -7,6 +7,8 @@
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
+VARIANTS = Literal['uni', 'inc', 'exp']
+
 
 class A2Grad(BaseOptimizer):
     r"""Optimal Adaptive and Accelerated Stochastic Gradient Descent.
@@ -26,7 +28,7 @@ def __init__(
         beta: float = 10.0,
         lips: float = 10.0,
         rho: float = 0.5,
-        variant: str = 'uni',
+        variant: VARIANTS = 'uni',
         **kwargs,
     ):
         self.validate_learning_rate(lr)
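
Illustrative sketch (not part of the patch) of what the new `VARIANTS = Literal['uni', 'inc', 'exp']` alias provides: static checkers such as mypy or pyright can now flag a misspelled `variant` at the call site, while a runtime guard still covers values built dynamically. `validate_variant` below is a hypothetical helper, not a function in the library.

```python
from typing import Literal, get_args

VARIANTS = Literal['uni', 'inc', 'exp']  # mirrors the alias added above


def validate_variant(variant: VARIANTS) -> str:
    # Literal annotations are enforced only by static type checkers,
    # so values constructed at runtime still need an explicit check.
    if variant not in get_args(VARIANTS):
        raise ValueError(f'invalid A2Grad variant: {variant}')
    return variant


validate_variant('exp')       # accepted
# validate_variant('linear')  # rejected by mypy/pyright and raises ValueError at runtime
```
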
diff --git a/pytorch_optimizer/optimizer/adashift.py b/pytorch_optimizer/optimizer/adashift.py
index bd8179d05..64f78b7ea 100644
--- a/pytorch_optimizer/optimizer/adashift.py
+++ b/pytorch_optimizer/optimizer/adashift.py
@@ -17,6 +17,7 @@ class AdaShift(BaseOptimizer):
     :param keep_num: int. number of gradients used to compute first moment estimation.
     :param reduce_func: Optional[Callable]. function applied to squared gradients to further reduce the correlation.
         If None, no function is applied.
+    :param cautious: bool. whether to use cautious feature.
     :param eps: float. term added to the denominator to improve numerical stability.
     """
 
@@ -27,6 +28,7 @@ def __init__(
         betas: BETAS = (0.9, 0.999),
         keep_num: int = 10,
         reduce_func: Optional[Callable] = torch.max,
+        cautious: bool = False,
         eps: float = 1e-10,
         **kwargs,
     ):
@@ -36,6 +38,7 @@
         self.validate_non_negative(eps, 'eps')
 
         self.reduce_func: Callable = reduce_func if reduce_func is not None else lambda x: x
+        self.cautious = cautious
 
         defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'keep_num': keep_num, 'eps': eps}
         super().__init__(params, defaults)
@@ -101,13 +104,16 @@
                 exp_avg = state['exp_avg']
                 exp_avg.sub_(offset_grad, alpha=first_grad_weight).mul_(beta1).add_(grad, alpha=last_grad_weight)
 
-                reduced_grad_sq = self.reduce_func(offset_grad.mul_(offset_grad))
+                reduced_grad_sq = self.reduce_func(offset_grad.pow_(2))
 
                 exp_avg_sq = state['exp_avg_sq']
                 exp_avg_sq.mul_(beta2).add_(reduced_grad_sq, alpha=1.0 - beta2)
 
-                de_nom = exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps'])
+                update = exp_avg.clone()
+                update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps']))
+                if self.cautious:
+                    self.apply_cautious(update, grad)
 
-                p.addcdiv_(exp_avg, de_nom, value=-group['lr'])
+                p.add_(update, alpha=-group['lr'])
 
         return loss
diff --git a/pytorch_optimizer/optimizer/lookahead.py b/pytorch_optimizer/optimizer/lookahead.py
index 63a15ee15..84e845207 100644
--- a/pytorch_optimizer/optimizer/lookahead.py
+++ b/pytorch_optimizer/optimizer/lookahead.py
@@ -31,12 +31,11 @@ def __init__(
         self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
         self._optimizer_step_post_hooks: Dict[int, Callable] = {}
 
+        self.optimizer = optimizer
         self.alpha = alpha
         self.k = k
         self.pullback_momentum = pullback_momentum
 
-        self.optimizer = optimizer
-
         self.state: STATE = defaultdict(dict)
 
         for group in self.param_groups:
@@ -93,11 +92,12 @@ def clear_and_load_backup(self):
                 del state['backup_params']
 
     def state_dict(self) -> STATE:
-        return self.optimizer.state_dict()
+        return {'lookahead_state': self.state, 'base_optimizer': self.optimizer.state_dict()}
 
     def load_state_dict(self, state: STATE):
         r"""Load state."""
-        self.optimizer.load_state_dict(state)
+        self.state = state['lookahead_state']
+        self.optimizer.load_state_dict(state['base_optimizer'])
 
     @torch.no_grad()
     def zero_grad(self):
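
A usage sketch (illustrative only, not part of the patch) of the new checkpoint layout: `state_dict()` now returns both the wrapper's own state (`'lookahead_state'`) and the inner optimizer's state (`'base_optimizer'`), and `load_state_dict()` expects the same structure. The model and hyper-parameters below are arbitrary.

```python
import torch
from pytorch_optimizer import AdamP, Lookahead

model = torch.nn.Linear(4, 1)
optimizer = Lookahead(AdamP(model.parameters(), lr=1e-3), k=5, alpha=0.5)

for _ in range(6):
    optimizer.zero_grad()
    model(torch.randn(2, 4)).sum().backward()
    optimizer.step()

checkpoint = optimizer.state_dict()  # {'lookahead_state': ..., 'base_optimizer': ...}

# Restore into a freshly built wrapper around a fresh inner optimizer.
restored = Lookahead(AdamP(model.parameters(), lr=1e-3), k=5, alpha=0.5)
restored.load_state_dict(checkpoint)
```
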
diff --git a/tests/constants.py b/tests/constants.py
index de3c1d446..21b4e0bf2 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -552,4 +552,5 @@
     (LaProp, {'lr': 1e0, 'cautious': True}, 2),
     (AdamP, {'lr': 1e0, 'cautious': True}, 2),
     (ADOPT, {'lr': 1e1, 'cautious': True}, 3),
+    (AdaShift, {'lr': 1e1, 'keep_num': 1, 'cautious': True}, 3),
 ]
diff --git a/tests/test_optimizer_parameters.py b/tests/test_optimizer_parameters.py
index e7610e9eb..901490136 100644
--- a/tests/test_optimizer_parameters.py
+++ b/tests/test_optimizer_parameters.py
@@ -77,15 +77,12 @@ def test_lookahead_parameters():
     _ = opt.__getstate__()
 
-    # test lookahead step `k`
     with pytest.raises(ValueError):
         Lookahead(optimizer, k=0)
 
-    # test ema ratio `alpha`
     with pytest.raises(ValueError):
         Lookahead(optimizer, alpha=-0.1)
 
-    # test invalid pullback momentum type
     with pytest.raises(ValueError):
         Lookahead(optimizer, pullback_momentum='invalid')
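
A sketch of the new option exercised by the test entry above (illustrative only; the model and learning rate are arbitrary): `cautious=True` applies the cautious masking, which zeroes update components whose sign disagrees with the current gradient, and `keep_num=1` keeps the gradient buffer minimal, as in the visualization script.

```python
import torch
from pytorch_optimizer import AdaShift

model = torch.nn.Linear(4, 1)
optimizer = AdaShift(model.parameters(), lr=1e-3, keep_num=1, cautious=True)

for _ in range(5):
    optimizer.zero_grad()
    model(torch.randn(8, 4)).sum().backward()
    optimizer.step()
```
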