Hi everyone. Today we are continuing our implementation of makemore. In the last lecture we implemented the multilayer perceptron along the lines of Bengio et al. 2003 for character-level language modeling. So we followed this paper, took in a few characters in the past, and used an MLP to predict the next character in a sequence. What we'd like to do now is move on to more complex and larger neural networks, like recurrent neural networks and their variations, like the GRU, LSTM, and so on. Before we do that, though, we have to stick around the level of the multilayer perceptron for a bit longer. I'd like to do this because I want us to have a very good intuitive understanding of the activations in the neural net during training, and especially of the gradients that are flowing backwards: how they behave and what they look like. This is going to be very important for understanding the history of the development of these architectures, because we'll see that recurrent neural networks, while they are very expressive in that they are a universal approximator and can in principle implement all the algorithms, are not very easily optimizable with the first-order gradient-based techniques that we have available to us and that we use all the time. The key to understanding why they are not easily optimizable is to understand the activations and the gradients and how they behave during training. We'll also see that a lot of the variants since recurrent neural networks have tried to improve that situation. So that's the path we have to take; let's get started.

The starting code for this lecture is largely the code from before, but I've cleaned it up a little bit. You'll see that we are importing all the torch and matplotlib utilities. We're reading in the words just like before; here are eight example words, and there's a total of 32,000 of them. Here's the vocabulary of all the lowercase letters and the special dot token. Here we are reading the dataset, processing it, and creating three splits: the train, dev, and test split. Now, the MLP is the identical same MLP, except that I removed a bunch of magic numbers that we had here. Instead we have the dimensionality of the embedding space of the characters and the number of hidden units in the hidden layer, and I've pulled them out here so that we don't have to go in and change all these magic numbers all the time.
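For concreteness, here is a minimal sketch of what that setup looks like with the magic numbers pulled out into named hyperparameters. The variable names and the generator seed are illustrative assumptions rather than the exact code on screen; the sizes (10-dimensional embeddings, 200 hidden units, 27 characters, 3 characters of context) are the ones used later in the lecture.

```python
import torch

block_size = 3    # context length: how many previous characters we look at
n_embd = 10       # dimensionality of the character embedding vectors
n_hidden = 200    # number of neurons in the hidden layer
vocab_size = 27   # 26 lowercase letters plus the special '.' token

g = torch.Generator().manual_seed(2147483647)  # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g)
b1 = torch.randn(n_hidden,                        generator=g)
W2 = torch.randn((n_hidden, vocab_size),          generator=g)
b2 = torch.randn(vocab_size,                      generator=g)
parameters = [C, W1, b1, W2, b2]
print(sum(p.nelement() for p in parameters))     # roughly 11,000 parameters
```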
We have the same neural net with 11,000 parameters, which we now optimize over 200,000 steps with a batch size of 32. You'll see that I've refactored the code here a little bit, but there are no functional changes: I just have a few extra variables and a few more comments, and I removed all the magic numbers. Otherwise it's the exact same thing. Then, when we optimize, we saw that our loss looked something like this, and the train and val loss were about 2.16 and so on. Here I refactored the code a little bit for the evaluation of arbitrary splits: you pass in a string for which split you'd like to evaluate, and then, depending on train, val, or test, I index in and get the correct split. Then this is the forward pass of the network, the evaluation of the loss, and printing it. So, just making it nicer.

One thing that you'll notice here is that I'm using a decorator, torch.no_grad, which you can also look up and read the documentation of. Basically, what this decorator does on top of a function is that whatever happens inside the function is assumed by torch to never require any gradients, so it will not do any of the bookkeeping that it normally does to keep track of all the gradients in anticipation of an eventual backward pass. It's almost as if all the tensors that get created here have requires_grad set to False. This just makes everything much more efficient, because you're telling torch that you will not call .backward() on any of this computation, and it doesn't need to maintain the graph under the hood. So that's what this does. You can also use a context manager, with torch.no_grad, and you can look those up.

Then here we have the sampling from the model, just as before: just a forward pass of the neural net, getting the distribution, sampling from it, adjusting the context window, and repeating until we get the special end token. And we see that we are starting to get much nicer-looking words sampled from the model. It's still not amazing, and they're still not fully name-like, but it's much better than what we had with the bigram model. So that's our starting point.
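As a small self-contained illustration of the torch.no_grad behavior described above (this is a toy example, not the lecture's evaluation function):

```python
import torch

W = torch.randn(10, 10, requires_grad=True)

@torch.no_grad()  # decorator form: no autograd bookkeeping inside this function
def evaluate(x):
    y = x @ W              # this matmul is not recorded in the autograd graph
    return (y ** 2).mean()

x = torch.randn(4, 10)
print(evaluate(x).requires_grad)   # False: no graph was built

# equivalent context-manager form
with torch.no_grad():
    y = x @ W
print(y.requires_grad)             # False
```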
Now, the first thing I would like to scrutinize is the initialization. I can tell that our network is very improperly configured at initialization, and there are multiple things wrong with it, but let's just start with the first one. Look here: on the zeroth iteration, the very first iteration, we are recording a loss of 27, and this rapidly comes down to roughly one or two or so. So I can tell that the initialization is all messed up, because this is way too high. In training neural nets, it is almost always the case that you will have a rough idea of what loss to expect at initialization, and that just depends on the loss function and the problem setup. In this case, I do not expect 27; I expect a much lower number, and we can calculate it together. Basically, at initialization, there are 27 characters that could come next for any one training example, and we have no reason to believe any character is much more likely than the others. So we'd expect the probability distribution that comes out initially to be roughly uniform, assigning about equal probability to all 27 characters. In other words, what we'd like is that the probability for any character is roughly 1 over 27. That is the probability we should record, and then the loss is the negative log probability. So let's wrap this in a tensor and take the log of it; the negative log probability is the loss we would expect, which is 3.29, much, much lower than 27. So what's happening right now is that at initialization the neural net is creating probability distributions that are all messed up: some characters are very confident and some characters are very not confident. And then, basically, the network is very confidently wrong, and that's what makes it record a very high loss.

So here's a smaller, four-dimensional example of the issue. Let's say we only have four characters, and we have logits that come out of the neural net that are very, very close to 0.
Then, when we take the softmax of all 0s, we get a probability distribution that is diffuse: it sums to one and is exactly uniform. In this case, if the label is, say, two, it doesn't actually matter whether the label is two or three or one or zero, because it's a uniform distribution and we're recording the exact same loss, in this case 1.38. So this is the loss we would expect for a four-dimensional example. And of course, as we start to manipulate these logits, we're going to be changing the loss. It could be that we luck out and, by chance, the logit at the correct label is a very high number, like five or something; in that case we'll record a very low loss, because we're assigning the correct probability at initialization, by chance, to the correct label. Much more likely, though, some other dimension will have a high logit, and then we start to record a much higher loss. What can happen is that the logits come out something like this, taking on extreme values, and we record a really high loss.

For example, if we take torch.randn of 4, these are normally distributed numbers, four of them, and here we can also print the logits, the probabilities that come out of them, and the loss. Because these logits are near zero for the most part, the loss that comes out is okay. But suppose this is times 10 now: you see how, because these are more extreme values, it's very unlikely that you're going to guess the correct bucket, and then you're confidently wrong and recording a very high loss. If your logits come out even more extreme, you might get extremely insane losses, like infinity, even at initialization. This is not good, and we want the logits to be roughly zero when the network is initialized. In fact, the logits don't have to be exactly zero; they just have to be equal. So, for example, if all the logits are one, then because of the normalization inside the softmax this will actually come out okay. But by symmetry we don't want them to be any arbitrary positive or negative number; we just want them to be all zeros and record the loss that we expect at initialization.
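Here is that toy calculation as a small runnable sketch (the label 2 is arbitrary, as noted above):

```python
import torch
import torch.nn.functional as F

print(-torch.tensor(1/27.0).log())  # ≈ 3.2958: the loss we expect with 27 equally likely characters

# four-character example: uniform logits give the expected loss regardless of the label
logits = torch.zeros(4)
probs = F.softmax(logits, dim=0)
print(probs, -probs[2].log())       # uniform probs, loss = -log(1/4) ≈ 1.386

# extreme logits: usually confidently wrong, so the loss blows up
logits = torch.randn(4) * 10
probs = F.softmax(logits, dim=0)
print(probs, -probs[2].log())
```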
So let's now concretely see where things go wrong in our example. Here we have the initialization; let me reinitialize the neural net, and let me break after the very first iteration, so we only see the initial loss, which is 27. So that's way too high, and intuitively we can now inspect the variables involved. We see that the logits, if we just print the first row, take on quite extreme values, and that's what's creating the fake confidence in incorrect answers and makes the loss get very, very high. These logits should be much, much closer to zero.

So now let's think through how we can get the logits coming out of this neural net to be much closer to zero. You see here that the logits are calculated as the hidden states multiplied by W2, plus b2. First of all, currently we're initializing b2 as random values of the right size, but because we want roughly zero, we don't actually want to be adding a bias of random numbers. So in fact I'm going to add a "times zero" here, to make sure that b2 is just basically zero at initialization. And second, this is h multiplied by W2, so if we want the logits to be very, very small, then we should be making W2 smaller. For example, if we scale down all the elements of W2 by 0.1 and again run just the very first iteration, you see that we are getting much closer to what we expect: roughly what we want is about 3.29, and we're not quite there yet. I can make the scale even smaller, and at 0.01 we get 3.32. Okay, so we're getting closer and closer.

Now, you're probably wondering: can we just set this to zero? Then we would get exactly what we're looking for at initialization. The reason I don't usually do this is that I'm very nervous, and I'll show you in a second why you don't want to be setting the weights of a neural net exactly to zero; you usually want them to be small numbers instead of exactly zero. For this output layer, in this specific case, I think it would be fine, but I'll show you in a second where things go wrong very quickly if you do that. So let's just go with 0.01. In that case our loss is close enough, but the weights aren't exactly zero; there's a little bit of entropy in there, and that's used for symmetry breaking, as we'll see in a second. The logits are now coming out much closer to zero, and everything is well and good. So if I just erase these, and I now take away the break statement, we can run the optimization with this new initialization and see what losses we record. Okay, so I'll let it run, and you see that we started off good and then came down a bit. The plot of the loss now doesn't have that hockey-stick appearance, because what's happening in the hockey stick, in the very first few iterations, is that the optimization
is just squashing down the logits and then rearranging them. So basically we took away this easy part of the loss function, where the weights were just being shrunk down, and therefore we don't get those easy gains in the beginning; we're just getting some of the hard gains of training the actual neural net, and so there's no hockey-stick appearance. So good things are happening: number one, the loss at initialization is what we expect, and the loss doesn't look like a hockey stick. This is true for any neural net you might train, and something to look out for. And second, the loss that came out is actually quite a bit improved. Unfortunately, I erased what we had here before, but I believe this was 2.12, and before it was 2.16, so we get a slightly improved result. The reason for that is that we're spending more cycles, more time, actually optimizing the neural net, instead of spending the first several thousand iterations just squashing down the weights because they are way too high at initialization. So that's something to look out for, and that's number one.

Now let's look at the second problem. Let me re-initialize our neural net and reintroduce the break statement, so we have a reasonable initial loss. Even though everything is looking good on the level of the loss, and we get something that we expect, there's still a deeper problem lurking inside this neural net and its initialization. The logits are now okay; the problem now is with the values of h, the activations of the hidden states. If we visualize this tensor h, it's kind of hard to see, but the problem, roughly speaking, is: you see how many of the elements are one or negative one? Recall that torch.tanh, the tanh function, is a squashing function: it takes arbitrary numbers and squashes them into the range of negative one to one, and it does so smoothly. So let's look at the histogram of h to get a better idea of the distribution of the values inside this tensor. We can see that h is 32 examples with 200 activations in each example. We can view it as -1 to stretch it out into one large vector, then call .tolist() to convert it into one large Python list of floats, and then pass this into plt.hist for a histogram, say with 50 bins, and a semicolon to suppress a bunch of output we don't want. And we see in this histogram that most of the values by far take on a value of negative one or one, so this tanh is very, very active.
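The histogram is a one-liner; here it is as a self-contained sketch, where hpreact is just a stand-in for the badly scaled pre-activations (in the lecture, h comes out of the actual forward pass):

```python
import torch
import matplotlib.pyplot as plt

hpreact = torch.randn(32, 200) * 5   # stand-in for broad pre-activations from an unscaled W1
h = torch.tanh(hpreact)

plt.hist(h.view(-1).tolist(), 50)    # flatten (32, 200) into one long list, 50 bins
plt.show()                           # most of the mass piles up near -1 and +1: saturation
```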
We can also look at why that is: if we look at the pre-activations that feed into the tanh, we can see that their distribution is very, very broad. These take on numbers roughly between negative 15 and 15, and that's why in the torch.tanh everything is being squashed and capped into the range of negative one to one, with lots of numbers taking on very extreme values.

Now, if you are new to neural networks, you might not see this as an issue. But if you're well versed in the dark arts of backpropagation, and have an intuitive sense of how these gradients flow through a neural net, you are looking at your distribution of tanh activations here and you are sweating. So let me show you why. We have to keep in mind that during backpropagation, just like we saw in micrograd, we are doing the backward pass starting at the loss and flowing through the network backwards. In particular, we're going to backpropagate through this torch.tanh. This layer is made up of 200 neurons for each one of these examples, and it implements an element-wise tanh. So let's look at what happens in tanh in the backward pass. We can actually go back to our micrograd code from the very first lecture and see how we implemented tanh. We saw that the input here was x, and then we calculate t, which is the tanh of x; t is between negative one and one, it's the output of the tanh. Then, in the backward pass, how do we backpropagate through a tanh? We take out.grad and multiply it, this is the chain rule, by the local gradient, which took the form of 1 minus t squared. So what happens if the outputs of your tanh are very close to negative one or one? If you plug in t equals 1 here, you're going to get 0 multiplying out.grad: no matter what out.grad is, we are killing the gradient, and we're effectively stopping the backpropagation through this tanh unit. Similarly, when t is negative one, this will again become 0, and out.grad just stops. Intuitively this makes sense, because this is a tanh neuron, and if its output is very close to one, then we are in a tail of the tanh, and changing the input is not going to impact the output of the tanh very much, because it's in the flat region of the tanh. Therefore there's no impact on the loss, and so indeed the weights and the biases that feed into this tanh neuron do not impact the loss, because the output of this tanh unit is in the flat region of the tanh and there's no influence.
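To make the 1 − t² factor concrete, here is a tiny numerical sketch of that micrograd-style backward pass (plain Python, simplified to a single scalar):

```python
import math

def tanh_backward(x, out_grad=1.0):
    t = math.tanh(x)                # forward: output squashed into (-1, 1)
    x_grad = (1 - t**2) * out_grad  # backward: chain rule with the local gradient 1 - t^2
    return t, x_grad

print(tanh_backward(0.0))   # t = 0.0: the gradient passes straight through
print(tanh_backward(5.0))   # t ≈ 0.9999: the gradient is almost entirely killed
```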
We can be changing them however we want, and the loss is not impacted. That's another way to justify that, indeed, the gradient would be basically zero: it vanishes. Indeed, when t equals zero we get 1 times out.grad, so when the tanh takes on exactly the value of zero, out.grad is just passed through. So basically what this is saying is: if t is equal to zero, the tanh unit is sort of inactive and the gradient just passes through, but the more you are in the flat tails, the more the gradient is squashed. In fact, you'll see that the gradient flowing through tanh can only ever decrease, and the amount that it decreases is proportional to the square here, depending on how far you are in the flat tails of this tanh. So that's the concern here: if all of these outputs h are in the flat regions of negative one and one, then the gradients that are flowing through the network will just get destroyed at this layer.

Now, there is some redeeming quality here, and we can actually get a sense of the problem as follows. I wrote some code here: basically, we want to take h, take the absolute value, and see how often it is in a flat region, say greater than 0.99. What you get is the following, and this is a boolean tensor, so in the visualization you get white if this is true and black if it is false. Basically, what we have here is the 32 examples and the 200 hidden neurons, and we see that a lot of this is white. What that's telling us is that in all these cases the tanh neurons were very, very active; they're in the flat tails, and so in all these white cases the backward gradient would get destroyed.

Now, we would be in a lot of trouble if, for any one of these 200 neurons, it were the case that the entire column is white, because in that case we have what's called a dead neuron. This could be a tanh neuron where the initialization of the weights and the biases is such that no single example ever activates this tanh in the active part of the tanh. If all the examples land in the tail, then this neuron will never learn; it is a dead neuron. And so, just scrutinizing this and looking for columns that are completely white, we see that this is not the case: I don't see a single neuron that is all white.
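For reference, that white/black saturation picture can be reproduced with something like the sketch below; hpreact is again a stand-in for the badly scaled pre-activations, and the figure size is just a display choice:

```python
import torch
import matplotlib.pyplot as plt

hpreact = torch.randn(32, 200) * 5             # stand-in pre-activations
h = torch.tanh(hpreact)

# white = |tanh| > 0.99 for that (example, neuron) pair; a fully white column
# would indicate a dead neuron that no example activates
plt.figure(figsize=(20, 10))
plt.imshow(h.abs() > 0.99, cmap='gray', interpolation='nearest')
plt.show()
```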
And so, therefore, it is the case that for every one of these tanh neurons, we do have some examples that activate them in the active part of the tanh. So some gradients will flow through, and this neuron will learn; the neuron will change, it will move, it will do something. But you can sometimes get yourself into cases where you have dead neurons. The way this manifests for a tanh neuron is that, no matter what inputs you plug in from your dataset, the tanh neuron always fires either completely one or completely negative one, and then it will just not learn, because all the gradients will be zeroed out.

This is true not just for tanh, but for a lot of the other nonlinearities that people use in neural networks. We certainly use tanh a lot, but sigmoid will have the exact same issue, because it is also a squashing nonlinearity. The same will also apply to ReLU. ReLU has a completely flat region here: a ReLU neuron is a pass-through if the pre-activation is positive, and if the pre-activation is negative, it just shuts it off. Since that region is completely flat, during backpropagation this exactly zeroes out the gradient; all of the gradient is set exactly to zero, rather than just a very, very small number depending on how positive or negative the pre-activation is. And so you can get, for example, a dead ReLU neuron: a neuron with a ReLU nonlinearity that never activates. For any example that you plug in from the dataset, it never turns on; it's always in this flat region, and so this ReLU neuron is a dead neuron. Its weights and bias will never learn; they will never get a gradient, because the neuron never activated. This can sometimes happen at initialization, because the weights and the biases just make it so that, by chance, some neuron never activates, so it is just forever dead. But it can also happen during optimization.
If you have too high of a learning rate, for example, sometimes these neurons get too large a gradient and they get knocked out of the data manifold. What happens is that, from then on, no example ever activates this neuron, so the neuron remains dead forever. It's kind of like permanent brain damage in the mind of a network. And so sometimes what can happen is: if your learning rate is very high, and you have a neural net with ReLU neurons, you train the neural net and you get some final loss, but then if you go through the entire training set and forward all your examples, you can find neurons that never activate. They are dead neurons in your network, and those neurons will never turn on. Usually what happens is that during training these ReLU neurons are changing, moving, and so on, and then, because of a high gradient somewhere, by chance they get knocked off, and then nothing ever activates them, and from then on they are just dead. So that's kind of like permanent brain damage that can happen to some of these neurons.

Other nonlinearities, like Leaky ReLU, will not suffer from this issue as much, because you can see that it doesn't have flat tails; you'll almost always get gradients. ELU, which is also fairly frequently used, might suffer from this issue as well, because it has flat parts. So this is just something to be aware of and something to be concerned about. In our case, we have way too many activations h that take on extreme values, and because there's no column of white, I think we will be okay, and indeed the network optimizes and gives us a pretty decent loss. But it's just not optimal, and this is not something you want, especially during initialization. Basically, what's happening is that the pre-activation that's flowing into the tanh is too extreme, too large; it's creating a distribution that is too saturated on both sides of the tanh, and it's not something you want, because it means there's less training for these neurons: they update less frequently.
So, how do we fix this? Well, the h pre-activation is embcat, which comes from C, so these values are roughly unit Gaussian, but then they get multiplied by W1, plus b1, and the h pre-activation is too far off from zero; that's what's causing the issue. So we want this pre-activation to be closer to zero, very similar to what we had with the logits; here we want something very, very similar. Now, it's okay to set the biases to a very small number; we can multiply them by something small, like 0.01, to get a little bit of entropy. I sometimes like to do that, just so that there's a little bit of variation and diversity in the original initialization of these tanh neurons, and I find in practice that that can help optimization a little bit. And then the weights we can also just squash, so let's multiply everything by 0.1.

Let's rerun the first batch, and now let's look at this. Well, first, let's look here: you see now, because we multiplied W1 by 0.1, we have a much better histogram, and that's because the pre-activations are now between roughly negative 1.5 and 1.5. And with this we expect much, much less white. Okay, there's no white at all, and that's because there are no neurons that saturated above 0.99 in either direction. So this is actually a pretty decent place to be. Maybe we can go up a little bit — sorry, am I changing W1 here? — so maybe we can go to 0.2. Okay, so maybe something like this is a nice distribution, and maybe this is what our initialization should be.
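Putting the two fixes together, the initialization now looks roughly like this sketch (same illustrative variable names as before; the scales 0.2 and 0.01 are the ones arrived at by eye in the lecture):

```python
import torch

vocab_size, block_size, n_embd, n_hidden = 27, 3, 10, 200
g = torch.Generator().manual_seed(2147483647)

C  = torch.randn((vocab_size, n_embd), generator=g)
# hidden layer: scale down so the tanh pre-activations are not saturated
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * 0.2
b1 = torch.randn(n_hidden,                        generator=g) * 0.01
# output layer: near-zero logits at init, so the initial loss is ~ -log(1/27)
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size,             generator=g) * 0
```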
So let me now erase these, and, starting with that initialization, let me run the full optimization without the break, and let's see what we get. Okay, the optimization finished, I re-ran the loss, and this is the result that we get. Just as a reminder, I put down all the losses that we saw previously in this lecture, and we see that we actually do get an improvement here. We started off with a validation loss of 2.17. By fixing the softmax being confidently wrong, we came down to 2.13, and by fixing the tanh layer being way too saturated, we came down to 2.10. The reason this is happening, of course, is that the initialization is better, and so we're spending more time doing productive training, instead of not-very-productive training where our gradients are basically zero, and where we have to learn very simple things, like the overconfidence of the softmax in the beginning, and spend cycles just squashing down the weight matrix. So this is illustrating initialization and its impact on performance, just by being aware of the internals of these neural nets, their activations, and their gradients.

Now, we're working with a very small network; this is just a one-hidden-layer multilayer perceptron. Because the network is so shallow, the optimization problem is quite easy and very forgiving, so even though our initialization was terrible, the network still learned eventually; it just got a bit worse result. This is not the case in general, though. Once we start working with much deeper networks that have, say, 50 layers, things can get much more complicated, and these problems stack up, and you can actually get into a place where the network is basically not training at all if your initialization is bad enough. The deeper your network is and the more complex it is, the less forgiving it is of some of these errors. So it's something to definitely be aware of, something to scrutinize, something to plot, and something to be careful with.

Okay, so it's great that that worked for us. But what we have here now is all these magic numbers, like 0.2. Where do I come
up with this, and how am I supposed to set these numbers if I have a large neural net with lots and lots of layers? Obviously no one does this by hand; there are actually some relatively principled ways of setting these scales that I would like to introduce to you now.

So let me paste some code here that I prepared, just to motivate the discussion. What I'm doing here is: we have some random input x that is drawn from a Gaussian, and there are 1,000 examples that are 10-dimensional. Then we have a weight layer here that is also initialized from a Gaussian, just like we did before; these neurons in the hidden layer look at 10 inputs, and there are 200 neurons in this hidden layer. Then we have, just like before, the multiplication x times W to get the pre-activations of these neurons. The analysis here asks: suppose these inputs are unit Gaussian and these weights are unit Gaussian, and I do x times W (forgetting the bias and the nonlinearity for now), then what are the mean and the standard deviation of the outputs? In the beginning, the input is just a standard Gaussian distribution: mean zero and standard deviation one, where the standard deviation, again, is just a measure of the spread of the Gaussian. But once we multiply and look at the histogram of y, we see that the mean of course stays about zero, because this is a symmetric operation, but the standard deviation has expanded to three. So the input standard deviation was one, but now we've grown to three, and what you're seeing in the histogram is that this Gaussian is expanding. We don't want that: we want most of the neural net to have relatively similar activations, so roughly unit Gaussian throughout the neural net. And so the question is: how do we scale these W's to preserve this distribution, so that it remains a Gaussian? Intuitively, if I multiply these elements of W by a large number, let's say by five, then this Gaussian grows and grows in standard deviation: now we're at 15.
So these numbers in the output y take on more and more extreme values. But if we scale it down, say by 0.2, then conversely this Gaussian gets smaller and smaller, it's shrinking, and you can see that the standard deviation is now 0.6. So the question is: what do I multiply by here to exactly preserve the standard deviation at one? It turns out that the correct answer, mathematically, when you work through the variance of this multiplication, is that you are supposed to divide by the square root of the fan-in. The fan-in is basically the number of input elements here, 10. So we are supposed to divide by the square root of 10, and one way to do a square root is to raise to the power of 0.5, which is the same thing. And when you divide by the square root of 10, we see that the output Gaussian has exactly a standard deviation of 1.

Now, unsurprisingly, a number of papers have looked into how to best initialize neural networks. In the case of multilayer perceptrons, we can have fairly deep networks with these nonlinearities in between, and we want to make sure that the activations are well behaved and don't expand to infinity or shrink all the way to zero. The question is how to initialize the weights so that the activations take on reasonable values throughout the network. One paper that has studied this in quite a bit of detail, and that is often referenced, is the paper by Kaiming He et al. called "Delving Deep into Rectifiers". In their case they actually study convolutional neural networks, and they study especially the ReLU nonlinearity and the PReLU nonlinearity instead of a tanh nonlinearity, but the analysis is very similar. Basically, the ReLU nonlinearity that they care about is a squashing function where all the negative numbers are simply clamped to zero: the positive numbers are a pass-through, but everything negative is just set to zero. And because you are basically throwing away half of the distribution, they find in their analysis of the forward activations in the neural net that you have to compensate for that with a gain. So they find that, when they initialize their weights, they have to do it with a zero-mean Gaussian whose standard deviation is the square root of 2 over the fan-in, whereas what we have here is that we are just dividing the Gaussian by the square root of the fan-in.
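Here is that whole little experiment as a runnable sketch (shapes as in the example just described):

```python
import torch

x = torch.randn(1000, 10)            # 1000 examples, 10-dimensional, unit Gaussian
w = torch.randn(10, 200)             # unit-Gaussian weights, fan_in = 10
print((x @ w).std())                 # ≈ sqrt(10) ≈ 3.16: the output distribution expands

w = torch.randn(10, 200) / 10**0.5   # divide by sqrt(fan_in)
print((x @ w).std())                 # ≈ 1: the spread is preserved
```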
So what we have is square root of one over the fanon because we have 441 00:32:55,920 --> 00:33:01,900 a division here. Now they have to add this factor of two because of the ReLU which basically 442 00:33:01,900 --> 00:33:06,600 discards half of the distribution and clamps it at zero and so that's where you get an initial 443 00:33:06,600 --> 00:33:12,880 factor. Now in addition to that this paper also studies not just the sort of behavior of the 444 00:33:12,880 --> 00:33:17,140 activations in the forward pass of the neural net but it also studies the back propagation 445 00:33:17,140 --> 00:33:18,220 and we have to make sure that we have a good distribution of the NeuralNet and we have to 446 00:33:18,240 --> 00:33:23,920 make sure that the gradients also are well behaved and so because ultimately they end up 447 00:33:23,920 --> 00:33:28,900 updating our parameters and what they find here through a lot of the analysis that I invite you 448 00:33:28,900 --> 00:33:34,720 to read through but it's not exactly approachable what they find is basically if you properly 449 00:33:34,720 --> 00:33:40,360 initialize the forward pass the backward pass is also approximately initialized up to a constant 450 00:33:40,360 --> 00:33:47,580 factor that has to do with the size of the number of hidden neurons in an early and a late layer. 451 00:33:48,240 --> 00:33:54,400 But basically they find empirically that this is not a choice that matters too much. Now this 452 00:33:54,400 --> 00:34:00,480 kymene initialization is also implemented in pytorch so if you go to torch.nn.init documentation 453 00:34:00,480 --> 00:34:05,440 you'll find kymene normal and in my opinion this is probably the most common way of initializing 454 00:34:05,440 --> 00:34:11,040 neural networks now and it takes a few keyword arguments here. So number one it wants to know 455 00:34:11,760 --> 00:34:15,600 the mode. Would you like to normalize the activations or would you like to normalize 456 00:34:15,600 --> 00:34:17,840 the gradients to to be always the same or would you like to normalize the gradients to to be always 457 00:34:18,400 --> 00:34:23,840 gaussian with zero mean and a unit or one standard deviation and because they find in the paper that 458 00:34:23,840 --> 00:34:27,440 this doesn't matter too much most of the people just leave it as the default which is pan in 459 00:34:28,160 --> 00:34:32,240 and then second pass in the non-linearity that you are using because depending on the 460 00:34:32,240 --> 00:34:37,360 non-linearity we need to calculate a slightly different gain and so if your non-linearity is 461 00:34:37,360 --> 00:34:43,280 just linear so there's no non-linearity then the gain here will be one and we have the exact same 462 00:34:43,840 --> 00:34:48,080 kind of formula that we've got here but if the non-linearity is something else we're going to 463 00:34:48,240 --> 00:34:53,520 slightly different gain and so if we come up here to the top we see that for example in the case of 464 00:34:53,520 --> 00:34:58,800 ReLU this gain is a square root of 2 and the reason it's a square root because in this paper 465 00:35:02,960 --> 00:35:09,920 you see how the 2 is inside of the square root so the gain is a square root of 2. In a case of 466 00:35:09,920 --> 00:35:16,000 linear or identity we just get a gain of 1. In a case of 10h which is what we're using here 467 00:35:16,000 --> 00:35:17,920 the advised gain is a 5 over 3. 
And intuitively, why do we need a gain on top of the initialization? It's because tanh, just like ReLU, is a contractive transformation. What that means is that you're taking the output distribution from this matrix multiplication and you are squashing it in some way. ReLU squashes it by taking everything below zero and clamping it to zero; tanh also squashes it, because it's a contractive operation: it takes the tails and squeezes them in. And so, in order to fight the squeezing in, we need to boost the weights a little bit, so that we renormalize everything back to unit standard deviation. That's why there's a little bit of a gain that comes out.

Now, I'm skipping through this section a little bit quickly, and I'm doing that intentionally. The reason is that about seven years ago, when this paper was written, you had to be extremely careful with the activations and the gradients, and their ranges and their histograms; you had to be very careful with the precise setting of gains and the scrutinizing of the nonlinearities used, and so on. Everything was very finicky and very fragile, and it had to be very properly arranged for the neural net to train, especially if your neural net was very deep. But there are a number of modern innovations that have made everything significantly more stable and more well behaved, and it has become less important to initialize these networks exactly right. Some of those modern innovations are, for example: number one, residual connections, which we will cover in the future; number two, the use of a number of normalization layers, like batch normalization, layer normalization, group normalization, and so on; and number three, much better optimizers, not just stochastic gradient descent, the simple optimizer we're basically using here, but slightly more complex optimizers like RMSprop and especially Adam. All of these modern innovations make it less important for you to precisely calibrate the initialization of the neural net.

All that being said, in practice, what should we do? In practice, when I initialize these neural nets, I basically just normalize my weights by the square root of the fan-in.
We want to set the standard deviation 496 00:37:31,000 --> 00:37:38,840 to be gain over the square root of fan-in, right? So to set the standard deviation of our weights, 497 00:37:38,840 --> 00:37:44,600 we will proceed as follows. Basically, when we have torch dot random, and let's say I just create 498 00:37:44,600 --> 00:37:46,340 a thousand numbers, we can look at the standard deviation of our weights, we can look at the 499 00:37:46,360 --> 00:37:50,200 standard deviation of this and, of course, that's one, that's the amount of spread. Let's make 500 00:37:50,200 --> 00:37:55,920 this a bit bigger, so it's closer to one. So that's the spread of the Gaussian of zero mean 501 00:37:55,920 --> 00:38:00,280 and unit standard deviation. Now, basically, when you take these, and you multiply by, 502 00:38:00,280 --> 00:38:05,780 say, point two, that basically scales down the Gaussian, and that makes its standard deviation 503 00:38:05,780 --> 00:38:09,780 point two. So basically, the number that you multiply by here ends up being the standard 504 00:38:09,780 --> 00:38:13,800 deviation of this Gaussian. So here, this is a standard deviation point two Gaussian. So this is 505 00:38:13,800 --> 00:38:14,780 a standard deviation point two Gaussian. And then you can right click the selection tool, and you 506 00:38:14,780 --> 00:38:15,680 can do something with the name of the operation. So in the aposemetics and the operation, that is going to 507 00:38:15,680 --> 00:38:15,840 give you the desired values, if you see an unnecessary deviation from one to the other. And itack will also 508 00:38:15,840 --> 00:38:23,040 Gaussian here when we sample rw1. But we want to set the standard deviation to gain over square 509 00:38:23,040 --> 00:38:30,660 root of fan mode, which is fanon. So in other words, we want to multiply by gain, which for 10h 510 00:38:30,660 --> 00:38:49,180 is 5 over 3. 5 over 3 is the gain. And then times, or I guess, sorry, divide square root of 511 00:38:49,180 --> 00:38:55,260 the fanon. And in this example here, the fanon was 10. And I just noticed that actually here, 512 00:38:55,480 --> 00:39:00,420 the fanon for w1 is actually n embed times block size, which as you will recall, 513 00:39:00,520 --> 00:39:00,640 is actually 10. And so we want to set the standard deviation to gain over square root of 514 00:39:00,640 --> 00:39:04,740 30. And that's because each character is 10 dimensional, but then we have three of them, 515 00:39:04,780 --> 00:39:09,120 and we concatenate them. So actually, the fanon here was 30. And I should have used 30 here, 516 00:39:09,220 --> 00:39:15,500 probably. But basically, we want 30 square root. So this is the number, this is what our standard 517 00:39:15,500 --> 00:39:21,020 deviation we want to be. And this number turns out to be 0.3. Whereas here, just by fiddling 518 00:39:21,020 --> 00:39:25,140 with it and looking at the distribution and making sure it looks okay, we came up with 0.2. 519 00:39:25,660 --> 00:39:30,120 And so instead, what we want to do here is we want to make the standard deviation be, 520 00:39:30,120 --> 00:39:42,420 5 over 3, which is our gain, divide this amount times 0.2 square root. And these brackets here 521 00:39:42,420 --> 00:39:47,280 are not that necessary, but I'll just put them here for clarity. This is basically what we want. 522 00:39:47,480 --> 00:39:53,100 This is the kymene in it, in our case, for a 10h non-linearity. And this is how we would 523 00:39:53,100 --> 00:39:59,660 initialize the neural net. 
And so we're multiplying by 0.3 instead of multiplying by 0.2. 524 00:40:00,120 --> 00:40:07,380 And so we can initialize this way. And then we can train the neural net and see what we get. 525 00:40:08,040 --> 00:40:12,760 Okay, so I trained the neural net, and we end up in roughly the same spot. So looking at the 526 00:40:12,760 --> 00:40:18,300 validation loss, we now get 2.10. And previously, we also had 2.10. And there's a little bit of a 527 00:40:18,300 --> 00:40:22,340 difference, but that's just the randomness of the process, I suspect. But the big deal, of course, 528 00:40:22,340 --> 00:40:29,840 is we get to the same spot. But we did not have to introduce any magic numbers that we got from just 529 00:40:30,120 --> 00:40:34,660 looking at histograms and guessing and checking. We have something that is semi-principled and will scale 530 00:40:34,660 --> 00:40:40,740 us to much bigger networks and something that we can sort of use as a guide. So I mentioned that 531 00:40:40,740 --> 00:40:45,120 the precise setting of these initializations is not as important today due to some modern 532 00:40:45,120 --> 00:40:48,780 innovations. And I think now is a pretty good time to introduce one of those modern innovations, 533 00:40:49,340 --> 00:40:54,900 and that is batch normalization. So batch normalization came out in 2015 from a team at 534 00:40:54,900 --> 00:40:59,880 Google. And it was an extremely impactful paper because it made it possible to train 535 00:41:00,120 --> 00:41:05,800 very deep neural nets quite reliably. And it basically just worked. So here's what batch 536 00:41:05,800 --> 00:41:12,120 normalization does, and let's implement it. Basically, we have these hidden states, 537 00:41:12,120 --> 00:41:19,860 H preact, right? And we were talking about how we don't want these pre-activation states to be way 538 00:41:19,860 --> 00:41:25,200 too small, because then the tanh is not doing anything. But we don't want them to be too large, 539 00:41:25,200 --> 00:41:29,860 because then the tanh is saturated. In fact, we want them to be roughly gaussian. 540 00:41:30,120 --> 00:41:35,120 So zero mean and a unit or one standard deviation, at least at initialization. 541 00:41:35,120 --> 00:41:40,500 So the insight from the batch normalization paper is, okay, you have these hidden states, 542 00:41:40,500 --> 00:41:46,380 and you'd like them to be roughly gaussian, then why not take the hidden states and just 543 00:41:46,380 --> 00:41:51,360 normalize them to be gaussian? And it sounds kind of crazy, but you can just do that because 544 00:41:51,360 --> 00:41:58,320 standardizing hidden states so that they're unit gaussian is a perfectly differentiable operation, 545 00:41:58,320 --> 00:42:00,100 as we'll soon see. And so that was 546 00:42:00,120 --> 00:42:04,360 kind of like the big insight in this paper. And when I first read it, my mind was blown, 547 00:42:04,360 --> 00:42:08,120 because you can just normalize these hidden states. And if you'd like unit gaussian states 548 00:42:08,120 --> 00:42:13,480 in your network, at least at initialization, you can just normalize them to be unit gaussian. 549 00:42:14,120 --> 00:42:19,080 So let's see how that works. So we're going to scroll to our pre-activations here just before 550 00:42:19,080 --> 00:42:23,880 they enter into the tanh. Now, the idea again, is remember, we're trying to make these roughly 551 00:42:23,880 --> 00:42:29,640 gaussian.
And that's because if these are way too small numbers, then the tanh here is kind of inactive. 552 00:42:30,120 --> 00:42:35,800 But if these are very large numbers, then the tanh is way too saturated and gradients don't flow. 553 00:42:36,320 --> 00:42:41,400 So we'd like this to be roughly gaussian. So the insight in batch normalization, again, 554 00:42:41,500 --> 00:42:48,320 is that we can just standardize these activations so they are exactly gaussian. So here, H preact 555 00:42:48,320 --> 00:42:55,000 has a shape of 32 by 200, 32 examples by 200 neurons in the hidden layer. 556 00:42:55,660 --> 00:42:59,900 So basically, what we can do is we can take H preact, and we can just calculate the mean, 557 00:43:01,080 --> 00:43:07,240 and the mean we want to calculate across the zero dimension. And we want to also set keepdim as True, 558 00:43:07,880 --> 00:43:15,400 so that we can easily broadcast this. So the shape of this is one by 200. In other words, 559 00:43:15,400 --> 00:43:22,520 we are doing the mean over all the elements in the batch. And similarly, we can calculate 560 00:43:22,520 --> 00:43:30,040 the standard deviation of these activations. And that will also be one by 200. Now in this paper, 561 00:43:30,120 --> 00:43:36,620 they have the sort of prescription here. And see, here we are calculating the mean, 562 00:43:36,880 --> 00:43:44,440 which is just taking the average value of any neuron's activation. And then their standard 563 00:43:44,440 --> 00:43:50,700 deviation is basically kind of like the measure of the spread that we've been using, which is 564 00:43:50,700 --> 00:43:57,440 the distance of every one of these values away from the mean, and that squared and averaged. 565 00:43:58,080 --> 00:43:59,040 So that's the... 566 00:44:00,120 --> 00:44:03,660 That's the variance. And then if you want to take the standard deviation you would 567 00:44:03,660 --> 00:44:09,100 square root the variance to get the standard deviation. So these are the two that we're 568 00:44:09,100 --> 00:44:14,300 calculating. And now we're going to normalize or standardize these x's by subtracting the mean, 569 00:44:14,300 --> 00:44:21,540 and dividing by the standard deviation.
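As a rough sketch of the standardization we're about to write, assuming hpreact is the (32, 200) pre-activation tensor from our code (the bnmeani / bnstdi names are just illustrative here):

```python
# hpreact: (32, 200) -- 32 examples in the batch, 200 hidden neurons
bnmeani = hpreact.mean(0, keepdim=True)   # (1, 200): per-neuron mean over the batch
bnstdi  = hpreact.std(0, keepdim=True)    # (1, 200): per-neuron standard deviation over the batch

# standardize: each neuron becomes roughly zero-mean, unit-std over this batch
hpreact = (hpreact - bnmeani) / bnstdi
h = torch.tanh(hpreact)                   # the tanh now sees well-scaled inputs
```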
So basically, we're taking H preact, and we subtract 570 00:44:23,140 --> 00:44:23,640 the mean, 571 00:44:23,640 --> 00:44:36,640 and then we divide by the standard deviation. This is exactly what these two, std and mean, 572 00:44:36,640 --> 00:44:44,380 are calculating. Oops, sorry, this is the mean and this is the variance. You see how the sigma is 573 00:44:44,380 --> 00:44:48,280 usually the standard deviation, so this is sigma squared, which is the variance, the square of the 574 00:44:48,280 --> 00:44:54,560 standard deviation. So this is how you standardize these values, and what this will do is that every 575 00:44:54,560 --> 00:45:00,540 single neuron now, and its firing rate, will be exactly unit gaussian on these 32 examples, at 576 00:45:00,540 --> 00:45:05,320 least of this batch. That's why it's called batch normalization: we are normalizing these batches. 577 00:45:05,320 --> 00:45:11,460 And then we could in principle train this. Notice that calculating the mean and the standard 578 00:45:11,460 --> 00:45:15,660 deviation, these are just mathematical formulas, they're perfectly differentiable. All this is 579 00:45:15,660 --> 00:45:17,760 perfectly differentiable and we can just train this. 580 00:45:17,760 --> 00:45:24,420 The problem is you actually won't achieve a very good result with this, and the reason for that is 581 00:45:24,420 --> 00:45:31,260 we want these to be roughly gaussian, but only at initialization. We don't want these to 582 00:45:31,260 --> 00:45:37,640 be forced to be gaussian always. We'd like to allow the neural net to move this around, to 583 00:45:37,640 --> 00:45:42,460 potentially make it more diffuse, to make it more sharp, to make some tanh neurons maybe be more 584 00:45:42,460 --> 00:45:47,280 trigger happy or less trigger happy. So we'd like this distribution to move around, 585 00:45:47,280 --> 00:45:47,740 and we'd like 586 00:45:47,740 --> 00:45:51,420 the backpropagation to tell us how the distribution should move around. 587 00:45:52,300 --> 00:45:58,220 And so in addition to this idea of standardizing the activations at any point in the network, 588 00:45:59,180 --> 00:46:04,540 we have to also introduce this additional component, in the paper here described as scale 589 00:46:04,540 --> 00:46:09,580 and shift. And so basically what we're doing is we're taking these normalized inputs and we are 590 00:46:09,580 --> 00:46:15,420 additionally scaling them by some gain and offsetting them by some bias to get our final 591 00:46:15,420 --> 00:46:16,780 output from this layer. 592 00:46:17,740 --> 00:46:22,860 And so what that amounts to is the following: we are going to allow a batch normalization gain 593 00:46:23,740 --> 00:46:30,460 to be initialized at just ones, and the ones will be in the shape of one by n_hidden, 594 00:46:32,300 --> 00:46:36,620 and then we also will have a bnbias, which will be torch.zeros, 595 00:46:37,580 --> 00:46:45,900 and it will also be of the shape one by n_hidden. And then here the bngain will multiply this 596 00:46:45,900 --> 00:46:51,900 and the bnbias will offset it here. So because this is initialized to one and this to zero, at 597 00:46:51,900 --> 00:46:58,380 initialization each neuron's firing values in this batch will be exactly unit gaussian, 598 00:46:58,860 --> 00:47:02,460 and we'll have nice numbers no matter what the distribution of the h preact is coming in 599 00:47:03,580 --> 00:47:07,180
coming out, it will be unit gaussian for each neuron, and that's roughly what we want, at least 600 00:47:07,740 --> 00:47:13,820 at initialization. And then during optimization, we'll be able to 601 00:47:13,820 --> 00:47:15,740 backpropagate into 602 00:47:15,740 --> 00:47:21,740 bngain and bnbias and change them, so the network is given the full ability to 603 00:47:21,740 --> 00:47:27,740 do with this whatever it wants internally. Here we just have to make sure that we 604 00:47:29,260 --> 00:47:33,260 include these in the parameters of the neural net, because they will be trained with 605 00:47:33,260 --> 00:47:39,180 backpropagation. So let's initialize this and then we should be able to train. 606 00:47:39,180 --> 00:47:51,180 And then we're going to also copy this line, which is the batch normalization layer 607 00:47:51,740 --> 00:47:55,660 here on a single line of code, and we're going to swing down here, and we're also going to 608 00:47:56,380 --> 00:47:58,060 do the exact same thing at test time here. 609 00:48:01,660 --> 00:48:04,140 So similar to train time, we're going to normalize 610 00:48:04,860 --> 00:48:08,940 and then scale, and that's going to give us our train and validation loss. 611 00:48:09,740 --> 00:48:12,620 And we'll see in a second that we're actually going to change this a little bit, but for now 612 00:48:12,620 --> 00:48:17,660 I'm going to keep it this way. So I'm just going to wait for this to converge. Okay, so I allowed 613 00:48:17,660 --> 00:48:21,900 the neural nets to converge here, and when we scroll down we see that our validation loss here 614 00:48:21,900 --> 00:48:27,740 is 2.10 roughly, which I wrote down here, and we see that this is actually kind of comparable to some 615 00:48:27,740 --> 00:48:33,340 of the results that we've achieved previously. Now I'm not actually expecting an improvement in this 616 00:48:33,340 --> 00:48:37,740 case, and that's because we are dealing with a very simple neural net that has just a single hidden 617 00:48:37,740 --> 00:48:38,780 layer. 618 00:48:39,340 --> 00:48:43,500 So in fact, in this very simple case of just one hidden layer, we were able to 619 00:48:43,500 --> 00:48:48,300 actually calculate what the scale of w should be to make these pre-activations 620 00:48:48,300 --> 00:48:52,460 already have a roughly Gaussian shape.
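For reference, a sketch of what the batch-normalized hidden layer looks like at this point, with the trainable gain and bias. I'm assuming the embcat, W1, b1 names and the parameters list from the existing notebook code:

```python
# trainable scale and shift for the batch norm layer
bngain = torch.ones((1, n_hidden))
bnbias = torch.zeros((1, n_hidden))
parameters += [bngain, bnbias]   # include them, and set requires_grad like the other parameters

# forward pass, just before the tanh (same expression is used when evaluating the splits):
hpreact = embcat @ W1 + b1       # (32, 200) pre-activations, as in the existing code
hpreact = bngain * (hpreact - hpreact.mean(0, keepdim=True)) / hpreact.std(0, keepdim=True) + bnbias
h = torch.tanh(hpreact)
```

At initialization the gain is all ones and the bias all zeros, so this reduces to the plain standardization above; during training, backpropagation is free to shift and rescale each neuron's distribution through bngain and bnbias.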
So the batch normalization is not doing much here 621 00:48:53,100 --> 00:48:57,020 but you might imagine that once you have a much deeper neural net that has lots of different 622 00:48:57,020 --> 00:49:02,220 types of operations and there's also for example residual connections which we'll cover and so on 623 00:49:02,780 --> 00:49:09,020 it will become basically very very difficult to tune the scales of your weight matrices such 624 00:49:09,020 --> 00:49:13,900 that all the activations throughout the neural net are roughly Gaussian and so that's going to 625 00:49:13,900 --> 00:49:19,340 become very quickly intractable but compared to that it's going to be much much easier to sprinkle 626 00:49:19,340 --> 00:49:25,100 batch normalization layers throughout the neural net so in particular it's common to look at every 627 00:49:25,100 --> 00:49:29,420 single linear layer like this one this is a linear layer multiplying by a weight matrix and adding a 628 00:49:29,420 --> 00:49:36,140 bias or for example convolutions which we'll cover later and also perform basically a multiplication 629 00:49:36,140 --> 00:49:41,420 with a weight matrix but in a more spatially structured format it's custom it's customary 630 00:49:41,420 --> 00:49:46,860 to take this linear layer or convolutional layer and append a batch normalization layer right after 631 00:49:46,860 --> 00:49:52,380 it to control the scale of these activations at every point in the neural net so we'd be adding 632 00:49:52,380 --> 00:49:56,380 these batch norm layers throughout the neural net and then this controls the scale of these 633 00:49:56,380 --> 00:50:02,060 activations throughout the neural net it doesn't require us to do perfect mathematics and care 634 00:50:02,060 --> 00:50:05,900 about the activation distributions for all these different types of neural network 635 00:50:06,140 --> 00:50:08,620 in general but it does require us to do perfect mathematics in order to be able to do this 636 00:50:08,620 --> 00:50:13,420 so what we're doing here is we're taking a bunch of basic lego building blocks that you might want 637 00:50:13,420 --> 00:50:17,660 to introduce into your neural net and it significantly stabilizes the training and 638 00:50:17,660 --> 00:50:21,820 that's why these layers are quite popular now the stability offered by batch normalization 639 00:50:21,820 --> 00:50:26,380 actually comes at a terrible cost and that cost is that if you think about what's happening here 640 00:50:27,020 --> 00:50:33,420 something something terribly strange and unnatural is happening it used to be that we have a single 641 00:50:33,420 --> 00:50:35,020 example feeding into a neural net and then we calculate its activations and its logits and this 642 00:50:35,020 --> 00:50:35,760 is a deterministic 643 00:50:36,140 --> 00:50:41,720 sort of process so you arrive at some logits for this example and then because of efficiency of 644 00:50:41,720 --> 00:50:46,120 training we suddenly started to use batches of examples but those batches of examples were 645 00:50:46,120 --> 00:50:50,980 processed independently and it was just an efficiency thing but now suddenly in batch 646 00:50:50,980 --> 00:50:55,140 normalization because of the normalization through the batch we are coupling these examples 647 00:50:55,140 --> 00:51:01,100 mathematically and in the forward pass and the backward pass of the neural net so now the hidden 648 00:51:01,100 --> 00:51:06,540 state activations h preact and your logits for any one input example are not 
just a function of 649 00:51:06,540 --> 00:51:11,340 that example and its input but they're also a function of all the other examples that happen 650 00:51:11,340 --> 00:51:16,960 to come for a ride in that batch and these examples are sampled randomly and so what's 651 00:51:16,960 --> 00:51:20,980 happening is for example when you look at h preact that's going to feed into h the hidden 652 00:51:20,980 --> 00:51:26,000 state activations for for example for for any one of these input examples is going to actually 653 00:51:26,000 --> 00:51:31,080 change slightly depending on what other examples there are in the batch and and 654 00:51:31,080 --> 00:51:36,540 depending on what other examples happen to come for a ride h is going to change suddenly and it's 655 00:51:36,540 --> 00:51:41,040 going to like jitter if you imagine sampling different examples because the statistics of 656 00:51:41,040 --> 00:51:45,860 the mean understanding deviation are going to be impacted and so you'll get a jitter for h and 657 00:51:45,860 --> 00:51:51,920 you'll get a jitter for logits and you think that this would be a bug or something undesirable 658 00:51:51,920 --> 00:51:58,320 but in a very strange way this actually turns out to be good in neural network training and 659 00:51:58,320 --> 00:52:00,900 as a side effect and the reason for that is that 660 00:52:01,080 --> 00:52:05,560 you can think of this as kind of like a regularizer because what's happening is you have your input 661 00:52:05,560 --> 00:52:10,600 and you get your h and then depending on the other examples this is jittering a bit and so what that 662 00:52:10,600 --> 00:52:15,160 does is that it's effectively padding out any one of these input examples and it's introducing a 663 00:52:15,160 --> 00:52:20,600 little bit of entropy and because of the padding out it's actually kind of like a form of data 664 00:52:20,600 --> 00:52:25,720 augmentation which we'll cover in the future and it's kind of like augmenting the input a little 665 00:52:25,720 --> 00:52:30,700 bit and it's jittering it and that makes it harder for the neural nets to overfit these concrete 666 00:52:31,080 --> 00:52:36,180 examples so by introducing all this noise it actually like pads out the examples and it 667 00:52:36,180 --> 00:52:41,700 regularizes the neural net and that's one of the reasons why deceivingly as a second order effect 668 00:52:41,700 --> 00:52:47,160 this is actually a regularizer and that has made it harder for us to remove the use of batch 669 00:52:47,160 --> 00:52:52,900 normalization because basically no one likes this property that the the examples in the batch are 670 00:52:52,900 --> 00:52:58,260 coupled mathematically and in the forward pass and at least all kinds of like strange results 671 00:52:58,260 --> 00:53:00,300 we'll go into some of that in a second as well 672 00:53:01,080 --> 00:53:06,180 um and it leads to a lot of bugs and um and so on and so no one likes this property 673 00:53:06,940 --> 00:53:11,700 and so people have tried to deprecate the use of batch normalization and move to other 674 00:53:11,700 --> 00:53:15,900 normalization techniques that do not couple the examples of a batch examples are layer 675 00:53:15,900 --> 00:53:20,620 normalization instance normalization group normalization and so on and we'll commerce 676 00:53:20,620 --> 00:53:26,260 we'll come or some of these uh later um but basically long story short batch normalization 677 00:53:26,260 --> 00:53:30,220 was the first kind of normalization layer to 
be introduced. It worked extremely well, 678 00:53:31,080 --> 00:53:37,140 it happened to have this regularizing effect, it stabilized training, and people have been trying 679 00:53:37,140 --> 00:53:42,300 to remove it and move to some of the other normalization techniques. But it's been hard, 680 00:53:42,300 --> 00:53:46,860 because it just works quite well, and some of the reason that it works quite well is again because 681 00:53:46,860 --> 00:53:52,100 of this regularizing effect, and because it is quite effective at controlling the 682 00:53:52,100 --> 00:53:56,940 activations and their distributions. So that's kind of like the brief story of batch normalization, 683 00:53:57,540 --> 00:54:00,140 and I'd like to show you one of the other weird 684 00:54:01,080 --> 00:54:05,940 outcomes of this coupling. So here's one of the strange outcomes that I only glossed over 685 00:54:05,940 --> 00:54:12,100 previously when I was evaluating the loss on the validation set. Basically, once we've trained a 686 00:54:12,100 --> 00:54:17,280 neural net, we'd like to deploy it in some kind of a setting, and we'd like to be able to feed in a 687 00:54:17,280 --> 00:54:22,300 single individual example and get a prediction out from our neural net. But how do we do that 688 00:54:22,300 --> 00:54:26,960 when our neural net now, in a forward pass, estimates the statistics of the mean and standard deviation 689 00:54:26,960 --> 00:54:31,060 of a batch? The neural net expects batches as an input now. So how do we feed 690 00:54:31,080 --> 00:54:36,760 in a single example and get sensible results out? And so the proposal in the batch normalization 691 00:54:36,760 --> 00:54:42,360 paper is the following. What we would like to do here is we would like to basically have a step 692 00:54:42,920 --> 00:54:50,440 after training that calculates and sets the batch norm mean and standard deviation a single time 693 00:54:50,440 --> 00:54:54,600 over the training set. And so I wrote this code here in the interest of time, 694 00:54:55,160 --> 00:54:59,960 and we're going to call it "calibrate the batch norm statistics". And basically what we do 695 00:54:59,960 --> 00:55:00,280 is 696 00:55:01,080 --> 00:55:07,320 telling pytorch that none of this will we call a dot backward on, and it's going to be a bit 697 00:55:07,320 --> 00:55:12,200 more efficient. We're going to take the training set, get the pre-activations for every single 698 00:55:12,200 --> 00:55:16,680 training example, and then one single time estimate the mean and standard deviation over the entire 699 00:55:16,680 --> 00:55:21,400 training set. And then we're going to get bnmean and bnstd, and now these 700 00:55:21,400 --> 00:55:27,000 are fixed numbers, estimated over the entire training set. And here, instead of estimating it 701 00:55:27,720 --> 00:55:30,840 dynamically, we are going to instead 702 00:55:31,080 --> 00:55:36,680 here use bnmean, and here we're just going to use bnstd. 703 00:55:38,120 --> 00:55:43,560 So at test time we are going to fix these, clamp them, and use them during inference. And now 704 00:55:45,480 --> 00:55:51,080 you see that we get basically identical results, but the benefit that we've gained is that we 705 00:55:51,080 --> 00:55:55,720 can now also forward a single example, because the mean and standard deviation are now fixed 706 00:55:55,720 --> 00:56:00,760 sort of tensors. That said, nobody actually wants to estimate this mean and standard deviation 707 00:56:01,080 -->
00:56:06,920 as a second stage after neural network training, because everyone is lazy. And so this batch 708 00:56:06,920 --> 00:56:11,560 normalization paper actually introduced one more idea, which is that we can estimate the mean 709 00:56:11,560 --> 00:56:17,320 and standard deviation in a running manner during training of the neural net, and then 710 00:56:17,320 --> 00:56:22,360 we can simply just have a single stage of training, and on the side of that training we are estimating 711 00:56:22,360 --> 00:56:27,720 the running mean and standard deviation. So let's see what that would look like. Let me basically 712 00:56:27,720 --> 00:56:32,280 take the mean here that we are estimating on the batch, and let me call this bnmeani, the mean on the i-th 713 00:56:32,280 --> 00:56:50,680 iteration, and then here this is bnstdi. Okay, and the mean comes here and the 714 00:56:50,680 --> 00:56:56,280 std comes here. So so far I've done nothing, I've just moved things around and I created these extra 715 00:56:56,280 --> 00:56:57,640 variables for the mean and standard deviation 716 00:56:57,720 --> 00:57:02,840 and I've put them here. So so far nothing has changed, but what we're going to do now is we're 717 00:57:02,840 --> 00:57:07,480 going to keep a running mean of both of these values during training. So let me swing up here 718 00:57:07,480 --> 00:57:17,400 and let me create a bnmean_running, and I'm going to initialize it at zeros, and then bnstd_ 719 00:57:17,400 --> 00:57:27,160 running, which I'll initialize at ones. Because in the beginning, because of the way we initialized w1 720 00:57:27,720 --> 00:57:32,520 and b1, each preact will be roughly unit gaussian, so the mean will be roughly zero and 721 00:57:32,520 --> 00:57:37,720 the standard deviation roughly one. So I'm going to initialize these that way. But then here I'm 722 00:57:37,720 --> 00:57:44,680 going to update these, and in pytorch these mean and standard deviation that are running, 723 00:57:45,320 --> 00:57:48,440 they're not actually part of the gradient-based optimization, we're never going to derive 724 00:57:48,440 --> 00:57:52,520 gradients with respect to them. They're updated on the side of training. 725 00:57:53,480 --> 00:57:57,320 And so what we're going to do here is we're going to say with torch.no_grad, 726 00:57:57,960 --> 00:58:03,000 telling pytorch that the update here is not supposed to be building out a graph, because 727 00:58:03,000 --> 00:58:09,000 there will be no dot backward. But this running mean is basically going to be 0.999 728 00:58:10,120 --> 00:58:20,520 times the current value plus 0.001 times this value, this new mean. And 729 00:58:21,160 --> 00:58:25,640 in the same way, bnstd_running will be mostly what it used to be, 730 00:58:28,520 --> 00:58:32,920 but it will receive a small update in the direction of what the current standard deviation is. 731 00:58:34,920 --> 00:58:39,080 And as you're seeing here, this update is outside and on the side of 732 00:58:39,080 --> 00:58:44,360 the gradient-based optimization, and it's simply being updated not using gradient descent, it's just 733 00:58:44,360 --> 00:58:51,960 being updated using a janky, smooth sort of running mean manner. 734 00:58:53,080 --> 00:58:57,640 And so while the network is training and these pre-activations are sort of changing and 735 00:58:57,720 --> 00:59:02,720 shifting around during backpropagation, we are keeping track of the typical mean and standard 736 00:59:02,720 -->
00:59:10,420 deviation, and we're estimating them once. And when I run this, now I'm keeping track of this 737 00:59:10,420 --> 00:59:15,100 in a running manner. And what we're hoping for, of course, is that the bnmean underscore running 738 00:59:15,100 --> 00:59:20,840 and bnstd underscore running are going to be very similar to the ones that we've calculated here 739 00:59:20,840 --> 00:59:26,500 before. And that way, we don't need a second stage, because we've sort of combined the two stages, 740 00:59:26,500 --> 00:59:29,700 and we've put them on the side of each other, if you want to look at it that way. 741 00:59:30,680 --> 00:59:35,660 And this is how this is also implemented in the batch normalization layer in PyTorch. So during 742 00:59:35,660 --> 00:59:40,500 training, the exact same thing will happen. And then later, when you're using inference, 743 00:59:41,040 --> 00:59:46,380 it will use the estimated running mean of both the mean and standard deviation of those hidden 744 00:59:46,380 --> 00:59:51,660 states. So let's wait for the optimization to converge. And hopefully, the running mean and 745 00:59:51,660 --> 00:59:56,280 standard deviation are roughly equal to these two. And then we can simply use it here, and we don't 749 00:59:56,660 --> 01:00:02,180 need this stage of explicit calibration at the end. Okay, so the optimization finished. I'll rerun the 750 01:00:02,180 --> 01:00:09,340 explicit estimation. And then the bnmean from the explicit estimation is here. And bnmean from the 751 01:00:09,340 --> 01:00:16,980 running estimation during the optimization, you can see is very, very similar. It's not identical, 752 01:00:16,980 --> 01:00:25,180 but it's pretty close. And in the same way, bnstd is this. And bnstd running is this. 753 01:00:25,180 --> 01:00:30,880 As you can see that once again, they are fairly similar values, not identical, but pretty close. 754 01:00:31,720 --> 01:00:36,680 And so then here, instead of bnmean, we can use the bnmean running. Instead of bnstd, 755 01:00:36,680 --> 01:00:43,080 we can use bnstd running. And hopefully, the validation loss will not be impacted too much. 756 01:00:44,320 --> 01:00:50,200 Okay, so basically identical. And this way, we've eliminated the need for this explicit 757 01:00:50,200 --> 01:00:54,960 stage of calibration, because we are doing it inline over here. Okay, so we're almost done with 758 01:00:54,960 --> 01:00:58,960 batch normalization. There are only two more notes that I'd like to make. Number one, I've 759 01:00:58,960 --> 01:01:04,160 skipped a discussion over what is this plus epsilon doing here. This epsilon is usually like some small 760 01:01:04,160 --> 01:01:08,240 fixed number, for example, 1e-5 by default. And what it's doing is that it's 761 01:01:08,240 --> 01:01:13,520 basically preventing a division by zero, in the case that the variance over your batch 762 01:01:14,320 --> 01:01:19,600 is exactly zero. In that case, here, we normally have a division by zero.
But because of the plus 763 01:01:19,600 --> 01:01:24,240 epsilon, this is going to become a small number in the denominator instead, and things will be more 764 01:01:24,240 --> 01:01:24,940 well behaved. 765 01:01:24,960 --> 01:01:29,520 So feel free to also add a plus epsilon here of a very small number, it doesn't actually 766 01:01:29,520 --> 01:01:33,760 substantially change the result, I'm going to skip it in our case, just because this is unlikely to 767 01:01:33,760 --> 01:01:38,400 happen in our very simple example here. And the second thing I want you to notice is that we're 768 01:01:38,400 --> 01:01:43,520 being wasteful here. And it's very subtle. But right here, where we are adding the bias 769 01:01:43,520 --> 01:01:49,440 into H preact, these biases now are actually useless, because we're adding them to the H 770 01:01:49,440 --> 01:01:54,240 preact. But then we are calculating the mean for every one of these neurons, 771 01:01:54,240 --> 01:01:59,760 and subtracting it. So whatever bias you add here is going to get subtracted right here. 772 01:02:00,640 --> 01:02:04,480 And so these biases are not doing anything. In fact, they're being subtracted out, 773 01:02:04,480 --> 01:02:08,560 and they don't impact the rest of the calculation. So if you look at b1.grad, 774 01:02:08,560 --> 01:02:12,480 it's actually going to be zero, because it's being subtracted out and doesn't actually have any 775 01:02:12,480 --> 01:02:17,120 effect. And so whenever you're using batch normalization layers, then if you have any weight 776 01:02:17,120 --> 01:02:22,080 layers before, like a linear or a comb or something like that, you're better off coming here 777 01:02:22,080 --> 01:02:23,600 and just like not using bias. 778 01:02:24,240 --> 01:02:27,680 So you don't want to use bias. And then here, you don't want to 779 01:02:28,400 --> 01:02:33,600 add it because that's spurious. Instead, we have this batch normalization bias here. 780 01:02:33,600 --> 01:02:38,800 And that batch normalization bias is now in charge of the biasing of this distribution, 781 01:02:38,800 --> 01:02:44,720 instead of this b1 that we had here originally. And so basically, the batch normalization layer 782 01:02:44,720 --> 01:02:49,920 has its own bias. And there's no need to have a bias in the layer before it, because that bias 783 01:02:49,920 --> 01:02:54,160 is going to be subtracted out anyway. So that's the other small detail to be careful with sometimes. 784 01:02:54,240 --> 01:02:59,360 It's not going to do anything catastrophic. This b1 will just be useless. It will never get any 785 01:02:59,360 --> 01:03:03,760 gradient. It will not learn. It will stay constant. And it's just wasteful. But it doesn't actually 786 01:03:04,400 --> 01:03:09,680 really impact anything otherwise. Okay, so I rearranged the code a little bit with comments. 787 01:03:09,680 --> 01:03:12,800 And I just wanted to give a very quick summary of the batch normalization layer. 788 01:03:13,520 --> 01:03:18,800 We are using batch normalization to control the statistics of activations in the neural net. 789 01:03:19,520 --> 01:03:24,160 It is common to sprinkle batch normalization layer across the neural net. And usually, we will play 790 01:03:24,240 --> 01:03:30,160 it after layers that have multiplications, like for example, a linear layer or a convolutional 791 01:03:30,160 --> 01:03:37,520 layer, which we may cover in the future. 
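Putting these notes together, here is a hedged sketch of how the batch norm layer ends up looking inside the training loop, with the epsilon in the denominator, the running statistics updated on the side, and no bias on the preceding linear layer (names follow the ones used above; the real code interleaves this with the rest of the forward pass, and the paper adds the epsilon to the variance under a square root rather than to the standard deviation):

```python
eps = 1e-5
bnmean_running = torch.zeros((1, n_hidden))   # buffers, not parameters: no gradients flow into these
bnstd_running  = torch.ones((1, n_hidden))

# --- training step ---
hpreact = embcat @ W1                          # note: no b1, the batch norm bias takes over that role
bnmeani = hpreact.mean(0, keepdim=True)
bnstdi  = hpreact.std(0, keepdim=True)
hpreact = bngain * (hpreact - bnmeani) / (bnstdi + eps) + bnbias
h = torch.tanh(hpreact)

# update running statistics on the side of the optimization
with torch.no_grad():
    bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani
    bnstd_running  = 0.999 * bnstd_running  + 0.001 * bnstdi

# --- at inference, use the fixed running statistics instead of the batch statistics ---
# hpreact = bngain * (hpreact - bnmean_running) / (bnstd_running + eps) + bnbias
```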
Now, the batch normalization internally has parameters 792 01:03:37,520 --> 01:03:43,440 for the gain and the bias. And these are trained using backpropagation. It also has two buffers. 793 01:03:44,240 --> 01:03:48,720 The buffers are the mean and the standard deviation, the running mean and the running 794 01:03:48,720 --> 01:03:53,520 mean of the standard deviation. And these are not trained using backpropagation. These are trained 795 01:03:53,520 --> 01:04:02,800 using this janky update of kind of like a running mean update. So these are sort of the parameters 796 01:04:02,800 --> 01:04:07,760 and the buffers of batch normalization layer. And then really what it's doing is it's calculating the 797 01:04:07,760 --> 01:04:12,000 mean and standard deviation of the activations that are feeding into the batch normalization layer 798 01:04:12,880 --> 01:04:17,760 over that batch. Then it's centering that batch to be unit Gaussian. 799 01:04:18,400 --> 01:04:22,720 And then it's offsetting and scaling it by the learned bias and gain. 800 01:04:24,080 --> 01:04:28,000 And then on top of that, it's keeping track of the mean and standard deviation of the inputs. 801 01:04:28,880 --> 01:04:33,600 And it's maintaining this running mean and standard deviation. And this will later be 802 01:04:33,600 --> 01:04:37,520 used at inference so that we don't have to re-estimate the mean and standard deviation 803 01:04:37,520 --> 01:04:42,560 all the time. And in addition, that allows us to basically forward individual examples 804 01:04:42,560 --> 01:04:47,120 at test time. So that's the batch normalization layer. It's a fairly complicated layer, 805 01:04:48,400 --> 01:04:52,560 but this is what it's doing internally. Now, I wanted to show you a little bit of a real example. 806 01:04:53,680 --> 01:04:59,680 You can search ResNet, which is a residual neural network. And these are contacts of neural networks 807 01:04:59,680 --> 01:05:05,840 used for image classification. And of course, we haven't come to ResNets in detail. So I'm not going 808 01:05:05,840 --> 01:05:11,520 to explain all the pieces of it. But for now, just note that the image feeds into a ResNet on the top 809 01:05:11,520 --> 01:05:16,480 here. And there's many, many layers with repeating structure all the way to predictions of what's 810 01:05:16,480 --> 01:05:21,760 inside that image. This repeating structure is made up of these blocks. And these blocks are just 811 01:05:21,760 --> 01:05:23,360 sequentially stacked up in this 812 01:05:23,720 --> 01:05:30,480 deep neural network. Now, the code for this, the block basically that's used and repeated 813 01:05:30,480 --> 01:05:38,800 sequentially in series, is called this bottleneck block. And there's a lot here. This is all PyTorch. 814 01:05:38,800 --> 01:05:42,240 And of course, we haven't covered all of it. But I want to point out some small pieces of it. 815 01:05:43,120 --> 01:05:47,600 Here in the init is where we initialized the neural net. So this code of block here is basically 816 01:05:47,600 --> 01:05:52,160 the kind of stuff we're doing here. We're initializing all the layers. 
And in the forward, 817 01:05:53,520 --> 01:05:58,960 act once you actually have the input so this code here is along the lines of what we're doing here 818 01:06:01,520 --> 01:06:07,440 and now these blocks are replicated and stacked up serially and that's what a residual network 819 01:06:07,440 --> 01:06:14,000 would be and so notice what's happening here conv1 these are convolution layers 820 01:06:14,800 --> 01:06:20,320 and these convolution layers basically they're the same thing as a linear layer except convolution 821 01:06:20,320 --> 01:06:26,400 layers don't apply convolutional layers are used for images and so they have spatial structure 822 01:06:26,400 --> 01:06:33,040 and basically this linear multiplication and bias offset are done on patches instead of a map 823 01:06:33,040 --> 01:06:38,560 instead of the full input so because these images have structure spatial structure convolutions just 824 01:06:38,560 --> 01:06:44,320 basically do wx plus b but they do it on overlapping patches of the input but otherwise 825 01:06:44,320 --> 01:06:50,080 it's wx plus b then we have the norm layer which by default here is initialized to be a batch 826 01:06:50,080 --> 01:06:50,240 normal 827 01:06:50,320 --> 01:06:56,720 in 2d so two-dimensional bash normalization layer and then we have a non-linearity like relu so 828 01:06:56,720 --> 01:07:04,400 instead of uh here they use relu we are using tanh in this case but both both are just non-linearities 829 01:07:04,400 --> 01:07:08,400 and you can just use them relatively interchangeably for very deep networks 830 01:07:08,400 --> 01:07:14,320 relu's typically empirically work a bit better so see the motif that's being repeated here we have 831 01:07:14,320 --> 01:07:20,080 convolution batch normalization rather convolution batch normalization etc and then here this is 832 01:07:20,080 --> 01:07:20,240 residual 833 01:07:20,320 --> 01:07:24,560 connection that we haven't covered yet but basically that's the exact same pattern we have 834 01:07:24,560 --> 01:07:32,720 here we have a weight layer like a convolution or like a linear layer batch normalization and then 835 01:07:33,280 --> 01:07:39,520 tanh which is non-linearity but basically a weight layer a normalization layer and non-linearity and 836 01:07:39,520 --> 01:07:43,920 that's the motif that you would be stacking up when you create these deep neural networks exactly 837 01:07:43,920 --> 01:07:48,880 as it's done here and one more thing i'd like you to notice is that here when they are initializing 838 01:07:48,880 --> 01:07:50,160 the conf layers when they are initializing the conf layers when they are initializing the conf layers 839 01:07:50,160 --> 01:07:57,040 like conv one by one the depth for that is right here and so it's initializing an nn.conf2d which is 840 01:07:57,040 --> 01:08:01,200 a convolution layer in pytorch and there's a bunch of keyword arguments here that i'm not going to 841 01:08:01,200 --> 01:08:06,400 explain yet but you see how there's bias equals false the bias equals false is exactly for the 842 01:08:06,400 --> 01:08:12,560 same reason as bias is not used in our case you see how i erase the use of bias and the use of 843 01:08:12,560 --> 01:08:17,120 bias is spurious because after this weight layer there's a batch normalization and the batch normalization subtracts that bias and then has its own bias 844 01:08:17,120 --> 01:08:19,120 and the batch normalization subtracts that bias and then has its own bias and then has its own bias 845 01:08:20,160 
--> 01:08:22,240 so there's no need to introduce these spurious 846 01:08:22,240 --> 01:08:28,000 parameters it wouldn't hurt performance it's just useless and so because they have this motif of 847 01:08:28,000 --> 01:08:33,600 conf they don't need a bias here because there's a bias inside here so 848 01:08:34,640 --> 01:08:38,000 by the way this example here is very easy to find just do resnet pytorch 849 01:08:39,280 --> 01:08:44,720 and uh it's this example here so this is kind of like the stock implementation of a residual neural 850 01:08:44,720 --> 01:08:49,920 network in pytorch and you can find that here but of course i haven't covered many of these parts yet 851 01:08:50,640 --> 01:08:55,360 and i would also like to briefly descend into the definitions of these pytorch layers and the 852 01:08:55,360 --> 01:08:59,200 parameters that they take now instead of a convolutional layer we're going to look at 853 01:08:59,200 --> 01:09:04,560 a linear layer because that's the one that we're using here this is a linear layer and i haven't 854 01:09:04,560 --> 01:09:09,120 covered convolutions yet but as i mentioned convolutions are basically linear layers except 855 01:09:09,120 --> 01:09:16,720 on patches so a linear layer performs a wx plus b except here they're calling the wa transpose 856 01:09:16,720 --> 01:09:21,920 so the calc is wx plus p very much like we did here to initialize this layer you need to know 857 01:09:21,920 --> 01:09:30,160 the fan in the fan out and that's so that they can initialize this w this is the fan in and the fan 858 01:09:30,160 --> 01:09:35,840 out so they know how how big the weight matrix should be you need to also pass in whether you 859 01:09:35,840 --> 01:09:42,320 whether or not you want a bias and if you set it to false then no bias will be inside this layer 860 01:09:43,200 --> 01:09:46,160 and you may want to do that exactly like in our case for instance 861 01:09:46,720 --> 01:09:50,340 if your layer is followed by a normalization layer such as batch norm. 862 01:09:51,480 --> 01:09:53,620 So this allows you to basically disable a bias. 863 01:09:54,600 --> 01:09:56,600 Now, in terms of the initialization, if we swing down here, 864 01:09:57,060 --> 01:10:00,400 this is reporting the variables used inside this linear layer. 865 01:10:00,920 --> 01:10:05,240 And our linear layer here has two parameters, the weight and the bias. 866 01:10:05,720 --> 01:10:07,620 In the same way, they have a weight and a bias. 867 01:10:08,400 --> 01:10:11,020 And they're talking about how they initialize it by default. 868 01:10:11,720 --> 01:10:15,420 So by default, PyTorch will initialize your weights by taking the fan in 869 01:10:15,420 --> 01:10:19,700 and then doing 1 over fan in square root. 870 01:10:20,560 --> 01:10:24,880 And then instead of a normal distribution, they are using a uniform distribution. 871 01:10:25,540 --> 01:10:30,380 So it's very much the same thing, but they are using a 1 instead of 5 over 3. 872 01:10:30,500 --> 01:10:32,300 So there's no gain being calculated here. 873 01:10:32,300 --> 01:10:33,260 The gain is just 1. 874 01:10:33,620 --> 01:10:38,800 But otherwise, it's exactly 1 over the square root of fan in, exactly as we have here. 875 01:10:40,260 --> 01:10:44,440 So 1 over the square root of k is the scale of the weights. 876 01:10:45,000 --> 01:10:45,400 But... 877 01:10:45,420 --> 01:10:48,500 But when they are drawing the numbers, they're not using a Gaussian by default. 
878 01:10:48,760 --> 01:10:51,020 They're using a uniform distribution by default. 879 01:10:51,500 --> 01:10:55,300 And so they draw uniformly from negative square root of k to square root of k. 880 01:10:55,880 --> 01:11:02,500 But it's the exact same thing and the same motivation with respect to what we've seen in this lecture. 881 01:11:03,040 --> 01:11:06,180 And the reason they're doing this is if you have a roughly Gaussian input, 882 01:11:06,600 --> 01:11:11,240 this will ensure that out of this layer, you will have a roughly Gaussian output. 883 01:11:11,560 --> 01:11:15,240 And you basically achieve that by scaling the weights. 884 01:11:15,420 --> 01:11:18,740 So that's what this is doing. 885 01:11:19,880 --> 01:11:22,820 And then the second thing is the batch normalization layer. 886 01:11:23,200 --> 01:11:25,120 So let's look at what that looks like in PyTorch. 887 01:11:25,920 --> 01:11:30,040 So here we have a one-dimensional batch normalization layer, exactly as we are using here. 888 01:11:30,640 --> 01:11:32,920 And there are a number of keyword arguments going into it as well. 889 01:11:33,340 --> 01:11:34,960 So we need to know the number of features. 890 01:11:35,500 --> 01:11:36,780 For us, that is 200. 891 01:11:37,240 --> 01:11:40,300 And that is needed so that we can initialize these parameters here. 892 01:11:40,820 --> 01:11:45,400 The gain, the bias, and the buffers for the running mean and standard deviation. 893 01:11:45,940 --> 01:11:49,240 Then they need to know the value of epsilon here. 894 01:11:49,920 --> 01:11:51,620 And by default, this is 1, negative 5. 895 01:11:51,720 --> 01:11:53,120 You don't typically change this too much. 896 01:11:53,960 --> 01:11:55,180 Then they need to know the momentum. 897 01:11:55,920 --> 01:12:02,080 And the momentum here, as they explain, is basically used for these running mean and running standard deviation. 898 01:12:02,800 --> 01:12:04,620 So by default, the momentum here is 0.1. 899 01:12:05,080 --> 01:12:08,240 The momentum we are using here in this example is 0.001. 900 01:12:09,740 --> 01:12:13,220 And basically, you may want to change this sometimes. 901 01:12:13,680 --> 01:12:14,620 And roughly speaking, 902 01:12:14,620 --> 01:12:16,560 if you have a very large batch size, 903 01:12:17,080 --> 01:12:20,560 then typically what you'll see is that when you estimate the mean and standard deviation, 904 01:12:21,420 --> 01:12:23,520 for every single batch size, if it's large enough, 905 01:12:23,680 --> 01:12:25,360 you're going to get roughly the same result. 906 01:12:26,160 --> 01:12:30,100 And so therefore, you can use slightly higher momentum, like 0.1. 907 01:12:30,860 --> 01:12:33,780 But for a batch size as small as 32, 908 01:12:34,440 --> 01:12:37,720 the mean and standard deviation here might take on slightly different numbers 909 01:12:37,720 --> 01:12:41,360 because there's only 32 examples we are using to estimate the mean and standard deviation. 910 01:12:41,840 --> 01:12:43,620 So the value is changing around a lot. 911 01:12:43,860 --> 01:12:44,600 And if you have a very large batch size, 912 01:12:44,600 --> 01:12:45,840 if your momentum is 0.1, 913 01:12:46,200 --> 01:12:48,880 that might not be good enough for this value to settle 914 01:12:48,880 --> 01:12:54,300 and converge to the actual mean and standard deviation over the entire training set. 
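As a small usage sketch of the two PyTorch layers being discussed, with the keyword values mirroring this discussion (defaults can differ a little across PyTorch versions):

```python
import torch
import torch.nn as nn

# linear layer with bias disabled, since a batch norm layer follows it
linear = nn.Linear(30, 200, bias=False)             # fan_in=30, fan_out=200
# by default, the weights are drawn uniformly from (-1/sqrt(fan_in), +1/sqrt(fan_in)), i.e. a gain of 1

# 1D batch norm over the 200 hidden features; eps and momentum as discussed
bn = nn.BatchNorm1d(200, eps=1e-5, momentum=0.001)  # momentum 0.001 instead of the default 0.1

x = torch.randn(32, 30)            # a batch of 32 examples
h = torch.tanh(bn(linear(x)))      # weight layer -> normalization -> non-linearity
```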
915 01:12:55,220 --> 01:12:56,980 And so basically, if your batch size is very small, 916 01:12:57,440 --> 01:12:59,460 momentum of 0.1 is potentially dangerous, 917 01:12:59,760 --> 01:13:02,840 and it might make it so that the running mean and standard deviation 918 01:13:02,840 --> 01:13:04,960 is thrashing too much during training, 919 01:13:05,140 --> 01:13:07,160 and it's not actually converging properly. 920 01:13:09,260 --> 01:13:12,860 Affine equals true determines whether this batch normalization layer 921 01:13:12,860 --> 01:13:14,360 has these learnable affine parameters, 922 01:13:14,600 --> 01:13:17,940 the gain and the bias. 923 01:13:18,500 --> 01:13:20,640 And this is almost always kept to true. 924 01:13:20,760 --> 01:13:23,900 I'm not actually sure why you would want to change this to false. 925 01:13:26,540 --> 01:13:29,280 Then track running stats is determining whether or not 926 01:13:29,400 --> 01:13:31,580 the batch normalization layer of PyTorch will be doing this. 927 01:13:32,840 --> 01:13:37,220 And one reason you may want to skip the running stats 928 01:13:37,660 --> 01:13:39,200 is because you may want to, for example, 929 01:13:39,200 --> 01:13:43,000 estimate them at the end as a stage two like this. 930 01:13:43,360 --> 01:13:44,320 And in that case, you don't want the batch normalization layer 934 01:13:45,160 --> 01:13:47,340 to be doing all this extra compute that you're not going to use. 935 01:13:48,720 --> 01:13:51,840 And finally, we need to know which device we're going to run 936 01:13:51,840 --> 01:13:54,280 this batch normalization on, a CPU or a GPU, 937 01:13:54,740 --> 01:13:56,320 and what the data type should be, 938 01:13:56,600 --> 01:13:59,360 half precision, single precision, double precision, and so on. 939 01:14:00,800 --> 01:14:02,320 So that's the batch normalization layer. 940 01:14:02,600 --> 01:14:03,760 Otherwise, they link to the paper. 941 01:14:03,940 --> 01:14:05,480 It's the same formula we've implemented, 942 01:14:05,920 --> 01:14:09,360 and everything is the same exactly as we've done here. 943 01:14:10,620 --> 01:14:13,040 Okay, so that's everything that I wanted to cover for this lecture. 944 01:14:13,620 --> 01:14:14,420 Really, what I wanted to talk about 945 01:14:14,420 --> 01:14:16,440 is the importance of understanding 946 01:14:16,440 --> 01:14:20,120 the activations and the gradients and their statistics in neural networks. 947 01:14:20,540 --> 01:14:21,940 And this becomes increasingly important, 948 01:14:22,080 --> 01:14:24,680 especially as you make your neural networks bigger, larger, and deeper. 949 01:14:25,560 --> 01:14:28,000 We looked at the distributions basically at the output layer, 950 01:14:28,300 --> 01:14:31,500 and we saw that if you have too confident mispredictions 951 01:14:31,500 --> 01:14:34,480 because the activations are too messed up at the last layer, 952 01:14:34,860 --> 01:14:36,820 you can end up with these hockey stick losses.
953 01:14:37,520 --> 01:14:40,220 And if you fix this, you get a better loss at the end of training 954 01:14:40,220 --> 01:14:43,040 because your training is not doing wasteful work. 955 01:14:43,660 --> 01:14:44,400 Then we also saw that if you have two confident mispredictions, 956 01:14:44,400 --> 01:14:45,920 we saw that we need to control the activations. 957 01:14:46,060 --> 01:14:50,160 We don't want them to squash to zero or explode to infinity 958 01:14:50,160 --> 01:14:52,800 because that you can run into a lot of trouble 959 01:14:52,800 --> 01:14:55,360 with all of these nonlinearities in these neural nets. 960 01:14:55,980 --> 01:14:57,960 And basically, you want everything to be fairly homogeneous 961 01:14:57,960 --> 01:14:58,860 throughout the neural net. 962 01:14:58,960 --> 01:15:01,320 You want roughly Gaussian activations throughout the neural net. 963 01:15:02,480 --> 01:15:06,260 Then we talked about, okay, if we want roughly Gaussian activations, 964 01:15:06,500 --> 01:15:09,200 how do we scale these weight matrices and biases 965 01:15:09,200 --> 01:15:10,920 during initialization of the neural net 966 01:15:10,920 --> 01:15:13,300 so that we don't get, you know, 967 01:15:13,300 --> 01:15:14,300 so everything is S-controllable? 968 01:15:14,400 --> 01:15:15,040 S-controllable is possible. 969 01:15:16,940 --> 01:15:19,260 So that gave us a large boost in improvement. 970 01:15:19,860 --> 01:15:25,040 And then I talked about how that strategy is not actually possible 971 01:15:25,040 --> 01:15:26,760 for much, much deeper neural nets 972 01:15:26,760 --> 01:15:30,160 because when you have much deeper neural nets 973 01:15:30,160 --> 01:15:31,900 with lots of different types of layers, 974 01:15:32,360 --> 01:15:35,700 it becomes really, really hard to precisely set the weights 975 01:15:35,700 --> 01:15:37,040 and the biases in such a way 976 01:15:37,040 --> 01:15:39,400 that the activations are roughly uniform 977 01:15:39,400 --> 01:15:40,520 throughout the neural net. 978 01:15:40,980 --> 01:15:43,900 So then I introduced the notion of a normalization layer. 979 01:15:44,400 --> 01:15:45,820 Now, there are many normalization layers 980 01:15:45,820 --> 01:15:47,580 that people use in practice. 981 01:15:47,960 --> 01:15:49,860 Batch normalization, layer normalization, 982 01:15:50,320 --> 01:15:52,220 instance normalization, group normalization. 983 01:15:52,520 --> 01:15:53,880 We haven't covered most of them, 984 01:15:54,060 --> 01:15:55,240 but I've introduced the first one 985 01:15:55,240 --> 01:15:58,060 and also the one that I believe came out first, 986 01:15:58,260 --> 01:15:59,580 and that's called batch normalization. 987 01:16:00,660 --> 01:16:02,180 And we saw how batch normalization works. 988 01:16:02,900 --> 01:16:04,400 This is a layer that you can sprinkle 989 01:16:04,400 --> 01:16:05,720 throughout your deep neural net. 990 01:16:06,320 --> 01:16:08,080 And the basic idea is 991 01:16:08,080 --> 01:16:09,920 if you want roughly Gaussian activations, 992 01:16:10,360 --> 01:16:11,640 well, then take your activations 993 01:16:11,640 --> 01:16:13,940 and take the mean understanding deviation 994 01:16:13,940 --> 01:16:15,840 and standard deviation and center your data. 995 01:16:16,440 --> 01:16:17,460 And you can do that 996 01:16:17,460 --> 01:16:20,360 because the centering operation is differentiable. 
997 01:16:21,400 --> 01:16:22,500 But on top of that, 998 01:16:22,560 --> 01:16:24,440 we actually had to add a lot of bells and whistles, 999 01:16:24,960 --> 01:16:26,760 and that gave you a sense of the complexities 1000 01:16:26,760 --> 01:16:28,060 of the batch normalization layer 1001 01:16:28,060 --> 01:16:30,140 because now we're centering the data. 1002 01:16:30,260 --> 01:16:30,640 That's great. 1003 01:16:30,920 --> 01:16:32,940 But suddenly, we need the gain and the bias, 1004 01:16:33,380 --> 01:16:34,380 and now those are trainable. 1005 01:16:35,500 --> 01:16:37,100 And then because we are coupling 1006 01:16:37,100 --> 01:16:38,260 all of the training examples, 1007 01:16:38,540 --> 01:16:39,620 now suddenly the question is, 1008 01:16:39,680 --> 01:16:40,500 how do you do the inference? 1009 01:16:41,160 --> 01:16:42,500 Well, to do the inference, 1010 01:16:42,500 --> 01:16:43,500 we need to now estimate 1011 01:16:43,940 --> 01:16:46,760 these mean and standard deviation 1012 01:16:46,760 --> 01:16:49,600 once over the entire training set 1013 01:16:49,600 --> 01:16:50,960 and then use those at inference. 1014 01:16:51,600 --> 01:16:53,380 But then no one likes to do stage two. 1015 01:16:53,760 --> 01:16:55,040 So instead, we fold everything 1016 01:16:55,040 --> 01:16:57,420 into the batch normalization layer during training 1017 01:16:57,420 --> 01:17:00,160 and try to estimate these in a running manner 1018 01:17:00,160 --> 01:17:01,560 so that everything is a bit simpler. 1019 01:17:02,420 --> 01:17:04,580 And that gives us the batch normalization layer. 1020 01:17:06,160 --> 01:17:07,460 And as I mentioned, 1021 01:17:07,640 --> 01:17:08,720 no one likes this layer. 1022 01:17:09,160 --> 01:17:10,920 It causes a huge amount of bugs. 1023 01:17:12,320 --> 01:17:13,560 And intuitively, 1024 01:17:13,560 --> 01:17:15,840 it's because it is coupling examples 1025 01:17:15,840 --> 01:17:18,060 in the forward-passive and neural net. 1026 01:17:18,800 --> 01:17:21,560 And I've shot myself in the foot 1027 01:17:21,560 --> 01:17:24,640 with this layer over and over again in my life, 1028 01:17:24,900 --> 01:17:27,160 and I don't want you to suffer the same. 1029 01:17:28,180 --> 01:17:30,420 So basically, try to avoid it as much as possible. 1030 01:17:31,700 --> 01:17:33,640 Some of the other alternatives to these layers 1031 01:17:33,640 --> 01:17:35,020 are, for example, group normalization 1032 01:17:35,020 --> 01:17:36,200 or layer normalization, 1033 01:17:36,540 --> 01:17:37,900 and those have become more common 1034 01:17:37,900 --> 01:17:40,220 in more recent deep learning, 1035 01:17:40,720 --> 01:17:42,200 but we haven't covered those yet. 1036 01:17:42,900 --> 01:17:43,540 But definitely, 1037 01:17:43,560 --> 01:17:45,800 batch normalization was very influential 1038 01:17:45,800 --> 01:17:48,380 at the time when it came out in roughly 2015 1039 01:17:48,740 --> 01:17:50,360 because it was kind of the first time 1040 01:17:50,360 --> 01:17:52,320 that you could train reliably 1041 01:17:53,660 --> 01:17:54,880 much deeper neural nets. 1042 01:17:55,380 --> 01:17:57,220 And fundamentally, the reason for that is because 1043 01:17:57,620 --> 01:18:00,520 this layer was very effective at controlling the statistics 1044 01:18:00,820 --> 01:18:02,080 of the activations in the neural net. 1045 01:18:03,180 --> 01:18:04,800 So that's the story so far. 1046 01:18:05,320 --> 01:18:07,600 And that's all I wanted to cover. 
1047 01:18:07,800 --> 01:18:09,000 And in the future lectures, 1048 01:18:09,000 --> 01:18:11,000 hopefully we can start going into recurring neural nets. 1049 01:18:11,560 --> 01:18:12,900 And recurring neural nets, 1050 01:18:12,900 --> 01:18:15,860 as we'll see, are just very, very deep networks 1051 01:18:15,860 --> 01:18:18,340 because you unroll the loop 1052 01:18:18,340 --> 01:18:20,680 when you actually optimize these neural nets. 1053 01:18:21,320 --> 01:18:25,040 And that's where a lot of this analysis 1054 01:18:25,040 --> 01:18:26,700 around the activation statistics 1055 01:18:26,700 --> 01:18:28,860 and all these normalization layers 1056 01:18:28,860 --> 01:18:32,220 will become very, very important for good performance. 1057 01:18:32,600 --> 01:18:33,600 So we'll see that next time. 1058 01:18:34,060 --> 01:18:34,320 Bye. 1059 01:18:35,240 --> 01:18:35,940 Okay, so I lied. 1060 01:18:36,300 --> 01:18:38,620 I would like us to do one more summary here as a bonus. 1061 01:18:39,040 --> 01:18:41,760 And I think it's useful as to have one more summary 1062 01:18:41,760 --> 01:18:42,880 of everything I've presented today. 1063 01:18:42,900 --> 01:18:43,420 In this lecture. 1064 01:18:43,820 --> 01:18:45,200 But also, I would like us to start 1065 01:18:45,200 --> 01:18:47,120 by torchifying our code a little bit. 1066 01:18:47,260 --> 01:18:49,580 So it looks much more like what you would encounter in PyTorch. 1067 01:18:50,040 --> 01:18:52,040 So you'll see that I will structure our code 1068 01:18:52,040 --> 01:18:53,700 into these modules, 1069 01:18:54,060 --> 01:18:56,300 like a linear module 1070 01:18:56,300 --> 01:18:58,200 and a batch form module. 1071 01:18:58,600 --> 01:19:00,800 And I'm putting the code inside these modules 1072 01:19:00,800 --> 01:19:02,740 so that we can construct neural networks 1073 01:19:02,740 --> 01:19:04,640 very much like we would construct them in PyTorch. 1074 01:19:04,720 --> 01:19:05,940 And I will go through this in detail. 1075 01:19:06,400 --> 01:19:07,800 So we'll create our neural net. 1076 01:19:08,560 --> 01:19:10,880 Then we will do the optimization loop 1077 01:19:10,880 --> 01:19:11,780 as we did before. 1078 01:19:12,360 --> 01:19:14,200 And then the one more thing that I want to do here 1079 01:19:14,200 --> 01:19:15,940 is I want to look at the activation statistics 1080 01:19:15,940 --> 01:19:17,240 both in the forward pass 1081 01:19:17,240 --> 01:19:18,780 and in the backward pass. 1082 01:19:19,240 --> 01:19:20,720 And then here we have the evaluation 1083 01:19:20,720 --> 01:19:22,100 and sampling just like before. 1084 01:19:22,700 --> 01:19:24,600 So let me rewind all the way up here 1085 01:19:24,600 --> 01:19:26,020 and go a little bit slower. 1086 01:19:26,640 --> 01:19:28,600 So here I am creating a linear layer. 1087 01:19:29,180 --> 01:19:30,820 You'll notice that torch.nn 1088 01:19:30,820 --> 01:19:32,360 has lots of different types of layers. 1089 01:19:32,740 --> 01:19:34,380 And one of those layers is the linear layer. 1090 01:19:35,240 --> 01:19:37,260 Torch.nn.linear takes a number of input features, 1091 01:19:37,420 --> 01:19:38,040 output features, 1092 01:19:38,260 --> 01:19:39,440 whether or not we should have bias, 1093 01:19:39,740 --> 01:19:41,420 and then the device that we want to place 1094 01:19:41,420 --> 01:19:42,100 this layer on, 1095 01:19:42,440 --> 01:19:43,220 and the data type. 
1096 01:19:43,800 --> 01:19:45,340 So I will omit these two, 1097 01:19:45,720 --> 01:19:47,660 but otherwise we have the exact same thing. 1098 01:19:48,140 --> 01:19:49,300 We have the fanIn, 1099 01:19:49,380 --> 01:19:50,420 which is the number of inputs, 1100 01:19:50,780 --> 01:19:52,820 fanOut, the number of outputs, 1101 01:19:53,220 --> 01:19:54,620 and whether or not we want to use a bias. 1102 01:19:55,240 --> 01:19:56,560 And internally inside this layer, 1103 01:19:56,820 --> 01:19:58,240 there's a weight and a bias, 1104 01:19:58,420 --> 01:19:59,080 if you'd like it. 1105 01:19:59,720 --> 01:20:02,100 It is typical to initialize the weight 1106 01:20:02,100 --> 01:20:05,380 using, say, random numbers drawn from a Gaussian. 1107 01:20:05,880 --> 01:20:07,540 And then here's the Kaiming initialization 1108 01:20:07,540 --> 01:20:10,040 that we discussed already in this lecture. 1109 01:20:10,100 --> 01:20:11,380 And that's a good default, 1110 01:20:11,420 --> 01:20:12,880 and also the default 1111 01:20:12,880 --> 01:20:14,100 that I believe PyTorch uses. 1112 01:20:14,760 --> 01:20:15,360 And by default, 1113 01:20:15,360 --> 01:20:17,660 the bias is usually initialized to zeros. 1114 01:20:18,380 --> 01:20:19,800 Now, when you call this module, 1115 01:20:20,800 --> 01:20:23,060 this will basically calculate w times x plus b, 1116 01:20:23,200 --> 01:20:24,200 if you have a bias. 1117 01:20:24,900 --> 01:20:27,320 And then when you also call the parameters on this module, 1118 01:20:27,440 --> 01:20:29,600 it will return the tensors 1119 01:20:29,780 --> 01:20:31,460 that are the parameters of this layer. 1120 01:20:32,200 --> 01:20:34,300 Now, next we have the batch normalization layer. 1121 01:20:34,520 --> 01:20:36,540 So I've written that here. 1122 01:20:37,020 --> 01:20:41,400 And this is very similar to torch.nn.BatchNorm1d. 1125 01:20:41,420 --> 01:20:45,180 So what PyTorch provides is a BatchNorm1d layer, as shown here. 1126 01:20:45,180 --> 01:20:48,020 So I'm kind of taking these three parameters here: 1127 01:20:48,020 --> 01:20:48,980 the dimensionality, 1128 01:20:48,980 --> 01:20:51,540 the epsilon that we will use in the division, 1129 01:20:51,540 --> 01:20:53,300 and the momentum that we will use 1130 01:20:53,300 --> 01:20:55,160 in keeping track of these running stats, 1131 01:20:55,160 --> 01:20:57,120 the running mean and the running variance. 1132 01:20:58,220 --> 01:21:00,500 Now PyTorch actually takes quite a few more things, 1133 01:21:00,500 --> 01:21:02,340 but I'm assuming some of their settings. 1134 01:21:02,340 --> 01:21:03,960 So for us, affine will be true. 1135 01:21:03,960 --> 01:21:06,180 That means that we will be using a gamma and a beta 1136 01:21:06,180 --> 01:21:08,060 after the normalization. 1137 01:21:08,060 --> 01:21:09,620 The track running stats will be true.
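Before the BatchNorm details continue below, here is a minimal sketch of the Linear module as just described; it follows the spoken description, so the notebook's exact code may differ slightly.

```python
import torch

class Linear:
    # Minimal sketch of the from-scratch linear layer described above.
    def __init__(self, fan_in, fan_out, bias=True):
        # weights drawn from a Gaussian, scaled by 1/sqrt(fan_in) (the Kaiming-style default)
        self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5
        # bias (if requested) is initialized to zeros
        self.bias = torch.zeros(fan_out) if bias else None

    def __call__(self, x):
        # w times x (plus b, if we have a bias); the output is stashed in .out
        # so that we can plot statistics of it later
        self.out = x @ self.weight
        if self.bias is not None:
            self.out = self.out + self.bias
        return self.out

    def parameters(self):
        # the tensors that are trained by backpropagation
        return [self.weight] + ([] if self.bias is None else [self.bias])
```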
1138 01:21:09,620 --> 01:21:10,860 So we will be keeping track of 1139 01:21:10,860 --> 01:21:16,380 the running mean and the running variance in the BatchNorm. Our device by default is the CPU, 1140 01:21:16,940 --> 01:21:23,740 and the data type by default is float32. So those are the defaults. Otherwise, 1141 01:21:24,380 --> 01:21:28,700 we are taking all the same parameters in this BatchNorm layer. So first I'm just saving them. 1142 01:21:29,740 --> 01:21:34,620 Now here's something new: there's a .training attribute, which by default is True. PyTorch nn modules 1143 01:21:34,620 --> 01:21:40,700 also have this attribute .training, and that's because many modules, and batch norm is included 1144 01:21:40,700 --> 01:21:45,500 in that, have a different behavior depending on whether you are training your neural net or whether you 1145 01:21:45,500 --> 01:21:49,900 are running it in evaluation mode, calculating your evaluation loss, or using 1146 01:21:49,900 --> 01:21:55,020 it for inference on some test examples. And batch norm is an example of this, because when 1147 01:21:55,020 --> 01:21:59,020 we are training, we are going to be using the mean and the variance estimated from the current batch, 1148 01:21:59,580 --> 01:22:03,020 but during inference we are using the running mean and running variance. 1149 01:22:03,900 --> 01:22:04,300 And so, 1150 01:22:04,780 --> 01:22:09,100 also, if we are training we are updating the mean and variance, but if we are testing then these 1151 01:22:09,100 --> 01:22:14,140 are not being updated, they're kept fixed. And so this flag is necessary, and by default True, 1152 01:22:14,140 --> 01:22:19,820 just like in PyTorch. Now, the parameters of BatchNorm1d are the gamma and the beta here, 1153 01:22:21,660 --> 01:22:28,140 and then the running mean and running variance are called buffers in PyTorch nomenclature. And these 1154 01:22:28,140 --> 01:22:34,380 buffers are trained using an exponential moving average, here explicitly, and they are not part of 1155 01:22:34,620 --> 01:22:38,460 the back propagation and stochastic gradient descent, so they are not sort of like parameters 1156 01:22:38,460 --> 01:22:43,500 of this layer. And that's why, when we have a parameters() here, we only return 1157 01:22:43,500 --> 01:22:48,700 the gamma and beta; we do not return the mean and the variance. This is trained sort of like internally 1158 01:22:48,700 --> 01:22:55,580 here, every forward pass, using the exponential moving average. So that's the initialization. 1159 01:22:56,700 --> 01:23:02,380 Now, in the forward pass, if we are training, then we use the mean and the variance estimated by the 1160 01:23:02,380 --> 01:23:04,300 batch. Let me pull up the paper here. 1161 01:23:05,580 --> 01:23:11,900 We calculate the mean and the variance. Now, up above I was estimating the standard deviation 1162 01:23:11,900 --> 01:23:16,380 and keeping track of the standard deviation here, in a running standard deviation instead of 1163 01:23:16,380 --> 01:23:22,140 a running variance. But let's follow the paper exactly: here they calculate the variance, which 1164 01:23:22,140 --> 01:23:26,380 is the standard deviation squared, and that's what's kept track of in the running variance, 1165 01:23:26,380 --> 01:23:31,660 instead of a running standard deviation. But those two would be very, very similar, I believe. 1166 01:23:33,580 --> 01:23:34,460 If we are not training, 1167 01:23:34,620 --> 01:23:40,780 then we use the running mean and variance. We normalize, and then here I am calculating the 1168 01:23:40,780 --> 01:23:46,540 output
of this layer, and I'm also assigning it to an attribute called .out. Now, .out is 1169 01:23:46,540 --> 01:23:51,500 something that I'm using in our modules here. This is not what you would find in PyTorch; we are 1170 01:23:51,500 --> 01:23:57,820 slightly deviating from it. I'm creating a .out because I would like to very easily maintain all 1171 01:23:57,820 --> 01:24:02,860 those variables, so that we can create statistics of them and plot them. But PyTorch nn modules will 1172 01:24:02,860 --> 01:24:04,380 not have a .out attribute. 1173 01:24:05,260 --> 01:24:09,500 And finally, here we are updating the buffers using, as I mentioned, the exponential moving average, 1174 01:24:10,700 --> 01:24:14,780 given the provided momentum. And importantly, you'll notice that I'm using 1175 01:24:14,780 --> 01:24:20,380 the torch.no_grad context manager. And I'm doing this because if we don't use this, then PyTorch 1176 01:24:20,380 --> 01:24:25,420 will start building out an entire computational graph out of these tensors, because it is expecting 1177 01:24:25,420 --> 01:24:29,660 that we will eventually call a .backward. But we are never going to be calling .backward 1178 01:24:29,660 --> 01:24:34,300 on anything that includes the running mean and running variance. So that's why we need to use this context, 1179 01:24:35,020 --> 01:24:41,020 so that we are not sort of maintaining them using all this additional memory. So this will make it 1180 01:24:41,020 --> 01:24:44,700 more efficient, and it's just telling PyTorch that there will be no backward; we just have a bunch 1181 01:24:44,700 --> 01:24:51,660 of tensors, we want to update them, that's it. And then we return. Okay, now scrolling down, we have the 1182 01:24:51,660 --> 01:24:59,020 Tanh layer. This is very, very similar to torch.tanh, and it doesn't do too much. It just calculates tanh, 1183 01:24:59,020 --> 01:25:04,380 as you might expect. So that's torch.tanh, and there's no parameters in this layer. 1184 01:25:05,020 --> 01:25:09,980 But because these are layers, it now becomes very easy to sort of like stack them up into 1185 01:25:10,700 --> 01:25:16,860 basically just a list, and we can do all the initializations that we're used to. So we have the 1186 01:25:17,420 --> 01:25:21,260 initial sort of embedding matrix, we have our layers, and we can call them sequentially. 1187 01:25:22,060 --> 01:25:25,980 And then again, with torch.no_grad, there's some initializations here. 1188 01:25:25,980 --> 01:25:31,100 So we want to make the output softmax a bit less confident, like we saw. And in addition to that, 1189 01:25:31,100 --> 01:25:34,380 because we are using a six-layer multilayer perceptron here 1190 01:25:34,620 --> 01:25:37,860 (you see how I'm stacking Linear, Tanh, Linear, Tanh, etc.), 1191 01:25:39,100 --> 01:25:41,160 I'm going to be using the gain here. 1192 01:25:41,320 --> 01:25:42,700 And I'm going to play with this in a second. 1193 01:25:42,880 --> 01:25:45,860 So you'll see how, when we change this, what happens to the statistics. 1194 01:25:47,140 --> 01:25:51,860 Finally, the parameters are basically the embedding matrix and all the parameters in all the layers. 1195 01:25:52,400 --> 01:25:56,100 And notice here, I'm using a double list comprehension, if you want to call it that. 1196 01:25:56,100 --> 01:26:00,360 But for every layer in layers, and for every parameter in each of those layers, 1197 01:26:00,560 --> 01:26:03,980 we are just stacking up all those P's, all those parameters.
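Here is a sketch of the BatchNorm1d and Tanh modules along the lines just described; it follows the spoken description, so the notebook's exact code may differ slightly.

```python
import torch

class BatchNorm1d:
    # Sketch of the from-scratch batch norm layer described above.
    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.momentum = momentum
        self.training = True                 # train vs. eval behaviour, like PyTorch's .training
        # trainable parameters (gamma and beta)
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
        # buffers: updated with an exponential moving average, not by backprop
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        if self.training:
            xmean = x.mean(0, keepdim=True)  # batch mean
            xvar = x.var(0, keepdim=True)    # batch variance (the paper's choice, not std)
        else:
            xmean = self.running_mean        # at inference, use the running estimates
            xvar = self.running_var
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # normalize to unit variance
        self.out = self.gamma * xhat + self.beta
        if self.training:
            with torch.no_grad():            # buffers live outside the computational graph
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        return self.out

    def parameters(self):
        # only gamma and beta; the running statistics are buffers, not parameters
        return [self.gamma, self.beta]

class Tanh:
    def __call__(self, x):
        self.out = torch.tanh(x)
        return self.out
    def parameters(self):
        return []
```

These module objects can then be stacked in a plain Python list and called one after another, with the parameters gathered using the double list comprehension mentioned above, e.g. [p for layer in layers for p in layer.parameters()].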
1198 01:26:04,840 --> 01:26:08,040 Now, in total, we have 46,000 parameters. 1199 01:26:09,060 --> 01:26:12,100 And I'm telling PyTorch that all of them require gradient. 1200 01:26:15,720 --> 01:26:19,900 Then here, we have everything we are actually mostly used to. 1201 01:26:20,380 --> 01:26:21,620 We are sampling a batch. 1202 01:26:21,940 --> 01:26:23,120 We are doing the forward pass. 1203 01:26:23,240 --> 01:26:26,640 The forward pass now is just the sequential application of all the layers in order, 1204 01:26:27,460 --> 01:26:28,480 followed by the cross entropy. 1205 01:26:29,460 --> 01:26:31,960 And then in the backward pass, you'll notice that for every single layer, 1206 01:26:32,220 --> 01:26:33,760 I now iterate over all the outputs. 1207 01:26:34,180 --> 01:26:34,600 And then, 1208 01:26:34,820 --> 01:26:36,720 I'm telling PyTorch to retain the gradient of them. 1209 01:26:37,420 --> 01:26:41,320 And then here, we are already used to all of this: set the gradients to none, 1210 01:26:41,720 --> 01:26:43,420 do the backward to fill in the gradients, 1211 01:26:43,920 --> 01:26:45,820 do an update using stochastic gradient descent, 1212 01:26:46,320 --> 01:26:48,120 and then track some statistics. 1213 01:26:48,720 --> 01:26:51,620 And then I am going to break after a single iteration. 1214 01:26:52,020 --> 01:26:53,920 Now, here in this cell, in this diagram, 1215 01:26:54,120 --> 01:26:58,420 I'm visualizing the histograms of the forward pass activations, 1216 01:26:58,720 --> 01:27:01,220 and I'm specifically doing it at the tanh layers. 1217 01:27:01,820 --> 01:27:04,120 So iterating over all the layers, 1218 01:27:04,120 --> 01:27:05,520 except for the very last one, 1219 01:27:05,720 --> 01:27:08,220 which is basically just the softmax layer. 1220 01:27:10,220 --> 01:27:11,620 If it is a tanh layer, 1221 01:27:11,820 --> 01:27:14,520 and I'm using a tanh layer just because they have a finite output, 1222 01:27:14,720 --> 01:27:15,420 negative one to one. 1223 01:27:15,620 --> 01:27:17,220 And so it's very easy to visualize here. 1224 01:27:17,420 --> 01:27:18,820 So you see negative one to one, 1225 01:27:19,020 --> 01:27:20,820 and it's a finite range and easy to work with. 1226 01:27:21,820 --> 01:27:25,220 I take the .out tensor from that layer into t, 1227 01:27:25,620 --> 01:27:26,920 and then I'm calculating the mean, 1228 01:27:27,120 --> 01:27:28,020 the standard deviation, 1229 01:27:28,220 --> 01:27:30,020 and the percent saturation of t. 1230 01:27:30,720 --> 01:27:32,720 And the way I define the percent saturation is that 1231 01:27:32,920 --> 01:27:33,920 t's absolute value 1232 01:27:33,920 --> 01:27:35,020 is greater than 0.97. 1233 01:27:35,520 --> 01:27:38,320 So that means we are here at the tails of the tanh. 1234 01:27:38,720 --> 01:27:40,820 And remember that when we are in the tails of the tanh, 1235 01:27:41,020 --> 01:27:42,420 that will actually stop gradients. 1236 01:27:42,820 --> 01:27:44,420 So we don't want this to be too high. 1237 01:27:45,620 --> 01:27:48,520 Now, here I'm calling torch.histogram, 1238 01:27:49,020 --> 01:27:50,720 and then I am plotting this histogram.
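A sketch of that forward-pass activation-statistics loop, assuming the layers list and the Tanh class from the module sketches above; the exact plotting details in the notebook may differ.

```python
import torch
import matplotlib.pyplot as plt

# Sketch of the forward-pass activation histograms described above. It assumes
# `layers` is the list of layer objects (with .out stashed by the forward pass)
# and `Tanh` is the class from the earlier sketch.
plt.figure(figsize=(20, 4))
legends = []
for i, layer in enumerate(layers[:-1]):        # skip the final (output) layer
    if isinstance(layer, Tanh):
        t = layer.out
        saturated = (t.abs() > 0.97).float().mean() * 100   # % of values in the tanh tails
        print(f'layer {i} ({layer.__class__.__name__}): mean {t.mean().item():+.2f}, '
              f'std {t.std().item():.2f}, saturated: {saturated.item():.2f}%')
        hy, hx = torch.histogram(t, density=True)
        plt.plot(hx[:-1].detach(), hy.detach())
        legends.append(f'layer {i} ({layer.__class__.__name__})')
plt.legend(legends)
plt.title('activation distribution')
plt.show()
```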
1239 01:27:51,320 --> 01:27:52,620 So basically what this is doing is that 1240 01:27:52,820 --> 01:27:53,920 every different type of layer, 1241 01:27:54,120 --> 01:27:55,120 and they all have a different color, 1242 01:27:55,420 --> 01:27:59,620 we are looking at how many values in these tensors 1243 01:27:59,820 --> 01:28:03,520 take on any of the values below on this axis here. 1244 01:28:03,920 --> 01:28:08,020 So the first layer is fairly saturated here at 20%. 1245 01:28:08,220 --> 01:28:10,120 So you can see that it's got tails here, 1246 01:28:10,320 --> 01:28:12,120 but then everything sort of stabilizes. 1247 01:28:12,320 --> 01:28:13,820 And if we had more layers here, 1248 01:28:14,020 --> 01:28:15,020 it would actually just stabilize 1249 01:28:15,220 --> 01:28:17,420 at around the standard deviation of about 0.65, 1250 01:28:17,620 --> 01:28:20,020 and the saturation would be roughly 5%. 1251 01:28:20,220 --> 01:28:22,320 And the reason that this stabilizes 1252 01:28:22,520 --> 01:28:24,120 and gives us a nice distribution here 1253 01:28:24,320 --> 01:28:26,820 is because gain is set to 5 over 3. 1254 01:28:27,020 --> 01:28:30,120 Now, here, this gain, 1255 01:28:30,320 --> 01:28:33,320 you see that by default we initialize with 1256 01:28:33,320 --> 01:28:34,920 1 over square root of fan in. 1257 01:28:35,120 --> 01:28:36,720 But then here during initialization, 1258 01:28:36,920 --> 01:28:38,620 I come in and I iterate over all the layers, 1259 01:28:38,820 --> 01:28:41,420 and if it's a linear layer, I boost that by the gain. 1260 01:28:41,620 --> 01:28:44,020 Now, we saw that 1, 1261 01:28:44,220 --> 01:28:46,720 so basically if we just do not use a gain, 1262 01:28:46,920 --> 01:28:47,920 then what happens? 1263 01:28:48,120 --> 01:28:49,720 If I redraw this, 1264 01:28:49,920 --> 01:28:53,720 you will see that the standard deviation is shrinking, 1265 01:28:53,920 --> 01:28:56,420 and the saturation is coming to 0. 1266 01:28:56,620 --> 01:28:58,120 And basically what's happening is 1267 01:28:58,320 --> 01:29:00,520 the first layer is pretty decent, 1268 01:29:00,720 --> 01:29:02,920 but then further layers are just kind of like, 1269 01:29:03,320 --> 01:29:04,820 shrinking down to 0. 1270 01:29:05,020 --> 01:29:07,520 And it's happening slowly, but it's shrinking to 0. 1271 01:29:07,720 --> 01:29:09,420 And the reason for that is 1272 01:29:09,620 --> 01:29:13,020 when you just have a sandwich of linear layers alone, 1273 01:29:13,220 --> 01:29:17,820 then initializing our weights in this manner, 1274 01:29:18,020 --> 01:29:19,020 we saw previously, 1275 01:29:19,220 --> 01:29:22,020 would have conserved the standard deviation of 1. 1276 01:29:22,220 --> 01:29:26,420 But because we have this interspersed 10H layers in there, 1277 01:29:26,620 --> 01:29:29,420 these 10H layers are squashing functions. 1278 01:29:29,620 --> 01:29:31,220 And so they take your distribution 1279 01:29:31,420 --> 01:29:32,820 and they slightly squash it. 1280 01:29:32,820 --> 01:29:36,920 And so some gain is necessary to keep expanding it 1281 01:29:37,120 --> 01:29:39,720 to fight the squashing. 1282 01:29:39,920 --> 01:29:43,220 So it just turns out that 5 over 3 is a good value. 1283 01:29:43,420 --> 01:29:45,520 So if we have something too small like 1, 1284 01:29:45,720 --> 01:29:48,820 we saw that things will come towards 0. 1285 01:29:49,020 --> 01:29:52,220 But if it's something too high, let's do 2. 
1286 01:29:52,420 --> 01:29:56,420 Then here we see that, 1287 01:29:56,620 --> 01:29:59,020 well, let me do something a bit more extreme 1288 01:29:59,220 --> 01:30:00,320 so it's a bit more visible. 1289 01:30:00,520 --> 01:30:02,320 Let's try 3. 1290 01:30:02,320 --> 01:30:03,920 So we see here that the saturations 1291 01:30:04,120 --> 01:30:05,920 are trying to be way too large. 1292 01:30:06,120 --> 01:30:10,120 So 3 would create way too saturated activations. 1293 01:30:10,320 --> 01:30:12,920 So 5 over 3 is a good setting 1294 01:30:13,120 --> 01:30:17,120 for a sandwich of linear layers with tanh activations. 1295 01:30:17,320 --> 01:30:20,120 And it roughly stabilizes the standard deviation 1296 01:30:20,320 --> 01:30:22,120 at a reasonable point. 1297 01:30:22,320 --> 01:30:24,120 Now, honestly, I have no idea 1298 01:30:24,320 --> 01:30:26,120 where 5 over 3 came from in PyTorch 1299 01:30:26,320 --> 01:30:29,120 when we were looking at the Kaiming initialization. 1300 01:30:29,320 --> 01:30:32,120 I see empirically that it stabilizes 1301 01:30:32,320 --> 01:30:34,120 a sandwich of linear and tanh 1302 01:30:34,320 --> 01:30:36,120 and that the saturation is in a good range. 1303 01:30:36,320 --> 01:30:39,120 But I don't actually know if this came out of some math formula. 1304 01:30:39,320 --> 01:30:42,120 I tried searching briefly for where this comes from, 1305 01:30:42,320 --> 01:30:44,120 but I wasn't able to find anything. 1306 01:30:44,320 --> 01:30:46,120 But certainly we see that empirically 1307 01:30:46,320 --> 01:30:47,120 these are very nice ranges. 1308 01:30:47,320 --> 01:30:49,120 Our saturation is roughly 5%, 1309 01:30:49,320 --> 01:30:51,120 which is a pretty good number. 1310 01:30:51,320 --> 01:30:55,120 And this is a good setting of the gain in this context. 1311 01:30:55,320 --> 01:30:58,120 Similarly, we can do the exact same thing with the gradients. 1312 01:30:58,320 --> 01:31:01,120 So here is the very same loop: if it's a tanh, 1313 01:31:01,120 --> 01:31:02,920 but instead of taking the layer.out, 1314 01:31:03,120 --> 01:31:03,920 I'm taking the grad. 1315 01:31:04,120 --> 01:31:06,920 And then I'm also showing the mean and the standard deviation. 1316 01:31:07,120 --> 01:31:09,920 And I'm plotting the histogram of these values. 1317 01:31:10,120 --> 01:31:11,920 And so you'll see that the gradient distribution 1318 01:31:12,120 --> 01:31:12,920 is fairly reasonable. 1319 01:31:13,120 --> 01:31:14,920 And in particular, what we're looking for 1320 01:31:15,120 --> 01:31:17,920 is that all the different layers in this sandwich 1321 01:31:18,120 --> 01:31:19,920 have roughly the same gradient. 1322 01:31:20,120 --> 01:31:21,920 Things are not shrinking or exploding. 1323 01:31:22,120 --> 01:31:23,920 So we can, for example, come here 1324 01:31:24,120 --> 01:31:25,920 and we can take a look at what happens 1325 01:31:26,120 --> 01:31:27,920 if this gain was way too small. 1326 01:31:28,120 --> 01:31:29,920 So this was 0.5. 1327 01:31:29,920 --> 01:31:31,720 Then you see the... 1328 01:31:31,920 --> 01:31:33,720 First of all, the activations are shrinking to 0, 1329 01:31:33,920 --> 01:31:35,720 but also the gradients are doing something weird. 1330 01:31:35,920 --> 01:31:37,720 The gradients started out here 1331 01:31:37,920 --> 01:31:40,720 and then now they're like expanding out.
1332 01:31:40,920 --> 01:31:43,720 And similarly, if we, for example, have a 2 high of a gain, 1333 01:31:43,920 --> 01:31:45,720 so like 3, 1334 01:31:45,920 --> 01:31:47,720 then we see that also the gradients have... 1335 01:31:47,920 --> 01:31:49,720 There's some asymmetry going on where 1336 01:31:49,920 --> 01:31:51,720 as you go into deeper and deeper layers, 1337 01:31:51,920 --> 01:31:53,720 the activations are also changing. 1338 01:31:53,920 --> 01:31:55,720 And so that's not what we want. 1339 01:31:55,920 --> 01:31:57,720 And in this case, we saw that without the use of BatchNorm, 1340 01:31:57,920 --> 01:31:59,720 as we are going through right now, 1341 01:31:59,920 --> 01:32:02,720 we have to very carefully set those gains 1342 01:32:02,920 --> 01:32:04,720 to get nice activations 1343 01:32:04,920 --> 01:32:07,720 in both the forward pass and the backward pass. 1344 01:32:07,920 --> 01:32:09,720 Now, before we move on to BatchNormalization, 1345 01:32:09,920 --> 01:32:11,720 I would also like to take a look at what happens 1346 01:32:11,920 --> 01:32:13,720 when we have no 10H units here. 1347 01:32:13,920 --> 01:32:16,720 So erasing all the 10H nonlinearities, 1348 01:32:16,920 --> 01:32:18,720 but keeping the gain at 5 over 3, 1349 01:32:18,920 --> 01:32:21,720 we now have just a giant linear sandwich. 1350 01:32:21,920 --> 01:32:23,720 So let's see what happens to the activations. 1351 01:32:23,920 --> 01:32:26,720 As we saw before, the correct gain here is 1. 1352 01:32:26,920 --> 01:32:29,720 That is the standard deviation preserving gain. 1353 01:32:29,720 --> 01:32:33,520 So 1.667 is too high. 1354 01:32:33,720 --> 01:32:37,520 And so what's going to happen now is the following. 1355 01:32:37,720 --> 01:32:39,520 I have to change this to be linear, 1356 01:32:39,720 --> 01:32:42,520 because there's no more 10H layers. 1357 01:32:42,720 --> 01:32:45,520 And let me change this to linear as well. 1358 01:32:45,720 --> 01:32:47,520 So what we're seeing is 1359 01:32:47,720 --> 01:32:50,520 the activations started out on the blue 1360 01:32:50,720 --> 01:32:54,520 and have, by layer 4, become very diffuse. 1361 01:32:54,720 --> 01:32:57,520 So what's happening to the activations is this. 1362 01:32:57,720 --> 01:32:59,520 And with the gradients, 1363 01:32:59,520 --> 01:33:01,320 on the top layer, the activation, 1364 01:33:01,520 --> 01:33:04,320 the gradient statistics are the purple, 1365 01:33:04,520 --> 01:33:07,320 and then they diminish as you go down deeper in the layers. 1366 01:33:07,520 --> 01:33:10,320 And so basically you have an asymmetry in the neural net. 1367 01:33:10,520 --> 01:33:13,320 And you might imagine that if you have very deep neural networks, 1368 01:33:13,520 --> 01:33:15,320 say like 50 layers or something like that, 1369 01:33:15,520 --> 01:33:18,320 this is not a good place to be. 1370 01:33:18,520 --> 01:33:21,320 So that's why before BatchNormalization, 1371 01:33:21,520 --> 01:33:24,320 this was incredibly tricky to set. 1372 01:33:24,520 --> 01:33:27,320 In particular, if this is too large of a gain, this happens. 1373 01:33:27,520 --> 01:33:29,320 And if it's too little of a gain, 1374 01:33:29,520 --> 01:33:31,320 then this happens. 1375 01:33:31,520 --> 01:33:33,320 So the opposite of that basically happens. 1376 01:33:33,520 --> 01:33:39,320 Here we have a shrinking and a diffusion, 1377 01:33:39,520 --> 01:33:42,320 depending on which direction you look at it from. 
1378 01:33:42,520 --> 01:33:44,320 And so certainly this is not what you want. 1379 01:33:44,520 --> 01:33:45,320 And in this case, 1380 01:33:45,520 --> 01:33:48,320 the correct setting of the gain is exactly 1, 1381 01:33:48,520 --> 01:33:50,320 just like we're doing at initialization. 1382 01:33:50,520 --> 01:33:53,320 And then we see that the statistics 1383 01:33:53,520 --> 01:33:56,320 for the forward and the backward pass are well behaved. 1384 01:33:56,520 --> 01:33:59,320 And so the reason I want to show you this 1385 01:33:59,320 --> 01:34:02,120 is that basically getting neural nets to train, 1386 01:34:02,320 --> 01:34:04,120 before these normalization layers 1387 01:34:04,320 --> 01:34:07,120 and before the use of advanced optimizers like Adam, 1388 01:34:07,320 --> 01:34:09,120 which we still have to cover, 1389 01:34:09,320 --> 01:34:11,120 and residual connections and so on, 1390 01:34:11,320 --> 01:34:13,120 training neural nets basically looked like this. 1391 01:34:13,320 --> 01:34:15,120 It's like a total balancing act. 1392 01:34:15,320 --> 01:34:18,120 You have to make sure that everything is precisely orchestrated 1393 01:34:18,320 --> 01:34:20,120 and you have to care about the activations 1394 01:34:20,320 --> 01:34:22,120 and the gradients and their statistics. 1395 01:34:22,320 --> 01:34:24,120 And then maybe you can train something. 1396 01:34:24,320 --> 01:34:26,120 But it was basically impossible to train very deep networks. 1397 01:34:26,320 --> 01:34:28,120 And this is fundamentally the reason for that. 1398 01:34:28,120 --> 01:34:31,920 You'd have to be very, very careful with your initialization. 1399 01:34:32,120 --> 01:34:34,920 The other point here is you might be asking yourself, 1400 01:34:35,120 --> 01:34:36,920 by the way, I'm not sure if I covered this, 1401 01:34:37,120 --> 01:34:39,920 why do we need these tanh layers at all? 1402 01:34:40,120 --> 01:34:42,920 Why do we include them and then have to worry about the gain? 1403 01:34:43,120 --> 01:34:44,920 And the reason for that, of course, 1404 01:34:45,120 --> 01:34:46,920 is that if you just have a stack of linear layers, 1405 01:34:47,120 --> 01:34:49,920 then certainly we're getting very easily 1406 01:34:50,120 --> 01:34:51,920 nice activations and so on. 1407 01:34:52,120 --> 01:34:53,920 But this is just a massive linear sandwich. 1408 01:34:54,120 --> 01:34:56,920 And it turns out that it collapses to a single linear layer 1409 01:34:56,920 --> 01:34:58,720 in terms of its representation power. 1410 01:34:58,920 --> 01:35:02,720 So if you were to plot the output as a function of the input, 1411 01:35:02,920 --> 01:35:04,720 you're just getting a linear function. 1412 01:35:04,920 --> 01:35:06,720 No matter how many linear layers you stack up, 1413 01:35:06,920 --> 01:35:08,720 you still just end up with a linear transformation. 1414 01:35:08,920 --> 01:35:13,720 All the Wx plus b's just collapse into a large Wx plus b 1415 01:35:13,920 --> 01:35:16,720 with slightly different W's and slightly different b's. 1416 01:35:16,920 --> 01:35:19,720 But interestingly, even though the forward pass collapses 1417 01:35:19,920 --> 01:35:22,720 to just a linear layer, because of back propagation 1418 01:35:22,920 --> 01:35:25,720 and the dynamics of the backward pass, 1419 01:35:25,720 --> 01:35:28,520 the optimization actually is not identical.
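To make that forward-pass collapse concrete before returning to the backward-pass point, here is a tiny numerical check that two stacked affine maps equal one affine map (the shapes are arbitrary):

```python
import torch

# Two stacked affine maps collapse into one in the forward pass:
#   W2 @ (W1 @ x + b1) + b2  ==  (W2 @ W1) @ x + (W2 @ b1 + b2)
torch.manual_seed(0)
x = torch.randn(5)
W1, b1 = torch.randn(4, 5), torch.randn(4)
W2, b2 = torch.randn(3, 4), torch.randn(3)

two_layers = W2 @ (W1 @ x + b1) + b2
W, b = W2 @ W1, W2 @ b1 + b2          # the single collapsed layer
one_layer = W @ x + b

print(torch.allclose(two_layers, one_layer, atol=1e-6))  # True
```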
1420 01:35:28,720 --> 01:35:32,520 You actually end up with all kinds of interesting dynamics 1421 01:35:32,720 --> 01:35:35,520 in the backward pass because of the way 1422 01:35:35,720 --> 01:35:37,520 the chain rule is calculating it. 1423 01:35:37,720 --> 01:35:40,520 And so optimizing a linear layer by itself 1424 01:35:40,720 --> 01:35:43,520 and optimizing a sandwich of 10 linear layers, 1425 01:35:43,720 --> 01:35:45,520 in both cases, those are just a linear transformation 1426 01:35:45,720 --> 01:35:47,520 in the forward pass, but the training dynamics 1427 01:35:47,720 --> 01:35:48,520 would be different. 1428 01:35:48,720 --> 01:35:50,520 And there's entire papers that analyze, in fact, 1429 01:35:50,720 --> 01:35:53,520 infinitely layered linear layers and so on. 1430 01:35:53,720 --> 01:35:55,520 And so there's a lot of things 1431 01:35:55,520 --> 01:35:57,320 that you can play with there. 1432 01:35:57,520 --> 01:36:00,320 But basically, the tanh nonlinearities 1433 01:36:00,520 --> 01:36:05,320 allow us to turn this sandwich 1434 01:36:05,520 --> 01:36:11,320 from just a linear function into a neural network 1435 01:36:11,520 --> 01:36:14,320 that can, in principle, approximate any arbitrary function. 1436 01:36:14,520 --> 01:36:17,320 Okay, so now I've reset the code to use 1437 01:36:17,520 --> 01:36:20,320 the linear-tanh sandwich like before. 1438 01:36:20,520 --> 01:36:23,320 And I reset everything, so the gain is 5 over 3. 1439 01:36:23,520 --> 01:36:25,320 We can run a single step of optimization. 1440 01:36:25,520 --> 01:36:27,320 And we can look at the activation statistics 1441 01:36:27,520 --> 01:36:29,320 of the forward pass and the backward pass. 1442 01:36:29,520 --> 01:36:31,320 But I've added one more plot here 1443 01:36:31,520 --> 01:36:33,320 that I think is really important to look at 1444 01:36:33,520 --> 01:36:35,320 when you're training your neural nets and to consider. 1445 01:36:35,520 --> 01:36:37,320 And ultimately, what we're doing is 1446 01:36:37,520 --> 01:36:39,320 we're updating the parameters of the neural net. 1447 01:36:39,520 --> 01:36:41,320 So we care about the parameters 1448 01:36:41,520 --> 01:36:43,320 and their values and their gradients. 1449 01:36:43,520 --> 01:36:45,320 So here what I'm doing is I'm actually 1450 01:36:45,520 --> 01:36:47,320 iterating over all the parameters available 1451 01:36:47,520 --> 01:36:50,320 and then I'm only restricting it 1452 01:36:50,520 --> 01:36:52,320 to the two-dimensional parameters, 1453 01:36:52,520 --> 01:36:54,320 which are basically the weights of these linear layers. 1454 01:36:54,320 --> 01:36:56,120 And I'm skipping the biases 1455 01:36:56,320 --> 01:36:59,120 and I'm skipping the gammas and the betas 1456 01:36:59,320 --> 01:37:02,120 in the batch norm, just for simplicity. 1457 01:37:02,320 --> 01:37:04,120 But you can also take a look at those as well. 1458 01:37:04,320 --> 01:37:06,120 But what's happening with the weights 1459 01:37:06,320 --> 01:37:08,120 is instructive by itself. 1460 01:37:08,320 --> 01:37:12,120 So here we have all the different weights, their shapes. 1461 01:37:12,320 --> 01:37:14,120 So this is the embedding layer, 1462 01:37:14,320 --> 01:37:15,120 the first linear layer, 1463 01:37:15,320 --> 01:37:17,120 all the way to the very last linear layer. 1464 01:37:17,320 --> 01:37:18,120 And then we have the mean, 1465 01:37:18,320 --> 01:37:21,120 the standard deviation of all these parameters.
1466 01:37:21,320 --> 01:37:23,120 The histogram, and you can see that 1467 01:37:23,120 --> 01:37:24,920 it actually doesn't look that amazing. 1468 01:37:25,120 --> 01:37:26,920 So there's some trouble in paradise. 1469 01:37:27,120 --> 01:37:28,920 Even though these gradients looked okay, 1470 01:37:29,120 --> 01:37:30,920 there's something weird going on here. 1471 01:37:31,120 --> 01:37:32,920 I'll get to that in a second. 1472 01:37:33,120 --> 01:37:35,920 And the last thing here is the gradient to data ratio. 1473 01:37:36,120 --> 01:37:37,920 So sometimes I like to visualize this as well 1474 01:37:38,120 --> 01:37:39,920 because what this gives you a sense of is 1475 01:37:40,120 --> 01:37:41,920 what is the scale of the gradient 1476 01:37:42,120 --> 01:37:44,920 compared to the scale of the actual values? 1477 01:37:45,120 --> 01:37:47,920 And this is important because we're going to end up 1478 01:37:48,120 --> 01:37:49,920 taking a step update 1479 01:37:50,120 --> 01:37:52,920 that is the learning rate times the gradient 1480 01:37:52,920 --> 01:37:53,920 to the data. 1481 01:37:54,120 --> 01:37:56,520 And so if the gradient has too large of a magnitude, 1482 01:37:56,720 --> 01:37:58,120 if the numbers in there are too large 1483 01:37:58,320 --> 01:37:59,920 compared to the numbers in data, 1484 01:38:00,120 --> 01:38:01,520 then you'd be in trouble. 1485 01:38:01,720 --> 01:38:05,120 But in this case, the gradient to data is our low numbers. 1486 01:38:05,320 --> 01:38:09,120 So the values inside grad are 1000 times smaller 1487 01:38:09,320 --> 01:38:13,720 than the values inside data in these weights, most of them. 1488 01:38:13,920 --> 01:38:16,920 Now, notably, that is not true about the last layer. 1489 01:38:17,120 --> 01:38:18,320 And so the last layer actually here, 1490 01:38:18,520 --> 01:38:20,520 the output layer is a bit of a troublemaker 1491 01:38:20,720 --> 01:38:22,520 in the way that this is currently arranged, 1492 01:38:22,520 --> 01:38:28,320 because you can see that the last layer here in pink 1493 01:38:28,520 --> 01:38:30,320 takes on values that are much larger 1494 01:38:30,520 --> 01:38:35,320 than some of the values inside the neural net. 1495 01:38:35,520 --> 01:38:38,320 So the standard deviations are roughly 1 in negative 3 throughout, 1496 01:38:38,520 --> 01:38:41,320 except for the last layer, 1497 01:38:41,520 --> 01:38:43,320 which actually has roughly 1 in negative 2 1498 01:38:43,520 --> 01:38:45,320 standard deviation of gradients. 1499 01:38:45,520 --> 01:38:47,320 And so the gradients on the last layer 1500 01:38:47,520 --> 01:38:50,320 are currently about 100 times greater, 1501 01:38:50,520 --> 01:38:52,320 sorry, 10 times greater 1502 01:38:52,520 --> 01:38:55,720 than all the other weights inside the neural net. 1503 01:38:55,920 --> 01:38:56,720 And so that's problematic, 1504 01:38:56,920 --> 01:39:00,120 because in the simple stochastic gradient descent setup, 1505 01:39:00,320 --> 01:39:03,520 you would be training this last layer about 10 times faster 1506 01:39:03,720 --> 01:39:06,920 than you would be training the other layers at initialization. 1507 01:39:07,120 --> 01:39:09,920 Now, this actually kind of fixes itself a little bit 1508 01:39:10,120 --> 01:39:11,120 if you train for a bit longer. 1509 01:39:11,320 --> 01:39:13,920 So, for example, if I greater than 1000, 1510 01:39:14,120 --> 01:39:17,320 only then do a break, let me reinitialize, 1511 01:39:17,520 --> 01:39:19,920 and then let me do it 1000 steps. 
1512 01:39:20,120 --> 01:39:22,320 And after 1000 steps, we can look at the 1513 01:39:22,720 --> 01:39:24,120 forward pass. 1514 01:39:24,320 --> 01:39:26,120 OK, so you see how the neurons are a bit, 1515 01:39:26,320 --> 01:39:27,720 are saturating a bit. 1516 01:39:27,920 --> 01:39:29,920 And we can also look at the backward pass, 1517 01:39:30,120 --> 01:39:31,120 but otherwise they look good. 1518 01:39:31,320 --> 01:39:33,920 They're about equal and there's no shrinking to zero 1519 01:39:34,120 --> 01:39:36,120 or exploding to infinities. 1520 01:39:36,320 --> 01:39:38,520 And you can see that here in the weights, 1521 01:39:38,720 --> 01:39:40,120 things are also stabilizing a little bit. 1522 01:39:40,320 --> 01:39:42,720 So the tails of the last pink layer 1523 01:39:42,920 --> 01:39:46,320 are actually coming in during the optimization. 1524 01:39:46,520 --> 01:39:48,720 But certainly this is like a little bit troubling, 1525 01:39:48,920 --> 01:39:50,920 especially if you are using a very simple update rule 1526 01:39:50,920 --> 01:39:52,520 like stochastic gradient descent 1527 01:39:52,720 --> 01:39:55,120 instead of a modern optimizer like Adam. 1528 01:39:55,320 --> 01:39:56,520 Now I'd like to show you one more plot 1529 01:39:56,720 --> 01:39:58,920 that I usually look at when I train neural networks. 1530 01:39:59,120 --> 01:40:01,720 And basically the gradient to data ratio 1531 01:40:01,920 --> 01:40:03,120 is not actually that informative, 1532 01:40:03,320 --> 01:40:04,320 because what matters at the end 1533 01:40:04,520 --> 01:40:06,120 is not the gradient to data ratio, 1534 01:40:06,320 --> 01:40:08,320 but the update to the data ratio, 1535 01:40:08,520 --> 01:40:09,520 because that is the amount by which 1536 01:40:09,720 --> 01:40:12,920 we will actually change the data in these tensors. 1537 01:40:13,120 --> 01:40:14,320 So coming up here, 1538 01:40:14,520 --> 01:40:15,720 what I'd like to do is I'd like to introduce 1539 01:40:15,920 --> 01:40:19,720 a new update to data ratio. 1540 01:40:19,920 --> 01:40:20,820 It's going to be a list, 1541 01:40:21,020 --> 01:40:23,120 and I'm going to build it out every single iteration. 1542 01:40:23,320 --> 01:40:25,720 And here I'd like to keep track of basically 1543 01:40:25,920 --> 01:40:29,920 the ratio every single iteration. 1544 01:40:30,120 --> 01:40:33,520 So without any gradients, 1545 01:40:33,720 --> 01:40:34,920 I'm comparing the update, 1546 01:40:35,120 --> 01:40:38,920 which is learning rate times the gradient. 1547 01:40:39,120 --> 01:40:40,320 That is the update that we're going to apply 1548 01:40:40,520 --> 01:40:42,520 to every parameter. 1549 01:40:42,720 --> 01:40:44,520 So see I'm iterating over all the parameters. 1550 01:40:44,720 --> 01:40:46,320 And then I'm taking the basically standard deviation 1551 01:40:46,520 --> 01:40:48,120 of the update we're going to apply 1552 01:40:48,320 --> 01:40:49,720 and divide it 1553 01:40:49,820 --> 01:40:52,420 by the actual content, 1554 01:40:52,620 --> 01:40:56,020 the data of that parameter, and its standard deviation. 1555 01:40:56,220 --> 01:40:58,020 So this is the ratio of basically 1556 01:40:58,220 --> 01:41:02,220 how great the updates are compared to the values in these tensors. 1557 01:41:02,420 --> 01:41:03,420 Then we're going to take a log of it. 1558 01:41:03,620 --> 01:41:07,320 And actually, I'd like to take a log 10, 1559 01:41:07,520 --> 01:41:10,220 just so it's a nicer visualization.
1560 01:41:10,420 --> 01:41:12,320 So we're going to be basically looking at the exponents 1561 01:41:12,520 --> 01:41:16,620 of this division here, 1562 01:41:16,820 --> 01:41:19,620 and then call .item() to pop out the float. 1563 01:41:19,920 --> 01:41:21,020 I'm going to be keeping track of this 1564 01:41:21,220 --> 01:41:24,420 for all the parameters and adding it to this ud list. 1565 01:41:24,620 --> 01:41:27,920 So now let me re-initialize and run a thousand iterations. 1566 01:41:28,120 --> 01:41:30,920 We can look at the activations, 1567 01:41:31,120 --> 01:41:33,220 the gradients and the parameter gradients 1568 01:41:33,420 --> 01:41:34,520 as we did before. 1569 01:41:34,720 --> 01:41:37,820 But now I have one more plot here to introduce. 1570 01:41:38,020 --> 01:41:39,520 And what's happening here is we're iterating over 1571 01:41:39,720 --> 01:41:42,120 all the parameters and I'm constraining it again, 1572 01:41:42,320 --> 01:41:44,920 like I did here, to just the weights. 1573 01:41:45,120 --> 01:41:48,120 So the number of dimensions in these tensors is two. 1574 01:41:48,320 --> 01:41:49,420 And then I'm basically plotting 1575 01:41:49,420 --> 01:41:54,320 all of these update ratios over time. 1576 01:41:54,520 --> 01:41:56,520 So when I plot this, 1577 01:41:56,720 --> 01:41:59,320 I plot those ratios and you can see that they evolve over time 1578 01:41:59,520 --> 01:42:01,820 during initialization to take on certain values. 1579 01:42:02,020 --> 01:42:04,020 And then these updates sort of like start stabilizing, 1580 01:42:04,220 --> 01:42:05,820 usually during training. 1581 01:42:06,020 --> 01:42:07,220 Then the other thing that I'm plotting here 1582 01:42:07,420 --> 01:42:09,220 is like an approximate value 1583 01:42:09,420 --> 01:42:12,720 that is a rough guide for what it roughly should be. 1584 01:42:12,920 --> 01:42:15,320 And it should be roughly 1e-3. 1585 01:42:15,520 --> 01:42:19,320 And so that means that basically there's some values in this tensor. 1586 01:42:19,520 --> 01:42:21,620 And they take on certain values. 1587 01:42:21,820 --> 01:42:24,020 And the updates to them at every single iteration 1588 01:42:24,220 --> 01:42:26,720 are no more than roughly one thousandth 1589 01:42:26,920 --> 01:42:30,720 of the actual, like, magnitude in those tensors. 1590 01:42:30,920 --> 01:42:33,220 If this was much larger, like for example, 1591 01:42:33,420 --> 01:42:37,520 if the log of this was like say negative one, 1592 01:42:37,720 --> 01:42:39,920 this is actually updating those values quite a lot. 1593 01:42:40,120 --> 01:42:42,020 They're undergoing a lot of change. 1594 01:42:42,220 --> 01:42:46,420 But the reason that the final layer here is an outlier 1595 01:42:46,620 --> 01:42:49,220 is because this layer was artificially 1596 01:42:49,520 --> 01:42:54,320 shrunk down to keep the softmax unconfident. 1597 01:42:54,520 --> 01:42:59,220 So here you see how we multiply the weight by 0.1 1598 01:42:59,420 --> 01:43:04,020 in the initialization to make the last layer prediction less confident. 1599 01:43:04,220 --> 01:43:09,220 That artificially made the values inside that tensor way too low. 1600 01:43:09,420 --> 01:43:12,020 And that's why we're getting temporarily a very high ratio. 1601 01:43:12,220 --> 01:43:14,020 But you see that that stabilizes over time 1602 01:43:14,220 --> 01:43:17,820 once that weight starts to learn. 1603 01:43:18,020 --> 01:43:19,320 But basically, I like to look at the evolution
1604 01:43:19,520 --> 01:43:23,020 of this update ratio for all my parameters, usually. 1605 01:43:23,220 --> 01:43:29,420 And I like to make sure that it's not too much above 1e-3, roughly. 1606 01:43:29,620 --> 01:43:32,820 So around negative three on this log plot. 1607 01:43:33,020 --> 01:43:34,020 If it's below negative three, 1608 01:43:34,220 --> 01:43:37,220 usually that means that the parameters are not training fast enough. 1609 01:43:37,420 --> 01:43:38,820 So if our learning rate was very low, 1610 01:43:39,020 --> 01:43:41,520 let's do that experiment. 1611 01:43:41,720 --> 01:43:43,020 Let's initialize. 1612 01:43:43,220 --> 01:43:47,320 And then let's actually do a learning rate of, say, 1e-3 here. 1613 01:43:47,520 --> 01:43:49,320 So 0.001. 1614 01:43:49,620 --> 01:43:53,620 If your learning rate is way too low, 1615 01:43:53,820 --> 01:43:56,220 this plot will typically reveal it. 1616 01:43:56,420 --> 01:44:00,120 So you see how all of these updates are way too small. 1617 01:44:00,320 --> 01:44:06,220 So the size of the update is basically 10,000 times smaller 1618 01:44:06,420 --> 01:44:10,520 in magnitude than the size of the numbers in that tensor in the first place. 1619 01:44:10,720 --> 01:44:14,420 So this is a symptom of training way too slow. 1620 01:44:14,620 --> 01:44:16,820 So this is another way to sometimes set the learning rate 1621 01:44:17,020 --> 01:44:19,320 and to get a sense of what that learning rate should be. 1622 01:44:19,520 --> 01:44:23,820 So ultimately, this is something that you would keep track of. 1623 01:44:25,020 --> 01:44:29,320 If anything, the learning rate here is a little bit on the higher side, 1624 01:44:29,520 --> 01:44:33,920 because you see that we're above the black line of negative three. 1625 01:44:34,120 --> 01:44:35,720 We're somewhere around negative 2.5. 1626 01:44:35,920 --> 01:44:36,820 It's like, OK. 1627 01:44:37,020 --> 01:44:39,620 But everything is like somewhat stabilizing. 1628 01:44:39,820 --> 01:44:43,820 And so this looks like a pretty decent setting of learning rates and so on. 1629 01:44:44,020 --> 01:44:45,220 But this is something to look at. 1630 01:44:45,420 --> 01:44:48,220 And when things are miscalibrated, you will see very quickly. 1631 01:44:48,420 --> 01:44:49,320 So for example, 1632 01:44:49,520 --> 01:44:52,020 everything looks pretty well behaved, right? 1633 01:44:52,220 --> 01:44:55,020 But just as a comparison, when things are not properly calibrated, 1634 01:44:55,220 --> 01:44:56,220 what does that look like? 1635 01:44:56,420 --> 01:45:01,520 Let me come up here and let's say that, for example, what do we do? 1636 01:45:01,720 --> 01:45:05,720 Let's say that we forgot to apply this fan-in normalization. 1637 01:45:05,920 --> 01:45:07,220 So the weights inside the linear layers 1638 01:45:07,420 --> 01:45:10,520 are just a sample from a Gaussian in all the stages. 1639 01:45:10,720 --> 01:45:14,220 What happens to our... how do we notice that something's off? 1640 01:45:14,420 --> 01:45:18,520 Well, the activation plot will tell you, whoa, your neurons are way too saturated. 1641 01:45:18,620 --> 01:45:21,220 The gradients are going to be all messed up. 1642 01:45:21,420 --> 01:45:25,020 The histogram for these weights are going to be all messed up as well. 1643 01:45:25,220 --> 01:45:26,820 And there's a lot of asymmetry. 1644 01:45:27,020 --> 01:45:30,420 And then if we look here, I suspect it's all going to be also pretty messed up.
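For reference, here is a sketch of the update-to-data ratio tracking and plot that these observations are based on; it assumes the parameters list, the learning rate lr, and the training loop from the lecture, so treat it as a fragment rather than the notebook's exact code.

```python
import torch
import matplotlib.pyplot as plt

# Sketch of the update-to-data ratio diagnostic described above.
ud = []   # per iteration: a list of log10(update std / data std), one per parameter

# --- inside the training loop, right after the parameter update: ---
with torch.no_grad():
    ud.append([((lr * p.grad).std() / p.data.std()).log10().item() for p in parameters])

# --- after training: plot the ratios over time for the 2-D parameters (the weights) ---
plt.figure(figsize=(20, 4))
legends = []
for i, p in enumerate(parameters):
    if p.ndim == 2:
        plt.plot([ud[j][i] for j in range(len(ud))])
        legends.append(f'param {i}')
plt.plot([0, len(ud)], [-3, -3], 'k')   # rough guide: updates ~1/1000 of the data values
plt.legend(legends)
plt.show()
```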
1645 01:45:30,620 --> 01:45:36,220 So you see, there's a lot of discrepancy in how fast these layers are learning. 1646 01:45:36,420 --> 01:45:38,320 And some of them are learning way too fast. 1647 01:45:38,520 --> 01:45:41,320 So negative one, negative 1.5. 1648 01:45:41,520 --> 01:45:44,020 Those are very large numbers in terms of this ratio. 1649 01:45:44,220 --> 01:45:48,420 Again, you should be somewhere around negative three and not much more above that. 1650 01:45:48,620 --> 01:45:52,720 So this is how miscalibrations of your neural nets are going to manifest. 1651 01:45:52,920 --> 01:45:58,220 And these kinds of plots here are a good way of sort of bringing 1652 01:45:58,420 --> 01:46:02,520 those miscalibrations sort of to your attention, 1653 01:46:02,720 --> 01:46:04,220 so that you can address them. 1654 01:46:04,420 --> 01:46:08,020 OK, so far we've seen that when we have this linear-tanh sandwich, 1655 01:46:08,220 --> 01:46:11,420 we can actually precisely calibrate the gains and make the activations, 1656 01:46:11,620 --> 01:46:15,620 the gradients and the parameters and the updates all look pretty decent. 1657 01:46:15,820 --> 01:46:18,420 But it definitely feels a little bit like balancing 1658 01:46:18,620 --> 01:46:20,920 a pencil on your finger. 1659 01:46:21,120 --> 01:46:25,620 And that's because this gain has to be very precisely calibrated. 1660 01:46:25,820 --> 01:46:29,720 So now let's introduce batch normalization layers into the mix, 1661 01:46:29,920 --> 01:46:33,620 and let's see how that helps fix the problem. 1662 01:46:33,820 --> 01:46:38,120 So here I'm going to take the BatchNorm1d class 1663 01:46:38,320 --> 01:46:40,820 and I'm going to start placing it inside. 1664 01:46:41,020 --> 01:46:42,620 And as I mentioned before, 1665 01:46:42,820 --> 01:46:46,920 the standard, typical place you would place it is after the linear layer, 1666 01:46:47,120 --> 01:46:48,420 so right after it, 1667 01:46:48,620 --> 01:46:50,120 and before the non-linearity. 1668 01:46:50,320 --> 01:46:52,320 But people have definitely played with that. 1669 01:46:52,520 --> 01:46:55,320 And in fact, you can get very similar results, 1670 01:46:55,520 --> 01:46:58,320 even if you place it after the non-linearity. 1671 01:46:58,520 --> 01:47:02,320 And the other thing that I wanted to mention is it's totally fine to also place it at the end, 1672 01:47:02,520 --> 01:47:05,320 after the last linear layer and before the loss function. 1673 01:47:05,520 --> 01:47:09,320 So this is potentially fine as well. 1674 01:47:09,520 --> 01:47:14,320 And in this case, the output would be vocab size. 1675 01:47:14,520 --> 01:47:17,120 Now, because the last layer is BatchNorm, 1676 01:47:17,220 --> 01:47:20,720 we would not be changing the weight to make the softmax less confident. 1677 01:47:20,920 --> 01:47:23,120 We'd be changing the gamma, 1678 01:47:23,320 --> 01:47:25,620 because gamma, remember, in the BatchNorm, 1679 01:47:25,820 --> 01:47:32,720 is the variable that multiplicatively interacts with the output of that normalization. 1680 01:47:32,920 --> 01:47:35,920 So we can initialize this sandwich now. 1681 01:47:36,120 --> 01:47:37,220 We can train. 1682 01:47:37,420 --> 01:47:41,720 And we can see that the activations are going to, of course, look very good. 1683 01:47:41,920 --> 01:47:46,820 And they are going to necessarily look good, because now before every single tanh layer, 1684 01:47:46,820 --> 01:47:49,020 there is a normalization in the BatchNorm.
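A sketch of what that stack looks like with BatchNorm1d placed after each linear layer, using the module sketches from earlier; the sizes are placeholders and the stack is abbreviated relative to the notebook's six-layer network.

```python
import torch

# Sketch: the Linear -> BatchNorm1d -> Tanh stack described above, using the module
# sketches from earlier. The sizes are placeholders and the stack is abbreviated
# (the embedding table and the training loop are omitted here).
n_embd, n_hidden, block_size, vocab_size = 10, 100, 3, 27

layers = [
    Linear(n_embd * block_size, n_hidden), BatchNorm1d(n_hidden), Tanh(),
    Linear(n_hidden, n_hidden),            BatchNorm1d(n_hidden), Tanh(),
    Linear(n_hidden, n_hidden),            BatchNorm1d(n_hidden), Tanh(),
    Linear(n_hidden, vocab_size),          BatchNorm1d(vocab_size),
]

with torch.no_grad():
    # make the softmax less confident at init; since the last layer is now a
    # BatchNorm, we scale its gamma rather than a weight matrix
    layers[-1].gamma *= 0.1

parameters = [p for layer in layers for p in layer.parameters()]
for p in parameters:
    p.requires_grad = True
```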
1685 01:47:49,220 --> 01:47:52,820 So this, unsurprisingly, all looks pretty good. 1686 01:47:53,020 --> 01:47:56,420 It's going to be a standard deviation of roughly 0.65, a saturation of roughly 2%, 1687 01:47:56,620 --> 01:47:59,620 and roughly equal standard deviation throughout the entire layers. 1688 01:47:59,820 --> 01:48:02,520 So everything looks very homogeneous. 1689 01:48:02,720 --> 01:48:04,520 The gradients look good. 1690 01:48:04,720 --> 01:48:09,020 The weights look good in their distributions. 1691 01:48:09,220 --> 01:48:14,020 And then the updates also look pretty reasonable. 1692 01:48:14,220 --> 01:48:16,720 We're going above negative three a little bit, 1693 01:48:16,820 --> 01:48:18,020 but not by too much. 1694 01:48:18,220 --> 01:48:24,820 So all the parameters are training at roughly the same rate here. 1695 01:48:25,020 --> 01:48:32,120 But now what we've gained is we are going to be slightly less brittle 1696 01:48:32,320 --> 01:48:34,320 with respect to the gain of these linear layers. 1697 01:48:34,520 --> 01:48:39,320 So, for example, I can make the gain be, say, 0.2 here, 1698 01:48:39,520 --> 01:48:43,120 which is much, much lower than what we had with the tanh. 1699 01:48:43,320 --> 01:48:46,520 But as we'll see, the activations will actually be exactly unaffected. 1701 01:48:46,820 --> 01:48:49,520 And that's because of, again, this explicit normalization. 1702 01:48:49,720 --> 01:48:51,220 The gradients are going to look okay. 1703 01:48:51,420 --> 01:48:53,720 The weight gradients are going to look okay. 1704 01:48:53,920 --> 01:48:56,720 But actually, the updates will change. 1705 01:48:56,920 --> 01:49:00,620 And so even though the forward and backward pass, to a very large extent, 1706 01:49:00,820 --> 01:49:03,520 look okay, because of the backward pass of the BatchNorm 1707 01:49:03,720 --> 01:49:05,820 and how the scale of the incoming activations 1708 01:49:06,020 --> 01:49:10,120 interacts in the BatchNorm and its backward pass, 1709 01:49:10,320 --> 01:49:15,820 this is actually changing the scale of the updates on these parameters. 1710 01:49:15,820 --> 01:49:18,720 So the gradients of these weights are affected. 1711 01:49:18,920 --> 01:49:24,220 So we still don't get a completely free pass to pass in arbitrary weights here, 1712 01:49:24,420 --> 01:49:28,020 but everything else is significantly more robust 1713 01:49:28,220 --> 01:49:32,620 in terms of the forward, backward, and the weight gradients. 1714 01:49:32,820 --> 01:49:35,020 It's just that you may have to retune your learning rate 1715 01:49:35,220 --> 01:49:39,420 if you are changing sufficiently the scale of the activations 1716 01:49:39,620 --> 01:49:40,820 that are coming into the BatchNorms. 1717 01:49:41,020 --> 01:49:44,720 So here, for example, we changed the gains 1718 01:49:44,720 --> 01:49:47,020 of these linear layers to be greater, 1719 01:49:47,220 --> 01:49:51,520 and we're seeing that the updates are coming out lower as a result. 1720 01:49:51,720 --> 01:49:54,420 And then finally, we can also, if we are using BatchNorms, 1721 01:49:54,620 --> 01:49:56,520 we don't actually need to necessarily... 1722 01:49:56,720 --> 01:49:59,020 Let me reset this to 1 so there's no gain. 1723 01:49:59,220 --> 01:50:03,320 We don't necessarily even have to normalize by fan-in sometimes.
1724 01:50:03,520 --> 01:50:08,020 So if I take out the fan-in, so these are just now random Gaussian, 1725 01:50:08,220 --> 01:50:11,620 we'll see that because of BatchNorm, this will actually be relatively well-behaved. 1726 01:50:11,820 --> 01:50:13,620 So... 1727 01:50:14,720 --> 01:50:17,320 The statistics look, of course, in the forward pass look good. 1728 01:50:17,520 --> 01:50:19,620 The gradients look good. 1729 01:50:19,820 --> 01:50:23,520 The backward weight updates look okay. 1730 01:50:23,720 --> 01:50:26,420 A little bit of fat tails on some of the layers. 1731 01:50:26,620 --> 01:50:29,020 And this looks okay as well. 1732 01:50:29,220 --> 01:50:33,320 But as you can see, we're significantly below negative 3, 1733 01:50:33,520 --> 01:50:36,320 so we'd have to bump up the learning rate of this BatchNorm 1734 01:50:36,520 --> 01:50:38,820 so that we are training more properly. 1735 01:50:39,020 --> 01:50:40,420 And in particular, looking at this, 1736 01:50:40,620 --> 01:50:43,020 roughly looks like we have to 10x the learning rate 1737 01:50:43,220 --> 01:50:44,620 to get to about 1e-3. 1738 01:50:44,820 --> 01:50:50,620 So we'd come here and we would change this to be update of 1.0. 1739 01:50:50,820 --> 01:50:52,820 And if I re-initialize... 1740 01:50:53,020 --> 01:51:01,020 Then we'll see that everything still, of course, looks good. 1741 01:51:01,220 --> 01:51:04,020 And now we are roughly here. 1742 01:51:04,220 --> 01:51:06,220 And we expect this to be an okay training run. 1743 01:51:06,420 --> 01:51:09,320 So long story short, we are significantly more robust 1744 01:51:09,520 --> 01:51:11,320 to the gain of these linear layers, 1745 01:51:11,520 --> 01:51:13,320 whether or not we have to apply the fan-in. 1746 01:51:13,520 --> 01:51:14,520 And then... 1747 01:51:14,820 --> 01:51:16,020 We can change the gain, 1748 01:51:16,220 --> 01:51:20,320 but we actually do have to worry a little bit about the update scales 1749 01:51:20,520 --> 01:51:23,920 and making sure that the learning rate is properly calibrated here. 1750 01:51:24,120 --> 01:51:27,420 But the activations of the forward-backward pass and the updates 1751 01:51:27,620 --> 01:51:30,120 are looking significantly more well-behaved, 1752 01:51:30,320 --> 01:51:34,520 except for the global scale that is potentially being adjusted here. 1753 01:51:34,720 --> 01:51:36,320 Okay, so now let me summarize. 1754 01:51:36,520 --> 01:51:39,320 There are three things I was hoping to achieve with this section. 1755 01:51:39,520 --> 01:51:42,120 Number one, I wanted to introduce you to BatchNormalization, 1756 01:51:42,320 --> 01:51:44,520 which is one of the first modern innovations 1757 01:51:44,520 --> 01:51:48,320 that we're looking into that helped stabilize very deep neural networks 1758 01:51:48,520 --> 01:51:49,520 and their training. 1759 01:51:49,720 --> 01:51:52,320 And I hope you understand how the BatchNormalization works 1760 01:51:52,520 --> 01:51:55,920 and how it would be used in a neural network. 1761 01:51:56,120 --> 01:51:59,120 Number two, I was hoping to PyTorchify some of our code 1762 01:51:59,320 --> 01:52:01,720 and wrap it up into these modules. 1763 01:52:01,920 --> 01:52:04,720 So like Linear, BatchNorm1D, 10H, etc. 1764 01:52:04,920 --> 01:52:06,620 These are layers or modules, 1765 01:52:06,820 --> 01:52:10,720 and they can be stacked up into neural nets like Lego building blocks. 
1766 01:52:10,920 --> 01:52:14,420 And these layers actually exist in PyTorch, 1767 01:52:14,620 --> 01:52:16,520 and if you import torch.nn, 1768 01:52:16,720 --> 01:52:19,120 then you can actually, the way I've constructed it, 1769 01:52:19,320 --> 01:52:22,520 you can simply just use PyTorch by prepending nn. 1770 01:52:22,720 --> 01:52:24,720 to all these different layers. 1771 01:52:24,920 --> 01:52:27,520 And actually everything will just work, 1772 01:52:27,720 --> 01:52:29,720 because the API that I've developed here 1773 01:52:29,920 --> 01:52:32,320 is identical to the API that PyTorch uses. 1774 01:52:32,520 --> 01:52:34,720 And the implementation also is basically, 1775 01:52:34,920 --> 01:52:37,920 as far as I'm aware, identical to the one in PyTorch. 1776 01:52:38,120 --> 01:52:41,120 And number three, I tried to introduce you to the diagnostic tools 1777 01:52:41,320 --> 01:52:44,320 that you would use to understand whether your neural network 1778 01:52:44,320 --> 01:52:46,120 is in a good state dynamically. 1779 01:52:46,320 --> 01:52:48,920 So we are looking at the statistics and histograms 1780 01:52:49,120 --> 01:52:52,120 of the forward pass activations, 1781 01:52:52,320 --> 01:52:53,920 the backward pass gradients, 1782 01:52:54,120 --> 01:52:56,920 and then also we're looking at the weights that are going to be updated 1783 01:52:57,120 --> 01:52:58,920 as part of stochastic gradient descent, 1784 01:52:59,120 --> 01:53:01,120 and we're looking at their means, standard deviations, 1785 01:53:01,320 --> 01:53:04,520 and also the ratio of gradients to data, 1786 01:53:04,720 --> 01:53:07,720 or even better, the updates to data. 1787 01:53:07,920 --> 01:53:10,320 And we saw that typically we don't actually look at it 1788 01:53:10,520 --> 01:53:13,720 as a single snapshot frozen in time at some particular iteration. 1789 01:53:13,720 --> 01:53:17,520 Typically, people look at this over time, just like I've done here. 1790 01:53:17,720 --> 01:53:19,520 And they look at these update-to-data ratios 1791 01:53:19,720 --> 01:53:21,320 and they make sure everything looks OK. 1792 01:53:21,520 --> 01:53:25,120 And in particular, I said that 1e-3, 1793 01:53:25,320 --> 01:53:27,320 or basically negative three on the log scale, 1794 01:53:27,520 --> 01:53:31,520 is a good rough heuristic for what you want this ratio to be. 1795 01:53:31,720 --> 01:53:34,120 And if it's way too high, then probably the learning rate 1796 01:53:34,320 --> 01:53:36,320 or the updates are a little too big. 1797 01:53:36,520 --> 01:53:39,520 And if it's way too small, then the learning rate is probably too small. 1798 01:53:39,720 --> 01:53:42,320 So that's just some of the things that you may want to play with 1799 01:53:42,320 --> 01:53:46,720 when you try to get your neural network to work very well. 1800 01:53:46,920 --> 01:53:49,120 Now, there's a number of things I did not try to achieve. 1801 01:53:49,320 --> 01:53:51,120 I did not try to beat our previous performance, 1802 01:53:51,320 --> 01:53:53,920 as an example, by introducing the BatchNorm layer. 1803 01:53:54,120 --> 01:53:56,920 Actually, I did try: I used 1804 01:53:57,120 --> 01:53:59,720 the learning-rate-finding mechanism that I've described before, 1805 01:53:59,920 --> 01:54:03,120 and I tried to train a BatchNorm neural net.
1806 01:54:03,320 --> 01:54:05,720 And I actually ended up with results that are very, 1807 01:54:05,920 --> 01:54:08,120 very similar to what we've obtained before. 1808 01:54:08,320 --> 01:54:11,920 And that's because our performance now is not bottlenecked by 1809 01:54:12,320 --> 01:54:14,920 optimization, which is what BatchNorm is helping with. 1810 01:54:15,120 --> 01:54:18,520 The performance at this stage is bottlenecked by, what I suspect is, 1811 01:54:18,720 --> 01:54:21,720 the context length of our model. 1812 01:54:21,920 --> 01:54:24,520 So currently we are taking three characters to predict the fourth one. 1813 01:54:24,720 --> 01:54:26,120 And I think we need to go beyond that. 1814 01:54:26,320 --> 01:54:29,520 And we need to look at more powerful architectures like recurrent neural 1815 01:54:29,720 --> 01:54:32,800 networks and transformers in order to further push 1816 01:54:33,000 --> 01:54:36,200 the log probabilities that we're achieving on this dataset. 1817 01:54:36,400 --> 01:54:41,920 And I also did not try to have a full explanation of all of these activations 1818 01:54:42,320 --> 01:54:44,920 and the backward pass and the statistics of all these gradients. 1819 01:54:45,120 --> 01:54:47,720 And so you may have found some of the parts here unintuitive. 1820 01:54:47,920 --> 01:54:51,720 And maybe you were slightly confused about, okay, if I change the gain here, 1821 01:54:51,920 --> 01:54:53,720 how come that we need a different learning rate? 1822 01:54:53,920 --> 01:54:56,520 And I didn't go into the full detail, because you'd have to actually look 1823 01:54:56,720 --> 01:54:59,320 at the backward pass of all these different layers and get an intuitive 1824 01:54:59,520 --> 01:55:00,920 understanding of how that works. 1825 01:55:01,120 --> 01:55:03,520 And I did not go into that in this lecture. 1826 01:55:03,720 --> 01:55:07,320 The purpose really was just to introduce you to the diagnostic tools and what 1827 01:55:07,520 --> 01:55:10,720 they look like, but there's still a lot of work remaining on the intuitive level 1828 01:55:10,920 --> 01:55:12,120 to understand the initialization, 1829 01:55:12,320 --> 01:55:14,920 the backward pass, and how all of that interacts. 1830 01:55:15,120 --> 01:55:18,120 But you shouldn't feel too bad, because honestly, 1831 01:55:18,320 --> 01:55:22,320 we are getting to the cutting edge of where the field is. 1832 01:55:22,520 --> 01:55:25,520 We certainly haven't, I would say, solved initialization, 1833 01:55:25,720 --> 01:55:27,720 and we haven't solved back propagation. 1834 01:55:27,920 --> 01:55:30,520 And these are still very much an active area of research. 1835 01:55:30,520 --> 01:55:33,120 People are still trying to figure out what is the best way to initialize these 1836 01:55:33,320 --> 01:55:37,320 networks, what is the best update rule to use, and so on. 1837 01:55:37,520 --> 01:55:40,920 So none of this is really solved, and we don't really have all the answers to all 1838 01:55:41,120 --> 01:55:42,120 the... 1839 01:55:42,320 --> 01:55:46,320 you know, all these cases, but at least, you know, we're making progress, and at 1840 01:55:46,520 --> 01:55:49,320 least we have some tools to tell us whether or not things are on the right 1841 01:55:49,520 --> 01:55:51,520 track for now. 1842 01:55:51,720 --> 01:55:55,720 So I think we've made positive progress in this lecture, and I hope you enjoyed 1843 01:55:55,720 --> 01:55:56,920 that, and I will see you next time.