fix: g_snr
kozistr committed Dec 21, 2024
1 parent bc9ab50 commit 3b2e11d
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions pytorch_optimizer/optimizer/sgd.py
@@ -406,7 +406,7 @@ class SGDSaI(BaseOptimizer):
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
- :param momentum: float. momentum factor (0.0 = SignSGD, >0 = Signum).
+ :param momentum: float. coefficients used for computing running averages of gradient.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
:param eps: float. term added to the denominator to improve numerical stability.
@@ -415,7 +415,7 @@ class SGDSaI(BaseOptimizer):
def __init__(
self,
params: PARAMETERS,
- lr: float = 1e-3,
+ lr: float = 1e-2,
momentum: float = 0.9,
weight_decay: float = 1e-2,
weight_decouple: bool = True,
@@ -468,10 +468,11 @@ def warmup_step(self, closure: CLOSURE = None) -> LOSS:
raise NoSparseGradientError(str(self))

sigma = grad.std().nan_to_num_()
- grad_norm_snr = grad.norm()
- grad_norm_snr.div_(sigma.add_(group['eps']))
+ grad_norm = grad.norm()

- self.state[p]['gsnr'] = grad_norm_snr
+ g_snr = grad_norm.div_(sigma.add_(group['eps'])) if sigma != 0.0 else grad_norm

+ self.state[p]['gsnr'] = g_snr

self.has_warmup = True

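For context: g_snr is the gradient signal-to-noise ratio, grad.norm() divided by (grad.std() + eps). The fix guards the division for the case sigma == 0 (e.g. a constant or single-element gradient, where nan_to_num_() has already mapped a NaN std to 0.0) and falls back to the raw gradient norm instead of dividing by eps alone, which would blow the ratio up. A minimal standalone sketch of the patched computation; the function name and the eps default are illustrative, not taken from the file:

    import torch

    def gradient_snr(grad: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        # std() of a one-element tensor is NaN; nan_to_num_() maps it to 0.0
        sigma = grad.std().nan_to_num_()
        grad_norm = grad.norm()
        # divide only when sigma is non-zero, mirroring the branch added in this commit
        return grad_norm.div_(sigma.add_(eps)) if sigma != 0.0 else grad_norm

    print(gradient_snr(torch.randn(128)))  # finite SNR for a noisy gradient
    print(gradient_snr(torch.ones(4)))     # std is 0, so the raw norm (2.0) is returned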
@@ -488,7 +489,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
loss = closure()

for group in self.param_groups:
- momentum = group['momentum']
+ momentum: float = group['momentum']
for p in group['params']:
if p.grad is None:
continue
@@ -506,8 +507,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
else:
buf = grad

[Codecov / codecov/patch: added line pytorch_optimizer/optimizer/sgd.py#L508 was not covered by tests]

- step_size = group['lr'] * state['gsnr']

self.apply_weight_decay(
p,
grad,
@@ -517,6 +516,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
False,
)

- p.add_(buf, alpha=-step_size)
+ p.add_(buf, alpha=-group['lr'] * state['gsnr'])

return loss

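Taken together, the warm-up pass stores a per-parameter g_snr once, and each subsequent step scales the buffer by lr * g_snr, i.e. p <- p - lr * g_snr * buf, where buf is the momentum buffer (or the raw gradient when momentum is 0). A hedged usage sketch against the released pytorch_optimizer package; the model and data are placeholders, and SGDSaI is assumed to be exported at the package top level like the repo's other optimizers:

    import torch
    from pytorch_optimizer import SGDSaI

    model = torch.nn.Linear(10, 2)
    # lr=1e-2 matches the default introduced in this commit
    optimizer = SGDSaI(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-2)

    loss = model(torch.randn(8, 10)).sum()
    loss.backward()
    optimizer.step()       # the warm-up pass is expected to populate state['gsnr'] before the first update
    optimizer.zero_grad()

With weight_decouple=True (the default shown in the diff), weight decay is applied to the parameters directly, AdamW-style, rather than being folded into the gradient.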