Merge pull request #310 from kozistr/fix/lookahead-device
[Fix] Save and load the Lookahead optimizer's state
kozistr authored Dec 14, 2024
2 parents 913f7b7 + 1cf2690 commit 7bb85f9
Showing 21 changed files with 76 additions and 25 deletions.
7 changes: 5 additions & 2 deletions .github/ISSUE_TEMPLATE/bug_report.md
@@ -15,8 +15,11 @@ A clear and concise description of what the bug is.

* OS : (e.g. Linux, Windows, MacOS)
* PyTorch version : (e.g. 2.0.1, 1.13, >=1.8, <1.10)
* Python version : (e.g. 3.8, 3.11
* reproducible codes :
* Python version : (e.g. 3.8, 3.11)
* pytorch-optimizer version : (e.g. 3.3.0)
* reproducible codes : please share your reproducible codes, scripts, or links. If sharing the code is complicated, you can manually write minimal code to reproduce bugs!

Here's an [example](https://github.com/kozistr/pytorch_optimizer/issues/305#issue-2721453417).

## Log

10 changes: 5 additions & 5 deletions README.md
@@ -10,11 +10,11 @@

## The reasons why you use `pytorch-optimizer`.

1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
2. Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
3. Easy to use, clean, and tested codes
4. Active maintenance
5. Somewhat a bit more optimized compared to the original implementation
* Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
* Somewhat a bit more optimized compared to the original implementation

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

8 changes: 8 additions & 0 deletions docs/changelogs/v3.3.1.md
@@ -2,9 +2,17 @@

### Feature

* Support `Cautious` variant to `AdaShift` optimizer. (#310)
* Save the state of the `Lookahead` optimizer too. (#310)

### Bug

* Fix `bias_correction` in `AdamG` optimizer. (#305, #308)
* Fix a potential bug when loading the state for `Lookahead` optimizer. (#306, #310)

### Docs

* Add more visualizations. (#310)

### Contributions

10 changes: 5 additions & 5 deletions docs/index.md
@@ -10,11 +10,11 @@

## The reasons why you use `pytorch-optimizer`.

1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
2. Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
3. Easy to use, clean, and tested codes
4. Active maintenance
5. Somewhat a bit more optimized compared to the original implementation
* Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
* Somewhat a bit more optimized compared to the original implementation

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

32 changes: 32 additions & 0 deletions docs/visualization.md
@@ -74,6 +74,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaPNM.png)

### AdaShift

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaShift.png)

### AdaSmooth

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaSmooth.png)
@@ -170,6 +174,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Lamb.png)

### LaProp

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_LaProp.png)

### LARS

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_LARS.png)
@@ -186,6 +194,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_MSVAG.png)

### Muon

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Muon.png)

### Nero

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Nero.png)
@@ -238,6 +250,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_ScheduleFreeAdamW.png)

### ScheduleFreeRAdam

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_ScheduleFreeRAdam.png)

### ScheduleFreeSGD

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_ScheduleFreeSGD.png)
@@ -368,6 +384,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaPNM.png)

### AdaShift

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaShift.png)

### AdaSmooth

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaSmooth.png)
@@ -464,6 +484,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Lamb.png)

### LaProp

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_LaProp.png)

### LARS

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_LARS.png)
@@ -480,6 +504,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_MSVAG.png)

### Muon

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Muon.png)

### Nero

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Nero.png)
@@ -532,6 +560,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_ScheduleFreeAdamW.png)

### ScheduleFreeRAdam

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_ScheduleFreeRAdam.png)

### ScheduleFreeSGD

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_ScheduleFreeSGD.png)
Binary file added docs/visualizations/rastrigin_AdaShift.png
Binary file modified docs/visualizations/rastrigin_AdamG.png
Binary file added docs/visualizations/rastrigin_LaProp.png
Binary file added docs/visualizations/rastrigin_Muon.png
Binary file added docs/visualizations/rastrigin_ScheduleFreeRAdam.png
Binary file added docs/visualizations/rosenbrock_AdaShift.png
Binary file modified docs/visualizations/rosenbrock_AdamG.png
Binary file added docs/visualizations/rosenbrock_LaProp.png
Binary file added docs/visualizations/rosenbrock_Muon.png
Binary file added docs/visualizations/rosenbrock_ScheduleFreeRAdam.png
4 changes: 3 additions & 1 deletion examples/visualize_optimizers.py
@@ -31,6 +31,8 @@ def execute_steps(func, initial_state, optimizer_class, optimizer_config, num_it

if optimizer_class.__name__ == 'Ranger21':
optimizer_config.update({'num_iterations': num_iters})
if optimizer_class.__name__ == 'AdaShift':
optimizer_config.update({'keep_num': 1})

optimizer = optimizer_class([x], **optimizer_config)

@@ -155,7 +157,7 @@ def main():
optimizers = [
(optimizer, -6, 0.5)
for optimizer_name, optimizer in OPTIMIZERS.items()
if optimizer_name.lower() not in {'alig', 'lomo', 'adalomo', 'bsam', 'adammini'}
if optimizer_name.lower() not in {'alig', 'lomo', 'adalomo', 'bsam', 'adammini', 'demo'}
]
optimizers.extend([(torch.optim.AdamW, -6, 0.5), (torch.optim.Adam, -6, 0.5), (torch.optim.SGD, -6, -1.0)])

6 changes: 4 additions & 2 deletions pytorch_optimizer/optimizer/a2grad.py
@@ -1,12 +1,14 @@
import math
from typing import Optional
from typing import Literal, Optional

import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS

VARIANTS = Literal['uni', 'inc', 'exp']


class A2Grad(BaseOptimizer):
r"""Optimal Adaptive and Accelerated Stochastic Gradient Descent.
@@ -26,7 +28,7 @@ def __init__(
beta: float = 10.0,
lips: float = 10.0,
rho: float = 0.5,
variant: str = 'uni',
variant: VARIANTS = 'uni',
**kwargs,
):
self.validate_learning_rate(lr)
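The `Literal` alias tells static type checkers exactly which `variant` values `A2Grad` accepts, so a typo is caught before runtime. A minimal sketch of the effect, assuming the package's usual top-level export; the model and learning rate are illustrative:

```python
import torch

from pytorch_optimizer import A2Grad

model = torch.nn.Linear(4, 1)

# accepted values are exactly those in VARIANTS: 'uni', 'inc', or 'exp';
# a checker such as mypy or pyright now flags e.g. variant='typo' before it runs
optimizer = A2Grad(model.parameters(), lr=1e-2, variant='exp')

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
```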
12 changes: 9 additions & 3 deletions pytorch_optimizer/optimizer/adashift.py
@@ -17,6 +17,7 @@ class AdaShift(BaseOptimizer):
:param keep_num: int. number of gradients used to compute first moment estimation.
:param reduce_func: Optional[Callable]. function applied to squared gradients to further reduce the correlation.
If None, no function is applied.
:param cautious: bool. whether to use cautious feature.
:param eps: float. term added to the denominator to improve numerical stability.
"""

@@ -27,6 +28,7 @@ def __init__(
betas: BETAS = (0.9, 0.999),
keep_num: int = 10,
reduce_func: Optional[Callable] = torch.max,
cautious: bool = False,
eps: float = 1e-10,
**kwargs,
):
@@ -36,6 +38,7 @@
self.validate_non_negative(eps, 'eps')

self.reduce_func: Callable = reduce_func if reduce_func is not None else lambda x: x
self.cautious = cautious

defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'keep_num': keep_num, 'eps': eps}
super().__init__(params, defaults)
@@ -101,13 +104,16 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg = state['exp_avg']
exp_avg.sub_(offset_grad, alpha=first_grad_weight).mul_(beta1).add_(grad, alpha=last_grad_weight)

reduced_grad_sq = self.reduce_func(offset_grad.mul_(offset_grad))
reduced_grad_sq = self.reduce_func(offset_grad.pow_(2))

exp_avg_sq = state['exp_avg_sq']
exp_avg_sq.mul_(beta2).add_(reduced_grad_sq, alpha=1.0 - beta2)

de_nom = exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps'])
update = exp_avg.clone()
update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps']))
if self.cautious:
self.apply_cautious(update, grad)

p.addcdiv_(exp_avg, de_nom, value=-group['lr'])
p.add_(update, alpha=-group['lr'])

return loss
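The new flag rides along as a regular keyword argument; as in the repository's other `cautious`-enabled optimizers, it masks update components whose sign disagrees with the current gradient. A minimal usage sketch, with an illustrative model, learning rate, and `keep_num` (the small `keep_num` mirrors the test and visualization configs in this commit):

```python
import torch

from pytorch_optimizer import AdaShift

model = torch.nn.Linear(4, 1)

# keep_num gradients are buffered before updates begin, so a small value
# starts adapting sooner; cautious=True switches on the new variant
optimizer = AdaShift(model.parameters(), lr=1e-2, keep_num=1, cautious=True)

for _ in range(5):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
```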
8 changes: 4 additions & 4 deletions pytorch_optimizer/optimizer/lookahead.py
@@ -31,12 +31,11 @@ def __init__(
self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
self._optimizer_step_post_hooks: Dict[int, Callable] = {}

self.optimizer = optimizer
self.alpha = alpha
self.k = k
self.pullback_momentum = pullback_momentum

self.optimizer = optimizer

self.state: STATE = defaultdict(dict)

for group in self.param_groups:
@@ -93,11 +92,12 @@ def clear_and_load_backup(self):
del state['backup_params']

def state_dict(self) -> STATE:
return self.optimizer.state_dict()
return {'lookahead_state': self.state, 'base_optimizer': self.optimizer.state_dict()}

def load_state_dict(self, state: STATE):
r"""Load state."""
self.optimizer.load_state_dict(state)
self.state = state['lookahead_state']
self.optimizer.load_state_dict(state['base_optimizer'])

@torch.no_grad()
def zero_grad(self):
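With the new format, `state_dict()` bundles the Lookahead bookkeeping together with the wrapped optimizer's state, so a checkpoint round-trip restores both. A minimal sketch of the intended save/load flow, assuming the package's top-level `Lookahead` export; the model, loop, and key check are illustrative:

```python
import torch

from pytorch_optimizer import Lookahead

model = torch.nn.Linear(4, 1)
base = torch.optim.AdamW(model.parameters(), lr=1e-3)
optimizer = Lookahead(base, k=5, alpha=0.5)

# a few steps so there is slow-weight / backup state worth keeping
for _ in range(6):
    optimizer.zero_grad()
    model(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()

# the dict now carries both parts of the state
checkpoint = optimizer.state_dict()
assert set(checkpoint) == {'lookahead_state', 'base_optimizer'}

# e.g. torch.save(checkpoint, 'checkpoint.pt') in a real run; restoring
# fills the Lookahead bookkeeping and then the inner AdamW state
optimizer.load_state_dict(checkpoint)
```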
1 change: 1 addition & 0 deletions tests/constants.py
@@ -552,4 +552,5 @@
(LaProp, {'lr': 1e0, 'cautious': True}, 2),
(AdamP, {'lr': 1e0, 'cautious': True}, 2),
(ADOPT, {'lr': 1e1, 'cautious': True}, 3),
(AdaShift, {'lr': 1e1, 'keep_num': 1, 'cautious': True}, 3),
]
3 changes: 0 additions & 3 deletions tests/test_optimizer_parameters.py
@@ -77,15 +77,12 @@ def test_lookahead_parameters():

_ = opt.__getstate__()

# test lookahead step `k`
with pytest.raises(ValueError):
Lookahead(optimizer, k=0)

# test ema ratio `alpha`
with pytest.raises(ValueError):
Lookahead(optimizer, alpha=-0.1)

# test invalid pullback momentum type
with pytest.raises(ValueError):
Lookahead(optimizer, pullback_momentum='invalid')

