-
Notifications
You must be signed in to change notification settings - Fork 5
/
federated_avg.py
202 lines (175 loc) · 8.22 KB
/
federated_avg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""
Implementation of simulator for FederatedAveraging as described in the paper
Communication-Efficient Learning of Deep Networks from Decentralized Data
(McMahan, Brendan, et al. 2017)
http://proceedings.mlr.press/v54/mcmahan17a/mcmahan17a.pdf
"""
import numpy as np
import utils
class fedavg_worker:
    """
    A simulated FederatedAveraging worker (a "Client" in the paper).

    A worker wraps its own local model together with the shard of training
    data it will train on; the shard starts out unset.
    """

    def __init__(self, model):
        """
        Wrap a Keras model as a worker.

        Parameters
        ----------
        model : tf.keras.Model
            The model this worker trains locally.
        """
        self.model = model
        # No local data yet; the caller assigns X_train / y_train later.
        self.X_train, self.y_train = None, None
class fedavg_server:
    """
    A simulated FederatedAveraging parameter server.

    It owns the global model and one slot per worker for holding that
    worker's uploaded weights.
    """

    def __init__(self, model, K):
        """
        Create a server around a global Keras model with K empty weight slots.

        Parameters
        ----------
        model : tf.keras.Model
            The global model maintained by the parameter server.
        K : int
            The number of workers that can participate in federated learning.
        """
        self.global_model = model
        # Slot k stores worker k's uploaded weights; None until something is stored.
        self.worker_weights = [None for _ in range(K)]
class fedavg_hparams:
    """
    Bundle of hyperparameters for the FederatedAveraging simulation.
    """

    def __init__(self, K=100, C=0.1, E=1, B=10, eta=0.1, MAX_T=10000, iid=True,
                 evaluation_interval=5, target_val_accuracy=0.9):
        """
        Store the FederatedAveraging hyperparameters.

        Parameters
        ----------
        K : int, optional
            Number of workers (paper notation). Default 100.
        C : float, optional
            Fraction of workers randomly selected per iteration (paper notation).
            Default 0.1.
        E : int, optional
            Number of epochs a worker runs over its data per iteration
            (paper notation). Default 1.
        B : int, optional
            Worker minibatch size (paper notation). Default 10.
        eta : float, optional
            Step size for the worker SGD optimizers (paper notation). Default 0.1.
        MAX_T : int, optional
            Maximum number of FederatedAveraging rounds to run. Default 10000.
        iid : bool, optional
            If True, simulate an evenly distributed random split of the data
            across workers; otherwise each worker gets data from only one (or a
            few) classes. Default True.
        evaluation_interval : int, optional
            Evaluate the global model every this many iterations. Default 5.
        target_val_accuracy : float, optional
            Stop training once this validation accuracy is reached. Default 0.9.
        """
        # Hyperparameters named after the paper's notation.
        self.K, self.C, self.E, self.B, self.eta = K, C, E, B, eta
        # Simulation controls that the paper does not name explicitly.
        self.MAX_T = MAX_T
        self.iid = iid
        self.evaluation_interval = evaluation_interval
        self.target_val_accuracy = target_val_accuracy
def federated_averaging(X_train, y_train, X_val, y_val, model_constructor, hparams, rng=None):
    """
    Simulate training a model using FederatedAveraging across K distributed devices.

    Return the final global model and metrics gathered over the course of the run.

    Each round samples ``m = max(ceil(C * K), 1)`` (capped at ``K``) *distinct*
    workers uniformly at random. Each selected worker trains for ``E`` epochs
    starting from the current global weights. The new global weights are the
    ``split_weights``-weighted sum over all K workers, where a worker that was
    not selected this round contributes the unchanged global weights.

    Parameters
    ----------
    X_train : numpy ndarray
        Training features.
    y_train : numpy ndarray
        Training targets.
    X_val : numpy ndarray
        Validation features.
    y_val : numpy ndarray
        Validation targets.
    model_constructor : function
        Function that constructs a compiled tf.keras.Model using hparams.
    hparams : fedavg_hparams
        Hyperparameters for FederatedAveraging.
    rng : numpy.random.Generator, optional
        Instance to use for random number generation. A fresh
        ``np.random.default_rng()`` is created when omitted.

    Returns
    -------
    global_model : tf.keras.Model
        The final global model.
    log : dict
        Dictionary containing training and validation metrics:
        loss :
            training loss at each logged iteration
        accuracy :
            training accuracy at each logged iteration
        val_loss :
            validation loss at each logged iteration
        val_accuracy :
            validation accuracy at each logged iteration
        iteration :
            the iteration number at which the measurements were made
        communication_rounds :
            the cumulative number of worker uploads by each iteration
        worker_upload_fraction :
            the average fraction of workers who upload each iteration
    """
    if rng is None:
        rng = np.random.default_rng()

    # Initialize the server and one independent model per worker.
    server = fedavg_server(model_constructor(hparams), hparams.K)
    workers = [fedavg_worker(model_constructor(hparams)) for _ in range(hparams.K)]

    # Partition the dataset into K splits and assign them to the workers.
    # Note: in the real world we would not have access to the full dataset, since it
    # would be distributed across the worker devices. In simulation we hold all of it
    # and decide which split goes to which worker.
    X_train_splits, y_train_splits, split_weights = utils.split_training_data(
        X_train, y_train, hparams.K, hparams.iid, rng)
    for i, worker in enumerate(workers):
        worker.X_train = X_train_splits[i]
        worker.y_train = y_train_splits[i]

    log = {"loss": [], "accuracy": [], "val_loss": [], "val_accuracy": [],
           "iteration": [], "communication_rounds": [], "worker_upload_fraction": []}
    # Cumulative number of worker uploads to the server.
    communication_rounds = 0

    # Evaluate the randomly initialized global model as a baseline.
    utils.evaluate_and_log(log, server.global_model, X_train, y_train, X_val, y_val,
                           0, communication_rounds, hparams.K)

    m = _num_selected_workers(hparams.C, hparams.K)

    # Note: in the real world each worker would update in parallel on its own device.
    # In simulation the updates run sequentially on the same device.
    for t in range(hparams.MAX_T):
        global_weights = server.global_model.get_weights()

        # Sample without replacement so that exactly m *different* workers train
        # and upload this round (sampling with replacement can repeat a worker).
        selected = np.sort(rng.choice(hparams.K, size=m, replace=False))

        # Local update on each selected worker, starting from the global weights.
        for wk_i in selected:
            worker = workers[wk_i]
            worker.model.set_weights(global_weights)
            worker.model.fit(worker.X_train, worker.y_train, batch_size=hparams.B, epochs=hparams.E)
            # Upload the worker's locally trained weights to the server.
            server.worker_weights[wk_i] = worker.model.get_weights()
        # Count the uploads that actually happened.
        communication_rounds += len(selected)

        server.global_model.set_weights(
            _aggregate_weights(global_weights, server.worker_weights, selected, split_weights))

        # Evaluate the global model on the train and validation sets every interval.
        if (t + 1) % hparams.evaluation_interval == 0:
            utils.evaluate_and_log(log, server.global_model, X_train, y_train, X_val, y_val,
                                   t + 1, communication_rounds, hparams.K)
            # Stop early once the target validation accuracy has been reached.
            if log["val_accuracy"][-1] >= hparams.target_val_accuracy:
                break

    return server.global_model, log


def _num_selected_workers(C, K):
    """Return m = max(ceil(C * K), 1), capped at K: the workers sampled per round."""
    # Round away float noise before ceil: e.g. 0.07 * 100 == 7.000000000000001,
    # which a bare ceil would turn into 8.
    m = int(np.ceil(np.round(C * K, 9)))
    return min(K, max(1, m))


def _aggregate_weights(global_weights, worker_weights, selected, split_weights):
    """
    Return the split-weighted sum of per-worker weights, layer by layer.

    Workers in ``selected`` contribute their uploaded weights; every other
    worker contributes ``global_weights`` unchanged.
    """
    selected = {int(k) for k in selected}
    # Unselected workers all contribute the same (global) weights, so their
    # split weights collapse into one coefficient. This keeps the work at
    # O(m) per layer instead of O(K).
    unselected_share = sum(
        split_weights[k] for k in range(len(worker_weights)) if k not in selected)
    return [
        unselected_share * layer
        + sum(split_weights[k] * worker_weights[k][i] for k in selected)
        for i, layer in enumerate(global_weights)
    ]