Predicting Handwritten Numbers with the MNIST dataset and TensorFlow.js…

Using an already trained network from a previous project, it is possible to make a tool that predicts the value of a number from a primitive sketch of it.

The overall tool is a small web app consisting of a few different elements, namely:

  • A pretrained convolutional neural network (CNN) based on the MNIST dataset. Note that the model and its weights are saved after training and then loaded when a prediction is to be made
  • TensorFlow.js, which is used to load the model on a webpage
  • A 2D HTML Canvas element where one can draw a number that will be used for making a prediction
  • D3.js, which is used to make a graph of the output prediction values
  • Other miscellaneous JavaScript code to make a simple user interface and get the above elements working with each other

Using a Pretrained Network…

In the environment where the original model was trained, the model can be exported to a format that is readable by TensorFlow.js. Note that the model should be Keras based and should use only the built-in layer functions (fancier/custom layers are not supported).

In that environment, the tensorflowjs package needs to be installed.

pip install tensorflowjs

From the command line, the following needs to be executed:

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

Canvas Drawing…

The big issue with drawing on the canvas is to make sure that the correct offset is used between the canvas and its parent object. In our code, everything is placed in a ‘contents’ div (which groups everything together so it can be centered).

const rect = document.getElementById('contents').getBoundingClientRect();
        
pos.x = event.clientX - rect.left;
pos.y = event.clientY - rect.top;
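The offset calculation can also be wrapped in a small helper so that it is easy to reuse from each mouse handler. The sketch below is only illustrative: the name toLocalPosition is hypothetical and not part of the original code, and plain objects stand in for the browser's MouseEvent and DOMRect.

```javascript
// Hypothetical helper (not from the original project): converts viewport
// mouse coordinates into coordinates relative to a bounding rectangle.
function toLocalPosition(event, rect) {
    return {
        x: event.clientX - rect.left,
        y: event.clientY - rect.top
    };
}

// Plain objects standing in for a MouseEvent and a DOMRect.
const examplePos = toLocalPosition(
    { clientX: 150, clientY: 90 },
    { left: 100, top: 50 }
);
console.log(examplePos); // { x: 50, y: 40 }
```

In the page itself, the second argument would be the result of getBoundingClientRect() on the element the drawing is positioned against.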

The drawing is done by: 1) setting the line parameters; 2) moving to the previous end position; 3) making a line to the current (offset) mouse position; and 4) drawing it on the screen (using stroke).

ctx.beginPath();

ctx.lineWidth = 30;
ctx.lineCap = 'round';
ctx.strokeStyle = '#000';

ctx.moveTo (pos.x, pos.y); // from
setPosition (e);
ctx.lineTo (pos.x, pos.y); // to

// draw it!
ctx.stroke();
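One possible way to organise these steps is a function that takes the context and the two end points explicitly. This is a sketch rather than the project's actual structure: drawSegment is a hypothetical name, and ctx is assumed to be a CanvasRenderingContext2D (or any object exposing the same methods).

```javascript
// Hypothetical helper: draws a single line segment between two points.
// `ctx` is assumed to be a 2D canvas context obtained via canvas.getContext('2d').
function drawSegment(ctx, from, to) {
    ctx.beginPath();

    // Same stroke settings as in the snippet above.
    ctx.lineWidth = 30;
    ctx.lineCap = 'round';
    ctx.strokeStyle = '#000';

    ctx.moveTo(from.x, from.y); // previous end position
    ctx.lineTo(to.x, to.y);     // current mouse position

    ctx.stroke();               // render the segment
}
```

A mousemove handler could then remember the last position, compute the new one, and call drawSegment only while the mouse button is held down.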

Preparing the Data…

The data needs to be conditioned so that it is in the format the model expects. The main steps are resizing the image data and adjusting the dimensions of the tensor so that it is [1, 28, 28, 1], corresponding to [batch size, image x, image y, channel number]. Also, the image data needs to be normalized and the image needs to be inverted (the inversion is required since the model is based on white numbers on a black background, while the drawing uses black numbers on a white background).

img = tf.image.resizeBilinear (img, [28, 28]).toFloat();
img  = img.mean (2)
          .toFloat ()
          .expandDims (0)
          .expandDims (-1);
img = img.div (tf.scalar (255.0));
img = tf.scalar (1.0).sub (img);
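To make the per-pixel arithmetic concrete, here is the same averaging, scaling, and inversion written in plain JavaScript without tensors. It is only an illustration: normalizeAndInvert is a hypothetical name, and the input is assumed to be a flat list of RGB values in the range 0 to 255.

```javascript
// Illustrative only: the per-pixel arithmetic of the tensor code, without TensorFlow.js.
// `rgb` is assumed to be a flat array [r, g, b, r, g, b, ...] with values 0-255.
function normalizeAndInvert(rgb) {
    const pixelCount = rgb.length / 3;
    const out = new Float32Array(pixelCount);

    for (let i = 0; i < pixelCount; i++) {
        // Average the three colour channels into one grayscale value.
        const mean = (rgb[3 * i] + rgb[3 * i + 1] + rgb[3 * i + 2]) / 3;

        // Scale to the range 0..1, then invert so that dark ink becomes bright.
        out[i] = 1.0 - mean / 255.0;
    }
    return out;
}

// A white background pixel maps to 0 and a black ink pixel maps to 1.
console.log(Array.from(normalizeAndInvert([255, 255, 255, 0, 0, 0]))); // [ 0, 1 ]
```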

Using TensorFlow.js…

To use TensorFlow.js, the required script should first be included at the top of the page:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>

When using the model to make a prediction, the code needs to be placed in a separate function since it uses the JavaScript async/await pattern. The loadLayersModel function, which is provided by TensorFlow.js, builds the pretrained model and assigns all the weights. Once the model is built, it can be used to predict (or infer) what the image is (note that arraySync converts a tensor to a JavaScript array).

async function predict ()
{
    const model = await tf.loadLayersModel ('Model_CNN/model.json');

    model.summary ();

    results = model.predict(img).arraySync()[0];
}
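Assuming the model outputs one score per digit (ten values for MNIST), the predicted digit is the index of the largest score. The helper below is a minimal sketch of that step; argMax is a hypothetical name and the example scores are made up for illustration.

```javascript
// Hypothetical helper: returns the index of the largest value in an array.
// For a ten-way digit classifier, that index is the predicted digit.
function argMax(values) {
    let best = 0;
    for (let i = 1; i < values.length; i++) {
        if (values[i] > values[best]) {
            best = i;
        }
    }
    return best;
}

// Made-up scores, not real model output.
const exampleResults = [0.01, 0.02, 0.05, 0.8, 0.02, 0.02, 0.03, 0.02, 0.02, 0.01];
console.log(argMax(exampleResults)); // 3
```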

Final Results…

Below is the small app that has all the elements listed above (click here to view it in a separate tab).

Overall, things behaved fairly well with numbers that didn’t deviate much from their original form. Below are some examples of correct (true positive) results.

Some interesting results were obtained when the drawings were more abstract.

Code…

All of the code along with the pretrained model is available here.

 
