ML Classifier is a machine learning engine for quickly training image classification models in your browser. Models can be saved with a single command, and the resulting models reused to make image classification predictions.
This package is intended as a companion for ml-classifier-ui, which provides a web frontend in React for uploading data and seeing results.
A walkthrough of the code can be found in the article Image Classification in the Browser with Javascript.
An interactive demo can be found here.
ml-classifier can be installed via yarn or npm:
yarn add ml-classifier
or
npm install ml-classifier
Start by instantiating a new MLClassifier.
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
Then, train the model:
await mlClassifier.train(imageData, {
  callbacks: {
    onTrainBegin: () => {
      console.log('training begins');
    },
    onBatchEnd: (batch: any, logs: any) => {
      console.log('Loss is: ' + logs.loss.toFixed(5));
    },
  },
});
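To keep the loss history rather than only log it, the same onBatchEnd callback can push each value into an array. A minimal sketch: createLossRecorder is a hypothetical helper, not part of ml-classifier, and it assumes only that TensorFlow.js calls onBatchEnd with a batch index and a logs object carrying a numeric loss, as in the example above.

```javascript
// Hypothetical helper (not part of ml-classifier): builds a `callbacks`
// object that records the loss reported after every batch.
function createLossRecorder() {
  const losses = [];
  const callbacks = {
    onTrainBegin: () => {
      // Start each training run with an empty history.
      losses.length = 0;
    },
    onBatchEnd: (batch, logs) => {
      // logs is optional in TensorFlow.js, so guard before reading loss.
      if (logs && typeof logs.loss === 'number') {
        losses.push(logs.loss);
      }
    },
  };
  return { callbacks, losses };
}

const { callbacks, losses } = createLossRecorder();
// Simulate two batches to show what gets recorded.
callbacks.onTrainBegin();
callbacks.onBatchEnd(0, { loss: 0.9 });
callbacks.onBatchEnd(1, { loss: 0.6 });
console.log(losses); // [ 0.9, 0.6 ]
```

You would pass the returned callbacks object in the options given to train, as in the example above, and read losses once training resolves.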
And get predictions:
const prediction = await mlClassifier.predict(data);
When you have a trained model you're happy with, save it with:
mlClassifier.save();
When you call save, TensorFlow.js will download a weights file and a model topology file.
You'll need to combine both into a single JSON file. Open your model topology file and, at the top level of the JSON, add a weightsManifest key pointing to your weights file, like:
{
  "weightsManifest": "ml-classifier-class1-class2.weights.bin",
  "modelTopology": {
    ...
  }
}
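The edit described above can be scripted. A minimal sketch that produces an object in the shape shown above; addWeightsManifest is a hypothetical helper, the topology contents are a stand-in, and the file name is the example from this README. Reading and writing the actual files (for example with Node's fs module) is left out so the function stays pure.

```javascript
// Hypothetical helper (not part of ml-classifier): returns a copy of the
// downloaded topology JSON with a top-level `weightsManifest` key pointing at
// the weights file, in the shape shown above.
function addWeightsManifest(topology, weightsFileName) {
  // Accept either the raw file contents or an already-parsed object.
  const parsed = typeof topology === 'string' ? JSON.parse(topology) : topology;
  // Drop any existing weightsManifest so the new one replaces it.
  const { weightsManifest: _previous, ...rest } = parsed;
  return { weightsManifest: weightsFileName, ...rest };
}

// Stand-in for the contents of the downloaded topology file.
const topologyFile = '{"modelTopology": {"class_name": "Sequential"}}';
const combined = addWeightsManifest(
  topologyFile,
  'ml-classifier-class1-class2.weights.bin',
);
// Serialize with indentation before saving it as your single JSON file.
const output = JSON.stringify(combined, null, 2);
console.log(output);
```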
When using the model in your app, there are a few things to keep in mind:
- You need to transform images into the dimensions expected by the pretrained model used during training. (For MOBILENET, this is 1x224x224x3.)
- You must create a pretrained model matching the dimensions used in training. An example for MOBILENET is below.
- You must first run your images through the pretrained model to get their activations.
- After getting the final prediction, you must take the argmax.
- The result is a number indicating the predicted class.
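Taking the argmax simply means picking the index of the largest score. A plain-JavaScript sketch operating on an array such as the one dataSync() returns; mapping the index back to a label assumes you know the label order used during training, which this README does not specify, so the labels array here is hypothetical.

```javascript
// Return the index of the largest score; on ties, the first index wins.
function argMax(scores) {
  if (scores.length === 0) {
    throw new Error('argMax needs at least one score');
  }
  let best = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}

// Hypothetical labels, in the same order as the model's output units.
const labels = ['class1', 'class2', 'class3'];
// e.g. Array.from(modelPrediction.dataSync())
const scores = [0.1, 0.7, 0.2];
const classIndex = argMax(scores);
console.log(classIndex, labels[classIndex]); // 1 class2
```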
Full example for MOBILENET:
import * as tf from '@tensorflow/tfjs';

// Load an image element. Set crossOrigin and the handlers before src so the
// request is made with CORS and the load event cannot be missed.
const loadImage = (src) => new Promise((resolve, reject) => {
  const image = new Image();
  image.crossOrigin = 'Anonymous';
  image.onload = () => resolve(image);
  image.onerror = (err) => reject(err);
  image.src = src;
});

const pretrainedModelURL = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';

// Truncate MobileNet at 'conv_pw_13_relu' so it outputs activations
// instead of ImageNet class scores.
tf.loadModel(pretrainedModelURL).then(model => {
  const layer = model.getLayer('conv_pw_13_relu');
  return tf.model({
    inputs: [model.inputs[0]],
    outputs: layer.output,
  });
}).then(pretrainedModel => {
  return tf.loadModel('/model.json').then(model => {
    return loadImage('/trees/tree1.png').then(loadedImage => {
      // The image must already be 224x224 pixels for this reshape to be valid.
      const image = tf.reshape(tf.fromPixels(loadedImage), [1, 224, 224, 3]);
      // Run the image through the truncated MobileNet first...
      const pretrainedModelPrediction = pretrainedModel.predict(image);
      // ...then through your trained model, and take the argmax.
      const modelPrediction = model.predict(pretrainedModelPrediction);
      const prediction = modelPrediction.as1D().argMax().dataSync()[0];
      console.log(prediction);
    });
  });
}).catch(err => {
  console.error('Error', err);
});
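tf.reshape in the example only checks that the element count matches, and fromPixels returns a [height, width, channels] tensor. An image with a different size but the same number of values would therefore reshape without error while scrambling its pixels. A small sketch of that arithmetic; MOBILENET_INPUT_SHAPE, elementCount, and matchesInputShape are illustrative names, not part of ml-classifier or TensorFlow.js.

```javascript
// MobileNet input shape used in the example: [batch, height, width, channels].
const MOBILENET_INPUT_SHAPE = [1, 224, 224, 3];

// Total number of values in a tensor of the given shape.
function elementCount(shape) {
  return shape.reduce((a, b) => a * b, 1);
}

// True only if a [height, width, channels] image can be batched into
// `inputShape` without reinterpreting its pixels.
function matchesInputShape(imageShape, inputShape) {
  const [, h, w, c] = inputShape;
  return imageShape[0] === h && imageShape[1] === w && imageShape[2] === c;
}

console.log(elementCount(MOBILENET_INPUT_SHAPE)); // 150528
// A 448x112 image has the same element count, so reshape would not throw...
console.log(elementCount([448, 112, 3]) === elementCount([224, 224, 3])); // true
// ...but it does not match the expected image dimensions.
console.log(matchesInputShape([448, 112, 3], MOBILENET_INPUT_SHAPE)); // false
```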
Start by instantiating a new instance of MLClassifier with:
const mlClassifier = new MLClassifier();
This will begin loading the pretrained model and provide you with an object to which you can add data and train.
MLClassifier accepts a number of callbacks for the beginning and end of various methods.
You can provide a custom pretrained model as pretrainedModel.
You can provide a custom training model as trainingModel.
- pretrainedModel (string | tf.Model) Optional - A string denoting which pretrained model to load from an internal config. Valid strings can be found on the exported object PRETRAINED_MODELS. You can also pass a preloaded pretrained model directly.
- trainingModel (tf.Model | Function) Optional - A custom model to use during training. Can be provided as a tf.Model or as a function that accepts {xs: [...], ys: [...]}, the number of classes, and the params provided to train.
- onLoadStart (Function) Optional - A callback for when load (loading the pretrained model) is first called.
- onLoadComplete (Function) Optional - A callback for when load (loading the pretrained model) is complete.
- onAddDataStart (Function) Optional - A callback for when addData is first called.
- onAddDataComplete (Function) Optional - A callback for when addData is complete.
- onClearDataStart (Function) Optional - A callback for when clearData is first called.
- onClearDataComplete (Function) Optional - A callback for when clearData is complete.
- onTrainStart (Function) Optional - A callback for when train is first called.
- onTrainComplete (Function) Optional - A callback for when train is complete.
- onEvaluateStart (Function) Optional - A callback for when evaluate is first called.
- onEvaluateComplete (Function) Optional - A callback for when evaluate is complete.
- onPredictStart (Function) Optional - A callback for when predict is first called.
- onPredictComplete (Function) Optional - A callback for when predict is complete.
- onSaveStart (Function) Optional - A callback for when save is first called.
- onSaveComplete (Function) Optional - A callback for when save is complete.
import MLClassifier, {
  PRETRAINED_MODELS,
} from 'ml-classifier';
const mlClassifier = new MLClassifier({
  pretrainedModel: PRETRAINED_MODELS.MOBILENET,
  onLoadStart: () => console.log('onLoadStart'),
  onLoadComplete: () => console.log('onLoadComplete'),
  onAddDataStart: () => console.log('onAddDataStart'),
  onAddDataComplete: () => console.log('onAddDataComplete'),
  onClearDataStart: () => console.log('onClearDataStart'),
  onClearDataComplete: () => console.log('onClearDataComplete'),
  onTrainStart: () => console.log('onTrainStart'),
  onTrainComplete: () => console.log('onTrainComplete'),
  onEvaluateStart: () => console.log('onEvaluateStart'),
  onEvaluateComplete: () => console.log('onEvaluateComplete'),
  onPredictStart: () => console.log('onPredictStart'),
  onPredictComplete: () => console.log('onPredictComplete'),
  onSaveStart: () => console.log('onSaveStart'),
  onSaveComplete: () => console.log('onSaveComplete'),
});
Example of specifying a preloaded pretrained model:
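import * as tf from '@tensorflow/tfjs';
import MLClassifier from 'ml-classifier';
// Load the pretrained model first, then hand the loaded model to MLClassifier.
const model = await tf.loadModel('... some pretrained model ...');
const mlClassifier = new MLClassifier({
  pretrainedModel: model,
});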
addData takes an array of incoming images, an optional array of labels, and an optional dataType.
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, 'train');
- images (Array<tf.Tensor3D | ImageData | HTMLImageElement | string>) - An array of 3D tensors, ImageData objects (for example, from a canvas), native browser Image elements, or strings representing an image src. Images can be any size, but will be cropped and resized to match the pretrained model.
- labels (string[]) - An array of strings, one label for each image passed above.
- dataType (string) Optional - An enum specifying which data type the images belong to. Data types can be train, for data used in train(), and eval, for data used in evaluate(). If no argument is supplied, dataType defaults to train.
addData returns nothing.
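This README does not say exactly how images are cropped. One common approach, assumed here, is a centered square crop followed by a resize to the model's input size. A sketch of the crop-box arithmetic; centerSquareCrop is a hypothetical helper, and the library's actual cropping may differ.

```javascript
// Hypothetical helper: compute a centered square crop box for an image
// of the given width and height (in pixels).
function centerSquareCrop(width, height) {
  const size = Math.min(width, height);
  return {
    x: Math.floor((width - size) / 2),
    y: Math.floor((height - size) / 2),
    size,
  };
}

// A 640x480 photo: keep the central 480x480 square, which would then be
// resized (e.g. with tf.image.resizeBilinear) to 224x224 for MobileNet.
console.log(centerSquareCrop(640, 480)); // { x: 80, y: 0, size: 480 }
```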
train begins training on the given dataset.
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, 'train');
await mlClassifier.train({
  callbacks: {
    onTrainBegin: () => {
      console.log('training begins');
    },
  },
});
- params (Object) Optional - A set of parameters that will be passed directly to model.fit. See the TensorFlow.js docs for an up-to-date list of arguments.
train returns the value that fit's promise resolves to: an object containing loss and accuracy.
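Because params is forwarded to model.fit, standard fit arguments such as epochs and batchSize apply. With in-memory data, fit splits each epoch into ceil(examples / batchSize) batches, the last one possibly partial, which tells you how often onBatchEnd fires (ignoring any validationSplit). A small sketch of that arithmetic; the parameter values and dataset size are hypothetical.

```javascript
// Hypothetical training parameters; epochs and batchSize are standard
// tf.Model.fit arguments.
const params = { epochs: 10, batchSize: 32 };

// Number of onBatchEnd calls per epoch: the last batch may be partial.
function batchesPerEpoch(numExamples, batchSize) {
  return Math.ceil(numExamples / batchSize);
}

const numExamples = 100; // hypothetical dataset size
const perEpoch = batchesPerEpoch(numExamples, params.batchSize);
console.log(perEpoch); // 4
console.log(perEpoch * params.epochs); // 40
```

Such an object could be passed to train alongside callbacks, e.g. train({ ...params, callbacks }).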
evaluate is used to evaluate a model's performance.
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, 'train');
await mlClassifier.train();
// Evaluation images need their own labels.
mlClassifier.addData(evaluationImages, evaluationLabels, 'eval');
const result = await mlClassifier.evaluate();
- params (Object) Optional - A set of parameters that will be passed directly to model.evaluate. See the TensorFlow.js docs for an up-to-date list of arguments.
evaluate returns a tf.Scalar representing the result of model.evaluate.
predict is used to make a prediction for a single image using the trained model.
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, 'train');
await mlClassifier.train();
const prediction = await mlClassifier.predict(imageToPredict);
- image (tf.Tensor3D) - A single image encoded as a tf.Tensor3D. The image can be any size, but will be cropped and resized to match the pretrained model.
predict returns a string containing the predicted label.
save is a proxy to the tf.Model save method, and will initiate a download from the browser or save to local storage.
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, 'train');
await mlClassifier.train();
await mlClassifier.save('path-to-save');
- handlerOrUrl (io.IOHandler | string) Optional - An argument to be passed to model.save. If omitted, a name is built by concatenating the model's unique labels, in the form class1-class2-class3.
- params (Object) Optional - A set of parameters that will be passed directly to model.save. See the TensorFlow.js docs for an up-to-date list of arguments.
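The default name described for handlerOrUrl can be sketched as joining the unique labels with hyphens. This is an illustration, not the library's code: defaultModelName is a hypothetical helper, and keeping labels in first-seen order is an assumption.

```javascript
// Hypothetical sketch: unique labels joined with '-', in first-seen order.
function defaultModelName(labels) {
  return Array.from(new Set(labels)).join('-');
}

console.log(defaultModelName(['cat', 'dog', 'cat', 'bird'])); // cat-dog-bird
```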
getModel returns the trained TensorFlow.js model. Calling this method before calling mlClassifier.train will return null.
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, 'train');
await mlClassifier.train();
const model = mlClassifier.getModel();
getModel takes no arguments and returns the trained TensorFlow.js model.
clearData clears out saved data.
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, 'train');
mlClassifier.clearData('train');
- dataType (DataType) Optional - Specifies which data to clear. If no argument is provided, all data will be cleared.
clearData returns nothing.
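The documented behavior of clearData (clear one data type, or everything when no type is given) can be modeled with a plain object keyed by data type. This is a hypothetical illustration of the semantics, not the library's implementation; the store layout and the function are assumptions.

```javascript
// Hypothetical model of clearData's documented behavior: clear one data
// type, or every type when none is given. Not ml-classifier's code.
function clearData(store, dataType) {
  const types = dataType === undefined ? Object.keys(store) : [dataType];
  for (const type of types) {
    store[type] = [];
  }
  return store;
}

const store = { train: ['img1', 'img2'], eval: ['img3'] };
clearData(store, 'train');
console.log(store); // { train: [], eval: [ 'img3' ] }
clearData(store);
console.log(store); // { train: [], eval: [] }
```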
Contributions are welcome!
You can start up a local copy of ml-classifier with:
yarn watch
ml-classifier is written in TypeScript.
Tests are a work in progress. Currently, the test suite only consists of unit tests. Pull requests for additional tests are welcome!
Run tests with:
yarn test
This project is licensed under the MIT License. See the LICENSE file for details.