Correct normalization scheme; deprecate batch_size
Existing code normalized as `norm = sqrt(batch_size / total_iterations)`, where `total_iterations` = (iterations per epoch) * (epochs per restart). However, iterations per epoch = `samples_per_epoch / batch_size` --> `norm = batch_size * sqrt(1 / (samples_per_epoch * epochs))`, making `norm` scale _linearly_ with `batch_size`, whereas the authors' scheme scales with the _square root_ of `batch_size`.
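
A minimal numeric sketch of the two schemes (training-size values here are illustrative, not from the repo):

```python
import numpy as np

samples_per_epoch, epochs = 3200, 10  # illustrative training setup

for batch_size in (16, 32, 64):
    total_iterations = (samples_per_epoch // batch_size) * epochs
    old_norm = np.sqrt(batch_size / total_iterations)  # pre-fix: grows linearly with batch_size
    new_norm = np.sqrt(1 / total_iterations)           # post-fix: authors' sqrt(b / (B*T))
    print(f"batch_size={batch_size}: old_norm={old_norm:.4f}, new_norm={new_norm:.4f}")
# old_norm doubles each time batch_size doubles; new_norm grows only by sqrt(2)
```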

Users who never changed `batch_size` throughout training are unaffected. (λ = λ_norm * sqrt(b / (B*T)), with b = batch size, B = samples per epoch, T = number of epochs; λ_norm is what we pick, our "guess". The point of the normalization is that if our guess works well at `batch_size=32`, it should also work well at `batch_size=16` - but if `batch_size` is never changed, performance depends only on the guess.)

Main change [here](https://github.com/OverLordGoldDragon/keras-adamw/pull/53/files#diff-220519926b87c12115d2f727803fbe6bR19), closing #52.

**Updating existing code**: for a choice of λ_norm that previously worked well, multiply it by `sqrt(batch_size)`. Ex (with `batch_size=32`): `Dense(bias_regularizer=l2(1e-4))` --> `Dense(bias_regularizer=l2(1e-4 * sqrt(32)))`.
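
A migration sketch, assuming the old λ_norm was tuned with `batch_size=32`; the model definition and the `AdamW(...)` arguments beyond those shown in this diff are illustrative:

```python
import numpy as np
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from keras_adamw import AdamW

old_lambda_norm = 1e-4                                   # worked well before this commit
batch_size = 32                                          # batch size it was tuned with
new_lambda_norm = old_lambda_norm * np.sqrt(batch_size)  # ~5.66e-4

ipt = Input(shape=(16,))
out = Dense(1, activation='sigmoid', bias_regularizer=l2(new_lambda_norm))(ipt)
model = Model(ipt, out)

# `batch_size` is no longer an optimizer argument; normalization now uses `total_iterations` only
optimizer = AdamW(model=model, zero_penalties=True, total_iterations=1000)
model.compile(optimizer, loss='binary_crossentropy')
```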
OverLordGoldDragon authored Jul 13, 2020
1 parent 29aa8f2 commit a99d833
Showing 9 changed files with 19 additions and 55 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -11,7 +11,7 @@
![](https://img.shields.io/badge/keras-tf.keras/eager-blue.svg)
![](https://img.shields.io/badge/keras-tf.keras/2.0-blue.svg)

Keras implementation of **AdamW**, **SGDW**, **NadamW**, and **Warm Restarts**, based on paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) - plus **Learning Rate Multipliers**
Keras/TF implementation of **AdamW**, **SGDW**, **NadamW**, and **Warm Restarts**, based on paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) - plus **Learning Rate Multipliers**

<img src="https://user-images.githubusercontent.com/16495490/65381086-233f7d00-dcb7-11e9-8c83-d0aec7b3663a.png" width="850">

@@ -107,7 +107,7 @@ for epoch in range(3):
## Use guidelines
### Weight decay
- **Set L2 penalty to ZERO** if regularizing a weight via `weight_decays` - else the purpose of the 'fix' is largely defeated, and weights will be over-decayed --_My recommendation_
- `lambda = lambda_norm * sqrt(batch_size/total_iterations)` --> _can be changed_; the intent is to scale λ to _decouple_ it from other hyperparams - including (but _not limited to_), train duration & batch size. --_Authors_ (Appendix, pg.1) (A-1)
- `lambda = lambda_norm * sqrt(1/total_iterations)` --> _can be changed_; the intent is to scale λ to _decouple_ it from other hyperparams - including (but _not limited to_), # of epochs & batch size. --_Authors_ (Appendix, pg.1) (A-1)
- `total_iterations_wd` --> set to normalize over _all epochs_ (or other interval `!= total_iterations`) instead of per-WR when using WR; may _sometimes_ yield better results --_My note_

### Warm restarts
2 changes: 1 addition & 1 deletion example.py
@@ -10,7 +10,7 @@
from keras_adamw.utils import K_eval

#%%############################################################################
ipt = Input(shape=(120,4))
ipt = Input(shape=(120, 4))
x = LSTM(60, activation='relu', name='lstm_1',
kernel_regularizer=l1(1e-4), recurrent_regularizer=l2(2e-4))(ipt)
out = Dense(1, activation='sigmoid', kernel_regularizer=l1_l2(1e-4, 2e-4))(x)
2 changes: 1 addition & 1 deletion keras_adamw/__init__.py
@@ -28,4 +28,4 @@
from .utils import get_weight_decays, fill_dict_in_order
from .utils import reset_seeds, K_eval

__version__ = '1.35'
__version__ = '1.36'
15 changes: 3 additions & 12 deletions keras_adamw/optimizers.py
@@ -30,7 +30,6 @@ class AdamW(Optimizer):
extracts weight penalties from layers, and overrides `weight_decays`.
zero_penalties: bool. If True and `model` is passed, will zero weight
penalties (loss-based). (RECOMMENDED; see README "Use guidelines").
batch_size: int >= 1. Train input batch size; used for normalization
total_iterations: int >= 0. Total expected iterations / weight updates
throughout training, used for normalization; <1>
lr_multipliers: dict / None. Name-value pairs specifying per-layer lr
@@ -74,7 +73,7 @@ class AdamW(Optimizer):
"""
def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
amsgrad=False, model=None, zero_penalties=True,
batch_size=32, total_iterations=0, total_iterations_wd=None,
total_iterations=0, total_iterations_wd=None,
use_cosine_annealing=False, lr_multipliers=None,
weight_decays=None, autorestart=None, init_verbose=True,
eta_min=0, eta_max=1, t_cur=0, **kwargs):
@@ -99,7 +98,6 @@ def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
self.eta_t = K.variable(eta_t, dtype='float32', name='eta_t')
self.t_cur = K.variable(t_cur, dtype='int64', name='t_cur')

self.batch_size = batch_size
self.total_iterations = total_iterations
self.total_iterations_wd = total_iterations_wd or total_iterations
self.amsgrad = amsgrad
@@ -189,7 +187,6 @@ def get_config(self):
'beta_1': float(K_eval(self.beta_1)),
'beta_2': float(K_eval(self.beta_2)),
'decay': float(K_eval(self.decay)),
'batch_size': int(self.batch_size),
'total_iterations': int(self.total_iterations),
'weight_decays': self.weight_decays,
'lr_multipliers': self.lr_multipliers,
@@ -226,7 +223,6 @@ class NadamW(Optimizer):
extracts weight penalties from layers, and overrides `weight_decays`.
zero_penalties: bool. If True and `model` is passed, will zero weight
penalties (loss-based). (RECOMMENDED; see README "Use guidelines").
batch_size: int >= 1. Train input batch size; used for normalization
total_iterations: int >= 0. Total expected iterations / weight updates
throughout training, used for normalization; <1>
lr_multipliers: dict / None. Name-value pairs specifying per-layer lr
@@ -271,7 +267,7 @@ class NadamW(Optimizer):
(https://arxiv.org/abs/1711.05101)
"""
def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999,
model=None, zero_penalties=True, batch_size=32,
model=None, zero_penalties=True,
total_iterations=0, total_iterations_wd=None,
use_cosine_annealing=False, lr_multipliers=None,
weight_decays=None, autorestart=None, init_verbose=True,
@@ -297,7 +293,6 @@ def __init__(self, learning_rate=0.002, beta_1=0.9, beta_2=0.999,
self.eta_t = K.variable(eta_t, dtype='float32', name='eta_t')
self.t_cur = K.variable(t_cur, dtype='int64', name='t_cur')

self.batch_size = batch_size
self.total_iterations = total_iterations
self.total_iterations_wd = total_iterations_wd or total_iterations
self.lr_multipliers = lr_multipliers
@@ -388,7 +383,6 @@ def get_config(self):
'beta_2': float(K_eval(self.beta_2)),
'epsilon': self.epsilon,
'schedule_decay': self.schedule_decay,
'batch_size': int(self.batch_size),
'total_iterations': int(self.total_iterations),
'weight_decays': self.weight_decays,
'lr_multipliers': self.lr_multipliers,
@@ -421,7 +415,6 @@ class SGDW(Optimizer):
extracts weight penalties from layers, and overrides `weight_decays`.
zero_penalties: bool. If True and `model` is passed, will zero weight
penalties (loss-based). (RECOMMENDED; see README "Use guidelines").
batch_size: int >= 1. Train input batch size; used for normalization
total_iterations: int >= 0. Total expected iterations / weight updates
throughout training, used for normalization; <1>
lr_multipliers: dict / None. Name-value pairs specifying per-layer lr
@@ -465,7 +458,7 @@ class SGDW(Optimizer):
(https://arxiv.org/abs/1711.05101)
"""
def __init__(self, learning_rate=0.01, momentum=0., nesterov=False,
model=None, zero_penalties=True, batch_size=32,
model=None, zero_penalties=True,
total_iterations=0, total_iterations_wd=None,
use_cosine_annealing=False, lr_multipliers=None,
weight_decays=None, autorestart=None, init_verbose=True,
@@ -489,7 +482,6 @@ def __init__(self, learning_rate=0.01, momentum=0., nesterov=False,
self.eta_t = K.variable(eta_t, dtype='float32', name='eta_t')
self.t_cur = K.variable(t_cur, dtype='int64', name='t_cur')

self.batch_size = batch_size
self.total_iterations = total_iterations
self.total_iterations_wd = total_iterations_wd or total_iterations
self.nesterov = nesterov
@@ -556,7 +548,6 @@ def get_config(self):
'momentum': float(K_eval(self.momentum)),
'decay': float(K_eval(self.decay)),
'nesterov': self.nesterov,
'batch_size': int(self.batch_size),
'total_iterations': int(self.total_iterations),
'weight_decays': self.weight_decays,
'lr_multipliers': self.lr_multipliers,
15 changes: 3 additions & 12 deletions keras_adamw/optimizers_225.py
@@ -19,7 +19,6 @@ class AdamW(Optimizer):
extracts weight penalties from layers, and overrides `weight_decays`.
zero_penalties: bool. If True and `model` is passed, will zero weight
penalties (loss-based). (RECOMMENDED; see README "Use guidelines").
batch_size: int >= 1. Train input batch size; used for normalization
total_iterations: int >= 0. Total expected iterations / weight updates
throughout training, used for normalization; <1>
lr_multipliers: dict / None. Name-value pairs specifying per-layer lr
@@ -63,7 +62,7 @@ class AdamW(Optimizer):
"""
def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, amsgrad=False,
epsilon=None, decay=0.0, model=None, zero_penalties=True,
batch_size=32, total_iterations=0, total_iterations_wd=None,
total_iterations=0, total_iterations_wd=None,
use_cosine_annealing=False, lr_multipliers=None,
weight_decays=None, autorestart=None, init_verbose=True,
eta_min=0, eta_max=1, t_cur=0, **kwargs):
@@ -86,7 +85,6 @@ def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, amsgrad=False,

self.initial_decay = decay
self.epsilon = epsilon or K.epsilon()
self.batch_size = batch_size
self.total_iterations = total_iterations
self.total_iterations_wd = total_iterations_wd or total_iterations
self.amsgrad = amsgrad
@@ -165,7 +163,6 @@ def get_config(self):
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'decay': float(K.get_value(self.decay)),
'batch_size': int(self.batch_size),
'total_iterations': int(self.total_iterations),
'weight_decays': self.weight_decays,
'lr_multipliers': self.lr_multipliers,
@@ -202,7 +199,6 @@ class NadamW(Optimizer):
extracts weight penalties from layers, and overrides `weight_decays`.
zero_penalties: bool. If True and `model` is passed, will zero weight
penalties (loss-based). (RECOMMENDED; see README "Use guidelines").
batch_size: int >= 1. Train input batch size; used for normalization
total_iterations: int >= 0. Total expected iterations / weight updates
throughout training, used for normalization; <1>
lr_multipliers: dict / None. Name-value pairs specifying per-layer lr
@@ -248,7 +244,7 @@ class NadamW(Optimizer):
"""
def __init__(self, lr=0.002, beta_1=0.9, beta_2=0.999,
schedule_decay=0.004, epsilon=None,
model=None, zero_penalties=True, batch_size=32,
model=None, zero_penalties=True,
total_iterations=0, total_iterations_wd=None,
use_cosine_annealing=False, lr_multipliers=None,
weight_decays=None, autorestart=None, init_verbose=True,
@@ -272,7 +268,6 @@ def __init__(self, lr=0.002, beta_1=0.9, beta_2=0.999,

self.epsilon = epsilon or K.epsilon()
self.schedule_decay = schedule_decay
self.batch_size = batch_size
self.total_iterations = total_iterations
self.total_iterations_wd = total_iterations_wd or total_iterations
self.lr_multipliers = lr_multipliers
@@ -352,7 +347,6 @@ def get_config(self):
'beta_2': float(K.get_value(self.beta_2)),
'epsilon': self.epsilon,
'schedule_decay': self.schedule_decay,
'batch_size': int(self.batch_size),
'total_iterations': int(self.total_iterations),
'weight_decays': self.weight_decays,
'lr_multipliers': self.lr_multipliers,
@@ -385,7 +379,6 @@ class SGDW(Optimizer):
extracts weight penalties from layers, and overrides `weight_decays`.
zero_penalties: bool. If True and `model` is passed, will zero weight
penalties (loss-based). (RECOMMENDED; see README "Use guidelines").
batch_size: int >= 1. Train input batch size; used for normalization
total_iterations: int >= 0. Total expected iterations / weight updates
throughout training, used for normalization; <1>
lr_multipliers: dict / None. Name-value pairs specifying per-layer lr
@@ -429,7 +422,7 @@ class SGDW(Optimizer):
(https://arxiv.org/abs/1711.05101)
"""
def __init__(self, lr=0.01, momentum=0., nesterov=False, decay=0.0,
model=None, zero_penalties=True, batch_size=32,
model=None, zero_penalties=True,
total_iterations=0, total_iterations_wd=None,
use_cosine_annealing=False, lr_multipliers=None,
weight_decays=None, autorestart=None, init_verbose=True,
@@ -451,7 +444,6 @@ def __init__(self, lr=0.01, momentum=0., nesterov=False, decay=0.0,
self.t_cur = K.variable(t_cur, dtype='int64', name='t_cur')

self.initial_decay = decay
self.batch_size = batch_size
self.total_iterations = total_iterations
self.total_iterations_wd = total_iterations_wd or total_iterations
self.nesterov = nesterov
@@ -516,7 +508,6 @@ def get_config(self):
'momentum': float(K.get_value(self.momentum)),
'decay': float(K.get_value(self.decay)),
'nesterov': self.nesterov,
'batch_size': int(self.batch_size),
'total_iterations': int(self.total_iterations),
'weight_decays': self.weight_decays,
'lr_multipliers': self.lr_multipliers,
15 changes: 3 additions & 12 deletions keras_adamw/optimizers_225tf.py
@@ -46,7 +46,6 @@ class AdamW(OptimizerV2):
extracts weight penalties from layers, and overrides `weight_decays`.
zero_penalties: bool. If True and `model` is passed, will zero weight
penalties (loss-based). (RECOMMENDED; see README "Use guidelines").
batch_size: int >= 1. Train input batch size; used for normalization
total_iterations: int >= 0. Total expected iterations / weight updates
throughout training, used for normalization; <1>
lr_multipliers: dict / None. Name-value pairs specifying per-layer lr
@@ -83,7 +82,7 @@ class AdamW(OptimizerV2):
"""
def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
epsilon=None, decay=0., amsgrad=False,
model=None, zero_penalties=True, batch_size=32,
model=None, zero_penalties=True,
total_iterations=0, total_iterations_wd=None,
use_cosine_annealing=False, lr_multipliers=None,
weight_decays=None, autorestart=None, init_verbose=True,
@@ -103,7 +102,6 @@ def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
self.eta_max = K.constant(eta_max, name='eta_max')
self.eta_t = K.variable(eta_t, dtype='float32', name='eta_t')
self.t_cur = K.variable(t_cur, dtype='int64', name='t_cur')
self.batch_size = batch_size
self.total_iterations = total_iterations
self.total_iterations_wd = total_iterations_wd or total_iterations
self.lr_multipliers = lr_multipliers
@@ -267,7 +265,6 @@ def get_config(self):
'beta_2': self._serialize_hyperparameter('beta_2'),
'epsilon': self.epsilon,
'amsgrad': self.amsgrad,
'batch_size': int(self.batch_size),
'total_iterations': int(self.total_iterations),
'weight_decays': self.weight_decays,
'use_cosine_annealing': self.use_cosine_annealing,
@@ -311,7 +308,6 @@ class NadamW(OptimizerV2):
extracts weight penalties from layers, and overrides `weight_decays`.
zero_penalties: bool. If True and `model` is passed, will zero weight
penalties (loss-based). (RECOMMENDED; see README "Use guidelines").
batch_size: int >= 1. Train input batch size; used for normalization
total_iterations: int >= 0. Total expected iterations / weight updates
throughout training, used for normalization; <1>
lr_multipliers: dict / None. Name-value pairs specifying per-layer lr
@@ -350,7 +346,7 @@ class NadamW(OptimizerV2):
"""
def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
epsilon=1e-7, model=None, zero_penalties=True,
batch_size=32, total_iterations=0, total_iterations_wd=None,
total_iterations=0, total_iterations_wd=None,
use_cosine_annealing=False, lr_multipliers=None,
weight_decays=None, autorestart=None, init_verbose=True,
eta_min=0, eta_max=1, t_cur=0, name="NadamW", **kwargs):
@@ -379,7 +375,6 @@ def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
self.eta_max = K.constant(eta_max, name='eta_max')
self.eta_t = K.variable(eta_t, dtype='float32', name='eta_t')
self.t_cur = K.variable(t_cur, dtype='int64', name='t_cur')
self.batch_size = batch_size
self.total_iterations = total_iterations
self.total_iterations_wd = total_iterations_wd or total_iterations
self.lr_multipliers = lr_multipliers
@@ -556,7 +551,6 @@ def get_config(self):
'beta_1': self._serialize_hyperparameter('beta_1'),
'beta_2': self._serialize_hyperparameter('beta_2'),
'epsilon': self.epsilon,
'batch_size': int(self.batch_size),
'total_iterations': int(self.total_iterations),
'weight_decays': self.weight_decays,
'use_cosine_annealing': self.use_cosine_annealing,
@@ -594,7 +588,6 @@ class SGDW(OptimizerV2):
extracts weight penalties from layers, and overrides `weight_decays`.
zero_penalties: bool. If True and `model` is passed, will zero weight
penalties (loss-based). (RECOMMENDED; see README "Use guidelines").
batch_size: int >= 1. Train input batch size; used for normalization
total_iterations: int >= 0. Total expected iterations / weight updates
throughout training, used for normalization; <1>
lr_multipliers: dict / None. Name-value pairs specifying per-layer lr
@@ -631,7 +624,7 @@ class SGDW(OptimizerV2):
(https://arxiv.org/abs/1711.05101)
"""
def __init__(self, learning_rate=0.01, momentum=0.0, nesterov=False,
model=None, zero_penalties=True, batch_size=32,
model=None, zero_penalties=True,
total_iterations=0, total_iterations_wd=None,
use_cosine_annealing=False, lr_multipliers=None,
weight_decays=None, autorestart=None, init_verbose=True,
@@ -657,7 +650,6 @@ def __init__(self, learning_rate=0.01, momentum=0.0, nesterov=False,
self.eta_max = K.constant(eta_max, name='eta_max')
self.eta_t = K.variable(eta_t, dtype='float32', name='eta_t')
self.t_cur = K.variable(t_cur, dtype='int64', name='t_cur')
self.batch_size = batch_size
self.total_iterations = total_iterations
self.total_iterations_wd = total_iterations_wd or total_iterations
self.lr_multipliers = lr_multipliers
@@ -773,7 +765,6 @@ def get_config(self):
"decay": self._serialize_hyperparameter("decay"),
"momentum": self._serialize_hyperparameter("momentum"),
"nesterov": self.nesterov,
'batch_size': int(self.batch_size),
'total_iterations': int(self.total_iterations),
'weight_decays': self.weight_decays,
'use_cosine_annealing': self.use_cosine_annealing,