[Feature] Implement SGDSaI optimizer #316

Merged (14 commits) on Dec 21, 2024
3 changes: 2 additions & 1 deletion README.md
@@ -10,7 +10,7 @@

## The reasons why you should use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **85 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **86 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, and `Gradient Centralization`
* Easy-to-use, clean, and tested code
* Active maintenance
@@ -194,6 +194,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |
| APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
| MARS | *Unleashing the Power of Variance Reduction for Training Large Models* | [github](https://github.com/AGI-Arena/MARS) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AGI-Arena/MARS/tree/main?tab=readme-ov-file#citation) |
| SGDSaI | *No More Adam: Learning Rate Scaling at Initialization is All You Need* | [github](https://github.com/AnonymousAlethiometer/SGD_SaI) | <https://arxiv.org/abs/2412.11768> | [cite](https://github.com/AnonymousAlethiometer/SGD_SaI?tab=readme-ov-file#citation) |

## Supported LR Scheduler

10 changes: 10 additions & 0 deletions docs/changelogs/v3.3.2.md
@@ -0,0 +1,10 @@
### Change Log

### Feature

* Implement `SGDSaI` optimizer. (#315, #316)
* [No More Adam: Learning Rate Scaling at Initialization is All You Need](https://arxiv.org/abs/2412.11768)

### Bug

* Clone `exp_avg` before calling `apply_cautious` so that `exp_avg` itself is not masked. (#316) A sketch of the pattern follows below.
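
The fix matters because `apply_cautious` masks its first argument in place: applying it directly to `exp_avg` corrupted the momentum state itself. A minimal sketch of the intended pattern, assuming a simplified stand-in for the library's cautious masking helper:

```python
import torch


def apply_cautious_(update: torch.Tensor, grad: torch.Tensor) -> None:
    """Zero out (in place) the update components whose sign disagrees with the gradient."""
    mask = (update * grad > 0).to(grad.dtype)
    mask.div_(mask.mean().clamp_(min=1e-3))  # rescale so the average step size is preserved
    update.mul_(mask)


exp_avg = torch.randn(4)  # momentum buffer held in optimizer state
grad = torch.randn(4)

update = exp_avg.clone()       # clone first ...
apply_cautious_(update, grad)  # ... so the mask touches only the temporary update
# exp_avg is left untouched and keeps accumulating correctly on later steps
```
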
3 changes: 2 additions & 1 deletion docs/index.md
@@ -10,7 +10,7 @@

## The reasons why you should use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **85 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **86 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, and `Gradient Centralization`
* Easy-to-use, clean, and tested code
* Active maintenance
@@ -194,6 +194,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |
| APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
| MARS | *Unleashing the Power of Variance Reduction for Training Large Models* | [github](https://github.com/AGI-Arena/MARS) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AGI-Arena/MARS/tree/main?tab=readme-ov-file#citation) |
| SGDSaI | *No More Adam: Learning Rate Scaling at Initialization is All You Need* | [github](https://github.com/AnonymousAlethiometer/SGD_SaI) | <https://arxiv.org/abs/2412.11768> | [cite](https://github.com/AnonymousAlethiometer/SGD_SaI?tab=readme-ov-file#citation) |

## Supported LR Scheduler

4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -332,6 +332,10 @@
:docstring:
:members:

::: pytorch_optimizer.SGDSaI
:docstring:
:members:

::: pytorch_optimizer.SGDP
:docstring:
:members:
8 changes: 8 additions & 0 deletions docs/visualization.md
@@ -274,6 +274,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGDP.png)

### SGDSaI

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGDSaI.png)

### SGDW

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGDW.png)
@@ -592,6 +596,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGDP.png)

### SGDSaI

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGDSaI.png)

### SGDW

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGDW.png)
Binary file added docs/visualizations/rastrigin_SGDSaI.png
Binary file added docs/visualizations/rosenbrock_SGDSaI.png
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytorch_optimizer"
version = "3.3.1"
version = "3.3.2"
description = "optimizer & lr scheduler & objective function collections in PyTorch"
license = "Apache-2.0"
authors = ["kozistr <[email protected]>"]
1 change: 1 addition & 0 deletions pytorch_optimizer/__init__.py
@@ -128,6 +128,7 @@
ScheduleFreeAdamW,
ScheduleFreeRAdam,
ScheduleFreeSGD,
SGDSaI,
Shampoo,
SignSGD,
SophiaH,
3 changes: 2 additions & 1 deletion pytorch_optimizer/optimizer/__init__.py
@@ -74,7 +74,7 @@
from pytorch_optimizer.optimizer.rotograd import RotoGrad
from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM
from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD
from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD
from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SGDSaI, SignSGD
from pytorch_optimizer.optimizer.sgdp import SGDP
from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
from pytorch_optimizer.optimizer.sm3 import SM3
@@ -281,6 +281,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
ScheduleFreeRAdam,
LaProp,
MARS,
SGDSaI,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

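
Because `OPTIMIZERS` maps each class's lowercased name to the class, registering `SGDSaI` in `OPTIMIZER_LIST` should make it loadable by name. A rough usage sketch, assuming the registry key is `'sgdsai'` as the comprehension above implies:

```python
import torch

from pytorch_optimizer import load_optimizer

model = torch.nn.Linear(10, 2)

# look the optimizer class up by its lowercased name, then instantiate it as usual
opt_class = load_optimizer('sgdsai')
optimizer = opt_class(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-2)
```
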
3 changes: 2 additions & 1 deletion pytorch_optimizer/optimizer/adashift.py
@@ -110,10 +110,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg_sq.mul_(beta2).add_(reduced_grad_sq, alpha=1.0 - beta2)

update = exp_avg.clone()
update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps']))
if self.cautious:
self.apply_cautious(update, grad)

update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps']))

p.add_(update, alpha=-group['lr'])

return loss
5 changes: 3 additions & 2 deletions pytorch_optimizer/optimizer/ademamix.py
@@ -146,10 +146,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:

de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

update = exp_avg.clone()
if self.cautious:
self.apply_cautious(exp_avg, grad)
self.apply_cautious(update, grad)

update = (exp_avg + alpha_t * exp_avg_slow).div_(de_nom)
update.add_(exp_avg_slow, alpha=alpha_t).div_(de_nom)

p.add_(update, alpha=-step_size)

25 changes: 14 additions & 11 deletions pytorch_optimizer/optimizer/mars.py
@@ -121,26 +121,27 @@ def optimize_mixed(

exp_avg.mul_(beta1).add_(c_t, alpha=1.0 - beta1)

update = exp_avg.clone()
if cautious:
self.apply_cautious(exp_avg, grad)
self.apply_cautious(update, grad)

if mars_type == 'adamw' or (mars_type == 'shampoo' and not is_grad_2d):
exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1.0 - beta2)

bias_correction1: float = self.debias(beta1, step)
bias_correction2_sq: float = math.sqrt(self.debias(beta2, step))

update = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
update.div_(bias_correction2_sq).mul_(bias_correction1)
de_nom = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
de_nom.div_(bias_correction2_sq).mul_(bias_correction1)

return exp_avg.div(update)
return update.div_(de_nom)

if mars_type == 'lion':
return exp_avg.sign()
return update.sign_()

factor: float = max(1.0, grad.size(0) / grad.size(1)) ** 0.5
factor: float = math.sqrt(max(1.0, grad.size(0) / grad.size(1)))

return zero_power_via_newton_schulz_5(exp_avg.mul(1.0 / (1.0 - beta1)), eps=eps).mul_(factor)
return zero_power_via_newton_schulz_5(update.mul_(1.0 / (1.0 - beta1)), eps=eps).mul_(factor)

def optimize_1d(
self,
@@ -162,13 +163,15 @@ def optimize_1d(
exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

update = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
update.div_(bias_correction2_sq).mul_(bias_correction1)
update = exp_avg.clone()

if cautious:
self.apply_cautious(exp_avg, grad)
self.apply_cautious(update, grad)

return exp_avg.div(update)
de_nom = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
de_nom.div_(bias_correction2_sq).mul_(bias_correction1)

return update.div_(de_nom)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
123 changes: 123 additions & 0 deletions pytorch_optimizer/optimizer/sgd.py
@@ -356,6 +356,9 @@ def __init__(
}
super().__init__(params, defaults)

def __str__(self) -> str:
return 'SignSGD'

@torch.no_grad()
def reset(self):
for group in self.param_groups:
@@ -396,3 +399,123 @@ def step(self, closure: CLOSURE = None) -> LOSS:
p.add_(torch.sign(buf), alpha=-group['lr'])

return loss


class SGDSaI(BaseOptimizer):
r"""No More Adam: Learning Rate Scaling at Initialization is All You Need.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param momentum: float. coefficients used for computing running averages of gradient.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
:param eps: float. term added to the denominator to improve numerical stability.
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 1e-2,
momentum: float = 0.9,
weight_decay: float = 1e-2,
weight_decouple: bool = True,
eps: float = 1e-8,
**kwargs,
):
self.validate_learning_rate(lr)
self.validate_range(momentum, 'beta', 0.0, 1.0)
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

self.has_warmup: bool = False

defaults: DEFAULTS = {
'lr': lr,
'momentum': momentum,
'weight_decay': weight_decay,
'weight_decouple': weight_decouple,
'eps': eps,
}
super().__init__(params, defaults)

def __str__(self) -> str:
return 'SGDSaI'

@torch.no_grad()
def reset(self):
for group in self.param_groups:
group['step'] = 0
for p in group['params']:
state = self.state[p]

if group['momentum'] > 0.0:
state['momentum_buffer'] = torch.zeros_like(p)

@torch.no_grad()
def warmup_step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(str(self))

sigma = grad.std().nan_to_num_()
grad_norm = grad.norm()

g_snr = grad_norm.div_(sigma.add_(group['eps'])) if sigma != 0.0 else grad_norm

self.state[p]['gsnr'] = g_snr

self.has_warmup = True

return loss

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
if not self.has_warmup:
self.warmup_step(closure)

loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
momentum: float = group['momentum']
for p in group['params']:
if p.grad is None:
continue

grad = p.grad

state = self.state[p]

if momentum > 0.0:
if 'momentum_buffer' not in state:
state['momentum_buffer'] = grad.clone()

buf = state['momentum_buffer']
buf.mul_(momentum).add_(grad, alpha=1.0 - momentum)
else:
buf = grad

self.apply_weight_decay(
p,
grad,
group['lr'],
group['weight_decay'],
group['weight_decouple'],
False,
)

p.add_(buf, alpha=-group['lr'] * state['gsnr'])

return loss
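
A short usage sketch of the optimizer added above: the first call to `step` runs `warmup_step`, which computes each parameter's gradient signal-to-noise ratio (gradient norm divided by gradient standard deviation plus `eps`) once, and every later step scales the SGD-with-momentum update by that fixed per-parameter factor. The training loop below is an assumed, standard one:

```python
import torch

from pytorch_optimizer import SGDSaI

model = torch.nn.Linear(16, 1)
optimizer = SGDSaI(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-2)

x, y = torch.randn(32, 16), torch.randn(32, 1)

for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    # the first call computes the per-parameter g-SNR internally (warmup_step);
    # subsequent calls reuse it to scale the momentum update
    optimizer.step()
```
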
3 changes: 3 additions & 0 deletions tests/constants.py
@@ -74,6 +74,7 @@
ScheduleFreeAdamW,
ScheduleFreeRAdam,
ScheduleFreeSGD,
SGDSaI,
Shampoo,
SignSGD,
SophiaH,
@@ -538,6 +539,8 @@
(MARS, {'lr': 1e-1, 'weight_decay': 1e-3, 'mars_type': 'lion', 'optimize_1d': True}, 5),
(MARS, {'lr': 5e-1, 'lr_1d': 5e-1, 'weight_decay': 1e-3, 'mars_type': 'shampoo'}, 5),
(MARS, {'lr': 5e-1, 'lr_1d': 5e-1, 'weight_decay': 1e-3, 'mars_type': 'adamw', 'ams_bound': True}, 5),
(SGDSaI, {'lr': 1e0}, 15),
(SGDSaI, {'lr': 1e0, 'momentum': 0.0}, 15),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
2 changes: 1 addition & 1 deletion tests/test_load_modules.py
@@ -34,7 +34,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):


def test_get_supported_optimizers():
assert len(get_supported_optimizers()) == 84
assert len(get_supported_optimizers()) == 85
assert len(get_supported_optimizers('adam*')) == 7
assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 9
