Merge pull request #3 from warner-benjamin/opt_accum
Add Optimizer Accumulation
warner-benjamin authored Mar 11, 2024
2 parents 3618921 + c899679 commit 0bec5ca
Showing 27 changed files with 656 additions and 234 deletions.
74 changes: 69 additions & 5 deletions README.md
@@ -1,15 +1,25 @@
# optimī

### Fast, Modern, and Low Precision PyTorch Optimizers
### Fast, Modern, Memory Efficient, and Low Precision PyTorch Optimizers

optimi enables accurate low precision training via Kahan summation, supports fully decoupled weight decay, and features fast implementations of modern optimizers.
optimi enables accurate low precision training via Kahan summation, integrates gradient release and optimizer accumulation for additional memory efficiency, supports fully decoupled weight decay, and features fast implementations of modern optimizers.

## Low Precision Training with Kahan Summation

optimi optimizers can match the performance of mixed precision when [training in BFloat16 by using Kahan summation](https://optimi.benjaminwarner.dev/kahan_summation).
optimi optimizers can nearly reach or match the performance of mixed precision when [training in BFloat16 by using Kahan summation](https://optimi.benjaminwarner.dev/kahan_summation).

Training in BFloat16 with Kahan summation can reduce non-activation training memory usage by [37.5 to 45.5 percent](https://optimi.benjaminwarner.dev/kahan_summation/#memory-savings) when using an Adam optimizer. BFloat16 training increases single GPU [training speed by ~10 percent](https://optimi.benjaminwarner.dev/kahan_summation/#training-speedup) at the same batch size.
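
The trick is compensated (Kahan) summation: a small extra buffer catches the low-order bits that BFloat16 addition would otherwise round away. Below is a minimal sketch of the idea, not optimi's internal implementation (optimi applies the compensation inside the optimizer step):

```python
import torch

# Minimal sketch of Kahan-compensated BFloat16 updates (illustrative only).
def kahan_update(param, compensation, update):
    # fold the rounding error saved from the previous step into this update
    corrected = update + compensation
    new_param = param + corrected
    # record what this BFloat16 addition lost, to re-add next step
    compensation.copy_(corrected - (new_param - param))
    param.copy_(new_param)

param = torch.randn(1024, dtype=torch.bfloat16)
compensation = torch.zeros_like(param)

# updates this small are mostly rounded away by plain BFloat16 addition,
# but the compensation buffer preserves them across steps
for _ in range(100):
    kahan_update(param, compensation, torch.full_like(param, 1e-4))
```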

## Gradient Release: Fused Backward and Optimizer Step

optimi optimizers can perform the [optimization step layer-by-layer during the backward pass](https://optimi.benjaminwarner.dev/gradient_release), immediately freeing gradient memory.

Unlike the current PyTorch implementation, optimi’s gradient release optimizers are a drop-in replacement for standard optimizers and seamlessly work with existing hyperparameter schedulers.

## Optimizer Accumulation: Gradient Release and Accumulation

optimi optimizers can approximate gradient accumulation with gradient release by [accumulating gradients into the optimizer states](https://optimi.benjaminwarner.dev/optimizer_accumulation).

## Fully Decoupled Weight Decay

In addition to supporting PyTorch-style decoupled weight decay, optimi optimizers also support [fully decoupled weight decay](https://optimi.benjaminwarner.dev/fully_decoupled_weight_decay).
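
Roughly speaking, the difference is whether the decay term scales with the learning rate $\gamma_t$ (the exact scaling optimi uses is described in the linked docs):

$$
\begin{aligned}
\text{decoupled (PyTorch-style):} \quad \theta_t &= \theta_{t-1} - \gamma_t \, (u_t + \lambda \theta_{t-1}) \\
\text{fully decoupled:} \quad \theta_t &= \theta_{t-1} - \gamma_t \, u_t - \lambda \theta_{t-1}
\end{aligned}
$$

where $u_t$ is the optimizer update and $\lambda$ the weight decay, which is why the fully decoupled example below uses a much smaller `weight_decay` than the PyTorch-style example.
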
@@ -44,7 +54,7 @@ from optimi import AdamW
# create or cast model in low precision (bfloat16)
model = nn.Linear(20, 1, dtype=torch.bfloat16)

# instantiate AdamW with parameters and fully decoupled weight decay
# initialize any optimi optimizer with parameters & fully decoupled weight decay
# Kahan summation is automatically enabled since model & inputs are bfloat16
opt = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5, decouple_lr=True)

@@ -63,10 +73,64 @@ To use with PyTorch-style weight decay with float32 or mixed precision:
# create model
model = nn.Linear(20, 1)

# instantiate AdamW with parameters
# initialize any optimi optimizer with parameters
opt = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
```

To use with gradient release:

```python
# initialize any optimi optimizer with `gradient_release=True`
# and call `prepare_for_gradient_release` on model and optimizer
opt = AdamW(model.parameters(), lr=1e-3, gradient_release=True)
prepare_for_gradient_release(model, opt)

# calling backward on the model will perform the optimizer step
loss = model(torch.randn(20, dtype=torch.bfloat16))
loss.backward()

# optimizer step and zero_grad are no longer needed, and will
# harmlessly no-op if called by an existing training framework
# opt.step()
# opt.zero_grad()

# optionally remove gradient release hooks when done training
remove_gradient_release(model)
```

To use with optimizer accumulation:

```python
# initialize any optimi optimizer with `gradient_release=True`
# and call `prepare_for_gradient_release` on model and optimizer
opt = AdamW(model.parameters(), lr=1e-3, gradient_release=True)
prepare_for_gradient_release(model, opt)

# update model parameters every four steps after accumulating
# gradients directly into the optimizer states
accumulation_steps = 4

# use existing PyTorch dataloader
for idx, batch in enumerate(dataloader):
    # `optimizer_accumulation=True` accumulates gradients into
    # optimizer states. Set `optimizer_accumulation=False` to
    # update parameters by performing a full gradient release step
    opt.optimizer_accumulation = (idx+1) % accumulation_steps != 0

    # calling backward on the model will perform the optimizer step,
    # either accumulating gradients or updating model parameters
    loss = model(batch)
    loss.backward()

    # optimizer step and zero_grad are no longer needed, and will
    # harmlessly no-op if called by an existing training framework
    # opt.step()
    # opt.zero_grad()

# optionally remove gradient release hooks when done training
remove_gradient_release(model)
```
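
For intuition, here is a toy, framework-free sketch of the underlying idea: each micro-batch's gradient is folded into an optimizer state and the parameter update is deferred. It uses a plain momentum buffer for clarity; optimi's actual optimizers do this per parameter inside the gradient release hooks and use their own update math.

```python
import torch

# Toy sketch of optimizer accumulation (illustrative only).
param = torch.zeros(10)
momentum = torch.zeros_like(param)
lr, beta = 1e-2, 0.9
accumulation_steps = 4

for idx in range(12):
    micro_batch_grad = torch.randn_like(param)   # stand-in for one micro-batch's gradient
    momentum.mul_(beta).add_(micro_batch_grad)   # gradient absorbed into optimizer state
    if (idx + 1) % accumulation_steps == 0:      # parameter update only every N micro-batches
        param -= lr * momentum
```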

## Differences from PyTorch

optimi optimizers do not support compilation, differentiation, complex numbers, or have capturable versions.
18 changes: 15 additions & 3 deletions docs/css/extra.css
@@ -46,8 +46,8 @@
--md-typeset-table-color: rgba(24, 24, 24, 0.05);

/* Code highlighting color shades */
/* --md-code-hl-color: #9a3fe4;
--md-code-hl-color--light: #9a3fe4; */
--md-code-hl-color: #9a3fe4;
--md-code-hl-color--light: #9a3fe43c;
--md-code-hl-number-color: #db5f00;
--md-code-hl-special-color: #d32300;
--md-code-hl-function-color: #cc9901;
@@ -161,7 +161,7 @@


/* Links */
.md-content a:not(.headerlink):not(.footnote-ref):not(.footnote-backref) {
.md-content a:not(.headerlink):not(.footnote-ref):not(.footnote-backref):not(:has(> code)) {
box-shadow: inset 0 -0.115rem 0 var(--light-purple);
text-decoration: none;
transition: all .15s cubic-bezier(.33,.66,.66,1);
@@ -171,6 +171,18 @@
color: var(--black) }
}

.md-content a code {
box-shadow: inset 0 -0.115rem 0 var(--light-purple);
text-decoration: none;
transition: all .15s cubic-bezier(.33,.66,.66,1);
z-index: 10;
border-bottom-left-radius: 0;
border-bottom-right-radius: 0;

&:hover { box-shadow: inset 0 -2rem 0 var(--dark-purple);
color: var(--black) }
}

/* Katex */
.katex-display {
margin-top: 0 !important;
2 changes: 1 addition & 1 deletion docs/foreach.md
@@ -8,7 +8,7 @@ Like PyTorch, optimi supports foreach implementations of all optimizers. Foreach

Foreach implementations can increase optimizer peak memory usage. optimi attempts to reduce this extra overhead by reusing the gradient buffer for temporary variables. If the gradients are required between the optimization step and [gradient reset step](https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html#torch.optim.Optimizer.zero_grad), set `foreach=False` to use the for-loop implementation.
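
For example (a minimal sketch, assuming `model` is already defined):

```python
from optimi import AdamW

# opt out of the faster foreach implementation when gradients must remain
# available between the optimizer step and the gradient reset step
opt = AdamW(model.parameters(), lr=1e-3, foreach=False)
```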

??? warning "Important: Foreach Requires PyTorch 2.1+"
??? note "Note: Foreach Requires PyTorch 2.1+"

optimi’s foreach implementations require PyTorch 2.1 or newer.

33 changes: 23 additions & 10 deletions docs/gradient_release.md
@@ -1,27 +1,40 @@
---
title: "Gradient Release: Fused Backward and Optimizer Step"
title: "Gradient Release"
description: "Fused Backward Pass and Optimizer Step"
---

# Gradient Release: Fused Backward and Optimizer Step
# Gradient Release

Gradient release reduces training memory by limiting gradients to one layer at any given time. Unlike [PyTorch’s implementation](https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html), optimi’s gradient release is fully compatible with existing learning rate and optimizer schedulers and training frameworks.
**Fused Backward Pass and Optimizer Step**

Gradient release reduces training memory by limiting gradients to one layer at any given time. Unlike [PyTorch’s implementation](https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html), optimi’s gradient release is fully compatible with both existing learning rate and optimizer schedulers and existing training frameworks.

During the backward pass, each model layer calculates its gradients, performs the optimizer step, and clears the gradients before proceeding to the backward pass for the next layer. This fused backward and optimizer step can reduce non-activation memory usage by ~25 percent for an Adam optimizer.

Gradient release can be combined with other techniques such as [Kahan summation](kahan_summation.md) or activation checkpointing for further memory savings.
Gradient release can also be combined with other techniques such as [Kahan summation](kahan_summation.md) or [activation checkpointing](https://pytorch.org/docs/stable/checkpoint.html) for further memory savings.

??? warning "Important: Gradient Release Requires PyTorch 2.1+"
??? note "Note: Gradient Release Requires PyTorch 2.1+"

Gradient release requires PyTorch 2.1 or newer.

Gradient release was proposed by Pudipeddi et al in [*Training Large Neural Networks with Constant Memory using a New Execution Algorithm*](https://arxiv.org/abs/2002.05645) and was enabled by PyTorch’s [`register_post_accumulate_grad_hook`](https://pytorch.org/docs/stable/generated/torch.Tensor.register_post_accumulate_grad_hook.html).
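
For reference, a bare-bones version of that hook pattern, mirroring PyTorch's tutorial rather than optimi's implementation (optimi wraps the equivalent machinery behind `gradient_release=True` and `prepare_for_gradient_release`):

```python
import torch
from torch import nn

# One tiny optimizer per parameter, stepped and zeroed as soon as that
# parameter's gradient is accumulated during the backward pass.
model = nn.Sequential(nn.Linear(20, 20), nn.Linear(20, 1))
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3, foreach=False)
                  for p in model.parameters()}

def optimizer_hook(param: torch.Tensor) -> None:
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()

hooks = [p.register_post_accumulate_grad_hook(optimizer_hook)
         for p in model.parameters()]

loss = model(torch.randn(4, 20)).sum()
loss.backward()  # each parameter's optimizer steps here; no opt.step() afterwards
```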

## Limitations and Workarounds

Since gradient release immediately frees the gradient during the backward pass, features which rely on persistent gradients like gradient clipping or gradient accumulation won’t work.
Since gradient release immediately frees the gradient during the backward pass, features which rely on persistent gradients like AMP's `GradScaler`, gradient clipping, or gradient accumulation won’t work.

!!! warning "Important: Gradient Release is Incompatible with FP16 Mixed Precision"

Gradient release is incompatible with Float16 Automatic Mixed Precision since PyTorch's `GradScaler` requires access to the entire model's gradients for the optimizer step.

Use BFloat16 Automatic Mixed Precision instead.

The recommended workaround for gradient clipping is to use [StableAdamW](optimizers/stableadamw.md) instead of Adam or AdamW, as StableAdamW removes the need for gradient clipping by porting Adafactor’s update clipping into AdamW.
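
For example, a minimal sketch of that swap:

```python
import torch
from torch import nn
from optimi import StableAdamW, prepare_for_gradient_release

# swap AdamW for StableAdamW when gradient clipping would otherwise be needed;
# the rest of the gradient release setup is unchanged
model = nn.Linear(20, 1, dtype=torch.bfloat16)
opt = StableAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2,
                  gradient_release=True)
prepare_for_gradient_release(model, opt)
```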

??? tip "Tip: Use Optimizer Accumulation to Approximate Gradient Accumulation"

optimi's [optimizer accumulation](optimizer_accumulation.md) approximates [gradient accumulation](https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation) by deferring parameter updates while accumulating gradients directly into the optimizer states.

One potential workaround for gradient accumulation is to increase the optimizer’s momentum or $\beta_1$ to approximate accumulating gradients across multiple batches.

## Example
@@ -45,10 +58,10 @@ prepare_for_gradient_release(model, opt)
loss = model(torch.randn(20, dtype=torch.bfloat16))
loss.backward()

# optimizer step and sero_grad is no longer needed, and
# will no-op if called by an existing training framework
opt.step()
opt.zero_grad()
# optimizer step and zero_grad are no longer needed, and will
# harmlessly no-op if called by an existing training framework
# opt.step()
# opt.zero_grad()

# optionally remove gradient release hooks when done training
remove_gradient_release(model)
64 changes: 49 additions & 15 deletions docs/index.md
@@ -1,17 +1,17 @@
---
title: "optimi"
description: "Fast, Modern, and Low Precision PyTorch Optimizers"
description: "Fast, Modern, Memory Efficient, and Low Precision PyTorch Optimizers"
---

# optimī

**Fast, Modern, and Low Precision PyTorch Optimizers**
**Fast, Modern, Memory Efficient, and Low Precision PyTorch Optimizers**

optimi enables accurate low precision training via Kahan summation, supports fully decoupled weight decay, and features fast implementations of modern optimizers.
optimi enables accurate low precision training via Kahan summation, integrates gradient release and optimizer accumulation for additional memory efficiency, supports fully decoupled weight decay, and features fast implementations of modern optimizers.

## Low Precision Training with Kahan Summation

optimi optimizers can match the performance of mixed precision when [training in BFloat16 by using Kahan summation](kahan_summation.md).
optimi optimizers can nearly reach or match the performance of mixed precision when [training in BFloat16 by using Kahan summation](kahan_summation.md).

Training in BFloat16 with Kahan summation can reduce non-activation training memory usage by [37.5 to 45.5 percent](kahan_summation.md/#memory-savings) when using an Adam optimizer. BFloat16 training increases single GPU [training speed by ~10 percent](kahan_summation.md/#training-speedup) at the same batch size.

@@ -21,6 +21,10 @@ optimi optimizers can perform the [optimization step layer-by-layer during the b

Unlike the current PyTorch implementation, optimi’s gradient release optimizers are a drop-in replacement for standard optimizers and seamlessly work with existing hyperparameter schedulers.

## Optimizer Accumulation: Gradient Release and Accumulation

optimi optimizers can approximate gradient accumulation with gradient release by [accumulating gradients into the optimizer states](optimizer_accumulation.md).

## Fully Decoupled Weight Decay

In addition to supporting PyTorch-style decoupled weight decay, optimi optimizers also support [fully decoupled weight decay](fully_decoupled_weight_decay.md).
@@ -51,7 +55,7 @@ from optimi import AdamW
# create or cast model in low precision (bfloat16)
model = nn.Linear(20, 1, dtype=torch.bfloat16)

# initialize AdamW with parameters and fully decoupled weight decay
# initialize any optimi optimizer with parameters & fully decoupled weight decay
# Kahan summation is automatically enabled since model & inputs are bfloat16
opt = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5, decouple_lr=True)

@@ -70,29 +74,59 @@ To use with PyTorch-style weight decay with float32 or mixed precision:
# create model
model = nn.Linear(20, 1)

# initialize AdamW with parameters
# initialize any optimi optimizer with parameters
opt = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
```

To use with gradient release:

```python
# create model
model = nn.Linear(20, 1)

# initialize AdamW with `gradient_release=True` and call
# `prepare_for_gradient_release` on model and optimizer
# initialize any optimi optimizer with `gradient_release=True`
# and call `prepare_for_gradient_release` on model and optimizer
opt = AdamW(model.parameters(), lr=1e-3, gradient_release=True)
prepare_for_gradient_release(model, opt)

# calling backward on the model will perform the optimizer step
loss = model(torch.randn(20, dtype=torch.bfloat16))
loss.backward()

# optimizer step and sero_grad is no longer needed, and
# will no-op if called by an existing training framework
opt.step()
opt.zero_grad()
# optimizer step and zero_grad are no longer needed, and will
# harmlessly no-op if called by an existing training framework
# opt.step()
# opt.zero_grad()

# optionally remove gradient release hooks when done training
remove_gradient_release(model)
```

To use with optimizer accumulation:

```python
# initialize any optimi optimizer with `gradient_release=True`
# and call `prepare_for_gradient_release` on model and optimizer
opt = AdamW(model.parameters(), lr=1e-3, gradient_release=True)
prepare_for_gradient_release(model, opt)

# update model parameters every four steps after accumulating
# gradients directly into the optimizer states
accumulation_steps = 4

# use existing PyTorch dataloader
for idx, batch in enumerate(dataloader):
    # `optimizer_accumulation=True` accumulates gradients into
    # optimizer states. Set `optimizer_accumulation=False` to
    # update parameters by performing a full gradient release step
    opt.optimizer_accumulation = (idx+1) % accumulation_steps != 0

    # calling backward on the model will perform the optimizer step,
    # either accumulating gradients or updating model parameters
    loss = model(batch)
    loss.backward()

    # optimizer step and zero_grad are no longer needed, and will
    # harmlessly no-op if called by an existing training framework
    # opt.step()
    # opt.zero_grad()

# optionally remove gradient release hooks when done training
remove_gradient_release(model)
4 changes: 2 additions & 2 deletions docs/kahan_summation.md
@@ -4,7 +4,7 @@ title: Low Precision Training with Kahan Summation

# Low Precision Training with Kahan Summation

While training models in low precision (Float16 or BFloat16) usually does not match training in full precision (Float32) or [mixed precision](https://pytorch.org/blog/what-every-user-should-know-about-mixed-precision-training-in-pytorch), optimi optimizers match the performance of mixed precision when training in BFloat16 by using Kahan summation[^1].
While training models in low precision (Float16 or BFloat16) usually differs from training in full precision (Float32) or [mixed precision](https://pytorch.org/blog/what-every-user-should-know-about-mixed-precision-training-in-pytorch), optimi optimizers nearly reach or match the performance of mixed precision when training in BFloat16 by using Kahan summation[^1].

Training in low precision [reduces memory usage](#memory-savings) and increases [training speed](#training-speedup) relative to mixed precision training.

@@ -108,6 +108,6 @@ $$

This shows the optimi implementation of Kahan summation optimizers, which is equivalent to the *Revisiting BFloat16 Training* formulation.

[^1]: Current testing on small models shows no degradation in training performance.
[^1]: Current testing on small models shows little to no degradation in model performance.

[^2]: Also known as Kahan–Babuška summation or compensated summation.