-
Notifications
You must be signed in to change notification settings - Fork 5
/
lookhead.py
91 lines (83 loc) · 3.72 KB
/
lookhead.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
""" Lookahead Optimizer Wrapper.
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch.optim.optimizer import Optimizer
from collections import defaultdict
class Lookahead(Optimizer):
    """Wrap another optimizer with the Lookahead update rule.

    The wrapped ``base_optimizer`` updates the "fast" weights on every call
    to :meth:`step`. Each param group carries a ``lookahead_step`` counter;
    whenever it reaches a multiple of ``lookahead_k`` a per-parameter "slow"
    buffer is moved a fraction ``lookahead_alpha`` of the way towards the
    fast weights, and the fast weights are overwritten with that buffer.

    ``Optimizer.__init__`` is not called by this class: ``param_groups`` and
    ``defaults`` are shared by reference with the base optimizer, and
    ``state`` only holds the slow buffers.

    Args:
        base_optimizer: the optimizer instance that performs the fast updates.
        alpha: slow-weight interpolation factor, must be within ``[0, 1]``.
        k: number of fast steps between slow-weight updates, must be ``>= 1``.

    Raises:
        ValueError: if ``alpha`` or ``k`` is out of range.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        # Share (not copy) the group list so both optimizers see the same
        # hyper-parameters and the same lookahead counters.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # Manually add our defaults to the already-existing param groups.
        for group in self.param_groups:
            for name, default in defaults.items():
                group.setdefault(name, default)

    def update_slow(self, group):
        """Pull the slow weights of ``group`` towards the fast ones and sync.

        Parameters without a gradient are skipped. On first use the slow
        buffer of a parameter is initialised to its current fast value.
        """
        for fast_p in group['params']:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if 'slow_buffer' not in param_state:
                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
                param_state['slow_buffer'].copy_(fast_p.data)
            slow = param_state['slow_buffer']
            # slow += alpha * (fast - slow); alpha is passed by keyword because
            # the positional ``add_(scalar, tensor)`` overload is deprecated.
            slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        """Run a slow-weight update on every param group immediately."""
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure=None):
        """Take one base-optimizer step, then update slow weights when due.

        Returns:
            Whatever ``base_optimizer.step(closure)`` returns.
        """
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def _packed_keys(self, packed_groups):
        """Pair each live parameter with its key in ``packed_groups``.

        ``packed_groups`` is a ``param_groups`` list in state-dict form; the
        pairing is positional, group by group and parameter by parameter.
        """
        pairs = []
        for group, packed in zip(self.base_optimizer.param_groups, packed_groups):
            pairs.extend(zip(group['params'], packed['params']))
        return pairs

    def state_dict(self):
        """Return the base optimizer's state dict plus a ``slow_state`` entry.

        ``slow_state`` is keyed with the same per-parameter keys that the base
        optimizer uses in its own packed ``param_groups``, so the entries can
        be matched back to parameters when loading.
        """
        fast_state_dict = self.base_optimizer.state_dict()
        param_groups = fast_state_dict['param_groups']
        key_of = {id(p): key for p, key in self._packed_keys(param_groups)}
        slow_state = {}
        for k, v in self.state.items():
            if isinstance(k, torch.Tensor):
                # Fall back to the object id for tensors the base optimizer
                # does not know about.
                k = key_of.get(id(k), id(k))
            slow_state[k] = v
        return {
            'state': fast_state_dict['state'],
            'slow_state': slow_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        """Restore the base optimizer and the slow buffers from ``state_dict``.

        A state dict without ``slow_state`` (saved from an optimizer that was
        not wrapped) is accepted: the slow state starts empty and missing
        lookahead entries are filled into the param groups from
        ``self.defaults``. ``state_dict`` itself is not modified.
        """
        self.base_optimizer.load_state_dict({
            'state': state_dict['state'],
            'param_groups': state_dict['param_groups'],
        })
        # Make both optimizers reference the same container again, in case
        # loading replaced the base optimizer's group list.
        self.param_groups = self.base_optimizer.param_groups

        slow_state_new = 'slow_state' not in state_dict
        if slow_state_new:
            print('Loading state_dict from optimizer without Lookahead applied.')
            saved_slow = {}
        else:
            saved_slow = state_dict['slow_state']

        # The slow state is restored directly rather than via
        # Optimizer.load_state_dict, since Optimizer.__init__ never ran on
        # this object and base-class attributes cannot be relied upon.
        param_of = {
            key: p for p, key in self._packed_keys(state_dict['param_groups'])
        }
        restored = defaultdict(dict)
        for key, value in saved_slow.items():
            param = param_of.get(key)
            if param is None:
                # Unknown key: keep the entry as-is instead of dropping it.
                restored[key] = value
                continue
            restored[param] = {
                name: (
                    # Clone so the loaded optimizer never aliases the buffers
                    # held by the saved dict; cast to the param's placement.
                    buf.detach().clone().to(device=param.device, dtype=param.dtype)
                    if isinstance(buf, torch.Tensor) else buf
                )
                for name, buf in value.items()
            }
        self.state = restored

        if slow_state_new:
            # Reapply defaults to catch missing lookahead-specific ones.
            for group in self.param_groups:
                for name, default in self.defaults.items():
                    group.setdefault(name, default)