pix2pix.js
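// Pix2pix: inference for a pre-trained pix2pix U-Net generator, with the
// forward pass assembled by hand from per-layer checkpoint weights.
// Assumed imports: tf.js, plus two helpers referenced below but defined
// elsewhere in this project (fetchWeights appears to load a checkpoint into a
// name -> tensor map; array3DToImage appears to convert a rank-3 tensor back
// into an image element).
import * as tf from '@tensorflow/tfjs';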
class Pix2pix {
  constructor(model, callback) {
    this.ready = false;
    this.loadCheckpoints(model).then(() => {
      this.ready = true;
      if (callback) {
        callback();
      }
    });
  }
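  // Fetches the trained generator weights as a name -> tensor map.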
  async loadCheckpoints(path) {
    this.weights = await fetchWeights(path);
  }
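  // Runs the generator on an image-like element: pixels are read, scaled to
  // [0, 1], pushed through the U-Net, and the result is handed to `callback`
  // as an image. The CPU backend is pinned before any tensors are created.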
  async transfer(inputElement, callback = () => {}) {
    tf.setBackend('cpu');
    const input = tf.browser.fromPixels(inputElement);
    const inputData = input.dataSync();
    const floatInput = tf.tensor3d(inputData, input.shape, 'float32');
    const normalizedInput = tf.div(floatInput, tf.scalar(255.0));
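    // Map [0, 1] pixels to the [-1, 1] range the generator expects, and back.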
    function preprocess(inputPreproc) {
      return tf.sub(tf.mul(inputPreproc, tf.scalar(2)), tf.scalar(1));
    }
    function deprocess(inputDeproc) {
      return tf.div(tf.add(inputDeproc, tf.scalar(1)), tf.scalar(2));
    }
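    // Batch normalization with moments computed on the fly over the spatial
    // axes, using the checkpoint's gamma (scale) and beta (offset).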
    function batchnorm(inputBat, scale, offset) {
      const moments = tf.moments(inputBat, [0, 1]);
      const varianceEpsilon = 1e-5;
      // tf.batchNormalization was removed in later tf.js releases; tf.batchNorm
      // is the replacement (note the different argument order).
      return tf.batchNorm(inputBat, moments.mean, moments.variance, offset, scale, varianceEpsilon);
    }
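    // Stride-2, same-padding convolution: halves the spatial dimensions.
    // Note that call sites below pass a checkpoint bias as a third argument,
    // but this helper never applies it.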
    function conv2d(inputCon, filterCon) {
      return tf.conv2d(inputCon, filterCon, [2, 2], 'same');
    }
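    // Transpose convolution with stride 2: doubles the spatial dimensions,
    // then adds the layer bias.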
    function deconv2d(inputDeconv, filterDeconv, biasDecon) {
      const convolved = tf.conv2dTranspose(inputDeconv, filterDeconv, [inputDeconv.shape[0] * 2, inputDeconv.shape[1] * 2, filterDeconv.shape[2]], [2, 2], 'same');
      const biased = tf.add(convolved, biasDecon);
      return biased;
    }
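    // Build the whole forward pass inside tf.tidy so intermediate tensors are
    // disposed automatically. encoder_1 is a bare stride-2 convolution with no
    // batchnorm.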
    const result = tf.tidy(() => {
      const preprocessedInput = preprocess(normalizedInput);
      const layers = [];
      let filter = this.weights['generator/encoder_1/conv2d/kernel'];
      let bias = this.weights['generator/encoder_1/conv2d/bias'];
      let convolved = conv2d(preprocessedInput, filter, bias);
      layers.push(convolved);
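      // encoder_2 .. encoder_8: leaky ReLU -> stride-2 conv -> batchnorm. Each
      // activation is kept in `layers` for the decoder's skip connections.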
      for (let i = 2; i <= 8; i += 1) {
        const scope = `generator/encoder_${i.toString()}`;
        filter = this.weights[`${scope}/conv2d/kernel`];
        const bias2 = this.weights[`${scope}/conv2d/bias`];
        const layerInput = layers[layers.length - 1];
        const rectified = tf.leakyRelu(layerInput, 0.2);
        convolved = conv2d(rectified, filter, bias2);
        const scale = this.weights[`${scope}/batch_normalization/gamma`];
        const offset = this.weights[`${scope}/batch_normalization/beta`];
        const normalized = batchnorm(convolved, scale, offset);
        layers.push(normalized);
      }
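      // decoder_8 .. decoder_2: ReLU -> transpose conv (2x upsample) -> batchnorm.
      // Every stage except the innermost concatenates the matching encoder
      // activation along the channel axis (U-Net skip connections).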
      for (let i = 8; i >= 2; i -= 1) {
        let layerInput;
        if (i === 8) {
          layerInput = layers[layers.length - 1];
        } else {
          const skipLayer = i - 1;
          layerInput = tf.concat([layers[layers.length - 1], layers[skipLayer]], 2);
        }
        const rectified = tf.relu(layerInput);
        const scope = `generator/decoder_${i.toString()}`;
        filter = this.weights[`${scope}/conv2d_transpose/kernel`];
        bias = this.weights[`${scope}/conv2d_transpose/bias`];
        convolved = deconv2d(rectified, filter, bias);
        const scale = this.weights[`${scope}/batch_normalization/gamma`];
        const offset = this.weights[`${scope}/batch_normalization/beta`];
        const normalized = batchnorm(convolved, scale, offset);
        layers.push(normalized);
      }
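      // decoder_1: final skip connection to encoder_1, one last transpose conv,
      // and tanh to squash the output into [-1, 1] before deprocessing.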
      const layerInput = tf.concat([layers[layers.length - 1], layers[0]], 2);
      let rectified2 = tf.relu(layerInput);
      filter = this.weights['generator/decoder_1/conv2d_transpose/kernel'];
      const bias3 = this.weights['generator/decoder_1/conv2d_transpose/bias'];
      convolved = deconv2d(rectified2, filter, bias3);
      rectified2 = tf.tanh(convolved);
      layers.push(rectified2);
      const output = layers[layers.length - 1];
      const deprocessedOutput = deprocess(output);
      return deprocessedOutput;
    });
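    // Yield to the browser for a frame before converting the result tensor
    // back into an image for the callback.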
    await tf.nextFrame();
    callback(array3DToImage(result));
  }
}
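// Factory wrapper so callers can construct a model without `new`.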
const pix2pix = (model, callback = () => {}) => new Pix2pix(model, callback);
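
// Usage sketch (hypothetical checkpoint path and DOM element, assuming this
// module exports `pix2pix` and a compatible checkpoint is served alongside it):
//
//   const model = pix2pix('models/edges2pikachu.pict', () => {
//     const canvas = document.getElementById('input-canvas');
//     model.transfer(canvas, (img) => document.body.appendChild(img));
//   });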