forked from huanzhang12/CLEVER
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_2layer.py
executable file
·81 lines (66 loc) · 2.95 KB
/
train_2layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
## train_2layer.py -- train MLP models for MNIST and CIFAR
##
## Copyright (C) 2017, Lily Weng <[email protected]>
## and Huan Zhang <[email protected]>
##
## This program is licenced under the BSD 2-Clause licence,
## contained in the LICENCE file in this directory.
import numpy as np
from tensorflow.contrib.keras.api.keras.models import Sequential
from tensorflow.contrib.keras.api.keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.contrib.keras.api.keras.layers import Conv2D, MaxPooling2D
from tensorflow.contrib.keras.api.keras.layers import Lambda
from tensorflow.contrib.keras.api.keras.models import load_model
from tensorflow.contrib.keras.api.keras.optimizers import SGD
import tensorflow as tf
from setup_mnist import MNIST
from setup_cifar import CIFAR
import os
def train(data, file_name, params, num_epochs=50, batch_size=128, train_temp=1, init=None, lr=0.01, decay=1e-5, momentum=0.9):
    """
    Train a 2-layer (one hidden layer) MLP for MNIST and CIFAR.

    Architecture: Flatten -> Dense(params[0]) -> softplus with alpha = 10
    -> Dense(10). The final layer has no activation, so the model outputs
    logits. Training uses Nesterov SGD on a temperature-scaled softmax
    cross-entropy loss.

    :param data: dataset object exposing ``train_data``, ``train_labels``,
        ``validation_data`` and ``validation_labels`` (e.g. ``MNIST()`` or
        ``CIFAR()``). The labels are passed as ``labels=`` to
        ``tf.nn.softmax_cross_entropy_with_logits``, so they are presumably
        one-hot over 10 classes -- TODO confirm against setup_mnist/setup_cifar.
    :param file_name: path the trained model is written to with
        ``model.save``; pass ``None`` to skip saving.
    :param params: sequence of layer sizes; only ``params[0]`` (the hidden
        layer width) is used.
    :param num_epochs: number of training epochs.
    :param batch_size: mini-batch size.
    :param train_temp: temperature; the logits are divided by it inside the
        loss only (the saved model still outputs unscaled logits).
    :param init: optional path to weights loaded with ``model.load_weights``
        before training; ``None`` keeps the fresh initialization.
    :param lr: SGD learning rate.
    :param decay: SGD learning-rate decay.
    :param momentum: SGD momentum (Nesterov momentum is always enabled).
    :return: the trained Keras ``Sequential`` model.
    """
    # create a Keras sequential model
    model = Sequential()
    # reshape the input (28*28*1) or (32*32*3) to 1-D
    model.add(Flatten(input_shape=data.train_data.shape[1:]))
    # first dense layer (the hidden layer)
    model.add(Dense(params[0]))
    # Keras' built-in softplus has no alpha parameter, so implement
    # softplus_alpha(x) = softplus(alpha * x) / alpha with alpha = 10
    # by scaling the input by 10 and the output by 0.1.
    # The constants are literals (not a captured variable) so the Lambda
    # layers serialize the same way when the model is saved.
    model.add(Lambda(lambda x: x * 10))
    model.add(Activation('softplus'))
    model.add(Lambda(lambda x: x * 0.1))
    # the output layer, with 10 classes (no activation: outputs logits)
    model.add(Dense(10))
    # load initial weights when given
    if init is not None:
        model.load_weights(init)

    # loss: softmax cross entropy between the true labels and the
    # temperature-scaled logits
    def fn(correct, predicted):
        return tf.nn.softmax_cross_entropy_with_logits(labels=correct,
                                                       logits=predicted/train_temp)

    # initiate the SGD optimizer with given hyper parameters
    sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
    # compile the Keras model, given the specified loss and optimizer
    model.compile(loss=fn,
                  optimizer=sgd,
                  metrics=['accuracy'])
    # run training with given dataset, and print progress
    model.fit(data.train_data, data.train_labels,
              batch_size=batch_size,
              validation_data=(data.validation_data, data.validation_labels),
              epochs=num_epochs,
              shuffle=True)
    # save model to a file
    if file_name is not None:
        model.save(file_name)
    return model
# Make sure the output directory for saved models exists. This runs at
# import time, not only under __main__, exactly as before. exist_ok=True
# removes the race between a separate isdir() check and makedirs(). It still
# raises FileExistsError if 'models' exists but is not a directory.
os.makedirs('models', exist_ok=True)
if __name__ == "__main__":
    # Train the 1024-unit hidden-layer networks for MNIST and CIFAR and save
    # them under models/.
    train(MNIST(), file_name="models/mnist_2layer", params=[1024], num_epochs=50, lr=0.1, decay=1e-3)
    train(CIFAR(), file_name="models/cifar_2layer", params=[1024], num_epochs=50, lr=0.2, decay=1e-3)