The goal of this article is to answer three simple questions:
- What is batch normalization?
- Why do we need batch normalization?
- How do we use it in a neural network?
What is batch normalization?
Batch normalization is normalization performed over a batch of data between hidden layers. Consider the diagram below.
The normalization is applied to the activation values of the previous hidden layer, and the result is passed to the next hidden layer. It is the same normalization that is applied to the data before the input layer: for each unit, subtract the mean computed over the batch, then divide by the standard deviation.
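The normalization step just described can be sketched in plain Python. This is a minimal illustration, assuming the activations of a hidden layer are stored as a list of rows (one row per example in the batch). `batch_normalize` is a hypothetical helper name, and `eps` is a small constant added to the variance to avoid division by zero, similar in spirit to the `variance_epsilon` argument of the TensorFlow function used later in this article.

```python
import math


def batch_normalize(batch, eps=1e-5):
    """Normalize each unit (column) of a batch of activations.

    batch: list of rows, one row of activations per example.
    Returns a new batch in which every column has mean ~0 and std ~1.
    """
    n = len(batch)
    num_units = len(batch[0])
    normalized = [[0.0] * num_units for _ in range(n)]
    for j in range(num_units):
        column = [row[j] for row in batch]
        mean = sum(column) / n
        # Variance over the batch (divide by n, not n - 1)
        var = sum((x - mean) ** 2 for x in column) / n
        std = math.sqrt(var + eps)
        for i in range(n):
            normalized[i][j] = (batch[i][j] - mean) / std
    return normalized


# Activations of a hidden layer with 2 units for a batch of 4 examples
activations = [[1.0, 200.0], [2.0, 400.0], [3.0, 600.0], [4.0, 800.0]]
for row in batch_normalize(activations):
    print([round(v, 3) for v in row])
```

Although the second unit's activations are 200 times larger than the first's, both columns come out as roughly [-1.342, -0.447, 0.447, 1.342] after normalization.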
Why do we need batch normalization?
Before reading the answer to this question, you should be familiar with how a neural network works.
So before explaining batch normalization, let’s first understand why we normalize the data before the input layer. Consider a dataset of numeric features with the following scales:
- f1: 0 to 1
- f2: 200 to 1000
- f3: -4 to 200
If we train the network on this data without normalizing it, there will be the following implications:
- The features with large values will affect the output the most, irrespective of how important they are.
- The coefficients trained on this data will not generalize well.
- If a query arrives whose features are on a different scale, the output for that query is likely to be wrong.
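A tiny numeric illustration of the first implication follows. It assumes a single linear unit that weights every feature equally and a made-up sample with values inside the ranges listed above; both are illustrative assumptions, not data from a real model.

```python
# Hypothetical sample whose values lie in the ranges listed above
sample = {"f1": 0.9, "f2": 750.0, "f3": 12.0}
# Assume an untrained linear unit that weights every feature equally
weights = {"f1": 1.0, "f2": 1.0, "f3": 1.0}

# Weighted sum, as a single linear unit would compute it
contributions = {name: weights[name] * value for name, value in sample.items()}
output = sum(contributions.values())

for name, contribution in contributions.items():
    # Share of the output that comes from each feature
    print(f"{name}: {contribution / output:.1%} of the output")
```

With these numbers f2 accounts for about 98% of the output purely because of its scale, so f1 barely influences the prediction no matter how informative it is.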
From the implications above, we can say the following about normalization:
- Normalization brings every feature to a comparable scale (zero mean and unit standard deviation), although individual values can still fall outside the range -1 to 1.
- After normalization, it is very unlikely that the output is dominated by the scale of a feature.
- From the second point, we can also infer that predictions depend more on a feature's importance than on its scale.
- This also improves the network's ability to generalize.
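The "subtract the mean, then divide by the standard deviation" step applied to the three features can be sketched with Python's standard `statistics` module. The feature values below are made up for illustration (they only respect the ranges given earlier), and `standardize` is a hypothetical helper name.

```python
import statistics

# Hypothetical dataset: four values per feature, within the ranges above
features = {
    "f1": [0.1, 0.4, 0.7, 0.9],
    "f2": [200.0, 450.0, 700.0, 1000.0],
    "f3": [-4.0, 20.0, 90.0, 200.0],
}


def standardize(values):
    """Subtract the mean, then divide by the (population) standard deviation."""
    mu = statistics.mean(values)
    sigma = statistics.pstdev(values)
    return [(v - mu) / sigma for v in values]


scaled = {name: standardize(vals) for name, vals in features.items()}
for name in features:
    before = (min(features[name]), max(features[name]))
    after = (round(min(scaled[name]), 2), round(max(scaled[name]), 2))
    print(name, "before:", before, "after:", after)
```

Before scaling the features differ by about three orders of magnitude; afterwards all three lie roughly between -1.5 and 1.6, so none of them dominates simply because of its units.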
Now let’s dive into batch normalization. The network in the diagram above contains one input layer, three hidden layers and one output layer, with a batch normalization layer between the 2nd and 3rd hidden layers. So why do we need to normalize the output of hidden layer 2? Suppose we are training this network on some data. The input to each layer depends on every layer before it, so as the weights of those layers are updated during training, the distribution of that input keeps shifting. These shifts compound as the network gets deeper: the variation experienced by the input layer is smaller than the variation experienced by the 3rd hidden layer, and this makes deep networks difficult to train. It also affects training in other ways:
- Because of this high variance we have to keep the learning rate low; with a large learning rate, training is likely to overshoot and the network may miss patterns in the data.
- This high variance also hurts the network's performance on test data.
We use batch normalization to reduce the effect of this variance, which grows as the network gets deeper. Does this mean we can use batch normalization multiple times? Yes: a batch normalization layer can be placed after any hidden layer.
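To make the depth argument concrete, here is a small pure-Python simulation. It is an illustrative sketch under assumptions chosen for this example, not the network in the diagram: six identical fully connected ReLU layers of width 32, weights drawn from N(0, 1) without any scaling (a deliberately naive initialization), and a batch of 16 random inputs. It reports the standard deviation of the activations entering each layer, once without normalization and once with batch normalization applied after every hidden layer.

```python
import math
import random

random.seed(0)
WIDTH, BATCH, DEPTH = 32, 16, 6


def matmul(batch, weights):
    """Multiply a batch (rows) by a WIDTH x WIDTH weight matrix."""
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*weights)] for row in batch]


def relu(batch):
    return [[max(0.0, x) for x in row] for row in batch]


def batch_norm(batch, eps=1e-5):
    """Normalize every unit (column) over the batch: (x - mean) / sqrt(var + eps)."""
    n = len(batch)
    cols = list(zip(*batch))
    means = [sum(c) / n for c in cols]
    stds = [math.sqrt(sum((x - m) ** 2 for x in c) / n + eps) for c, m in zip(cols, means)]
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in batch]


def overall_std(batch):
    values = [x for row in batch for x in row]
    mean = sum(values) / len(values)
    return math.sqrt(sum((x - mean) ** 2 for x in values) / len(values))


# Naive initialization: every weight drawn from N(0, 1), deliberately not scaled
layers = [[[random.gauss(0.0, 1.0) for _ in range(WIDTH)] for _ in range(WIDTH)] for _ in range(DEPTH)]
inputs = [[random.gauss(0.0, 1.0) for _ in range(WIDTH)] for _ in range(BATCH)]


def input_stds(use_batch_norm):
    """Std of the activations fed into each layer of the stack."""
    h, stds = inputs, []
    for w in layers:
        stds.append(overall_std(h))
        h = relu(matmul(h, w))
        if use_batch_norm:
            h = batch_norm(h)
    return stds


print("without BN:", [round(s, 1) for s in input_stds(False)])
print("with BN:   ", [round(s, 2) for s in input_stds(True)])
```

Without normalization the spread of the activations grows by roughly a constant factor per layer, so the deeper layers receive inputs on a completely different scale from the first one; with batch normalization after every layer, each layer receives inputs with a standard deviation close to 1.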
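How do we use it in a neural network?
Most deep-learning frameworks provide a batch normalization operation; we will use TensorFlow for this explanation. The snippet below assumes hidden_layer_2 holds the activations of the 2nd hidden layer as a tensor of shape [batch_size, num_units].
import tensorflow as tf
# Calculate the mean and variance of the current batch;
# axes=[0] is the batch axis, so one mean and one variance are computed per unit
batch_mean, batch_var = tf.nn.moments(hidden_layer_2, axes=[0])
num_units = hidden_layer_2.shape[-1]
gamma = tf.Variable(tf.ones([num_units]))  # scale, initialised to 1
beta = tf.Variable(tf.zeros([num_units]))  # offset, initialised to 0
# Arguments: input, mean, variance, offset (beta), scale (gamma)
hidden_layer_2 = tf.nn.batch_normalization(hidden_layer_2, batch_mean, batch_var, beta, gamma, variance_epsilon=1e-05)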
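You might be wondering what all these parameters are for. They are needed because tf.nn.batch_normalization does a little more than plain normalization. Internally it computes

$$y = \gamma \, \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta$$

where $\mu_B$ and $\sigma_B^2$ are the batch mean and variance (batch_mean and batch_var), $\epsilon$ is variance_epsilon, $\gamma$ is the scale (gamma) and $\beta$ is the offset (beta). To keep the normalization simple, gamma is initialised to ones and beta to zeros, so the equation becomes

$$y = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$

which is the plain normalization equation, with the small $\epsilon$ only guarding against division by zero. Because gamma and beta are tf.Variables, they are trainable, so during training the network can learn to rescale and shift the normalized values. That covers what batch normalization is, why it helps, and how to implement it in TensorFlow.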
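The role of gamma and beta can be checked numerically. This pure-Python sketch applies the scale-and-shift equation that tf.nn.batch_normalization computes, y = gamma * (x - mean) / sqrt(var + eps) + beta, to a single unit over a small made-up batch; `batch_norm_1d` is a hypothetical helper, not a TensorFlow function. With gamma = 1 and beta = 0 it is plain normalization, and with gamma set to sqrt(var + eps) and beta to the batch mean it returns the original values, which shows that learning gamma and beta lets the network undo the normalization where that helps.

```python
import math


def batch_norm_1d(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """y = gamma * (x - mean) / sqrt(var + eps) + beta for one unit over a batch."""
    n = len(xs)
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n
    return [gamma * (x - mean) / math.sqrt(var + eps) + beta for x in xs]


xs = [2.0, 4.0, 6.0, 8.0]

# gamma = 1, beta = 0: plain normalization
plain = batch_norm_1d(xs)

# gamma = sqrt(var + eps), beta = mean: the normalization is undone
mean = sum(xs) / len(xs)
var = sum((x - mean) ** 2 for x in xs) / len(xs)
restored = batch_norm_1d(xs, gamma=math.sqrt(var + 1e-5), beta=mean)

print([round(v, 3) for v in plain])
print([round(v, 3) for v in restored])
```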