-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
index.js
286 lines (250 loc) · 11 KB
/
index.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs';
// This is a helper class for loading and managing MNIST data specifically.
// It is a useful example of how you could create your own data manager class
// for arbitrary data though. It's worth a look :)
import {IMAGE_H, IMAGE_W, MnistData} from './data';
// This is a helper class for drawing loss graphs and MNIST images to the
// window. For the purposes of understanding the machine learning bits, you can
// largely ignore it
import * as ui from './ui';
/**
 * Creates a convolutional neural network (Convnet) for the MNIST data.
 *
 * Architecture: three 3x3 conv layers (16, 32, 32 filters, ReLU) with 2x2 max
 * pooling after the first two, then flatten -> dense(64, ReLU) ->
 * dense(10, softmax).
 *
 * @returns {tf.Model} An instance of tf.Model.
 */
function createConvModel() {
  // Create a sequential neural network model. tf.sequential provides an API
  // for creating "stacked" models where the output from one layer is used as
  // the input to the next layer.
  const model = tf.sequential();

  // The first layer of the convolutional neural network plays a dual role:
  // it is both the input layer of the neural network and a layer that performs
  // the first convolution operation on the input. It receives the 28x28 pixels
  // black and white images. This input layer uses 16 filters with a kernel size
  // of 3 pixels each. It uses a simple RELU activation function which pretty
  // much just looks like this: __/
  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_H, IMAGE_W, 1],
    kernelSize: 3,
    filters: 16,
    activation: 'relu'
  }));

  // After the first layer we include a MaxPooling layer. This acts as a sort of
  // downsampling using max values in a region instead of averaging.
  // https://www.quora.com/What-is-max-pooling-in-convolutional-neural-networks
  model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));

  // Our third layer is another convolution, this time with 32 filters.
  model.add(tf.layers.conv2d({kernelSize: 3, filters: 32, activation: 'relu'}));

  // Max pooling again.
  model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));

  // Add another conv2d layer.
  model.add(tf.layers.conv2d({kernelSize: 3, filters: 32, activation: 'relu'}));

  // Now we flatten the output from the 2D filters into a 1D vector to prepare
  // it for input into our last layer. This is common practice when feeding
  // higher dimensional data to a final classification output layer.
  model.add(tf.layers.flatten({}));

  model.add(tf.layers.dense({units: 64, activation: 'relu'}));

  // Our last layer is a dense layer which has 10 output units, one for each
  // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9). Here the classes actually
  // represent numbers, but it's the same idea if you had classes that
  // represented other entities like dogs and cats (two output classes: 0, 1).
  // We use the softmax function as the activation for the output layer as it
  // creates a probability distribution over our 10 classes so their output
  // values sum to 1.
  model.add(tf.layers.dense({units: 10, activation: 'softmax'}));

  return model;
}
/**
 * Creates a model consisting of only flatten and dense layers.
 *
 * The model created here has roughly the same number of trainable parameters
 * (~33k; with 28x28 inputs: 784*42+42 + 42*10+10 = 33,400) as the convnet
 * created by `createConvModel()` (~33.2k), but is expected to show a
 * significantly worse accuracy after training, due to the fact that it
 * doesn't utilize the spatial information as the convnet does.
 *
 * This is for comparison with the convolutional network above.
 *
 * @returns {tf.Model} An instance of tf.Model.
 */
function createDenseModel() {
  const model = tf.sequential();
  // Flatten each IMAGE_H x IMAGE_W x 1 image into a single vector, discarding
  // the 2D neighbourhood structure the convnet exploits.
  model.add(tf.layers.flatten({inputShape: [IMAGE_H, IMAGE_W, 1]}));
  // 42 hidden units is chosen to keep the parameter count comparable to the
  // convnet's.
  model.add(tf.layers.dense({units: 42, activation: 'relu'}));
  // One softmax output per digit class.
  model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
  return model;
}
/**
* This callback type is used by the `train` function for insertion into
* the model.fit callback loop.
*
* @callback onIterationCallback
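 * @param {string} eventType Name of the event that fired: 'onBatchEnd' or
 *     'onEpochEnd'.
 * @param {number} batchOrEpochNumber The current batch index (within the
 *     epoch) or epoch index, depending on `eventType`.
 * @param {tf.Logs} logs The metric values reported by model.fit() for that
 *     batch or epoch.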
*/
/**
 * Compile and train the given model, then report validation and test accuracy.
 *
 * Loss and accuracy are plotted through the `ui` module after every batch
 * (training metrics) and every epoch (validation metrics) while fit() runs.
 *
 * @param {tf.Model} model The model to train.
 * @param {onIterationCallback} onIteration Optional callback, invoked on every
 *     batch whose in-epoch index is a multiple of 10 and at every epoch end.
 * @returns {Promise<void>} Resolves after training and the final evaluation on
 *     the test set have finished and the final status has been logged.
 */
async function train(model, onIteration) {
  ui.logStatus('Training model...');

  // Now that we've defined our model, we will define our optimizer. The
  // optimizer will be used to optimize our model's weight values during
  // training so that we can decrease our training loss and increase our
  // classification accuracy.
  // We are using rmsprop as our optimizer.
  // An optimizer is an iterative method for minimizing a loss function.
  // It tries to find the minimum of our loss function with respect to the
  // model's weight parameters.
  const optimizer = 'rmsprop';

  // We compile our model by specifying an optimizer, a loss function, and a
  // list of metrics that we will use for model evaluation. Here we're using a
  // categorical crossentropy loss, the standard choice for a multi-class
  // classification problem like MNIST digits.
  // The categorical crossentropy loss is differentiable and hence makes
  // model training possible. But it is not amenable to easy interpretation
  // by a human. This is why we include a "metric", namely accuracy, which is
  // simply a measure of how many of the examples are classified correctly.
  // This metric is not differentiable and hence cannot be used as the loss
  // function of the model.
  model.compile({
    optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  // Batch size is another important hyperparameter. It defines the number of
  // examples we group together, or batch, between updates to the model's
  // weights during training. A value that is too low will update weights using
  // too few examples and will not generalize well. Larger batch sizes require
  // more memory resources and aren't guaranteed to perform better.
  const batchSize = 320;

  // Leave out the last 15% of the training data for validation, to monitor
  // overfitting during training.
  const validationSplit = 0.15;

  // Get number of training epochs from the UI.
  const trainEpochs = ui.getTrainEpochs();

  // Number of batches processed so far across all epochs. It is the x-axis of
  // the loss/accuracy plots and drives the progress percentage.
  let trainBatchCount = 0;

  // NOTE(review): these tensors are not disposed here. If data.getTrainData()
  // and data.getTestData() allocate fresh tensors on each call, they leak
  // after training; confirm ownership in ./data before adding dispose calls.
  const trainData = data.getTrainData();
  const testData = data.getTestData();

  // Estimated total number of batches, used only for the progress message.
  const totalNumBatches =
      Math.ceil(trainData.xs.shape[0] * (1 - validationSplit) / batchSize) *
      trainEpochs;

  // During the long-running fit() call for model training, we include
  // callbacks, so that we can plot the loss and accuracy values in the page
  // as the training progresses.
  // Validation accuracy of the most recently completed epoch; stays undefined
  // if no epoch completes (e.g. 0 epochs were requested).
  let valAcc;
  await model.fit(trainData.xs, trainData.labels, {
    batchSize,
    validationSplit,
    epochs: trainEpochs,
    callbacks: {
      onBatchEnd: async (batch, logs) => {
        trainBatchCount++;
        ui.logStatus(
            `Training... (` +
            `${(trainBatchCount / totalNumBatches * 100).toFixed(1)}%` +
            ` complete). To stop training, refresh or close page.`);
        ui.plotLoss(trainBatchCount, logs.loss, 'train');
        ui.plotAccuracy(trainBatchCount, logs.acc, 'train');
        if (onIteration && batch % 10 === 0) {
          onIteration('onBatchEnd', batch, logs);
        }
        // Yield to the browser so the page can repaint between batches.
        await tf.nextFrame();
      },
      onEpochEnd: async (epoch, logs) => {
        valAcc = logs.val_acc;
        ui.plotLoss(trainBatchCount, logs.val_loss, 'validation');
        ui.plotAccuracy(trainBatchCount, logs.val_acc, 'validation');
        if (onIteration) {
          onIteration('onEpochEnd', epoch, logs);
        }
        await tf.nextFrame();
      }
    }
  });

  // evaluate() returns newly allocated scalar tensors ([loss, accuracy],
  // because an accuracy metric was compiled in). Read the accuracy, then free
  // them so that repeated training runs do not leak memory.
  const testResult = model.evaluate(testData.xs, testData.labels);
  let testAccPercent;
  try {
    testAccPercent = testResult[1].dataSync()[0] * 100;
  } finally {
    tf.dispose(testResult);
  }

  // Only report validation accuracy if at least one epoch produced one;
  // otherwise the message would read "NaN%".
  const statusParts = [];
  if (valAcc !== undefined) {
    statusParts.push(
        `Final validation accuracy: ${(valAcc * 100).toFixed(1)}%`);
  }
  statusParts.push(`Final test accuracy: ${testAccPercent.toFixed(1)}%`);
  ui.logStatus(statusParts.join('; '));
}
/**
 * Show predictions on a number of test examples.
 *
 * Fetches 100 test examples, runs the model on them, and passes the examples,
 * the predicted class indices and the true class indices to
 * `ui.showTestResults()`.
 *
 * @param {tf.Model} model The model to be used for making the predictions.
 * @returns {Promise<void>} Resolves once the results have been handed to the
 *     UI (all work happens synchronously inside tf.tidy()).
 */
async function showPredictions(model) {
  const testExamples = 100;
  // Code wrapped in a tf.tidy() function callback will have the tensors it
  // allocates freed after execution without having to call dispose().
  // The tf.tidy callback runs synchronously.
  tf.tidy(() => {
    // The examples are fetched inside the tidy scope so their tensors are
    // released too. This function is used as the training iteration callback,
    // so tensors created outside the scope would leak on every call.
    // NOTE(review): assumes data.getTestData() returns freshly allocated
    // tensors rather than cached ones that must outlive this call; confirm
    // in ./data.
    const examples = data.getTestData(testExamples);
    const output = model.predict(examples.xs);

    // tf.argMax() returns the indices of the maximum values in the tensor along
    // a specific axis. Categorical classification tasks like this one often
    // represent classes as one-hot vectors. One-hot vectors are 1D vectors with
    // one element for each output class. All values in the vector are 0
    // except for one, which has a value of 1 (e.g. [0, 0, 0, 1, 0]). The
    // output from model.predict() will be a probability distribution, so we use
    // argMax to get the index of the vector element that has the highest
    // probability. This is our prediction.
    // (e.g. argmax([0.07, 0.1, 0.03, 0.75, 0.05]) == 3)
    // dataSync() synchronously downloads the tf.tensor values from the GPU so
    // that we can use them in our normal CPU JavaScript code
    // (for a non-blocking version of this function, use data()).
    const axis = 1;
    const labels = Array.from(examples.labels.argMax(axis).dataSync());
    const predictions = Array.from(output.argMax(axis).dataSync());

    ui.showTestResults(examples, predictions, labels);
  });
}
/**
 * Builds the model architecture currently selected in the UI.
 *
 * @returns {tf.Model} The model returned by createConvModel() for 'ConvNet'
 *     or createDenseModel() for 'DenseNet'.
 * @throws {Error} If the UI reports any other model type id.
 */
function createModel() {
  const selectedType = ui.getModelTypeId();
  switch (selectedType) {
    case 'ConvNet':
      return createConvModel();
    case 'DenseNet':
      return createDenseModel();
    default:
      throw new Error(`Invalid model type: ${selectedType}`);
  }
}
// Module-level handle to the MNIST data manager. load() replaces it; train()
// and showPredictions() read from it.
let data;

/**
 * Creates a new MnistData instance, publishes it as `data`, and waits for it
 * to finish loading.
 *
 * @returns {Promise<void>} Resolves when MnistData#load() resolves.
 */
async function load() {
  const mnist = new MnistData();
  // Published before loading completes, exactly like assigning first and
  // then awaiting data.load().
  data = mnist;
  await mnist.load();
}
// This is our main function. When the train button is clicked it loads the
// MNIST data, creates the model selected in the UI, trains it, and shows what
// the model predicts on unseen test data while training runs.
ui.setTrainButtonCallback(async () => {
  try {
    ui.logStatus('Loading MNIST data...');
    await load();

    ui.logStatus('Creating model...');
    const model = createModel();
    model.summary();

    ui.logStatus('Starting model training...');
    await train(model, () => showPredictions(model));
  } catch (err) {
    // Surface failures (data loading, an unknown model type from
    // createModel(), training errors). Without this the status line would
    // stay frozen on the last progress message and the rejection would
    // propagate to setTrainButtonCallback, whose handling of the returned
    // promise isn't visible here.
    console.error(err);
    const message = err instanceof Error ? err.message : String(err);
    ui.logStatus(`Error: ${message}`);
  }
});