First, the images need to be converted to tensors. The first approach is to create one tensor containing all the features (and, respectively, one tensor containing all the labels). This is the way to go only if the dataset contains few images.
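const tf = require('@tensorflow/tfjs');
const tfnode = require('@tensorflow/tfjs-node');
const fs = require('fs').promises;
// read one image file and decode it (run inside an async function)
const imageBuffer = await fs.readFile(feature_file);
const tensorFeature = tfnode.node.decodeImage(imageBuffer); // create a tensor for the image
// repeat for every image, then stack the image tensors into one tensor;
// all images must have the same shape for tf.stack to work
const tensorFeatures = tf.stack([tensorFeature, tensorFeature2, tensorFeature3]);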
The labels are an array indicating the class of each image:
const labelArray = [0, 1, 2] // for example, 0 for dog, 1 for cat and 2 for bird
Next, create a one-hot encoding of the labels:
const tensorLabels = tf.oneHot(tf.tensor1d(labelArray, 'int32'), 3);
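To make the encoding concrete, here is a minimal plain-JavaScript sketch of what a one-hot encoding looks like for the three labels above. The helper name oneHotEncode is made up for illustration and is not part of TensorFlow.js; throwing on an out-of-range label is a choice of this sketch, not a description of how tf.oneHot behaves.

```javascript
// Plain-JavaScript sketch of one-hot encoding (no TensorFlow.js required).
// `oneHotEncode` is a hypothetical helper used only for illustration.
function oneHotEncode(labels, numClasses) {
  return labels.map(label => {
    // This sketch rejects labels outside [0, numClasses - 1].
    if (!Number.isInteger(label) || label < 0 || label >= numClasses) {
      throw new RangeError(`label ${label} is outside [0, ${numClasses - 1}]`);
    }
    const row = new Array(numClasses).fill(0);
    row[label] = 1; // a single 1 at the position of the class index
    return row;
  });
}

const encoded = oneHotEncode([0, 1, 2], 3);
console.log(encoded); // [ [ 1, 0, 0 ], [ 0, 1, 0 ], [ 0, 0, 1 ] ]
```

Each row has exactly one 1, which is the target format a softmax output layer with three units is trained against.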
Once the tensors are ready, the next step is to create the model for training. Here is a simple model.
const model = tf.sequential();
model.add(tf.layers.conv2d({
  inputShape: [height, width, numberOfChannels], // numberOfChannels = 3 for color (RGB) images, 1 for grayscale
  filters: 32,
  kernelSize: 3,
  activation: 'relu',
}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({units: 3, activation: 'softmax'}));
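To get a feel for the size of this model, the layer output shapes and parameter counts can be worked out by hand. The sketch below assumes the conv2d layer keeps a stride of 1 and 'valid' (no) padding, and that both layers use a bias. The helper name describeModel and the 28x28 example size are hypothetical, chosen only for illustration.

```javascript
// Hand computation of shapes and parameter counts for conv2d -> flatten -> dense.
// Assumes stride 1, 'valid' padding and a bias term in both layers.
// `describeModel` is a hypothetical helper, not part of TensorFlow.js.
function describeModel(height, width, channels, filters = 32, kernelSize = 3, units = 3) {
  // With 'valid' padding, each spatial dimension shrinks by kernelSize - 1.
  const outHeight = height - kernelSize + 1;
  const outWidth = width - kernelSize + 1;
  // One kernelSize x kernelSize x channels kernel per filter, plus one bias per filter.
  const convParams = kernelSize * kernelSize * channels * filters + filters;
  // flatten() turns the [outHeight, outWidth, filters] volume into one vector.
  const flatSize = outHeight * outWidth * filters;
  // One weight per (input value, unit) pair, plus one bias per unit.
  const denseParams = flatSize * units + units;
  return {
    convOutputShape: [outHeight, outWidth, filters],
    flatSize,
    convParams,
    denseParams,
    totalParams: convParams + denseParams,
  };
}

const summary = describeModel(28, 28, 3);
console.log(summary.convOutputShape); // [ 26, 26, 32 ]
console.log(summary.totalParams);     // 65795
```

Almost all of the parameters sit in the dense layer, because its input size grows with the image area. For large images, a pooling layer or a larger stride before flatten() is a common way to keep that number down.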
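Then the model can be compiled and trained (a model has to be compiled before it can be fitted):
model.compile({optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy']});
await model.fit(tensorFeatures, tensorLabels);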
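Stacking every image into a single tensor means the whole dataset sits in memory at once. As a rough back-of-the-envelope sketch, the footprint can be estimated as below. The helper name estimateTensorBytes and the example sizes are made up for illustration, and 4 bytes per value assumes an int32 or float32 tensor.

```javascript
// Rough memory estimate for one tensor of shape [numImages, height, width, channels].
// `estimateTensorBytes` is a hypothetical helper; 4 bytes per value assumes int32/float32.
function estimateTensorBytes(numImages, height, width, channels, bytesPerValue = 4) {
  return numImages * height * width * channels * bytesPerValue;
}

// Example sizes chosen only for illustration: 10,000 RGB images of 224x224 pixels.
const bytes = estimateTensorBytes(10000, 224, 224, 3);
console.log(bytes); // 6021120000 bytes, i.e. several gigabytes
```

With a few hundred small images this is negligible, but the estimate grows linearly with the number of images and with the image area.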
If the dataset contains a lot of images, create a tf.data.Dataset instead, so that the images are loaded as they are needed rather than all at once. This answer discusses why.
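// read one image file and decode it into a tensor
const genFeatureTensor = async imagePath => {
  const imageBuffer = await fs.readFile(imagePath);
  return tfnode.node.decodeImage(imageBuffer);
}
// one-hot array for a class index, e.g. 1 -> [0, 1, 0] when numberOfClasses is 3
const labelArray = classIndex => Array.from({length: numberOfClasses}, (_, k) => k === classIndex ? 1 : 0)
// imagePaths[i] and classIndices[i] are placeholder names for the
// file path and the class index of the i-th image
async function* dataGenerator() {
  const numElements = numberOfImages;
  let index = 0;
  while (index < numElements) {
    const feature = await genFeatureTensor(imagePaths[index]);
    const label = tf.tensor1d(labelArray(classIndices[index]));
    index++;
    yield {xs: feature, ys: label};
  }
}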
const ds = tf.data.generator(dataGenerator).batch(1); // specify an appropriate batch size
Then use model.fitDataset(ds, {epochs: numberOfEpochs}) to train the model; fitDataset expects the number of epochs to be given.
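The idea behind a generator followed by batch() can be illustrated without TensorFlow.js: samples are produced one at a time and only grouped when a batch is requested. The sketch below uses plain JavaScript generators. The names sampleGenerator and batch, and the dummy sample contents, are hypothetical. Unlike tf.data, this version only groups items into arrays and does not stack tensors.

```javascript
// Hypothetical stand-in for loading one sample at a time;
// a real loader would read and decode an image here.
function* sampleGenerator(numSamples) {
  for (let i = 0; i < numSamples; i++) {
    yield { xs: [i, i + 1], ys: i % 3 };
  }
}

// Groups consecutive items from any iterable into arrays of at most `batchSize`.
// Nothing is read from `iterable` until the consumer asks for the next batch.
function* batch(iterable, batchSize) {
  let current = [];
  for (const item of iterable) {
    current.push(item);
    if (current.length === batchSize) {
      yield current;
      current = [];
    }
  }
  if (current.length > 0) {
    yield current; // final, possibly smaller batch
  }
}

const batchSizes = [...batch(sampleGenerator(5), 2)].map(b => b.length);
console.log(batchSizes); // [ 2, 2, 1 ]
```

Because only one batch is materialized at a time, memory use depends on the batch size rather than on the size of the whole dataset.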
The above is for training in Node.js. To do the same processing in the browser, genFeatureTensor can be written as follows:
function loadImage(url){
  return new Promise((resolve, reject) => {
    const im = new Image()
    im.crossOrigin = 'anonymous'
    im.onload = () => {
      resolve(im)
    }
    im.onerror = reject // reject the promise if the image fails to load
    im.src = url // set src last so the handlers are already attached
  })
}
const genFeatureTensor = async imageUrl => {
  const img = await loadImage(imageUrl);
  return tf.browser.fromPixels(img); // convert the loaded image element to a tensor
}
One word of caution: heavy processing may block the main thread in the browser. This is where web workers come into play.