Add optimizer accumulation
warner-benjamin committed Mar 4, 2024
1 parent 3618921 commit efb7e5e
Showing 18 changed files with 395 additions and 181 deletions.
34 changes: 22 additions & 12 deletions optimi/adam.py
@@ -169,6 +169,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None):
kahan_sum=group["kahan_sum"],
foreach=group["foreach"],
gradient_release=False,
+optimizer_accumulation=False,
)
else:
state = self.state[param]
@@ -193,6 +194,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None):
kahan_sum=group["kahan_sum"],
foreach=False,
gradient_release=True,
+optimizer_accumulation=self._optimizer_accumulation,
)

return loss
@@ -217,6 +219,7 @@ def adam(
kahan_sum: bool = False,
foreach: bool = False,
gradient_release: bool = False,
+optimizer_accumulation: bool = False,
):
"""Functional API to apply an Adam or AdamW optimization step.
@@ -240,6 +243,7 @@ def adam(
kahan_sum: Enables Kahan summation for low precision parameters
foreach: Enables the faster foreach implementation
gradient_release: Fuses optimizer step as part of the parameter's backward pass
+optimizer_accumulation: Accumulate gradients into state during gradient release step
"""
# calculate debiased beta hat & complement terms
step.add_(1)
@@ -276,6 +280,7 @@ def adam(
eps=eps,
decouple_wd=(decouple_wd or decouple_lr),
kahan_sum=kahan_sum,
+update_parameters=(not optimizer_accumulation),
)


@@ -293,6 +298,7 @@ def _single_adam(
eps: float,
decouple_wd: bool,
kahan_sum: bool = False,
+update_parameters: bool = True,
):
for i, param in enumerate(params):
grad = grads[i]
@@ -313,6 +319,7 @@ def _single_adam(
eps=eps,
decouple_wd=decouple_wd,
kahan_sum=kahan_sum,
+update_parameters=update_parameters,
)


@@ -330,9 +337,10 @@ def _single_param_adam(
eps: float,
decouple_wd: bool,
kahan_sum: bool = False,
+update_parameters: bool = True,
):
# decoupled weight decay, fully decoupled weight decay, or L2 weight decay
-if weight_decay != 0:
+if weight_decay != 0 and update_parameters:
if decouple_wd:
param.mul_(weight_decay)
else:
@@ -342,19 +350,20 @@ def _single_param_adam(
exp_avg.lerp_(grad, weight=beta1_comp)
exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1 - beta2_hat)

-if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]:
-    # Adam step
-    kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr)
+if update_parameters:
+    if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]:
+        # Adam step
+        kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr)

-    # update weights with kahan compensation using grad as temp buffer
-    grad.copy_(param.detach())
-    param.add_(kahan_comp)
+        # update weights with kahan compensation using grad as temp buffer
+        grad.copy_(param.detach())
+        param.add_(kahan_comp)

-    # save error back to kahan compensation for next iteration
-    kahan_comp.add_(grad.sub_(param))
-else:
-    # Adam step
-    param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr)
+        # save error back to kahan compensation for next iteration
+        kahan_comp.add_(grad.sub_(param))
+    else:
+        # Adam step
+        param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr)


def _foreach_adam(
@@ -371,6 +380,7 @@ def _foreach_adam(
eps: float,
decouple_wd: bool,
kahan_sum: bool = False,
+**kwargs,
):
grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, kahan_comps])
for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_exp_avg_sqs, dev_kahan_comps), _) in grouped_tensors.items():
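In effect, when update_parameters is False (optimizer accumulation enabled during a gradient release step), _single_param_adam skips weight decay, the Kahan compensation, and the parameter step, and only folds the incoming gradient into the Adam moving averages. A minimal sketch of that reduced update, using the same tensor operations as the hunk above (illustrative helper name, not part of the commit):

```python
import torch

def adam_accumulate_into_state(
    exp_avg: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    grad: torch.Tensor,
    beta1_comp: float,
    beta2_hat: float,
):
    # same moving-average updates that _single_param_adam performs above
    exp_avg.lerp_(grad, weight=beta1_comp)
    exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1 - beta2_hat)
    # no weight decay, no Kahan compensation, and no param.addcdiv_ step
```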
3 changes: 3 additions & 0 deletions optimi/adamw.py
@@ -91,6 +91,7 @@ def adamw(
kahan_sum: bool = False,
foreach: bool = False,
gradient_release: bool = False,
+optimizer_accumulation: bool = False,
):
"""Functional API to apply an AdamW optimization step.
@@ -113,6 +114,7 @@ def adamw(
kahan_sum: Enables Kahan summation for low precision `params`
foreach: Enables the faster foreach implementation
gradient_release: Fuses optimizer step as part of the parameter's backward pass
+optimizer_accumulation: Accumulate gradients into state during gradient release step
"""
adam(
params=params,
@@ -132,4 +134,5 @@ def adamw(
kahan_sum=kahan_sum,
foreach=foreach,
gradient_release=gradient_release,
+optimizer_accumulation=optimizer_accumulation,
)
52 changes: 31 additions & 21 deletions optimi/adan.py
@@ -190,6 +190,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None):
kahan_sum=group["kahan_sum"],
foreach=group["foreach"],
gradient_release=False,
+optimizer_accumulation=False,
)
else:
state = self.state[param]
@@ -216,6 +217,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None):
kahan_sum=group["kahan_sum"],
foreach=False,
gradient_release=True,
+optimizer_accumulation=self._optimizer_accumulation,
)

return loss
@@ -243,6 +245,7 @@ def adan(
kahan_sum: bool = False,
foreach: bool = False,
gradient_release: bool = False,
+optimizer_accumulation: bool = False,
):
"""Functional API to apply a Adan optimization step.
@@ -269,6 +272,7 @@ def adan(
kahan_sum: Enables Kahan summation for low precision parameters
foreach: Enables the faster foreach implementation
gradient_release: Fuses optimizer step as part of the parameter's backward pass
+optimizer_accumulation: Accumulate gradients into state during gradient release step
"""
# calculate debiased beta hat & complement terms
step.add_(1)
@@ -315,6 +319,7 @@ def adan(
weight_decay=weight_decay,
adam_wd=adam_wd,
kahan_sum=kahan_sum,
+update_parameters=(not optimizer_accumulation),
)


@@ -336,6 +341,7 @@ def _single_adan(
weight_decay: float,
adam_wd: bool,
kahan_sum: bool = False,
+update_parameters: bool = True,
):
for i, param in enumerate(params):
grad = grads[i]
@@ -362,6 +368,7 @@ def _single_adan(
weight_decay=weight_decay,
adam_wd=adam_wd,
kahan_sum=kahan_sum,
+update_parameters=update_parameters,
)


@@ -383,6 +390,7 @@ def _single_param_adan(
weight_decay: float,
adam_wd: bool,
kahan_sum: bool = False,
+update_parameters: bool = True,
):
# difference between current & previous gradients, prev_grad is negated in last step
prev_grad.add_(grad)
@@ -400,32 +408,33 @@ def _single_param_adan(
# set next step's prior_grad as negated current grad
prev_grad.copy_(grad).mul_(-1)

-# calculate 1/η_k using prev_grad as buffer. LR is multiplied in Adan step
-denom = exp_avg_sq.sqrt().add_(eps)
+if update_parameters:
+    # calculate 1/η_k using prev_grad as buffer. LR is multiplied in Adan step
+    denom = exp_avg_sq.sqrt().add_(eps)

-# Adam-style weight decay
-if adam_wd and weight_decay != 0:
-    param.mul_(weight_decay)
+    # Adam-style weight decay
+    if adam_wd and weight_decay != 0:
+        param.mul_(weight_decay)

-if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]:
-    # Adan step
-    kahan_comp.addcdiv_(exp_avg, denom, value=-lr)
-    kahan_comp.addcdiv_(exp_avg_diff, denom, value=-lr * beta2)
+    if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]:
+        # Adan step
+        kahan_comp.addcdiv_(exp_avg, denom, value=-lr)
+        kahan_comp.addcdiv_(exp_avg_diff, denom, value=-lr * beta2)

-    # update weights with kahan compensation using grad as temp buffer
-    grad.copy_(param.detach())
-    param.add_(kahan_comp)
+        # update weights with kahan compensation using grad as temp buffer
+        grad.copy_(param.detach())
+        param.add_(kahan_comp)

-    # save error back to kahan compensation for next iteration
-    kahan_comp.add_(grad.sub_(param))
-else:
-    # Adan step
-    param.addcdiv_(exp_avg, denom, value=-lr)
-    param.addcdiv_(exp_avg_diff, denom, value=-lr * beta2)
+        # save error back to kahan compensation for next iteration
+        kahan_comp.add_(grad.sub_(param))
+    else:
+        # Adan step
+        param.addcdiv_(exp_avg, denom, value=-lr)
+        param.addcdiv_(exp_avg_diff, denom, value=-lr * beta2)

-# Adan-style weight decay
-if not adam_wd and weight_decay != 0:
-    param.div_(weight_decay)
+    # Adan-style weight decay
+    if not adam_wd and weight_decay != 0:
+        param.div_(weight_decay)


def _foreach_adan(
@@ -446,6 +455,7 @@ def _foreach_adan(
weight_decay: float,
adam_wd: bool,
kahan_sum: bool = False,
+**kwargs,
):
grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, exp_avg_diffs, prev_grads, kahan_comps])
for (_, dtype), (
37 changes: 24 additions & 13 deletions optimi/lion.py
@@ -146,6 +146,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None):
kahan_sum=group["kahan_sum"],
foreach=group["foreach"],
gradient_release=False,
+optimizer_accumulation=False,
)
else:
state = self.state[param]
@@ -166,6 +167,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None):
kahan_sum=group["kahan_sum"],
foreach=False,
gradient_release=True,
+optimizer_accumulation=self._optimizer_accumulation,
)

return loss
@@ -186,6 +188,7 @@ def lion(
kahan_sum: bool = False,
foreach: bool = False,
gradient_release: bool = False,
+optimizer_accumulation: bool = False,
):
"""Functional API to apply a Lion optimization step.
@@ -205,6 +208,7 @@ def lion(
kahan_sum: Enables Kahan summation for low precision `params`
foreach: Enables the faster foreach implementation
gradient_release: Fuses optimizer step as part of the parameter's backward pass
+optimizer_accumulation: Accumulate gradients into state during gradient release step
"""
# calculate decoupled weight decay or fully decoupled weight decay
if weight_decay != 0:
@@ -237,6 +241,7 @@ def lion(
beta2_comp=beta2_comp,
weight_decay=weight_decay,
kahan_sum=kahan_sum,
+update_parameters=(not optimizer_accumulation),
)


@@ -251,6 +256,7 @@ def _single_lion(
beta2_comp: float,
weight_decay: float,
kahan_sum: bool = False,
+update_parameters: bool = True,
):
for i, param in enumerate(params):
grad = grads[i]
@@ -267,6 +273,7 @@ def _single_lion(
beta2_comp=beta2_comp,
weight_decay=weight_decay,
kahan_sum=kahan_sum,
+update_parameters=update_parameters,
)


@@ -281,30 +288,33 @@ def _single_param_lion(
beta2_comp: float,
weight_decay: float,
kahan_sum: bool = False,
+update_parameters: bool = True,
):
# decoupled weight decay or fully decoupled weight decay
-if weight_decay != 0:
+if weight_decay != 0 and update_parameters:
param.mul_(weight_decay)

# parameter update value
-update = exp_avg.lerp(grad, weight=beta1_comp).sign_()
+if update_parameters:
+    update = exp_avg.lerp(grad, weight=beta1_comp).sign_()

# update gradient moving average
exp_avg.lerp_(grad, weight=beta2_comp)

-if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]:
-    # Lion step
-    kahan_comp.add_(update, alpha=-lr)
+if update_parameters:
+    if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]:
+        # Lion step
+        kahan_comp.add_(update, alpha=-lr)

-    # update weights with kahan compensation using grad as temp buffer
-    grad.copy_(param.detach())
-    param.add_(kahan_comp)
+        # update weights with kahan compensation using grad as temp buffer
+        grad.copy_(param.detach())
+        param.add_(kahan_comp)

-    # save error back to kahan compensation for next iteration
-    kahan_comp.add_(grad.sub_(param))
-else:
-    # Lion step
-    param.add_(update, alpha=-lr)
+        # save error back to kahan compensation for next iteration
+        kahan_comp.add_(grad.sub_(param))
+    else:
+        # Lion step
+        param.add_(update, alpha=-lr)


def _foreach_lion(
@@ -318,6 +328,7 @@ def _foreach_lion(
beta2_comp: float,
weight_decay: float,
kahan_sum: bool = False,
+**kwargs,
):
grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, kahan_comps])
for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_kahan_comps), _) in grouped_tensors.items():
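Lion's accumulation path is the simplest of the three: with update_parameters set to False, the sign-based step, weight decay, and Kahan compensation are all skipped, and only the single momentum buffer absorbs the gradient. A minimal sketch (illustrative helper name, not part of the commit):

```python
import torch

def lion_accumulate_into_state(exp_avg: torch.Tensor, grad: torch.Tensor, beta2_comp: float):
    # same moving-average update that _single_param_lion performs above;
    # the sign_() update and parameter step only run when parameters are updated
    exp_avg.lerp_(grad, weight=beta2_comp)
```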
13 changes: 13 additions & 0 deletions optimi/optimizer.py
@@ -40,6 +40,9 @@ def __init__(self, params: Iterable[Tensor] | Iterable[dict], defaults: dict[str

super().__init__(params, defaults)

+# by default perform the normal parameter update step
+self._optimizer_accumulation = False

# if gradient_release is enabled, disable foreach step so normal optimizer step won't error
if self.defaults["gradient_release"]:
self.defaults["foreach"] = False
@@ -48,6 +51,16 @@ def __init__(self, params: Iterable[Tensor] | Iterable[dict], defaults: dict[str
for p in group["params"]:
self.state[p]["group"] = group

+@property
+def optimizer_accumulation(self) -> bool:
+    "Accumulate gradients in optimizer states during gradient release instead of a full step."
+    return self._optimizer_accumulation
+
+@optimizer_accumulation.setter
+def optimizer_accumulation(self, optimizer_accumulation: bool):
+    "Accumulate gradients in optimizer states during gradient release instead of a full step."
+    self._optimizer_accumulation = optimizer_accumulation

def step(self, closure: Callable | None = None, param: Tensor | None = None):
"""Performs a single optimization step on the whole model or individual parameter.
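A hypothetical training-loop sketch of how the new property might be used: the flag is switched on for every micro-batch except the last one in an accumulation window, so hooked parameters accumulate gradients into optimizer state during the backward pass and only take a real step on the final micro-batch. The prepare_for_gradient_release helper and the trailing step()/zero_grad() calls are assumptions about the surrounding gradient release setup, not something this commit defines:

```python
import torch
import optimi

model = torch.nn.Linear(16, 1)
optimizer = optimi.AdamW(model.parameters(), lr=1e-3, gradient_release=True)

# assumed optimi helper that registers the per-parameter backward hooks
optimi.prepare_for_gradient_release(model, optimizer)

accumulation_steps = 4
for idx in range(16):
    batch = torch.randn(8, 16)
    # accumulate into optimizer state on all but the last micro-batch of each window
    optimizer.optimizer_accumulation = (idx + 1) % accumulation_steps != 0
    loss = model(batch).pow(2).mean()
    loss.backward()       # hooked parameters take their accumulation or full step here
    optimizer.step()      # assumed to cover any parameters not handled during backward
    optimizer.zero_grad()
```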