#008 TF An implementation of a Shallow Neural Network in tf.keras – digits dataset
In this post we will see how to classify handwritten digits using a shallow neural network implemented with tf.keras.
1. Load the digits dataset
First, let us import all necessary libraries.
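After the imports, we can use scikit-learn's load_digits() function to load the digits data. It does not download anything: it returns the whole dataset bundled with scikit-learn (1,797 grayscale 8x8 images, each flattened into 64 features) in one piece, so we need to split it into train and test sets ourselves.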
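A common way to do the split is scikit-learn's train_test_split. To show what such a split does, here is a minimal NumPy sketch built around a hypothetical helper, split_train_test. It shuffles the sample indices with a fixed seed and holds out a fraction of them for testing. The 20% test fraction, the seed and the random stand-in data (same shape as the digits data) are assumptions for illustration only.

```python
import numpy as np

def split_train_test(X, y, test_fraction=0.2, seed=42):
    """Hypothetical helper: shuffle the indices, use the first test_fraction of them as the test set."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(X))              # random order of sample indices
    n_test = int(round(len(X) * test_fraction))    # number of held-out samples
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

# Stand-in data with the digits dataset's shape: 1797 samples, 64 pixel features each.
X = np.random.default_rng(0).random((1797, 64))
y = np.arange(1797) % 10
X_train, X_test, y_train, y_test = split_train_test(X, y)
print(X_train.shape, X_test.shape)  # (1438, 64) (359, 64)
```

Shuffling before splitting matters because the samples may be stored in a structured order; a fixed seed keeps the split reproducible between runs.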
We can also plot some digits to see how they look.
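Plotting is usually done with matplotlib's imshow after reshaping each 64-value row back to 8x8. As a dependency-free sketch of that reshaping step, the hypothetical helper digit_to_text below prints a digit as text instead. It assumes pixel intensities in the 0 to 16 range used by load_digits; the hand-made "0" image is invented for illustration.

```python
import numpy as np

def digit_to_text(flat_pixels, shades=" .:#"):
    """Hypothetical helper: render a flattened 8x8 digit (values 0..16) as 8 lines of text."""
    image = np.asarray(flat_pixels).reshape(8, 8)   # undo the flattening: 64 -> 8x8
    # Map intensities 0..16 onto the shade characters (0 -> ' ', 16 -> '#').
    levels = np.minimum(image * len(shades) // 17, len(shades) - 1).astype(int)
    return "\n".join("".join(shades[v] for v in row) for row in levels)

# A hand-drawn "0" in the same 8x8, 0..16 format (not taken from the dataset).
zero = np.array([
    0, 0, 8, 16, 16, 8, 0, 0,
    0, 8, 16, 0, 0, 16, 8, 0,
    0, 16, 0, 0, 0, 0, 16, 0,
    0, 16, 0, 0, 0, 0, 16, 0,
    0, 16, 0, 0, 0, 0, 16, 0,
    0, 16, 0, 0, 0, 0, 16, 0,
    0, 8, 16, 0, 0, 16, 8, 0,
    0, 0, 8, 16, 16, 8, 0, 0,
])
print(digit_to_text(zero))
```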
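Many machine learning algorithms cannot operate on label data directly. They require all input and output variables to be numeric, so categorical data must be converted to a numerical form. Our labels are already the integers 0 to 9, but they are categories rather than quantities (a 9 is not "more" than a 3), and the categorical_crossentropy loss we use later expects, for each sample, a vector with a 1 in the position of the correct class. For that reason, we will perform one-hot encoding. (1)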
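In Keras this is commonly done with tf.keras.utils.to_categorical. The NumPy sketch below (the helper name one_hot is our own) shows the same idea: row i of an identity matrix has a single 1 in column i, so indexing the identity with the labels produces the one-hot rows.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Turn integer class labels into one-hot rows (one column per class)."""
    labels = np.asarray(labels)
    return np.eye(num_classes)[labels]   # pick the identity row for each label

encoded = one_hot([0, 3, 9])
print(encoded.shape)   # (3, 10)
print(encoded[1])      # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```

Taking argmax along each row recovers the original label, which is also how predictions are turned back into digits later.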
2. Implementing a Neural Network
When all the data is loaded and prepared, it is time to create a model. We will use the simple Sequential API to do this. Our model will have 2 layers: the 64 input features (one per pixel of the 8x8, height x width, image) feed a hidden layer of 64 neurons, which is followed by an output layer of 10 neurons, one for each digit class.
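To make the layer sizes concrete, here is a NumPy sketch of what such a 64-64-10 network computes in a forward pass. It is not the post's tf.keras code. The ReLU hidden activation and softmax output are assumptions, since the text does not name the activations, and the weights are random, untrained values.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)     # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

layer_sizes = [64, 64, 10]   # input features, hidden units, output classes
rng = np.random.default_rng(0)
# Normally distributed weights (assumed stddev 0.05) and zero biases for each Dense-style layer.
weights = [rng.normal(0.0, 0.05, size=(n_in, n_out))
           for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
biases = [np.zeros(n_out) for n_out in layer_sizes[1:]]

def forward(x):
    hidden = relu(x @ weights[0] + biases[0])        # shape (batch, 64)
    return softmax(hidden @ weights[1] + biases[1])  # shape (batch, 10), rows sum to 1

n_params = sum(w.size + b.size for w, b in zip(weights, biases))
print(n_params)      # 4810

probs = forward(rng.random((5, 64)))
print(probs.shape)   # (5, 10)
```

The count is 64*64 + 64 = 4,160 parameters for the hidden layer plus 64*10 + 10 = 650 for the output layer, which is the total that model.summary() should report for two Dense layers of these sizes.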
We will use a normal initializer, which draws the initial weights from a normal distribution.
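In tf.keras this corresponds to tf.keras.initializers.RandomNormal, whose defaults are mean 0.0 and stddev 0.05. The NumPy sketch below (random_normal_init is a hypothetical name) draws a 64x64 hidden-layer weight matrix the same way. Small random values break the symmetry between neurons while keeping the initial activations moderate.

```python
import numpy as np

def random_normal_init(shape, mean=0.0, stddev=0.05, seed=None):
    """Hypothetical helper: draw an initial weight tensor from N(mean, stddev^2)."""
    rng = np.random.default_rng(seed)
    return rng.normal(loc=mean, scale=stddev, size=shape)

W_hidden = random_normal_init((64, 64), seed=1)
print(W_hidden.shape)                    # (64, 64)
print(round(float(W_hidden.std()), 3))   # close to 0.05
```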
The optimizer we’ll use is Adam. It is an optimization algorithm that can be used instead of classical stochastic gradient descent to update the network weights iteratively based on the training data. Adam keeps running averages of each weight's gradient and squared gradient and uses them to adapt that weight's step size. This is a large part of why it is popular in deep learning: it usually reaches good results quickly with little tuning.
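Here is a minimal pure-Python sketch of the Adam update for a single scalar parameter, applied to the toy function f(x) = (x - 3)^2. beta1, beta2 and epsilon match the tf.keras.optimizers.Adam defaults (0.9, 0.999, 1e-7). The learning rate of 0.1 is our assumption; Keras defaults to 0.001, and we use a larger value so the toy problem converges in a few thousand steps.

```python
import math

def adam_minimize(grad, x0, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-7, steps=1000):
    """Run Adam on one scalar parameter and return the trajectory of x."""
    x, m, v = x0, 0.0, 0.0
    history = [x]
    for t in range(1, steps + 1):
        g = grad(x)
        m = beta1 * m + (1 - beta1) * g        # moving average of gradients
        v = beta2 * v + (1 - beta2) * g * g    # moving average of squared gradients
        m_hat = m / (1 - beta1 ** t)           # bias correction for the zero start
        v_hat = v / (1 - beta2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
        history.append(x)
    return history

# Minimize f(x) = (x - 3)^2, whose gradient is 2(x - 3).
path = adam_minimize(lambda x: 2 * (x - 3), x0=0.0, lr=0.1, steps=2000)
print(round(path[1], 4))   # 0.1
```

Because of the bias correction, the first step has size close to lr regardless of how large the gradient is (here the gradient is -6). This per-parameter normalization is what makes Adam relatively insensitive to the scale of the gradients.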
To make this work, we need to compile the model. An important choice at this step is the loss function. We use the categorical_crossentropy loss because it measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry belongs to exactly one class).
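In Keras the loss is typically selected with a call such as model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']). To show what the loss computes, the NumPy sketch below evaluates it for two samples. It uses 3 classes instead of 10 for brevity, and the predicted probabilities are made up. For each sample the loss is -sum_k y_true[k] * log(y_pred[k]). With one-hot targets, this reduces to minus the log of the probability assigned to the correct class.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over samples of -sum_k y_true[k] * log(y_pred[k])."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)   # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

y_true = np.array([[0, 1, 0], [1, 0, 0]], dtype=float)   # one-hot targets
y_pred = np.array([[0.1, 0.7, 0.2], [0.5, 0.25, 0.25]])  # predicted probabilities
print(round(categorical_crossentropy(y_true, y_pred), 4))  # 0.5249
```

The first sample contributes -log(0.7) and the second -log(0.5). Confident correct predictions push the loss toward 0, while confident wrong ones are penalized heavily.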
3. Visualization and Testing
Let’s now visualize the outputs of our neural network.
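Before plotting, the network's outputs have to be turned into digits. Assuming a softmax output layer, model.predict returns one row of class probabilities per test image. The predicted digit is the argmax of that row, and accuracy is the fraction of rows where it matches the true label. The helpers below are a hypothetical NumPy sketch of that step, using made-up probabilities for 3 classes.

```python
import numpy as np

def predicted_digits(probabilities):
    """Pick the most probable class in each row of softmax outputs."""
    return np.argmax(probabilities, axis=1)

def accuracy(probabilities, one_hot_targets):
    """Fraction of samples whose predicted class matches the target class."""
    matches = predicted_digits(probabilities) == np.argmax(one_hot_targets, axis=1)
    return float(np.mean(matches))

probs = np.array([
    [0.05, 0.90, 0.05],   # predicted class 1
    [0.60, 0.30, 0.10],   # predicted class 0
    [0.20, 0.30, 0.50],   # predicted class 2
    [0.10, 0.20, 0.70],   # predicted class 2
])
targets = np.eye(3)[[1, 0, 1, 2]]   # true classes: 1, 0, 1, 2
print(predicted_digits(probs))      # [1 0 2 2]
print(accuracy(probs, targets))     # 0.75
```

The predicted labels are what you would typically show next to each test image when plotting, for example as the title of each matplotlib subplot.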
In the next post we will learn how to perform classification with a convolutional neural network on the MNIST dataset using TensorFlow.
More resources on the topic:
- Why One-Hot Encode Data in Machine Learning? (1)
- Practical Machine Learning with Python and Keras
- An implementation of a Convolutional Neural Network in tf.keras – MNIST dataset