Hi everyone. So today we are once again continuing our implementation of MakeMore. Now, so far we've come up to here, multilayer perceptrons, and our neural net looked like this, and we were implementing this over the last few lectures. Now, I'm sure everyone is very excited to go into recurrent neural networks and all of their variants and how they work — the diagrams look cool, it's very exciting and interesting, and we're going to get a better result — but unfortunately I think we have to remain here for one more lecture.

And the reason for that is: we've already trained this multilayer perceptron, right, and we are getting pretty good loss, and I think we have a pretty decent understanding of the architecture and how it works. But the line of code here that I take issue with is this one: loss.backward. That is, we are taking PyTorch autograd and using it to calculate all of our gradients along the way. I would like to remove the use of loss.backward, and I would like us to write our backward pass manually, on the level of tensors. And I think that this is a very useful exercise for the following reasons.

I actually have an entire blog post on this topic, but I like to call backpropagation a leaky abstraction. What I mean by that is: backpropagation doesn't just make your neural networks work magically. It's not the case that you can stack up arbitrary Lego blocks of differentiable functions, cross your fingers, backpropagate, and everything is great. Things don't just work automatically. It is a leaky abstraction in the sense that you can shoot yourself in the foot if you do not understand its internals. It will magically not work, or not work optimally, and you will need to understand how it works under the hood if you're hoping to debug it and if you are hoping to address it in your neural net.

So this blog post from a while ago goes into some of those examples. For example — and we've already covered some of them — the flat tails of these functions and how you do not want to saturate them too much, because your gradients will die; the case of dead neurons, which I've already covered as well; and the case of vanishing or exploding gradients in recurrent neural networks, which we are about to cover.

And then also you will often come across some examples in the wild. This is a snippet that I found in a random code base on the internet, where they actually have a very subtle but pretty major bug in their implementation.
And the bug points at the fact that the author of this code does not actually understand backpropagation. What they're trying to do here is clip the loss at a certain maximum value. But what they actually want is to clip the gradients to a maximum value, not the loss. And indirectly, they're causing some of the outliers to be ignored, because when you clip the loss of an outlier, you are setting its gradient to zero. So have a look through this and read through it, but there are basically a bunch of subtle issues that you're going to avoid if you actually know what you're doing. And that's why I don't think it's the case that, just because PyTorch or other frameworks offer autograd, it is okay for us to ignore how it works.

Now, we've actually already covered autograd and we wrote micrograd, but micrograd was an autograd engine only on the level of individual scalars — the atoms were single individual numbers — and I don't think that's enough. I'd like us to think about backpropagation on the level of tensors as well. So, in summary, I think it's a good exercise. I think it is very, very valuable. You're going to become better at debugging neural networks and making sure that you understand what you're doing. It is going to make everything fully explicit, so you're not going to be nervous about what is hidden away from you. And basically, in general, we're going to emerge stronger. So let's get into it.

A bit of a fun historical note here: today, writing your backward pass by hand is not recommended, and no one does it except for the purposes of exercise. But about 10 years ago in deep learning this was fairly standard and in fact pervasive. At the time, everyone used to write their backward pass by hand, manually — including myself. It's just what you would do. So we used to write the backward pass by hand, and now everyone just calls loss.backward. We've lost something. I want to give you a few examples of this.

Here's a 2006 paper from Geoff Hinton and Ruslan Salakhutdinov in Science that was influential at the time, training architectures called restricted Boltzmann machines — basically, it's an autoencoder being trained here. And this is from roughly 2010: I had a library for training restricted Boltzmann machines, and it was written in Matlab at the time. Python was not used for deep learning pervasively; it was all Matlab.
And Matlab was this scientific computing package that everyone would use. So we would write Matlab, which is barely a programming language, but it had a very convenient tensor class, and it was this computing environment where you would run your code. It would all run on the CPU, of course, but you would have very nice plots to go with it and a built-in debugger, and it was pretty nice.

Now, the code in this package that I wrote in 2010 for fitting restricted Boltzmann machines is, to a large extent, recognizable. I wanted to show you how you would do it: here I'm creating the data and the x, y batches; I'm initializing the neural net, so it's got weights and biases just like we're used to; and then this is the training loop, where we actually do the forward pass. And then here — at this time, people didn't even necessarily use backpropagation to train neural networks. This in particular implements contrastive divergence, which estimates a gradient, and then here we take that gradient and use it for a parameter update along the lines that we're used to. But you can see that people were meddling with these gradients directly, inline, themselves — it wasn't that common to use an autograd engine.

Here's one more example, from a paper of mine from 2014 called Deep Fragment Embeddings, where I was aligning images and text. Here I'm implementing the cost function, and it was standard to implement not just the cost but also the backward pass manually. So here I'm calculating the image embeddings and the sentence embeddings, I calculate the scores, and this is the loss function. And then once I have the loss function, I do the backward pass right here: I backward through the loss function and through the neural net, and I append regularization. So everything was done by hand, manually — you would just write out the backward pass — and then you would use a gradient checker to make sure that your numerical estimate of the gradient agrees with the one you calculated during backpropagation.

So this was very standard for a long time. Today, of course, it is standard to use an autograd engine. But it was definitely useful, and I think people understood how these neural networks work on a very intuitive level. And so I think it's a good exercise again, and this is where we want to be.
Okay, so just as a reminder from our previous lecture, this is the Jupyter notebook that we implemented at the time. We're going to keep everything the same: we're still going to have a two-layer multilayer perceptron with a batch normalization layer, so the forward pass will be basically identical to that lecture. But here we're going to get rid of loss.backward, and instead we're going to write the backward pass manually.

Now, here's the starter code for this lecture. We are becoming a backprop ninja in this notebook. The first few cells here are identical to what we are used to: we are doing some imports, loading in the dataset, and processing the dataset. None of this changed. Now, here I'm introducing a utility function that we're going to use later to compare the gradients. In particular, we are going to have the gradients that we estimate manually ourselves, and we're going to have the gradients that PyTorch calculates, and we're going to check for correctness — assuming, of course, that PyTorch is correct.

Then here we have the initialization that we are quite used to: we have our embedding table for the characters, the first layer, the second layer, and a batch normalization in between. And here's where we create all the parameters. Now, you will note that I changed the initialization a little bit, to be small numbers. Normally you would set the biases to be all zero; here I'm setting them to be small random numbers. I'm doing this because if your variables are initialized to exactly zero, sometimes that can mask an incorrect implementation of a gradient — when everything is zero, it simplifies and gives you a much simpler expression for the gradient than you would otherwise get. So by making them small numbers, I'm trying to unmask those potential errors in these calculations.

You'll also notice that I'm using b1 in the first layer — I'm using a bias despite having batch normalization right afterwards. This would typically not be what you'd do, because we talked about the
fact that you don't need a bias there. But I'm doing it here just for fun, because we're going to have a gradient with respect to it, and we can check that we are still calculating it correctly even though this bias is spurious. So here I'm sampling a single batch, and then here I am doing a forward pass.

Now, you'll notice that the forward pass is significantly expanded from what we are used to — here the forward pass was just this. The reason the forward pass is longer is twofold. Number one, here we just had an F.cross_entropy, but now I am bringing back an explicit implementation of the loss function. And number two, I've broken up the implementation into manageable chunks, so we have a lot more intermediate tensors along the way in the forward pass. That's because we are about to go backwards and calculate the gradients in this backpropagation, from the bottom to the top. So we're going to go upwards, and just as we have, for example, the logprobs tensor in the forward pass, in the backward pass we're going to have dlogprobs, which is going to store the derivative of the loss with respect to the logprobs tensor. So we're going to be prepending d to every one of these tensors and calculating it along the way of this backpropagation. As an example, we have a bnraw here, so we're going to be calculating a dbnraw.

So here I'm telling PyTorch that we want to retain the grad of all these intermediate values, because in exercise one we're going to calculate the backward pass. We're going to calculate all these d-variables and use the cmp function I've introduced above to check our correctness with respect to what PyTorch is telling us. This is going to be exercise one, where we backpropagate through this entire graph.

Now, just to give you a very quick preview of what's going to happen in exercise two and below: here we have fully broken up the loss and backpropagated through it manually in all the little atomic pieces that make it up, but there we're going to collapse the loss into a single cross-entropy call, and instead we're going to analytically derive, using math and paper and pencil, the gradient of the loss with respect to the logits. So instead of backpropagating through all of its little chunks one at a time, we're just going to analytically derive what that gradient is and implement that, which is much more efficient, as we'll see in a bit. Then we're going to do the exact same thing for batch normalization: instead of breaking up
batchnorm into all the little tiny components, we're going to use pen and paper and mathematics and calculus to derive the gradient through the batchnorm layer. So we're going to calculate the backward pass through the batchnorm layer as a much more efficient expression, instead of backpropagating through all of its little pieces independently. That's going to be exercise three.

And then in exercise four, we're going to put it all together — this is the full code of training this two-layer MLP — and we're going to basically insert our manual backprop and take out loss.backward. You will see that you can get all the same results using fully your own code, and the only thing we're using from PyTorch is torch.Tensor, to make the calculations efficient. But otherwise you will understand fully what it means to forward and backward the neural net and train it, and I think that'll be awesome. So let's get to it.

Okay, so I ran all the cells of this notebook all the way up to here, and I'm going to erase this and start implementing the backward pass, starting with dlogprobs. So we go here to calculate the gradient of the loss with respect to all the elements of the logprobs tensor. Now, I'm going to give away the answer here, but I want to put a quick note first: I think what would be most pedagogically useful for you is to actually go into the description of this video and find the link to this Jupyter notebook. You can find it both on GitHub, and you can also find a Google Colab with it, so you don't have to install anything — you'll just go to a website on Google Colab and you can try to implement these derivatives or gradients yourself. And then, if you are not able to, come to my video and see me do it. So work in tandem: try it first yourself, and then see me give away the answer. I think that will be most valuable to you, and that's how I recommend you go through this lecture.

So we are starting here with dlogprobs. Now, dlogprobs will hold the derivative of the loss with respect to all the elements of logprobs. What is inside logprobs? The shape of this is 32 by 27,
so it's not going to surprise you that dlogprobs should also be an array of size 32 by 27, because we want the derivative of the loss with respect to all of its elements — the sizes of those are always going to be equal.

Now, how does logprobs influence the loss? The loss is the negative of logprobs indexed with range(n) and Yb, and then the mean of that. Just as a reminder, Yb is basically the array of all the correct indices. So what we're doing here is we're taking the logprobs array, of size 32 by 27, and then we are going into every single row, and in each row we are plucking out the index 8, and then 14, and 15, and so on. So we're going down the rows — that's the iterator range(n) — and we are always plucking out the index of the column specified by this tensor Yb. So in the zeroth row we are taking the eighth column, in the first row we're taking the 14th column, etc. And so logprobs indexed this way plucks out all of the log probabilities of the correct next character in the sequence. The shape of this, or the size of it, is of course 32, because our batch size is 32. So these elements get plucked out, and then their mean, and the negative of that, becomes the loss.

I always like to use simpler examples to understand the numerical form of a derivative. What's going on here is that once we've plucked out these examples, we're taking the mean and then the negative. So the loss, if I can write it this way, is the negative of, say, a plus b plus c, divided by three — that would be how we take the mean of three numbers a, b, and c — although we actually have 32 numbers here. So what is dloss by, say, da? Well, if we simplify this expression mathematically, it's negative one third of a, plus negative one third of b, plus negative one third of c. So dloss by da is just negative one third. And you can see that if we don't just have a, b, and c but 32 numbers, then dloss by d-each one of those numbers is going to be negative 1 over n more generally, where n is the size of the batch — 32 in this case. So dloss by dlogprobs is negative 1 over n in all these places.
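Written out as a formula, this is just the derivation above (with n the batch size and Yb the tensor of correct indices, as in the notebook):

    loss = -\frac{1}{n} \sum_{i=0}^{n-1} \text{logprobs}[i,\, Yb[i]]
    \quad\Longrightarrow\quad
    \frac{\partial\, \text{loss}}{\partial\, \text{logprobs}[i,\, Yb[i]]} = -\frac{1}{n}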
Now, what about the other elements inside logprobs? Logprobs is a large array — you see that logprobs.shape is 32 by 27 — but only 32 of its elements participate in the loss calculation. So what's the derivative of all the other elements, the ones that do not get plucked out here? Well, their gradient intuitively is zero, and that's because they did not participate in the loss. Most of these numbers inside this tensor do not feed into the loss, so if we were to change them, the loss doesn't change — which is equivalent to saying that the derivative of the loss with respect to them is zero. They don't impact it.

So here's a way to implement this derivative. We start out with torch.zeros of shape 32 by 27 — or, let's say, because we don't want to hard-code numbers, let's do torch.zeros_like(logprobs). This is going to create an array of zeros in exactly the shape of logprobs. And then we need to set the derivative of negative 1 over n at exactly these locations. So here's what we can do: dlogprobs, indexed in the identical way, will just be set to negative 1.0 divided by n, just like we derived here. So now let me erase all of this reasoning; this is the candidate derivative for dlogprobs. Let's uncomment the first line and check that this is correct.
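For reference, here is a minimal sketch of the lines being described, assuming the variable names used in the notebook (logprobs, Yb, and n for the batch size):

    # zero everywhere, except -1/n at the (row, correct-column) positions that feed the loss
    dlogprobs = torch.zeros_like(logprobs)
    dlogprobs[range(n), Yb] = -1.0 / n
    # then check against PyTorch's logprobs.grad, e.g. with the cmp helper: cmp('logprobs', dlogprobs, logprobs)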
Okay, so cmp ran — let's go back to cmp. You see that what it's doing is checking whether the value calculated by us, which is dt, is exactly equal to t.grad as calculated by PyTorch. It makes sure that all of the elements are exactly equal, and then converts that to a single Boolean value, because we don't want a Boolean tensor, we just want a Boolean value. And then here we are making sure that, okay, if they're not exactly equal, maybe they are approximately equal because of some floating point issues, but they're very, very close. So here we are using torch.allclose, which has a little bit of wiggle room, because if you use a slightly different calculation, you can get a slightly different result because of floating point arithmetic. So this is checking whether you get an approximately close result. And then here we are checking the maximum — basically the element that has the highest difference — and what that difference is, the absolute value of the difference between the two. So we are printing whether we have an exact equality, an approximate equality, and what the largest difference is.

And here we see that we actually have exact equality, so of course we also have approximate equality, and the maximum difference is exactly zero. So basically our dlogprobs is exactly equal to what PyTorch calculated logprobs.grad to be in its backpropagation. So far, we're doing pretty well.

Okay, let's now continue our backpropagation. We have that logprobs depends on probs through a log: all the elements of probs have log applied to them element-wise. Now, if we want dprobs, then remember your micrograd training: we have a log node, it takes in probs and creates logprobs, and dprobs will be the local derivative of that individual operation, log, times the derivative of the loss with respect to its output, which in this case is dlogprobs. So what is the local derivative of this operation? Well, we are taking log element-wise, and we can come here and see — Wolfram Alpha is your friend — that d by dx of log of x is simply 1 over x. So in this case x is probs, and d by dx is 1 over x, which is 1 over probs. That's the local derivative, and then times — we want to chain it, so this is the chain rule — times dlogprobs. Let me uncomment this and run the cell in place, and we see that the derivative of probs, as we calculated here, is exactly correct.
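As a sketch of that line in code (same naming assumptions as before):

    # local derivative of log is 1/x, chained with the gradient from above
    dprobs = (1.0 / probs) * dlogprobs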
And notice how this works: probs gets inverted and then element-wise multiplied here. So if your probs is very, very close to 1 — meaning your network is currently predicting the character correctly — then this becomes 1 over 1, and dlogprobs just gets passed through. But if your probabilities are incorrectly assigned — if the correct character here is getting a very low probability — then 1.0 divided by it will boost this, and then multiply by dlogprobs. So what this line is doing, intuitively, is taking the examples that currently have a very low probability assigned and boosting their gradient. You can look at it that way.

Next up is counts_sum_inv — we want the derivative of this. Now, let me pause here and introduce what's happening in general, because I know it's a little bit confusing. We have the logits that come out of the neural net. Here, what I'm doing is finding the maximum in each row and subtracting it, for the purpose of numerical stability. We talked about how, if you do not do this, you run into numerical issues if some of the logits take on too-large values, because we end up exponentiating them. So this is done just for numerical safety. Then here's the exponentiation of all of the logits to create our counts, and then we want to take the sum of these counts and normalize, so that all of the probs sum to 1. Now here, instead of using 1 over counts_sum, I raise it to the power of negative 1. Mathematically they are identical; I just found that there's something wrong with the PyTorch implementation of the backward pass of division, and it gives a weird result, but that doesn't happen for ** -1, so I'm using this formula instead. But basically, all that's happening here is: we take the logits, we exponentiate all of them, and we normalize the counts to create our probabilities. It's just that it's happening across multiple lines.
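For orientation, the forward chunk just described looks roughly like this — a sketch reconstructed from the description above, so treat the exact variable names (logit_maxes, norm_logits, counts, counts_sum, counts_sum_inv) as assumptions:

    logit_maxes = logits.max(1, keepdim=True).values  # row-wise max, for numerical stability
    norm_logits = logits - logit_maxes                 # broadcasts the 32x1 column across 32x27
    counts = norm_logits.exp()
    counts_sum = counts.sum(1, keepdim=True)
    counts_sum_inv = counts_sum**-1                    # same as 1/counts_sum
    probs = counts * counts_sum_inv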
So now, here, we want to first take the derivative — we want to backpropagate into counts_sum_inv, and then into counts as well. So what should dcounts_sum_inv be? Now, we actually have to be careful here, because we have to scrutinize the shapes. counts.shape and counts_sum_inv.shape are different: in particular, counts is 32 by 27, but counts_sum_inv is 32 by 1. So in this multiplication we also have an implicit broadcast that PyTorch will do, because it needs to take this column tensor of 32 numbers and replicate it horizontally 27 times to align the two tensors so it can do an element-wise multiply.

So really, what this looks like is the following, using a toy example again. What we have here is just: probs is counts times counts_sum_inv, so it's c equals a times b, but a is 3 by 3 and b is just 3 by 1, a column tensor. So PyTorch internally replicated the elements of b across all the columns. For example b1, the first element of b, would be replicated here across all the columns in this multiplication. And now we're trying to backpropagate through this operation to counts_sum_inv.

When we are calculating this derivative, it's important to realize that this looks like a single operation, but it is actually two operations applied sequentially. The first operation PyTorch did is it took this column tensor and replicated it across all the columns, basically 27 times — that's the replication — and the second operation is the multiplication. So let's first backprop through the multiplication. If these two arrays were of the same size — we just have a and b, both of them 3 by 3 — how do we backpropagate through a multiplication? Well, if we just have scalars and not tensors, then if c equals a times b, the derivative of c with respect to b is just a. That's the local derivative. So here, in our case, backpropagating through just the multiplication itself, which is element-wise, the local derivative is simply counts, because counts is the a.
So this is the local derivative, and then times, because of the chain rule, dprobs. This gives us the gradient, but with respect to the replicated b. We don't have a replicated b, though — we just have a single b column. So how do we now backpropagate through the replication? Intuitively, this b1 is the same variable, just reused multiple times, and you can look at it as being equivalent to a case we've encountered in micrograd. Here I'm just pulling up a graph we used in micrograd: we had an example where a single node has its output feeding into two branches of the graph, all the way to the loss function, and we talked about how the correct thing to do in the backward pass is to sum all the gradients that arrive at any one node. So across these different branches the gradients sum: if a node is used multiple times, the gradients for all of its uses sum during backpropagation.

Here, b1 is used multiple times across all these columns, so the right thing to do is to sum horizontally, across the columns within each row. We want to sum in dimension 1, but we want to retain this dimension — keepdim equals True — so that counts_sum_inv and its gradient are of exactly the same shape and we don't lose the dimension. This will make dcounts_sum_inv be of shape exactly 32 by 1. So, revealing this comparison as well and running it, we see that we get an exact match. This derivative is exactly correct. Let me erase this.

Now let's also backpropagate into counts, the other variable here that creates probs. So from probs into counts_sum_inv — we just did that; let's go into counts as well. Here, counts is our a, and dc by da is just b, which is counts_sum_inv, and then times, chain rule, dprobs. Now, counts_sum_inv is 32 by 1 and
400 00:26:53,240 --> 00:26:57,140 dProps is 32 by 27. 401 00:26:57,240 --> 00:27:01,140 So those will broadcast fine 402 00:27:01,240 --> 00:27:03,140 and will give us dCounts. 403 00:27:03,240 --> 00:27:06,140 There's no additional summation required here. 404 00:27:06,240 --> 00:27:08,140 There will be a broadcasting 405 00:27:08,240 --> 00:27:10,140 that happens in this multiply here 406 00:27:10,240 --> 00:27:13,140 because countSumInv needs to be replicated again 407 00:27:13,240 --> 00:27:15,140 to correctly multiply dProps. 408 00:27:15,140 --> 00:27:18,040 But that's going to give the correct result. 409 00:27:18,140 --> 00:27:21,040 So as far as this single operation is concerned. 410 00:27:21,140 --> 00:27:24,040 So we've backpropagated from props to counts, 411 00:27:24,140 --> 00:27:28,040 but we can't actually check the derivative of counts. 412 00:27:28,140 --> 00:27:30,040 I have it much later on. 413 00:27:30,140 --> 00:27:32,040 And the reason for that is because 414 00:27:32,140 --> 00:27:34,040 countSumInv depends on counts. 415 00:27:34,140 --> 00:27:36,040 And so there's a second branch here 416 00:27:36,140 --> 00:27:37,040 that we have to finish. 417 00:27:37,140 --> 00:27:40,040 Because countSumInv backpropagates into countSum, 418 00:27:40,140 --> 00:27:42,040 and countSum will backpropagate into counts. 419 00:27:42,140 --> 00:27:45,040 And so counts is a node that is being used twice. 420 00:27:45,040 --> 00:27:46,940 It's used right here in two props, 421 00:27:47,040 --> 00:27:48,940 and it goes through this other branch 422 00:27:49,040 --> 00:27:50,940 through countSumInv. 423 00:27:51,040 --> 00:27:52,940 So even though we've calculated 424 00:27:53,040 --> 00:27:54,940 the first contribution of it, 425 00:27:55,040 --> 00:27:56,940 we still have to calculate the second contribution of it later. 426 00:27:57,040 --> 00:27:58,940 Okay, so we're continuing with this branch. 427 00:27:59,040 --> 00:28:00,940 We have the derivative for countSumInv. 428 00:28:01,040 --> 00:28:02,940 Now we want the derivative for countSum. 429 00:28:03,040 --> 00:28:04,940 So dCountSum equals 430 00:28:05,040 --> 00:28:06,940 what is the local derivative of this operation? 431 00:28:07,040 --> 00:28:08,940 So this is basically an element-wise 432 00:28:09,040 --> 00:28:10,940 1 over countsSum. 433 00:28:11,040 --> 00:28:13,940 So countSum raised to the power of negative 1 434 00:28:13,940 --> 00:28:15,840 is the same as 1 over countsSum. 435 00:28:15,940 --> 00:28:17,840 If we go to wall from alpha, 436 00:28:17,940 --> 00:28:19,840 we see that x to the negative 1, 437 00:28:19,940 --> 00:28:21,840 d by dx of it, 438 00:28:21,940 --> 00:28:23,840 is basically negative x to the negative 2. 439 00:28:23,940 --> 00:28:25,840 Negative 1 over s squared 440 00:28:25,940 --> 00:28:27,840 is the same as negative x to the negative 2. 441 00:28:27,940 --> 00:28:31,840 So dCountSum here will be 442 00:28:31,940 --> 00:28:33,840 local derivative is going to be 443 00:28:33,940 --> 00:28:38,840 negative countsSum to the negative 2, 444 00:28:38,940 --> 00:28:40,840 that's the local derivative, 445 00:28:40,940 --> 00:28:43,840 times chain rule, which is 446 00:28:43,840 --> 00:28:45,740 countSumInv. 447 00:28:45,840 --> 00:28:47,740 So that's dCountSum. 448 00:28:47,840 --> 00:28:49,740 Let's uncomment this 449 00:28:49,840 --> 00:28:51,740 and check that I am correct. 450 00:28:51,840 --> 00:28:53,740 Okay, so we have perfect equality. 
Okay, so we're continuing with this branch. We have the derivative for counts_sum_inv; now we want the derivative for counts_sum. So dcounts_sum equals — what is the local derivative of this operation? This is basically an element-wise 1 over counts_sum: counts_sum raised to the power of negative 1 is the same as 1 over counts_sum. If we go to Wolfram Alpha, we see that d by dx of x to the negative 1 is negative x to the negative 2 — negative 1 over x squared is the same as negative x to the negative 2. So for dcounts_sum, the local derivative is negative counts_sum to the negative 2, times, chain rule, dcounts_sum_inv. So that's dcounts_sum. Let's uncomment this and check that I am correct. Okay, so we have perfect equality, and there's no sketchiness going on here with any shapes, because these are of the same shape.

Okay, next up we want to backpropagate through this line: we have that counts_sum is counts.sum along the rows. So I wrote out some help here. We have to keep in mind that counts, of course, is 32 by 27, and counts_sum is 32 by 1. So in this backpropagation we need to take this column of derivatives and transform it into a two-dimensional array of derivatives. What is this operation doing? We're taking some kind of an input — say a 3 by 3 matrix a — and we are summing up each row into a column tensor b: b1, b2, b3. Now we have the derivatives of the loss with respect to the elements of b, and we want the derivative of the loss with respect to all these little a's. So how do the b's depend on the a's — that's what we're after. What is the local derivative of this operation? Well, we can see that b1 only depends on these elements here: the derivative of b1 with respect to all of these elements down here is 0, but for these elements here — a11, a12, and so on — the local derivative is 1. So db1 by da11, for example, is 1; it's 1, 1, and 1.

So when we have the derivative of the loss with respect to b1, the local derivative of b1 with respect to these inputs is 0 down here, but it's 1 on these guys. In the chain rule we have the local derivative times the derivative of b1, and because the local derivative is 1 on these three elements, that product is just the derivative of b1. So you can look at it as a router: basically, an addition is a router of gradient.
Whatever gradient comes from above just gets routed equally to all the elements that participate in that addition. So in this case, the derivative of b1 flows equally to the derivatives of a11, a12, and a13. So if we have the derivatives of all the elements of b in this column tensor, which we calculated just now, what that amounts to is that all of these now flow to all of these elements of a, and they do that horizontally. So basically what we want is to take dcounts_sum, of size 32 by 1, and replicate it 27 times horizontally to create a 32 by 27 array.

There are many ways to implement this operation. You could of course just replicate the tensor, but I think one clean way is that dcounts is simply torch.ones_like — just a two-dimensional array of ones in the shape of counts, so 32 by 27 — times dcounts_sum. This way we're letting the broadcasting implement the replication; you can look at it that way. But then we also have to be careful, because dcounts was already calculated — we calculated it earlier, and that was just the first branch, and we're now finishing the second branch. So we need to make sure that these gradients add: plus equals. And then here let's uncomment the comparison and make sure, crossing fingers, that we have the correct result. So PyTorch agrees with us on this gradient as well.
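In code, the two steps just walked through would look roughly like this (same assumptions about the notebook's variable names):

    # d/dx of x**-1 is -x**-2, chained with dcounts_sum_inv
    dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv

    # the sum along dimension 1 routes the gradient back equally to every element
    # of each row, so broadcast dcounts_sum across the row and accumulate with +=
    dcounts += torch.ones_like(counts) * dcounts_sum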
Okay, hopefully we're getting the hang of this now. Counts is an element-wise exp of norm_logits, so now we want dnorm_logits, and because it's an element-wise operation, everything is very simple. What is the local derivative of e to the x? It's famously just e to the x. So that is the local derivative. Now, we already calculated it, and it's sitting inside counts, so we may as well just reuse counts. That is the local derivative, times dcounts. Funny as it looks, counts times dcounts is dnorm_logits. Let's erase this, let's verify — and there we go. So that's dnorm_logits.

Okay, so we are here on this line now. We have dnorm_logits, and we're trying to calculate dlogits and dlogit_maxes, so we're backpropagating through this line. Now, we have to be careful here, because the shapes, again, are not the same, and so there's an implicit broadcast happening: norm_logits has the shape 32 by 27, and logits does as well, but logit_maxes is only 32 by 1, so there's a broadcast here in the minus.

Here I tried to write out a toy example again. We basically have that this is our c equals a minus b, and we see that because of the shapes these are 3 by 3, but this one is just a column. So for every element of c we have to look at how it came to be, and every element of c is just the corresponding element of a minus the associated element of b. So it's very clear now that the derivative of every one of these c's with respect to its inputs is 1 for the corresponding a, and negative 1 for the corresponding b. And so the derivatives on c will flow equally to the corresponding a's, and also to the corresponding b's.
613 00:34:31,840 --> 00:34:33,740 But then in addition to that, 614 00:34:33,840 --> 00:34:35,740 the b's are broadcast, 615 00:34:35,840 --> 00:34:37,740 so we'll have to do the additional sum 616 00:34:37,840 --> 00:34:39,740 just like we did before. 617 00:34:39,740 --> 00:34:41,640 And of course, the derivatives for b's 618 00:34:41,740 --> 00:34:43,640 will undergo a minus, 619 00:34:43,740 --> 00:34:45,640 because the local derivative here 620 00:34:45,740 --> 00:34:47,640 is negative 1. 621 00:34:47,740 --> 00:34:49,640 So dc32 by db3 is negative 1. 622 00:34:49,740 --> 00:34:51,640 So let's just implement that. 623 00:34:51,740 --> 00:34:53,640 Basically, dlogits will be 624 00:34:53,740 --> 00:34:55,640 exactly copying 625 00:34:55,740 --> 00:34:57,640 the derivative on normlogits. 626 00:34:57,740 --> 00:34:59,640 So 627 00:34:59,740 --> 00:35:01,640 dlogits equals 628 00:35:01,740 --> 00:35:03,640 dnormlogits, and I'll do a .clone 629 00:35:03,740 --> 00:35:05,640 for safety, so we're just making a copy. 630 00:35:05,740 --> 00:35:07,640 And then we have that 631 00:35:07,740 --> 00:35:09,640 dlogitmaxis, 632 00:35:09,640 --> 00:35:11,540 will be the negative 633 00:35:11,640 --> 00:35:13,540 of dnormlogits, 634 00:35:13,640 --> 00:35:15,540 because of the negative sign. 635 00:35:15,640 --> 00:35:17,540 And then we have to be careful because 636 00:35:17,640 --> 00:35:19,540 dlogitmaxis is 637 00:35:19,640 --> 00:35:21,540 a column, and so 638 00:35:21,640 --> 00:35:23,540 just like we saw before, 639 00:35:23,640 --> 00:35:25,540 because we keep replicating the same 640 00:35:25,640 --> 00:35:27,540 elements across all the 641 00:35:27,640 --> 00:35:29,540 columns, then in the 642 00:35:29,640 --> 00:35:31,540 backward pass, because we keep reusing 643 00:35:31,640 --> 00:35:33,540 this, these are all just like separate 644 00:35:33,640 --> 00:35:35,540 branches of use of that one variable. 645 00:35:35,640 --> 00:35:37,540 And so therefore, we have to do a 646 00:35:37,640 --> 00:35:39,540 sum along 1, we'd keep 647 00:35:39,540 --> 00:35:41,440 them equals true, so that we don't 648 00:35:41,540 --> 00:35:43,440 destroy this dimension. 649 00:35:43,540 --> 00:35:45,440 And then dlogitmaxis will be the same 650 00:35:45,540 --> 00:35:47,440 shape. Now we have to be careful because 651 00:35:47,540 --> 00:35:49,440 this dlogits is not the final dlogits, 652 00:35:49,540 --> 00:35:51,440 and that's because 653 00:35:51,540 --> 00:35:53,440 not only do we get gradient signal 654 00:35:53,540 --> 00:35:55,440 into logits through here, but 655 00:35:55,540 --> 00:35:57,440 logitmaxis is a function of logits, 656 00:35:57,540 --> 00:35:59,440 and that's a second branch into 657 00:35:59,540 --> 00:36:01,440 logits. So this is not yet our final 658 00:36:01,540 --> 00:36:03,440 derivative for logits, we will come back 659 00:36:03,540 --> 00:36:05,440 later for the second branch. 660 00:36:05,540 --> 00:36:07,440 For now, dlogitmaxis is the final derivative, 661 00:36:07,540 --> 00:36:09,440 so let me uncomment this 662 00:36:09,440 --> 00:36:11,340 cmp here, and let's just run this. 663 00:36:11,440 --> 00:36:13,340 And logitmaxis, 664 00:36:13,440 --> 00:36:15,340 if pytorch, agrees 665 00:36:15,440 --> 00:36:17,340 with us. So that was the derivative 666 00:36:17,440 --> 00:36:19,340 into, through 667 00:36:19,440 --> 00:36:21,340 this line. 
Now before 668 00:36:21,440 --> 00:36:23,340 we move on, I want to pause here briefly, 669 00:36:23,440 --> 00:36:25,340 and I want to look at these logitmaxis, and 670 00:36:25,440 --> 00:36:27,340 especially their gradients. We've 671 00:36:27,440 --> 00:36:29,340 talked previously in the previous lecture 672 00:36:29,440 --> 00:36:31,340 that the only reason we're doing this is 673 00:36:31,440 --> 00:36:33,340 for the numerical stability of the softmax 674 00:36:33,440 --> 00:36:35,340 that we are implementing here. 675 00:36:35,440 --> 00:36:37,340 And we talked about how if you take 676 00:36:37,440 --> 00:36:39,340 these logits for any one of these examples, 677 00:36:39,440 --> 00:36:41,340 so one row of this logits tensor, 678 00:36:41,440 --> 00:36:43,340 if you add or subtract 679 00:36:43,440 --> 00:36:45,340 any value equally to all the elements, 680 00:36:45,440 --> 00:36:47,340 then the value 681 00:36:47,440 --> 00:36:49,340 of the probes will be unchanged. 682 00:36:49,440 --> 00:36:51,340 You're not changing the softmax. The only thing 683 00:36:51,440 --> 00:36:53,340 that this is doing is it's making sure that 684 00:36:53,440 --> 00:36:55,340 exp doesn't overflow. And the 685 00:36:55,440 --> 00:36:57,340 reason we're using a max is because then we 686 00:36:57,440 --> 00:36:59,340 are guaranteed that each row of logits, 687 00:36:59,440 --> 00:37:01,340 the highest number, is zero. 688 00:37:01,440 --> 00:37:03,340 And so this will be safe. 689 00:37:03,440 --> 00:37:05,340 And so 690 00:37:05,440 --> 00:37:07,340 basically 691 00:37:07,440 --> 00:37:09,340 that has repercussions. 692 00:37:09,440 --> 00:37:11,340 If it is the case that changing 693 00:37:11,440 --> 00:37:13,340 logitmaxis does not change the 694 00:37:13,440 --> 00:37:15,340 probes, and therefore does not change the loss, 695 00:37:15,440 --> 00:37:17,340 then the gradient on logitmaxis 696 00:37:17,440 --> 00:37:19,340 should be zero. 697 00:37:19,440 --> 00:37:21,340 Because saying those two things is the same. 698 00:37:21,440 --> 00:37:23,340 So indeed we hope that this is 699 00:37:23,440 --> 00:37:25,340 very, very small numbers. Indeed we hope this is zero. 700 00:37:25,440 --> 00:37:27,340 Now because of floating 701 00:37:27,440 --> 00:37:29,340 point sort of wonkiness, 702 00:37:29,440 --> 00:37:31,340 this doesn't come out exactly zero. 703 00:37:31,440 --> 00:37:33,340 Only in some of the rows it does. 704 00:37:33,440 --> 00:37:35,340 But we get extremely small values, like 705 00:37:35,440 --> 00:37:37,340 1e-9 or 10. 706 00:37:37,440 --> 00:37:39,340 And so this is telling us that the values of 707 00:37:39,340 --> 00:37:41,240 logitmaxis are not impacting the loss 708 00:37:41,340 --> 00:37:43,240 as they shouldn't. 709 00:37:43,340 --> 00:37:45,240 It feels kind of weird to backpropagate through this branch, 710 00:37:45,340 --> 00:37:47,240 honestly, because 711 00:37:47,340 --> 00:37:49,240 if you have any 712 00:37:49,340 --> 00:37:51,240 implementation of f.crossentropy and 713 00:37:51,340 --> 00:37:53,240 pytorch, and you block together 714 00:37:53,340 --> 00:37:55,240 all of these elements, and you're not doing backpropagation 715 00:37:55,340 --> 00:37:57,240 piece by piece, then you would 716 00:37:57,340 --> 00:37:59,240 probably assume that the derivative 717 00:37:59,340 --> 00:38:01,240 through here is exactly zero. 718 00:38:01,340 --> 00:38:03,240 So you would be sort of 719 00:38:03,340 --> 00:38:05,240 skipping 720 00:38:05,340 --> 00:38:07,240 this branch. 
Because 721 00:38:07,340 --> 00:38:09,240 it's only done for numerical stability. 722 00:38:09,340 --> 00:38:11,240 But it's interesting to see that even if you break up 723 00:38:11,340 --> 00:38:13,240 everything into the full atoms, and you 724 00:38:13,340 --> 00:38:15,240 still do the computation as you'd like 725 00:38:15,340 --> 00:38:17,240 with respect to numerical stability, the correct thing 726 00:38:17,340 --> 00:38:19,240 happens. And you still get 727 00:38:19,340 --> 00:38:21,240 very, very small gradients here. 728 00:38:21,340 --> 00:38:23,240 Basically reflecting the fact that 729 00:38:23,340 --> 00:38:25,240 the values of these do not matter 730 00:38:25,340 --> 00:38:27,240 with respect to the final loss. 731 00:38:27,340 --> 00:38:29,240 Okay, so let's now continue backpropagation 732 00:38:29,340 --> 00:38:31,240 through this line here. 733 00:38:31,340 --> 00:38:33,240 We've just calculated the logitmaxis, and now 734 00:38:33,340 --> 00:38:35,240 we want to backprop into logits through this 735 00:38:35,340 --> 00:38:37,240 second branch. Now here of course 736 00:38:37,340 --> 00:38:39,240 we took logits, and we took the max 737 00:38:39,240 --> 00:38:41,140 along all the rows, and then 738 00:38:41,240 --> 00:38:43,140 we looked at its values here. 739 00:38:43,240 --> 00:38:45,140 Now the way this works is that in pytorch, 740 00:38:47,240 --> 00:38:49,140 this thing here, 741 00:38:49,240 --> 00:38:51,140 the max returns both the values, 742 00:38:51,240 --> 00:38:53,140 and it returns the indices at which those 743 00:38:53,240 --> 00:38:55,140 values to count the maximum value. 744 00:38:55,240 --> 00:38:57,140 Now in the forward pass, we only 745 00:38:57,240 --> 00:38:59,140 used values, because that's all we needed. 746 00:38:59,240 --> 00:39:01,140 But in the backward pass, it's extremely 747 00:39:01,240 --> 00:39:03,140 useful to know about where those 748 00:39:03,240 --> 00:39:05,140 maximum values occurred. 749 00:39:05,240 --> 00:39:07,140 And we have the indices at which they occurred. 750 00:39:07,240 --> 00:39:09,140 And this will of course help us do 751 00:39:09,240 --> 00:39:11,140 the backpropagation. 752 00:39:11,240 --> 00:39:13,140 Because what should the backward pass be 753 00:39:13,240 --> 00:39:15,140 here in this case? We have the logit tensor, 754 00:39:15,240 --> 00:39:17,140 which is 32 by 27, and 755 00:39:17,240 --> 00:39:19,140 in each row we find the maximum value, 756 00:39:19,240 --> 00:39:21,140 and then that value gets plucked out into 757 00:39:21,240 --> 00:39:23,140 logitmaxis. And so intuitively, 758 00:39:23,240 --> 00:39:25,140 basically 759 00:39:25,240 --> 00:39:27,140 the derivative 760 00:39:27,240 --> 00:39:29,140 flowing through here then 761 00:39:29,240 --> 00:39:31,140 should be 1 762 00:39:31,240 --> 00:39:33,140 times the local derivative 763 00:39:33,240 --> 00:39:35,140 is 1 for the appropriate entry that was 764 00:39:35,240 --> 00:39:37,140 plucked out, and 765 00:39:37,240 --> 00:39:39,140 then times the global derivative, 766 00:39:39,140 --> 00:39:41,040 of the logitmaxis. 767 00:39:41,140 --> 00:39:43,040 So really what we're doing here, if you think through it, 768 00:39:43,140 --> 00:39:45,040 is we need to take the delogitmaxis, 769 00:39:45,140 --> 00:39:47,040 and we need to scatter it to 770 00:39:47,140 --> 00:39:49,040 the correct positions 771 00:39:49,140 --> 00:39:51,040 in these logits, 772 00:39:51,140 --> 00:39:53,040 from where the maximum values came. 
773 00:39:53,140 --> 00:39:55,040 And so, 774 00:39:55,140 --> 00:39:57,040 I came up with 775 00:39:57,140 --> 00:39:59,040 one line of code that does that. 776 00:39:59,140 --> 00:40:01,040 Let me just erase a bunch of stuff here. 777 00:40:01,140 --> 00:40:03,040 You could do it kind of very similar 778 00:40:03,140 --> 00:40:05,040 to what we've done here, where we create 779 00:40:05,140 --> 00:40:07,040 a zeros, and then we populate 780 00:40:07,140 --> 00:40:09,040 the correct elements. 781 00:40:09,140 --> 00:40:11,040 So we use the indices here, and we would 782 00:40:11,140 --> 00:40:13,040 set them to be 1. But you can 783 00:40:13,140 --> 00:40:15,040 also use one hot. 784 00:40:15,140 --> 00:40:17,040 So F.one_hot, 785 00:40:17,140 --> 00:40:19,040 and then I'm taking the logits.max 786 00:40:19,140 --> 00:40:21,040 over the first dimension 787 00:40:21,140 --> 00:40:23,040 .indices, and I'm telling 788 00:40:23,140 --> 00:40:25,040 PyTorch that 789 00:40:25,140 --> 00:40:27,040 the dimension of 790 00:40:27,140 --> 00:40:29,040 every one of these tensors should be 791 00:40:29,140 --> 00:40:31,040 27. 792 00:40:31,140 --> 00:40:33,040 And so what this is going to do 793 00:40:33,140 --> 00:40:35,040 is... 794 00:40:35,140 --> 00:40:37,040 Okay, I apologize, this is crazy. 795 00:40:37,140 --> 00:40:39,040 plt.imshow of this. 796 00:40:39,140 --> 00:40:41,040 It's really just an array 797 00:40:41,140 --> 00:40:43,040 of where the maxes came from 798 00:40:43,140 --> 00:40:45,040 in each row, and that element is 1, 799 00:40:45,140 --> 00:40:47,040 and all the other elements are 0. 800 00:40:47,140 --> 00:40:49,040 So it's a one-hot vector in each row, 801 00:40:49,140 --> 00:40:51,040 and these indices are now populating 802 00:40:51,140 --> 00:40:53,040 a single 1 in the proper 803 00:40:53,140 --> 00:40:55,040 place. And then what I'm doing 804 00:40:55,140 --> 00:40:57,040 here is I'm multiplying by the logit 805 00:40:57,140 --> 00:40:59,040 maxes. And keep in mind that 806 00:40:59,140 --> 00:41:01,040 this is a column 807 00:41:01,140 --> 00:41:03,040 of 32 by 1. 808 00:41:03,140 --> 00:41:05,040 And so when I'm doing this 809 00:41:05,140 --> 00:41:07,040 times the logit maxes, 810 00:41:07,140 --> 00:41:09,040 the logit maxes will broadcast 811 00:41:09,040 --> 00:41:10,940 and that column will get replicated, 812 00:41:11,040 --> 00:41:12,940 and then element-wise multiply 813 00:41:13,040 --> 00:41:14,940 will ensure that each of these 814 00:41:15,040 --> 00:41:16,940 just gets routed to whichever 815 00:41:17,040 --> 00:41:18,940 one of these bits is turned on. 816 00:41:19,040 --> 00:41:20,940 And so that's another way to implement 817 00:41:21,040 --> 00:41:22,940 this kind of 818 00:41:23,040 --> 00:41:24,940 operation, and 819 00:41:25,040 --> 00:41:26,940 both of these can be used. I just 820 00:41:27,040 --> 00:41:28,940 thought I would show an equivalent way to do it. 821 00:41:29,040 --> 00:41:30,940 And I'm using plus equals because 822 00:41:31,040 --> 00:41:32,940 we already calculated dlogits here, 823 00:41:33,040 --> 00:41:34,940 and this is now the second branch. 824 00:41:35,040 --> 00:41:36,940 So let's 825 00:41:37,040 --> 00:41:38,940 look at logits and make sure that 826 00:41:38,940 --> 00:41:40,840 this is correct. 827 00:41:40,940 --> 00:41:42,840 And we see that we have exactly the correct answer. 828 00:41:42,940 --> 00:41:44,840 Next up, 829 00:41:44,940 --> 00:41:46,840 we want to continue with logits here.
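Written out, the one-liner described above might look roughly like this (assuming logits, dlogits, and dlogit_maxes as the notebook tensors, with 27 classes as stated):

```python
import torch.nn.functional as F

# second branch into logits, through: logit_maxes = logits.max(1, keepdim=True).values
# one-hot rows mark where each row's max came from; multiplying by the 32x1 dlogit_maxes
# routes that gradient to exactly those positions
dlogits += F.one_hot(logits.max(1).indices, num_classes=27) * dlogit_maxes
```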
830 00:41:46,940 --> 00:41:48,840 That is an outcome of a matrix 831 00:41:48,940 --> 00:41:50,840 multiplication and a bias offset 832 00:41:50,940 --> 00:41:52,840 in this linear layer. 833 00:41:52,940 --> 00:41:54,840 So I've 834 00:41:54,940 --> 00:41:56,840 printed out the shapes of all these intermediate 835 00:41:56,940 --> 00:41:58,840 tensors. We see that logits 836 00:41:58,940 --> 00:42:00,840 is of course 32 by 27, as we've just 837 00:42:00,940 --> 00:42:02,840 seen. Then the 838 00:42:02,940 --> 00:42:04,840 h here is 32 by 64. 839 00:42:04,940 --> 00:42:06,840 So these are 64-dimensional hidden states. 840 00:42:06,940 --> 00:42:08,840 And then this w 841 00:42:08,840 --> 00:42:10,740 matrix projects those 64-dimensional 842 00:42:10,840 --> 00:42:12,740 vectors into 27 dimensions. 843 00:42:12,840 --> 00:42:14,740 And then there's a 27-dimensional 844 00:42:14,840 --> 00:42:16,740 offset, which is a 845 00:42:16,840 --> 00:42:18,740 one-dimensional vector. Now we 846 00:42:18,840 --> 00:42:20,740 should note that this plus here actually 847 00:42:20,840 --> 00:42:22,740 broadcasts, because h multiplied 848 00:42:22,840 --> 00:42:24,740 by w2 849 00:42:24,840 --> 00:42:26,740 will give us a 32 by 27. 850 00:42:26,840 --> 00:42:28,740 And so then this plus 851 00:42:28,840 --> 00:42:30,740 b2 is a 27-dimensional 852 00:42:30,840 --> 00:42:32,740 vector here. Now in the 853 00:42:32,840 --> 00:42:34,740 rules of broadcasting, what's going to happen with this bias 854 00:42:34,840 --> 00:42:36,740 vector is that this one-dimensional 855 00:42:36,840 --> 00:42:38,740 vector of 27 will get a lot 856 00:42:38,840 --> 00:42:40,740 aligned with an padded dimension 857 00:42:40,840 --> 00:42:42,740 of 1 on the left. 858 00:42:42,840 --> 00:42:44,740 And it will basically become a row vector, 859 00:42:44,840 --> 00:42:46,740 and then it will get replicated 860 00:42:46,840 --> 00:42:48,740 vertically 32 times to make it 861 00:42:48,840 --> 00:42:50,740 32 by 27, and then there's an element-wise 862 00:42:50,840 --> 00:42:52,740 multiply. 863 00:42:52,840 --> 00:42:54,740 Now the question 864 00:42:54,840 --> 00:42:56,740 is how do we backpropagate from 865 00:42:56,840 --> 00:42:58,740 logits to the hidden states, 866 00:42:58,840 --> 00:43:00,740 the weight matrix w2, and the bias 867 00:43:00,840 --> 00:43:02,740 b2? And you might 868 00:43:02,840 --> 00:43:04,740 think that we need to go to some 869 00:43:04,840 --> 00:43:06,740 matrix calculus, 870 00:43:06,740 --> 00:43:08,640 and then we have to look up the derivative 871 00:43:08,740 --> 00:43:10,640 for matrix multiplication, 872 00:43:10,740 --> 00:43:12,640 but actually you don't have to do any of that, and you can go 873 00:43:12,740 --> 00:43:14,640 back to first principles and derive this yourself 874 00:43:14,740 --> 00:43:16,640 on a piece of paper. 
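Just to make the shapes concrete, a tiny stand-in example (random tensors, purely illustrative) of the broadcasting described here:

```python
import torch

h  = torch.randn(32, 64)   # 64-dimensional hidden states for the 32 examples
W2 = torch.randn(64, 27)   # projects 64 dimensions into 27
b2 = torch.randn(27)       # one-dimensional bias; broadcasts to a (1, 27) row, replicated down the 32 rows
logits = h @ W2 + b2
print(logits.shape)        # torch.Size([32, 27])
```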
875 00:43:16,740 --> 00:43:18,640 And specifically what I like to do, and what 876 00:43:18,740 --> 00:43:20,640 I find works well for me, is you find 877 00:43:20,740 --> 00:43:22,640 a specific small example 878 00:43:22,740 --> 00:43:24,640 that you then fully write out, and then 879 00:43:24,740 --> 00:43:26,640 in the process of analyzing how that individual 880 00:43:26,740 --> 00:43:28,640 small example works, you will understand 881 00:43:28,740 --> 00:43:30,640 the broader pattern, and you'll be able to generalize 882 00:43:30,740 --> 00:43:32,640 and write out the full 883 00:43:32,740 --> 00:43:34,640 general formula for 884 00:43:34,640 --> 00:43:36,540 how these derivatives flow in an expression 885 00:43:36,640 --> 00:43:38,540 like this. So let's try that out. 886 00:43:38,640 --> 00:43:40,540 So pardon the low-budget production 887 00:43:40,640 --> 00:43:42,540 here, but what I've done here 888 00:43:42,640 --> 00:43:44,540 is I'm writing it out on a piece of paper. 889 00:43:44,640 --> 00:43:46,540 Really what we are interested in is we have 890 00:43:46,640 --> 00:43:48,540 a multiply b plus c, 891 00:43:48,640 --> 00:43:50,540 and that creates a d. 892 00:43:50,640 --> 00:43:52,540 And we have the derivative 893 00:43:52,640 --> 00:43:54,540 of the loss with respect to d, and we'd like to 894 00:43:54,640 --> 00:43:56,540 know what the derivative of the loss is with respect to a, b, 895 00:43:56,640 --> 00:43:58,540 and c. Now these 896 00:43:58,640 --> 00:44:00,540 here are little two-dimensional examples 897 00:44:00,640 --> 00:44:02,540 of matrix multiplication. 898 00:44:02,640 --> 00:44:04,540 2 by 2 times a 2 by 2, 899 00:44:04,540 --> 00:44:06,440 plus a 2, 900 00:44:06,540 --> 00:44:08,440 a vector of just two elements, c1 and c2, 901 00:44:08,540 --> 00:44:10,440 gives me a 2 by 2. 902 00:44:10,540 --> 00:44:12,440 Now notice here that 903 00:44:12,540 --> 00:44:14,440 I have a bias vector 904 00:44:14,540 --> 00:44:16,440 here called c, and the 905 00:44:16,540 --> 00:44:18,440 bias vector is c1 and c2, but 906 00:44:18,540 --> 00:44:20,440 as I described over here, that bias 907 00:44:20,540 --> 00:44:22,440 vector will become a row vector in the broadcasting, 908 00:44:22,540 --> 00:44:24,440 and will replicate vertically. 909 00:44:24,540 --> 00:44:26,440 So that's what's happening here as well. c1, c2 910 00:44:26,540 --> 00:44:28,440 is replicated vertically, 911 00:44:28,540 --> 00:44:30,440 and we see how we have two rows of c1, 912 00:44:30,540 --> 00:44:32,440 c2 as a result. 913 00:44:32,540 --> 00:44:34,440 So now when I say write it out, 914 00:44:34,540 --> 00:44:36,440 I just mean like this. 915 00:44:36,540 --> 00:44:38,440 Basically break up this matrix multiplication 916 00:44:38,540 --> 00:44:40,440 into the actual thing that's 917 00:44:40,540 --> 00:44:42,440 going on under the hood. 918 00:44:42,540 --> 00:44:44,440 So as a result of matrix multiplication 919 00:44:44,540 --> 00:44:46,440 and how it works, d11 920 00:44:46,540 --> 00:44:48,440 is the result of a dot product between the 921 00:44:48,540 --> 00:44:50,440 first row of a and the first column 922 00:44:50,540 --> 00:44:52,440 of b. So a11, b11, 923 00:44:52,540 --> 00:44:54,440 plus a12, b21, 924 00:44:54,540 --> 00:44:56,440 plus c1. 925 00:44:56,540 --> 00:44:58,440 And so on 926 00:44:58,540 --> 00:45:00,440 and so forth for all the other elements of d. 
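Written out in full, the little 2 by 2 example reads:

$$
\begin{aligned}
d_{11} &= a_{11} b_{11} + a_{12} b_{21} + c_1, &\quad d_{12} &= a_{11} b_{12} + a_{12} b_{22} + c_2,\\
d_{21} &= a_{21} b_{11} + a_{22} b_{21} + c_1, &\quad d_{22} &= a_{21} b_{12} + a_{22} b_{22} + c_2.
\end{aligned}
$$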
927 00:45:00,540 --> 00:45:02,440 And once you actually write 928 00:45:02,540 --> 00:45:04,440 it out, it becomes obvious that it's just a bunch of 929 00:45:04,440 --> 00:45:06,340 multiplies and adds. 930 00:45:06,440 --> 00:45:08,340 And we know from micrograd 931 00:45:08,440 --> 00:45:10,340 how to differentiate multiplies and adds. 932 00:45:10,440 --> 00:45:12,340 And so this is not scary anymore. 933 00:45:12,440 --> 00:45:14,340 It's not just matrix multiplication. 934 00:45:14,440 --> 00:45:16,340 It's just tedious, unfortunately. 935 00:45:16,440 --> 00:45:18,340 But this is completely tractable. 936 00:45:18,440 --> 00:45:20,340 We have dl by d for all of these, 937 00:45:20,440 --> 00:45:22,340 and we want dl by 938 00:45:22,440 --> 00:45:24,340 all these little other variables. 939 00:45:24,440 --> 00:45:26,340 So how do we achieve that, and how do we 940 00:45:26,440 --> 00:45:28,340 actually get the gradients? 941 00:45:28,440 --> 00:45:30,340 Okay, so the low-budget production continues here. 942 00:45:30,440 --> 00:45:32,340 So let's, for example, derive 943 00:45:32,440 --> 00:45:34,340 the derivative of the loss with respect to 944 00:45:34,340 --> 00:45:36,240 a11. 945 00:45:36,340 --> 00:45:38,240 We see here that a11 occurs twice 946 00:45:38,340 --> 00:45:40,240 in our simple expression, right here, right here, 947 00:45:40,340 --> 00:45:42,240 and influences d11 and d12. 948 00:45:42,340 --> 00:45:44,240 So this is, so what 949 00:45:44,340 --> 00:45:46,240 is dl by d a11? 950 00:45:46,340 --> 00:45:48,240 Well, it's dl by d11 951 00:45:48,340 --> 00:45:50,240 times 952 00:45:50,340 --> 00:45:52,240 the local derivative of d11, 953 00:45:52,340 --> 00:45:54,240 which in this case is just b11, 954 00:45:54,340 --> 00:45:56,240 because that's what's multiplying 955 00:45:56,340 --> 00:45:58,240 a11 here. 956 00:45:58,340 --> 00:46:00,240 And likewise here, the local 957 00:46:00,340 --> 00:46:02,240 derivative of d12 with respect to a11 958 00:46:02,340 --> 00:46:04,240 is just b12. 959 00:46:04,240 --> 00:46:06,140 And so b12 will, in the chain rule, therefore, 960 00:46:06,240 --> 00:46:08,140 multiply dl by d12. 961 00:46:08,240 --> 00:46:10,140 And then, because a11 962 00:46:10,240 --> 00:46:12,140 is used both to produce 963 00:46:12,240 --> 00:46:14,140 d11 and d12, we need 964 00:46:14,240 --> 00:46:16,140 to add up the contributions 965 00:46:16,240 --> 00:46:18,140 of both of those sort of 966 00:46:18,240 --> 00:46:20,140 chains that are running in parallel. 967 00:46:20,240 --> 00:46:22,140 And that's why we get a plus, just 968 00:46:22,240 --> 00:46:24,140 adding up those two, 969 00:46:24,240 --> 00:46:26,140 those two contributions. And that gives 970 00:46:26,240 --> 00:46:28,140 us dl by d a11. 971 00:46:28,240 --> 00:46:30,140 We can do the exact same analysis for 972 00:46:30,240 --> 00:46:32,140 the other one, for all the other 973 00:46:32,240 --> 00:46:34,140 elements of a. And when you 974 00:46:34,140 --> 00:46:36,040 simply write it out, it's just super 975 00:46:36,140 --> 00:46:38,040 simple taking of 976 00:46:38,140 --> 00:46:40,040 gradients on, you know, 977 00:46:40,140 --> 00:46:42,040 expressions like this. 
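In equation form, the argument for a11 is just the chain rule summed over the two places it is used:

$$
\frac{\partial L}{\partial a_{11}}
= \frac{\partial L}{\partial d_{11}}\, b_{11} + \frac{\partial L}{\partial d_{12}}\, b_{12},
$$

and the analogous expressions hold for the other elements of a.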
978 00:46:42,140 --> 00:46:44,040 You find that 979 00:46:44,140 --> 00:46:46,040 this matrix, dl by d a, 980 00:46:46,140 --> 00:46:48,040 that we're after, right, if we 981 00:46:48,140 --> 00:46:50,040 just arrange all of them in the 982 00:46:50,140 --> 00:46:52,040 same shape as a takes, so 983 00:46:52,140 --> 00:46:54,040 a is just a 2x2 matrix, so 984 00:46:54,140 --> 00:46:56,040 dl by d a here will be 985 00:46:56,140 --> 00:46:58,040 also just the same 986 00:46:58,140 --> 00:47:00,040 shape 987 00:47:00,140 --> 00:47:02,040 tensor with the derivatives 988 00:47:02,140 --> 00:47:04,040 now. So dl by d a11 989 00:47:04,040 --> 00:47:05,940 etc. And we see that actually 990 00:47:06,040 --> 00:47:07,940 we can express what we've written out here 991 00:47:08,040 --> 00:47:09,940 as a matrix multiply. 992 00:47:10,040 --> 00:47:11,940 And so it just so 993 00:47:12,040 --> 00:47:13,940 happens that dl by, that all 994 00:47:14,040 --> 00:47:15,940 of these formulas that we've derived here 995 00:47:16,040 --> 00:47:17,940 by taking gradients can actually 996 00:47:18,040 --> 00:47:19,940 be expressed as a matrix multiplication. 997 00:47:20,040 --> 00:47:21,940 And in particular, we see that it is the matrix 998 00:47:22,040 --> 00:47:23,940 multiplication of these two 999 00:47:24,040 --> 00:47:25,940 matrices. So it 1000 00:47:26,040 --> 00:47:27,940 is the dl 1001 00:47:28,040 --> 00:47:29,940 by d and then matrix 1002 00:47:30,040 --> 00:47:31,940 multiplying b, but b 1003 00:47:32,040 --> 00:47:33,840 transpose, actually. So you see that 1004 00:47:33,840 --> 00:47:35,740 b21 and b12 1005 00:47:35,840 --> 00:47:37,740 have changed place, 1006 00:47:37,840 --> 00:47:39,740 whereas before we had, of course, 1007 00:47:39,840 --> 00:47:41,740 b11, b12, b21, 1008 00:47:41,840 --> 00:47:43,740 b22. So you see that 1009 00:47:43,840 --> 00:47:45,740 this other matrix, b, 1010 00:47:45,840 --> 00:47:47,740 is transposed. And so 1011 00:47:47,840 --> 00:47:49,740 basically what we have, long story short, just by 1012 00:47:49,840 --> 00:47:51,740 doing very simple reasoning here, 1013 00:47:51,840 --> 00:47:53,740 by breaking up the expression in the case of 1014 00:47:53,840 --> 00:47:55,740 a very simple example, is that 1015 00:47:55,840 --> 00:47:57,740 dl by d a is 1016 00:47:57,840 --> 00:47:59,740 which is this, is simply 1017 00:47:59,840 --> 00:48:01,740 equal to dl by dd matrix 1018 00:48:01,840 --> 00:48:03,740 multiplied with b transpose. 1019 00:48:03,840 --> 00:48:05,740 So that 1020 00:48:05,840 --> 00:48:07,740 is what we have so far. 1021 00:48:07,840 --> 00:48:09,740 Now, we also want the derivative with respect to 1022 00:48:09,840 --> 00:48:11,740 b and c. 1023 00:48:11,840 --> 00:48:13,740 Now, for 1024 00:48:13,840 --> 00:48:15,740 b, I'm not actually doing the full derivation 1025 00:48:15,840 --> 00:48:17,740 because, honestly, it's 1026 00:48:17,840 --> 00:48:19,740 not deep. It's just 1027 00:48:19,840 --> 00:48:21,740 annoying. It's exhausting. You can 1028 00:48:21,840 --> 00:48:23,740 actually do this analysis yourself. You'll 1029 00:48:23,840 --> 00:48:25,740 also find that if you take these expressions 1030 00:48:25,840 --> 00:48:27,740 and you differentiate with respect to b 1031 00:48:27,840 --> 00:48:29,740 instead of a, you will find that 1032 00:48:29,840 --> 00:48:31,740 dl by db is also a matrix 1033 00:48:31,840 --> 00:48:33,740 multiplication. 
In this case, you have to take 1034 00:48:33,740 --> 00:48:35,640 the matrix a and transpose 1035 00:48:35,740 --> 00:48:37,640 it and matrix multiply that with 1036 00:48:37,740 --> 00:48:39,640 dl by dd. 1037 00:48:39,740 --> 00:48:41,640 And that's what gives you dl by db. 1038 00:48:41,740 --> 00:48:43,640 And then here for 1039 00:48:43,740 --> 00:48:45,640 the offsets, c1 and 1040 00:48:45,740 --> 00:48:47,640 c2, if you again just differentiate 1041 00:48:47,740 --> 00:48:49,640 with respect to c1, you will find 1042 00:48:49,740 --> 00:48:51,640 an expression like this. 1043 00:48:51,740 --> 00:48:53,640 And c2, an expression 1044 00:48:53,740 --> 00:48:55,640 like this. And basically 1045 00:48:55,740 --> 00:48:57,640 you'll find that dl by dc is 1046 00:48:57,740 --> 00:48:59,640 simply, because they're just offsetting 1047 00:48:59,740 --> 00:49:01,640 these expressions, you just have to take 1048 00:49:01,740 --> 00:49:03,640 the dl by dd matrix 1049 00:49:03,640 --> 00:49:05,540 of the derivatives of d 1050 00:49:05,640 --> 00:49:07,540 and you just have to 1051 00:49:07,640 --> 00:49:09,540 sum across the columns. 1052 00:49:09,640 --> 00:49:11,540 And that gives you the derivatives 1053 00:49:11,640 --> 00:49:13,540 for c. 1054 00:49:13,640 --> 00:49:15,540 So, long story short, 1055 00:49:15,640 --> 00:49:17,540 the backward pass of a matrix 1056 00:49:17,640 --> 00:49:19,540 multiply is a matrix multiply. 1057 00:49:19,640 --> 00:49:21,540 And instead of, just like we had 1058 00:49:21,640 --> 00:49:23,540 d equals a times b plus c, 1059 00:49:23,640 --> 00:49:25,540 in a scalar case, 1060 00:49:25,640 --> 00:49:27,540 we sort of arrive at something very, very similar 1061 00:49:27,640 --> 00:49:29,540 but now with a matrix multiplication 1062 00:49:29,640 --> 00:49:31,540 instead of a scalar multiplication. 1063 00:49:31,640 --> 00:49:33,540 So, the derivative 1064 00:49:33,540 --> 00:49:35,440 of d with respect 1065 00:49:35,540 --> 00:49:37,440 to a is 1066 00:49:37,540 --> 00:49:39,440 dl by dd 1067 00:49:39,540 --> 00:49:41,440 matrix multiply b transpose 1068 00:49:41,540 --> 00:49:43,440 and here it's a transpose 1069 00:49:43,540 --> 00:49:45,440 multiply dl by dd. But in both 1070 00:49:45,540 --> 00:49:47,440 cases it's a matrix multiplication with 1071 00:49:47,540 --> 00:49:49,440 the derivative and 1072 00:49:49,540 --> 00:49:51,440 the other term in the 1073 00:49:51,540 --> 00:49:53,440 multiplication. And 1074 00:49:53,540 --> 00:49:55,440 for c it is a sum. 1075 00:49:55,540 --> 00:49:57,440 Now, I'll tell you a secret. 1076 00:49:57,540 --> 00:49:59,440 I can never remember the formulas 1077 00:49:59,540 --> 00:50:01,440 that we just derived for backpropagating 1078 00:50:01,440 --> 00:50:03,340 a matrix multiplication and I can backpropagate 1079 00:50:03,440 --> 00:50:05,340 through these expressions just fine. 1080 00:50:05,440 --> 00:50:07,340 And the reason this works is because 1081 00:50:07,440 --> 00:50:09,340 the dimensions have to work out. 1082 00:50:09,440 --> 00:50:11,340 So, let me give you an example. 1083 00:50:11,440 --> 00:50:13,340 Say I want to create dh. 1084 00:50:13,440 --> 00:50:15,340 Then what should dh be? 1085 00:50:15,440 --> 00:50:17,340 Number one, I have to know that 1086 00:50:17,440 --> 00:50:19,340 the shape of dh must be the same 1087 00:50:19,440 --> 00:50:21,340 as the shape of h. 1088 00:50:21,440 --> 00:50:23,340 And the shape of h is 32 by 64. 
1089 00:50:23,440 --> 00:50:25,340 And then the other piece of information I know 1090 00:50:25,440 --> 00:50:27,340 is that dh 1091 00:50:27,440 --> 00:50:29,340 must be some kind of matrix multiplication 1092 00:50:29,340 --> 00:50:31,240 of dlogits with w2. 1093 00:50:31,340 --> 00:50:33,240 And dlogits 1094 00:50:33,340 --> 00:50:35,240 is 32 by 27 1095 00:50:35,340 --> 00:50:37,240 and w2 is 1096 00:50:37,340 --> 00:50:39,240 64 by 27. There is only 1097 00:50:39,340 --> 00:50:41,240 a single way to make the shape work out 1098 00:50:41,340 --> 00:50:43,240 in this case 1099 00:50:43,340 --> 00:50:45,240 and it is indeed the correct result. 1100 00:50:45,340 --> 00:50:47,240 In particular here, h 1101 00:50:47,340 --> 00:50:49,240 needs to be 32 by 64. The only 1102 00:50:49,340 --> 00:50:51,240 way to achieve that is to take dlogits 1103 00:50:51,340 --> 00:50:53,240 and matrix 1104 00:50:53,340 --> 00:50:55,240 multiply it with 1105 00:50:55,340 --> 00:50:57,240 you see how I have to take w2 but I have to 1106 00:50:57,340 --> 00:50:59,240 transpose it to make the dimensions work out. 1107 00:50:59,340 --> 00:51:01,240 So w2 transpose. 1108 00:51:01,340 --> 00:51:03,240 And it is the only way to make these 1109 00:51:03,340 --> 00:51:05,240 to matrix multiply those two pieces 1110 00:51:05,340 --> 00:51:07,240 to make the shapes work out. And that turns out 1111 00:51:07,340 --> 00:51:09,240 to be the correct formula. So if we come 1112 00:51:09,340 --> 00:51:11,240 here, we want 1113 00:51:11,340 --> 00:51:13,240 dh which is da and we see 1114 00:51:13,340 --> 00:51:15,240 that da is dl by 1115 00:51:15,340 --> 00:51:17,240 dd matrix multiply b transpose. 1116 00:51:17,340 --> 00:51:19,240 So that is dlogits multiply 1117 00:51:19,340 --> 00:51:21,240 and b is w2. 1118 00:51:21,340 --> 00:51:23,240 So w2 transpose which is exactly 1119 00:51:23,340 --> 00:51:25,240 what we have here. So 1120 00:51:25,340 --> 00:51:27,240 there is no need to remember these formulas. 1121 00:51:27,340 --> 00:51:29,240 Similarly, now if I 1122 00:51:29,240 --> 00:51:31,140 want dw2 1123 00:51:31,240 --> 00:51:33,140 well I know that it must be a matrix 1124 00:51:33,240 --> 00:51:35,140 multiplication of 1125 00:51:35,240 --> 00:51:37,140 dlogits and h 1126 00:51:37,240 --> 00:51:39,140 and maybe there is a few transpose 1127 00:51:39,240 --> 00:51:41,140 like there is one transpose in there as well. 1128 00:51:41,240 --> 00:51:43,140 And I don't know which way it is so I have to come to w2 1129 00:51:43,240 --> 00:51:45,140 and I see that its shape is 1130 00:51:45,240 --> 00:51:47,140 64 by 27 1131 00:51:47,240 --> 00:51:49,140 and that has to come from some matrix 1132 00:51:49,240 --> 00:51:51,140 multiplication of these two. 1133 00:51:51,240 --> 00:51:53,140 And so to get a 64 by 27 1134 00:51:53,240 --> 00:51:55,140 I need to take 1135 00:51:55,240 --> 00:51:57,140 h 1136 00:51:57,240 --> 00:51:59,140 I need to transpose it 1137 00:51:59,240 --> 00:52:01,140 and then I need to matrix multiply it 1138 00:52:01,240 --> 00:52:03,140 so that will become 64 by 32 1139 00:52:03,240 --> 00:52:05,140 and then I need to matrix multiply it with 1140 00:52:05,240 --> 00:52:07,140 32 by 27 and that's going to give me 1141 00:52:07,240 --> 00:52:09,140 a 64 by 27. So I need 1142 00:52:09,240 --> 00:52:11,140 to matrix multiply this with dlogits.shape 1143 00:52:11,240 --> 00:52:13,140 just like that. 
That's the only way 1144 00:52:13,240 --> 00:52:15,140 to make the dimensions work out and 1145 00:52:15,240 --> 00:52:17,140 just use matrix multiplication. 1146 00:52:17,240 --> 00:52:19,140 And if we come here, we see that 1147 00:52:19,240 --> 00:52:21,140 that's exactly what's here. So a transpose, 1148 00:52:21,240 --> 00:52:23,140 a for us is h, 1149 00:52:23,240 --> 00:52:25,140 multiplied with dlogits. 1150 00:52:25,240 --> 00:52:27,140 So that's dw2, 1151 00:52:27,240 --> 00:52:29,140 and then db2 1152 00:52:29,140 --> 00:52:31,040 is just 1153 00:52:31,140 --> 00:52:33,040 the 1154 00:52:33,140 --> 00:52:35,040 vertical sum and actually 1155 00:52:35,140 --> 00:52:37,040 in the same way, there's only one way to make 1156 00:52:37,140 --> 00:52:39,040 the shapes work out. I don't have to remember that 1157 00:52:39,140 --> 00:52:41,040 it's a vertical sum along the 0th axis 1158 00:52:41,140 --> 00:52:43,040 because that's the only way that this makes sense. 1159 00:52:43,140 --> 00:52:45,040 Because b2's shape is 27 1160 00:52:45,140 --> 00:52:47,040 so in order to get 1161 00:52:47,140 --> 00:52:49,040 a dlogits 1162 00:52:49,140 --> 00:52:51,040 here 1163 00:52:51,140 --> 00:52:53,040 it's 32 by 27 so 1164 00:52:53,140 --> 00:52:55,040 knowing that it's just sum over dlogits 1165 00:52:55,140 --> 00:52:57,040 in some direction 1166 00:52:57,040 --> 00:53:00,940 that direction must be 0 1167 00:53:01,040 --> 00:53:02,940 because I need to eliminate this dimension. 1168 00:53:03,040 --> 00:53:04,940 So it's this. 1169 00:53:05,040 --> 00:53:06,940 So this is 1170 00:53:07,040 --> 00:53:08,940 kind of like the hacky way. 1171 00:53:09,040 --> 00:53:10,940 Let me copy paste and delete that 1172 00:53:11,040 --> 00:53:12,940 and let me swing over here 1173 00:53:13,040 --> 00:53:14,940 and this is our backward pass for the linear layer. 1174 00:53:15,040 --> 00:53:16,940 Hopefully. 1175 00:53:17,040 --> 00:53:18,940 So now let's uncomment 1176 00:53:19,040 --> 00:53:20,940 these three and we're checking that 1177 00:53:21,040 --> 00:53:22,940 we got all the three 1178 00:53:23,040 --> 00:53:24,940 derivatives correct and 1179 00:53:25,040 --> 00:53:26,940 run 1180 00:53:26,940 --> 00:53:28,840 and we see that h, 1181 00:53:28,940 --> 00:53:30,840 w2 and b2 are all exactly correct. 1182 00:53:30,940 --> 00:53:32,840 So we backpropagated through 1183 00:53:32,940 --> 00:53:34,840 a linear layer. 1184 00:53:34,940 --> 00:53:36,840 Now next up 1185 00:53:36,940 --> 00:53:38,840 we have the derivative for h 1186 00:53:38,940 --> 00:53:40,840 already and we need to backpropagate 1187 00:53:40,940 --> 00:53:42,840 through tanh into h preact. 1188 00:53:42,940 --> 00:53:44,840 So we want to derive 1189 00:53:44,940 --> 00:53:46,840 dh preact 1190 00:53:46,940 --> 00:53:48,840 and here we have to backpropagate through a tanh 1191 00:53:48,940 --> 00:53:50,840 and we've already done this in micrograd 1192 00:53:50,940 --> 00:53:52,840 and we remember that tanh has a very simple 1193 00:53:52,940 --> 00:53:54,840 backward formula. Now unfortunately 1194 00:53:54,940 --> 00:53:56,840 if I just put in d by dx of 1195 00:53:56,940 --> 00:53:58,840 tanh of x into Wolfram Alpha 1196 00:53:58,940 --> 00:54:00,840 it lets us down. It tells us that it's a 1197 00:54:00,940 --> 00:54:02,840 hyperbolic secant function squared 1198 00:54:02,940 --> 00:54:04,840 of x.
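Before moving on to the tanh, here is the linear-layer backward pass that was just verified, as a minimal sketch (shapes as stated above: dlogits is 32 by 27, h is 32 by 64, W2 is 64 by 27, b2 is 27; notebook spellings assumed):

```python
# backward through: logits = h @ W2 + b2
dh  = dlogits @ W2.T    # (32, 27) @ (27, 64) -> (32, 64), the only way to match h's shape
dW2 = h.T @ dlogits     # (64, 32) @ (32, 27) -> (64, 27), matches W2's shape
db2 = dlogits.sum(0)    # the bias was broadcast down the rows, so its gradient sums over dim 0 -> (27,)
```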
It's not exactly helpful 1199 00:54:04,940 --> 00:54:06,840 but luckily google image 1200 00:54:06,940 --> 00:54:08,840 search does not let us down and it gives 1201 00:54:08,940 --> 00:54:10,840 us the simpler formula. In particular 1202 00:54:10,940 --> 00:54:12,840 if you have that a is equal to tanh 1203 00:54:12,940 --> 00:54:14,840 of z then da by 1204 00:54:14,940 --> 00:54:16,840 dz backpropagating through tanh 1205 00:54:16,940 --> 00:54:18,840 is just 1 minus a square 1206 00:54:18,940 --> 00:54:20,840 and take note that in 1 1207 00:54:20,940 --> 00:54:22,840 minus a square, a here is the 1208 00:54:22,940 --> 00:54:24,840 output of the tanh, not the input to 1209 00:54:24,940 --> 00:54:26,840 the tanh, z. So 1210 00:54:26,840 --> 00:54:28,740 the da by dz is here 1211 00:54:28,840 --> 00:54:30,740 formulated in terms of the output of that tanh 1212 00:54:30,840 --> 00:54:32,740 and here also 1213 00:54:32,840 --> 00:54:34,740 in google image search we have the full derivation 1214 00:54:34,840 --> 00:54:36,740 if you want to actually take the 1215 00:54:36,840 --> 00:54:38,740 actual definition of tanh and work 1216 00:54:38,840 --> 00:54:40,740 through the math to figure out 1 minus 1217 00:54:40,840 --> 00:54:42,740 tanh square of z. So 1218 00:54:42,840 --> 00:54:44,740 1 minus a square is 1219 00:54:44,840 --> 00:54:46,740 the local derivative. In our case 1220 00:54:46,840 --> 00:54:48,740 that is 1 minus 1221 00:54:48,840 --> 00:54:50,740 the output of tanh 1222 00:54:50,840 --> 00:54:52,740 square which here is h 1223 00:54:52,840 --> 00:54:54,740 so it's h square 1224 00:54:54,840 --> 00:54:56,740 and that is the local derivative 1225 00:54:56,840 --> 00:54:58,740 and then times the chain rule 1226 00:54:58,840 --> 00:55:00,740 dh. So 1227 00:55:00,840 --> 00:55:02,740 that is going to be our candidate implementation 1228 00:55:02,840 --> 00:55:04,740 so if we come here 1229 00:55:04,840 --> 00:55:06,740 and then uncomment 1230 00:55:06,840 --> 00:55:08,740 this let's hope for the best 1231 00:55:08,840 --> 00:55:10,740 and we have 1232 00:55:10,840 --> 00:55:12,740 the right answer. Okay next 1233 00:55:12,840 --> 00:55:14,740 up we have dh preact and 1234 00:55:14,840 --> 00:55:16,740 we want to backpropagate into the gain, 1235 00:55:16,840 --> 00:55:18,740 the bnraw, and the bnbias. 1236 00:55:18,840 --> 00:55:20,740 So here these are the batch norm 1237 00:55:20,840 --> 00:55:22,740 parameters bngain and bnbias inside 1238 00:55:22,840 --> 00:55:24,740 the batch norm that take the bnraw 1239 00:55:24,840 --> 00:55:26,740 that is the exact unit Gaussian 1240 00:55:26,740 --> 00:55:28,640 and they scale it and shift it 1241 00:55:28,740 --> 00:55:30,640 and these are the parameters of the 1242 00:55:30,740 --> 00:55:32,640 batch norm. Now here 1243 00:55:32,740 --> 00:55:34,640 we have a multiplication but 1244 00:55:34,740 --> 00:55:36,640 it's worth noting that this multiply is very very 1245 00:55:36,740 --> 00:55:38,640 different from this matrix multiply here 1246 00:55:38,740 --> 00:55:40,640 matrix multiplies are dot products 1247 00:55:40,740 --> 00:55:42,640 between rows and columns of these 1248 00:55:42,740 --> 00:55:44,640 matrices involved. This is an 1249 00:55:44,740 --> 00:55:46,640 element wise multiply so things are quite a bit 1250 00:55:46,740 --> 00:55:48,640 simpler. Now we do have to 1251 00:55:48,740 --> 00:55:50,640 be careful with some of the broadcasting happening 1252 00:55:50,740 --> 00:55:52,640 in this line of code though.
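And for reference, the tanh step that was just checked, as a one-line sketch (h is the tanh output and dh the gradient flowing in, with the notebook spellings assumed):

```python
# backward through: h = torch.tanh(hpreact)
# d tanh(z)/dz = 1 - tanh(z)**2, conveniently expressed via the output h
dhpreact = (1.0 - h**2) * dh
```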
So you 1253 00:55:52,740 --> 00:55:54,640 see how b in gain and b in bias 1254 00:55:54,740 --> 00:55:56,640 are 1 by 64 1255 00:55:56,740 --> 00:55:58,640 but h preact and 1256 00:55:58,740 --> 00:56:00,640 b in raw are 32 by 64. 1257 00:56:00,740 --> 00:56:02,640 So 1258 00:56:02,740 --> 00:56:04,640 we have to be careful with that and make sure that all the shapes 1259 00:56:04,740 --> 00:56:06,640 work out fine and that the broadcasting is 1260 00:56:06,740 --> 00:56:08,640 correctly backpropagated. So 1261 00:56:08,740 --> 00:56:10,640 in particular let's start with db in gain 1262 00:56:10,740 --> 00:56:12,640 so db in gain 1263 00:56:12,740 --> 00:56:14,640 should be 1264 00:56:14,740 --> 00:56:16,640 and here this is again element wise 1265 00:56:16,740 --> 00:56:18,640 multiply and whenever we have a times 1266 00:56:18,740 --> 00:56:20,640 b equals c we saw that 1267 00:56:20,740 --> 00:56:22,640 the local derivative here is just if this 1268 00:56:22,740 --> 00:56:24,640 is a the local derivative is just the 1269 00:56:24,740 --> 00:56:26,640 b the other one. So this 1270 00:56:26,740 --> 00:56:28,640 local derivative is just b in raw 1271 00:56:28,740 --> 00:56:30,640 and then times chain rule 1272 00:56:30,740 --> 00:56:32,640 so dh preact. 1273 00:56:32,740 --> 00:56:34,640 So 1274 00:56:34,740 --> 00:56:36,640 this is the candidate 1275 00:56:36,740 --> 00:56:38,640 gradient. Now again 1276 00:56:38,740 --> 00:56:40,640 we have to be careful because b in gain 1277 00:56:40,740 --> 00:56:42,640 is of size 1 by 64 1278 00:56:42,740 --> 00:56:44,640 but this here 1279 00:56:44,740 --> 00:56:46,640 would be 32 by 64 1280 00:56:46,740 --> 00:56:48,640 and so 1281 00:56:48,740 --> 00:56:50,640 the correct thing to do 1282 00:56:50,740 --> 00:56:52,640 in this case of course is that b in gain 1283 00:56:52,740 --> 00:56:54,640 here is a rule vector of 64 numbers 1284 00:56:54,740 --> 00:56:56,640 it gets replicated vertically 1285 00:56:56,640 --> 00:56:58,540 in this operation 1286 00:56:58,640 --> 00:57:00,540 and so therefore the correct thing to do 1287 00:57:00,640 --> 00:57:02,540 is to sum because it's being replicated 1288 00:57:02,640 --> 00:57:04,540 and therefore 1289 00:57:04,640 --> 00:57:06,540 all the gradients in each of the rows 1290 00:57:06,640 --> 00:57:08,540 that are now flowing backwards 1291 00:57:08,640 --> 00:57:10,540 need to sum up to that same 1292 00:57:10,640 --> 00:57:12,540 tensor db in gain. 1293 00:57:12,640 --> 00:57:14,540 So we have to sum across all the zero 1294 00:57:14,640 --> 00:57:16,540 all the examples 1295 00:57:16,640 --> 00:57:18,540 basically which is the direction 1296 00:57:18,640 --> 00:57:20,540 in which this gets replicated 1297 00:57:20,640 --> 00:57:22,540 and now we have to be also careful because 1298 00:57:22,640 --> 00:57:24,540 b in gain is of shape 1299 00:57:24,640 --> 00:57:26,540 1 by 64. So in fact 1300 00:57:26,640 --> 00:57:28,540 I need to keep them as true 1301 00:57:28,640 --> 00:57:30,540 otherwise I would just get 64. 
1302 00:57:30,640 --> 00:57:32,540 Now I don't actually 1303 00:57:32,640 --> 00:57:34,540 really remember why 1304 00:57:34,640 --> 00:57:36,540 the b in gain and the b in bias 1305 00:57:36,640 --> 00:57:38,540 I made them be 1 by 64 1306 00:57:38,640 --> 00:57:40,540 but the biases 1307 00:57:40,640 --> 00:57:42,540 b1 and b2 1308 00:57:42,640 --> 00:57:44,540 I just made them be one-dimensional vectors 1309 00:57:44,640 --> 00:57:46,540 they're not two-dimensional tensors 1310 00:57:46,640 --> 00:57:48,540 so I can't recall exactly why 1311 00:57:48,640 --> 00:57:50,540 I left the gain 1312 00:57:50,640 --> 00:57:52,540 and the bias as two-dimensional 1313 00:57:52,640 --> 00:57:54,540 but it doesn't really matter as long as you are consistent 1314 00:57:54,640 --> 00:57:56,540 and you're keeping it the same. 1315 00:57:56,640 --> 00:57:58,540 So in this case we want to keep the dimension 1316 00:57:58,640 --> 00:58:00,540 so that the tensor shapes work. 1317 00:58:00,640 --> 00:58:02,540 Next up we have 1318 00:58:02,640 --> 00:58:04,540 b in raw 1319 00:58:04,640 --> 00:58:06,540 so db in raw will be 1320 00:58:06,640 --> 00:58:08,540 b in gain 1321 00:58:08,640 --> 00:58:10,540 multiplying 1322 00:58:10,640 --> 00:58:12,540 dh preact 1323 00:58:12,640 --> 00:58:14,540 that's our chain rule. 1324 00:58:14,640 --> 00:58:16,540 Now what about the 1325 00:58:16,640 --> 00:58:18,540 dimensions of this? 1326 00:58:18,640 --> 00:58:20,540 We have to be careful, right? 1327 00:58:20,640 --> 00:58:22,540 So dh preact is 1328 00:58:22,640 --> 00:58:24,540 32 by 64 1329 00:58:24,640 --> 00:58:26,540 b in gain is 1 by 64 1330 00:58:26,540 --> 00:58:28,440 so it will just get replicated 1331 00:58:28,540 --> 00:58:30,440 to create this multiplication 1332 00:58:30,540 --> 00:58:32,440 which is the correct thing 1333 00:58:32,540 --> 00:58:34,440 because in a forward pass it also gets replicated 1334 00:58:34,540 --> 00:58:36,440 in just the same way. 1335 00:58:36,540 --> 00:58:38,440 So in fact we don't need the brackets here, we're done. 1336 00:58:38,540 --> 00:58:40,440 And the shapes are already correct. 1337 00:58:40,540 --> 00:58:42,440 And finally for the bias 1338 00:58:42,540 --> 00:58:44,440 very similar 1339 00:58:44,540 --> 00:58:46,440 this bias here is very very similar 1340 00:58:46,540 --> 00:58:48,440 to the bias we saw in the linear layer 1341 00:58:48,540 --> 00:58:50,440 and we see that the gradients 1342 00:58:50,540 --> 00:58:52,440 from h preact will simply flow 1343 00:58:52,540 --> 00:58:54,440 into the biases and add up 1344 00:58:54,540 --> 00:58:56,440 because these are just offsets. 1345 00:58:56,540 --> 00:58:58,440 And so basically we want this to be 1346 00:58:58,540 --> 00:59:00,440 dh preact but it needs 1347 00:59:00,540 --> 00:59:02,440 to sum along the right dimension 1348 00:59:02,540 --> 00:59:04,440 and in this case similar to the gain 1349 00:59:04,540 --> 00:59:06,440 we need to sum across the 0th 1350 00:59:06,540 --> 00:59:08,440 dimension, the examples 1351 00:59:08,540 --> 00:59:10,440 because of the way that the bias gets replicated 1352 00:59:10,540 --> 00:59:12,440 vertically and we also 1353 00:59:12,540 --> 00:59:14,440 want to have keep them as true. 1354 00:59:14,540 --> 00:59:16,440 And so this will basically take 1355 00:59:16,540 --> 00:59:18,440 this and sum it up and give us 1356 00:59:18,540 --> 00:59:20,440 a 1 by 64. 
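Putting these three pieces together, a sketch of the candidate gradients for the scale-and-shift line (keepdim=True preserves the 1 by 64 shapes of the parameters; notebook spellings assumed):

```python
# backward through: hpreact = bngain * bnraw + bnbias
# dhpreact and bnraw are (32, 64); bngain and bnbias are (1, 64) and were broadcast in the forward pass
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)   # replicated in the forward pass -> sum over dim 0 in the backward pass
dbnraw  = bngain * dhpreact                         # broadcasts back out to (32, 64)
dbnbias = dhpreact.sum(0, keepdim=True)             # offsets just accumulate the incoming gradients
```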
1357 00:59:20,540 --> 00:59:22,440 So this is the candidate implementation 1358 00:59:22,540 --> 00:59:24,440 it makes all the shapes work 1359 00:59:24,440 --> 00:59:26,340 let me bring it up 1360 00:59:26,440 --> 00:59:28,340 down here 1361 00:59:28,440 --> 00:59:30,340 and then let me uncomment these 3 lines 1362 00:59:30,440 --> 00:59:32,340 to check that 1363 00:59:32,440 --> 00:59:34,340 we are getting the correct result 1364 00:59:34,440 --> 00:59:36,340 for all the 3 tensors 1365 00:59:36,440 --> 00:59:38,340 and indeed we see that all of that 1366 00:59:38,440 --> 00:59:40,340 got backpropagated correctly. 1367 00:59:40,440 --> 00:59:42,340 So now we get to the batchnorm layer 1368 00:59:42,440 --> 00:59:44,340 we see how here bngain and bnbias 1369 00:59:44,440 --> 00:59:46,340 are the parameters so the backpropagation ends 1370 00:59:46,440 --> 00:59:48,340 but bnraw now 1371 00:59:48,440 --> 00:59:50,340 is the output of the standardization 1372 00:59:50,440 --> 00:59:52,340 so here what I'm doing 1373 00:59:52,440 --> 00:59:54,340 of course is I'm breaking up the batchnorm 1374 00:59:54,340 --> 00:59:56,240 into manageable pieces so we can backpropagate 1375 00:59:56,340 --> 00:59:58,240 through each line individually 1376 00:59:58,340 --> 01:00:00,240 but basically what's happening is 1377 01:00:00,340 --> 01:00:02,240 bnmeani is the mean 1378 01:00:02,340 --> 01:00:04,240 so this is the 1379 01:00:04,340 --> 01:00:06,240 bnmeani 1380 01:00:06,340 --> 01:00:08,240 I apologize for the variable naming 1381 01:00:08,340 --> 01:00:10,240 bndiff is x minus mu 1382 01:00:10,340 --> 01:00:12,240 bndiff2 1383 01:00:12,340 --> 01:00:14,240 is x minus mu squared 1384 01:00:14,340 --> 01:00:16,240 here inside the variance 1385 01:00:16,340 --> 01:00:18,240 bnvar is the variance 1386 01:00:18,340 --> 01:00:20,240 so sigma square 1387 01:00:20,340 --> 01:00:22,240 this is bnvar 1388 01:00:22,340 --> 01:00:24,240 and it's basically the sum of squares 1389 01:00:24,340 --> 01:00:26,240 so this is the x minus mu 1390 01:00:26,340 --> 01:00:28,240 squared 1391 01:00:28,340 --> 01:00:30,240 and then the sum 1392 01:00:30,340 --> 01:00:32,240 now you'll notice one departure here 1393 01:00:32,340 --> 01:00:34,240 here it is normalized as 1 over m 1394 01:00:34,340 --> 01:00:36,240 which is the number of examples 1395 01:00:36,340 --> 01:00:38,240 here I'm normalizing 1396 01:00:38,340 --> 01:00:40,240 as 1 over n minus 1 instead of m 1397 01:00:40,340 --> 01:00:42,240 and this is deliberate and I'll come back to that 1398 01:00:42,340 --> 01:00:44,240 in a bit when we are at this line 1399 01:00:44,340 --> 01:00:46,240 it is something called the Bessel's correction 1400 01:00:46,340 --> 01:00:48,240 but this is how I want it 1401 01:00:48,340 --> 01:00:50,240 in our case 1402 01:00:50,340 --> 01:00:52,240 bnvar inv 1403 01:00:52,340 --> 01:00:54,240 then becomes basically bnvar 1404 01:00:54,340 --> 01:00:56,240 plus epsilon 1405 01:00:56,340 --> 01:00:58,240 epsilon is 1e-5 1406 01:00:58,340 --> 01:01:00,240 and then 1 over square root 1407 01:01:00,340 --> 01:01:02,240 is the same as raising to the power of 1408 01:01:02,340 --> 01:01:04,240 negative 0.5 1409 01:01:04,340 --> 01:01:06,240 because 0.5 is square root 1410 01:01:06,340 --> 01:01:08,240 and then negative makes it 1 over square root 1411 01:01:08,340 --> 01:01:10,240 so bnvar inv 1412 01:01:10,340 --> 01:01:12,240 is 1 over this denominator here 1413 01:01:12,340 --> 01:01:14,240 and then we can see that 1414 01:01:14,340
--> 01:01:16,240 bnraw which is the x hat here 1415 01:01:16,340 --> 01:01:18,240 is equal to the 1416 01:01:18,340 --> 01:01:20,240 bndiff the numerator 1417 01:01:20,340 --> 01:01:22,240 multiplied by the 1418 01:01:22,340 --> 01:01:24,240 bnvar inv 1419 01:01:24,240 --> 01:01:26,140 and this line here 1420 01:01:26,240 --> 01:01:28,140 that creates hpreact was the last piece 1421 01:01:28,240 --> 01:01:30,140 we've already backpropagated through it 1422 01:01:30,240 --> 01:01:32,140 so now what we want to do 1423 01:01:32,240 --> 01:01:34,140 is we are here 1424 01:01:34,240 --> 01:01:36,140 and we have bnraw 1425 01:01:36,240 --> 01:01:38,140 and we have to first backpropagate 1426 01:01:38,240 --> 01:01:40,140 into bndiff and bnvar inv 1427 01:01:40,240 --> 01:01:42,140 so now we are here 1428 01:01:42,240 --> 01:01:44,140 and we have dbnraw 1429 01:01:44,240 --> 01:01:46,140 and we need to backpropagate through this line 1430 01:01:46,240 --> 01:01:48,140 now I've written out the shapes here 1431 01:01:48,240 --> 01:01:50,140 and indeed 1432 01:01:50,240 --> 01:01:52,140 bnvar inv is a shape 1 by 64 1433 01:01:52,240 --> 01:01:54,140 so there is a 1434 01:01:54,140 --> 01:01:56,040 little bit of broadcasting happening here 1435 01:01:56,140 --> 01:01:58,040 that we have to be careful with 1436 01:01:58,140 --> 01:02:00,040 but it is just an elementwise simple multiplication 1437 01:02:00,140 --> 01:02:02,040 by now we should be pretty comfortable with that 1438 01:02:02,140 --> 01:02:04,040 to get dbndiff 1439 01:02:04,140 --> 01:02:06,040 we know that this is just 1440 01:02:06,140 --> 01:02:08,040 bnvar inv multiplied with 1441 01:02:08,140 --> 01:02:10,040 dbnraw 1442 01:02:10,140 --> 01:02:12,040 and conversely 1443 01:02:12,140 --> 01:02:14,040 to get dbnvar inv 1444 01:02:14,140 --> 01:02:16,040 we need to take 1445 01:02:16,140 --> 01:02:18,040 bndiff 1446 01:02:18,140 --> 01:02:20,040 and multiply that by dbnraw 1447 01:02:22,140 --> 01:02:24,040 so this is the candidate 1448 01:02:24,040 --> 01:02:25,940 but of course 1449 01:02:26,040 --> 01:02:27,940 we need to make sure that broadcasting is obeyed 1450 01:02:28,040 --> 01:02:29,940 so in particular 1451 01:02:30,040 --> 01:02:31,940 bnvar inv multiplying with dbnraw 1452 01:02:32,040 --> 01:02:33,940 will be okay 1453 01:02:34,040 --> 01:02:35,940 and give us 32 by 64 as we expect 1454 01:02:36,040 --> 01:02:37,940 but dbnvar inv 1455 01:02:38,040 --> 01:02:39,940 would be taking 1456 01:02:40,040 --> 01:02:41,940 a 32 by 64 1457 01:02:42,040 --> 01:02:43,940 multiplying it by 1458 01:02:44,040 --> 01:02:45,940 32 by 64 1459 01:02:46,040 --> 01:02:47,940 so this is a 32 by 64 1460 01:02:48,040 --> 01:02:49,940 but of course this bnvar inv 1461 01:02:50,040 --> 01:02:51,940 is only 1 by 64 1462 01:02:52,040 --> 01:02:53,940 so this second line here 1463 01:02:53,940 --> 01:02:55,840 needs a sum across the examples 1464 01:02:55,940 --> 01:02:57,840 and because there's this 1465 01:02:57,940 --> 01:02:59,840 dimension here we need to make sure that 1466 01:02:59,940 --> 01:03:01,840 keep them is true 1467 01:03:01,940 --> 01:03:03,840 so this is the candidate 1468 01:03:03,940 --> 01:03:05,840 let's erase this 1469 01:03:05,940 --> 01:03:07,840 and let's swing down here 1470 01:03:07,940 --> 01:03:09,840 and implement it 1471 01:03:09,940 --> 01:03:11,840 and then let's comment out 1472 01:03:11,940 --> 01:03:13,840 dbnvar inv 1473 01:03:13,940 --> 01:03:15,840 and dbndiff 1474 01:03:15,940 --> 01:03:17,840 now we'll 
actually notice 1475 01:03:17,940 --> 01:03:19,840 that dbndiff by the way 1476 01:03:19,940 --> 01:03:21,840 is going to be incorrect 1477 01:03:21,940 --> 01:03:23,840 so when I run this 1478 01:03:23,840 --> 01:03:25,740 bnvar inv is correct 1479 01:03:25,840 --> 01:03:27,740 bndiff is not correct 1480 01:03:27,840 --> 01:03:29,740 and this is actually expected 1481 01:03:29,840 --> 01:03:31,740 because we're not done 1482 01:03:31,840 --> 01:03:33,740 with bndiff 1483 01:03:33,840 --> 01:03:35,740 so in particular when we slide here 1484 01:03:35,840 --> 01:03:37,740 we see here that bnraw is a function of bndiff 1485 01:03:37,840 --> 01:03:39,740 but actually bnvar inv 1486 01:03:39,840 --> 01:03:41,740 is a function of bnvar 1487 01:03:41,840 --> 01:03:43,740 which is a function of bndiff too 1488 01:03:43,840 --> 01:03:45,740 which is a function of bndiff 1489 01:03:45,840 --> 01:03:47,740 so it comes here 1490 01:03:47,840 --> 01:03:49,740 so bdndiff 1491 01:03:49,840 --> 01:03:51,740 these variable names are crazy I'm sorry 1492 01:03:51,840 --> 01:03:53,740 it branches out into two branches 1493 01:03:53,740 --> 01:03:55,640 we've only done one branch of it 1494 01:03:55,740 --> 01:03:57,640 we have to continue our backpropagation 1495 01:03:57,740 --> 01:03:59,640 and eventually come back to bndiff 1496 01:03:59,740 --> 01:04:01,640 and then we'll be able to do a plus equals 1497 01:04:01,740 --> 01:04:03,640 and get the actual correct gradient 1498 01:04:03,740 --> 01:04:05,640 for now it is good to verify that cmp also works 1499 01:04:05,740 --> 01:04:07,640 it doesn't just lie to us 1500 01:04:07,740 --> 01:04:09,640 and tell us that everything is always correct 1501 01:04:09,740 --> 01:04:11,640 it can in fact detect when your 1502 01:04:11,740 --> 01:04:13,640 gradient is not correct 1503 01:04:13,740 --> 01:04:15,640 so that's good to see as well 1504 01:04:15,740 --> 01:04:17,640 okay so now we have the derivative here 1505 01:04:17,740 --> 01:04:19,640 and we're trying to backpropagate through this line 1506 01:04:19,740 --> 01:04:21,640 and because we're raising to a power of negative 0.5 1507 01:04:21,740 --> 01:04:23,640 I brought up the power rule 1508 01:04:23,640 --> 01:04:25,540 and we see that basically we have that 1509 01:04:25,640 --> 01:04:27,540 the bnvar will now be 1510 01:04:27,640 --> 01:04:29,540 we bring down the exponent 1511 01:04:29,640 --> 01:04:31,540 so negative 0.5 times x 1512 01:04:31,640 --> 01:04:33,540 which is this 1513 01:04:33,640 --> 01:04:35,540 and now raised to the power of 1514 01:04:35,640 --> 01:04:37,540 negative 0.5 minus 1 1515 01:04:37,640 --> 01:04:39,540 which is negative 1.5 1516 01:04:39,640 --> 01:04:41,540 now we would have to also apply 1517 01:04:41,640 --> 01:04:43,540 a small chain rule here in our head 1518 01:04:43,640 --> 01:04:45,540 because we need to take further 1519 01:04:45,640 --> 01:04:47,540 the derivative of bnvar 1520 01:04:47,640 --> 01:04:49,540 with respect to this expression here 1521 01:04:49,640 --> 01:04:51,540 inside the bracket 1522 01:04:51,640 --> 01:04:53,540 but because this is an element-wise operation 1523 01:04:53,540 --> 01:04:55,440 everything is fairly simple 1524 01:04:55,540 --> 01:04:57,440 that's just one 1525 01:04:57,540 --> 01:04:59,440 and so there's nothing to do there 1526 01:04:59,540 --> 01:05:01,440 so this is the local derivative 1527 01:05:01,540 --> 01:05:03,440 and then times the global derivative 1528 01:05:03,540 --> 01:05:05,440 to create the chain rule 1529 
01:05:05,540 --> 01:05:07,440 this is just times the bnvar 1530 01:05:07,540 --> 01:05:09,440 so this is our candidate 1531 01:05:09,540 --> 01:05:11,440 let me bring this down 1532 01:05:11,540 --> 01:05:13,440 and uncomment the check 1533 01:05:13,540 --> 01:05:15,440 and we see that 1534 01:05:15,540 --> 01:05:17,440 we have the correct result 1535 01:05:17,540 --> 01:05:19,440 now before we backpropagate through the next line 1536 01:05:19,540 --> 01:05:21,440 I want to briefly talk about the node here 1537 01:05:21,540 --> 01:05:23,440 where I'm using the Bessel's correction 1538 01:05:23,440 --> 01:05:25,340 which is 1 over n minus 1 1539 01:05:25,440 --> 01:05:27,340 instead of dividing by n 1540 01:05:27,440 --> 01:05:29,340 when I normalize here 1541 01:05:29,440 --> 01:05:31,340 the sum of squares 1542 01:05:31,440 --> 01:05:33,340 now you'll notice that this is a departure from the paper 1543 01:05:33,440 --> 01:05:35,340 which uses 1 over n instead 1544 01:05:35,440 --> 01:05:37,340 not 1 over n minus 1 1545 01:05:37,440 --> 01:05:39,340 there m is rn 1546 01:05:39,440 --> 01:05:41,340 so it turns out that there are two ways 1547 01:05:41,440 --> 01:05:43,340 of estimating variance of an array 1548 01:05:43,440 --> 01:05:45,340 one is the biased estimate 1549 01:05:45,440 --> 01:05:47,340 which is 1 over n 1550 01:05:47,440 --> 01:05:49,340 and the other one is the unbiased estimate 1551 01:05:49,440 --> 01:05:51,340 which is 1 over n minus 1 1552 01:05:51,440 --> 01:05:53,340 now confusingly in the paper 1553 01:05:53,340 --> 01:05:55,240 it's not very clearly described 1554 01:05:55,340 --> 01:05:57,240 and also it's a detail that kind of matters 1555 01:05:57,340 --> 01:05:59,240 I think 1556 01:05:59,340 --> 01:06:01,240 we are using the biased version at training time 1557 01:06:01,340 --> 01:06:03,240 but later when they are talking about the inference 1558 01:06:03,340 --> 01:06:05,240 they are mentioning that 1559 01:06:05,340 --> 01:06:07,240 when they do the inference 1560 01:06:07,340 --> 01:06:09,240 they are using the unbiased estimate 1561 01:06:09,340 --> 01:06:11,240 which is the n minus 1 version 1562 01:06:11,340 --> 01:06:13,240 in basically 1563 01:06:13,340 --> 01:06:15,240 for inference 1564 01:06:15,340 --> 01:06:17,240 and to calibrate the running mean 1565 01:06:17,340 --> 01:06:19,240 and the running variance basically 1566 01:06:19,340 --> 01:06:21,240 and so they actually introduce 1567 01:06:21,340 --> 01:06:23,240 a train test mismatch 1568 01:06:23,340 --> 01:06:25,240 where in training they use the biased version 1569 01:06:25,340 --> 01:06:27,240 and in test time they use the unbiased version 1570 01:06:27,340 --> 01:06:29,240 I find this extremely confusing 1571 01:06:29,340 --> 01:06:31,240 you can read more about 1572 01:06:31,340 --> 01:06:33,240 the Bessel's correction 1573 01:06:33,340 --> 01:06:35,240 and why dividing by n minus 1 1574 01:06:35,340 --> 01:06:37,240 gives you a better estimate of the variance 1575 01:06:37,340 --> 01:06:39,240 in the case where you have population sizes 1576 01:06:39,340 --> 01:06:41,240 or samples from a population 1577 01:06:41,340 --> 01:06:43,240 that are very small 1578 01:06:43,340 --> 01:06:45,240 and that is indeed the case for us 1579 01:06:45,340 --> 01:06:47,240 because we are dealing with mini-matches 1580 01:06:47,340 --> 01:06:49,240 and these mini-matches are a small sample 1581 01:06:49,340 --> 01:06:51,240 of a larger population 1582 01:06:51,340 --> 01:06:53,240 which is the entire 
training set 1583 01:06:53,240 --> 01:06:55,140 and it turns out that 1584 01:06:55,240 --> 01:06:57,140 if you just estimate it using 1 over n 1585 01:06:57,240 --> 01:06:59,140 that actually almost always 1586 01:06:59,240 --> 01:07:01,140 underestimates the variance 1587 01:07:01,240 --> 01:07:03,140 and it is a biased estimator 1588 01:07:03,240 --> 01:07:05,140 and it is advised that you use the unbiased version 1589 01:07:05,240 --> 01:07:07,140 and divide by n minus 1 1590 01:07:07,240 --> 01:07:09,140 and you can go through this article here 1591 01:07:09,240 --> 01:07:11,140 that I liked that actually describes 1592 01:07:11,240 --> 01:07:13,140 the full reasoning 1593 01:07:13,240 --> 01:07:15,140 and I'll link it in the video description 1594 01:07:15,240 --> 01:07:17,140 now when you calculate the variance in torch, torch.var, 1595 01:07:17,240 --> 01:07:19,140 you'll notice that it takes an unbiased flag 1596 01:07:19,240 --> 01:07:21,140 whether or not you want to divide by n 1597 01:07:21,240 --> 01:07:23,140 or n minus 1 1598 01:07:23,140 --> 01:07:25,040 so there is a default for unbiased 1599 01:07:25,140 --> 01:07:27,040 and I believe unbiased by default 1600 01:07:27,140 --> 01:07:29,040 is true 1601 01:07:29,140 --> 01:07:31,040 I'm not sure why the docs here don't state that 1602 01:07:31,140 --> 01:07:33,040 now the BatchNorm1d 1603 01:07:33,140 --> 01:07:35,040 documentation, again, 1604 01:07:35,140 --> 01:07:37,040 is kind of wrong and confusing 1605 01:07:37,140 --> 01:07:39,040 it says that the standard deviation is calculated 1606 01:07:39,140 --> 01:07:41,040 via the biased estimator 1607 01:07:41,140 --> 01:07:43,040 but this is actually not exactly right 1608 01:07:43,140 --> 01:07:45,040 and people have pointed out that it is not right 1609 01:07:45,140 --> 01:07:47,040 in a number of issues since then 1610 01:07:47,140 --> 01:07:49,040 because actually the rabbit hole is deeper 1611 01:07:49,140 --> 01:07:51,040 and they follow the paper exactly 1612 01:07:51,140 --> 01:07:53,040 and they use the biased 1613 01:07:53,040 --> 01:07:54,940 version for training 1614 01:07:55,040 --> 01:07:56,940 but when they're estimating the running standard deviation 1615 01:07:57,040 --> 01:07:58,940 they are using the unbiased version 1616 01:07:59,040 --> 01:08:00,940 so again there's the train test mismatch 1617 01:08:01,040 --> 01:08:02,940 so long story short 1618 01:08:03,040 --> 01:08:04,940 I'm not a fan of train test discrepancies 1619 01:08:05,040 --> 01:08:06,940 I basically kind of consider 1620 01:08:07,040 --> 01:08:08,940 the fact that we use the biased version 1621 01:08:09,040 --> 01:08:10,940 at training time 1622 01:08:11,040 --> 01:08:12,940 and the unbiased at test time 1623 01:08:13,040 --> 01:08:14,940 I basically consider this to be a bug 1624 01:08:15,040 --> 01:08:16,940 and I don't think that there's a good reason for that 1625 01:08:17,040 --> 01:08:18,940 it's not really 1626 01:08:19,040 --> 01:08:20,940 they don't really go into the detail 1627 01:08:21,040 --> 01:08:22,940 of the reasoning behind it in this paper 1628 01:08:22,940 --> 01:08:24,840 I basically prefer to use the Bessel's correction 1629 01:08:24,940 --> 01:08:26,840 in my own work 1630 01:08:26,940 --> 01:08:28,840 unfortunately BatchNorm does not take 1631 01:08:28,940 --> 01:08:30,840 a keyword argument that tells you whether or not 1632 01:08:30,940 --> 01:08:32,840 you want to use the unbiased version 1633 01:08:32,940 --> 01:08:34,840 or the biased version in both train and
test 1634 01:08:34,940 --> 01:08:36,840 and so therefore anyone using BatchNormalization 1635 01:08:36,940 --> 01:08:38,840 basically in my view 1636 01:08:38,940 --> 01:08:40,840 has a bit of a bug in the code 1637 01:08:40,940 --> 01:08:42,840 and this turns out to be much less of a problem 1638 01:08:42,940 --> 01:08:44,840 if your batch 1639 01:08:44,940 --> 01:08:46,840 many batch sizes are a bit larger 1640 01:08:46,940 --> 01:08:48,840 but still I just find it kind of 1641 01:08:48,940 --> 01:08:50,840 unpalatable 1642 01:08:50,940 --> 01:08:52,840 so maybe someone can explain why this is okay 1643 01:08:52,840 --> 01:08:54,740 but for now I prefer to use the unbiased version 1644 01:08:54,840 --> 01:08:56,740 consistently both during training 1645 01:08:56,840 --> 01:08:58,740 and at test time 1646 01:08:58,840 --> 01:09:00,740 and that's why I'm using 1 over n minus 1 here 1647 01:09:00,840 --> 01:09:02,740 okay so let's now actually backpropagate 1648 01:09:02,840 --> 01:09:04,740 through this line 1649 01:09:04,840 --> 01:09:06,740 so 1650 01:09:06,840 --> 01:09:08,740 the first thing that I always like to do 1651 01:09:08,840 --> 01:09:10,740 is I like to scrutinize the shapes first 1652 01:09:10,840 --> 01:09:12,740 so in particular here looking at the shapes 1653 01:09:12,840 --> 01:09:14,740 of what's involved 1654 01:09:14,840 --> 01:09:16,740 I see that bnvar shape is 1 by 64 1655 01:09:16,840 --> 01:09:18,740 so it's a row vector 1656 01:09:18,840 --> 01:09:20,740 and bndiff2.shape is 32 by 64 1657 01:09:20,840 --> 01:09:22,740 so I can see that 1658 01:09:22,840 --> 01:09:24,740 so clearly here we're doing a sum 1659 01:09:24,840 --> 01:09:26,740 over the 0th axis 1660 01:09:26,840 --> 01:09:28,740 to squash the first dimension 1661 01:09:28,840 --> 01:09:30,740 of the shapes here 1662 01:09:30,840 --> 01:09:32,740 using a sum 1663 01:09:32,840 --> 01:09:34,740 so that right away actually hints to me 1664 01:09:34,840 --> 01:09:36,740 that there will be some kind of a replication 1665 01:09:36,840 --> 01:09:38,740 or broadcasting in the backward pass 1666 01:09:38,840 --> 01:09:40,740 and maybe you're noticing the pattern here 1667 01:09:40,840 --> 01:09:42,740 but basically any time you have a sum 1668 01:09:42,840 --> 01:09:44,740 in the forward pass 1669 01:09:44,840 --> 01:09:46,740 that turns into a replication 1670 01:09:46,840 --> 01:09:48,740 or broadcasting in the backward pass 1671 01:09:48,840 --> 01:09:50,740 along the same dimension 1672 01:09:50,840 --> 01:09:52,740 and conversely when we have a replication 1673 01:09:52,740 --> 01:09:54,640 or a broadcasting in the forward pass 1674 01:09:54,740 --> 01:09:56,640 that indicates a variable reuse 1675 01:09:56,740 --> 01:09:58,640 and so in the backward pass 1676 01:09:58,740 --> 01:10:00,640 that turns into a sum 1677 01:10:00,740 --> 01:10:02,640 over the exact same dimension 1678 01:10:02,740 --> 01:10:04,640 and so hopefully you're noticing that duality 1679 01:10:04,740 --> 01:10:06,640 that those two are kind of like the opposites 1680 01:10:06,740 --> 01:10:08,640 of each other in the forward and backward pass 1681 01:10:08,740 --> 01:10:10,640 now once we understand the shapes 1682 01:10:10,740 --> 01:10:12,640 the next thing I like to do always 1683 01:10:12,740 --> 01:10:14,640 is I like to look at a toy example in my head 1684 01:10:14,740 --> 01:10:16,640 to sort of just like understand roughly how 1685 01:10:16,740 --> 01:10:18,640 the variable dependencies go 1686 01:10:18,740 --> 01:10:20,640 
in the mathematical formula 1687 01:10:20,740 --> 01:10:22,640 so here we have 1688 01:10:22,640 --> 01:10:24,540 a two-dimensional array 1689 01:10:24,640 --> 01:10:26,540 b and div 2 which we are scaling 1690 01:10:26,640 --> 01:10:28,540 by a constant and then we are summing 1691 01:10:28,640 --> 01:10:30,540 vertically over the columns 1692 01:10:30,640 --> 01:10:32,540 so if we have a 2x2 matrix a 1693 01:10:32,640 --> 01:10:34,540 and then we sum over the columns 1694 01:10:34,640 --> 01:10:36,540 and scale we would get a 1695 01:10:36,640 --> 01:10:38,540 row vector b1 b2 and 1696 01:10:38,640 --> 01:10:40,540 b1 depends on a in this way 1697 01:10:40,640 --> 01:10:42,540 where it's just sum that is scaled 1698 01:10:42,640 --> 01:10:44,540 of a and b2 1699 01:10:44,640 --> 01:10:46,540 in this way where it's the second column 1700 01:10:46,640 --> 01:10:48,540 summed and scaled 1701 01:10:48,640 --> 01:10:50,540 and so looking at this basically 1702 01:10:50,640 --> 01:10:52,540 what we want to do is 1703 01:10:52,540 --> 01:10:54,440 we have the derivatives on b1 and b2 1704 01:10:54,540 --> 01:10:56,440 and we want to back propagate them into a's 1705 01:10:56,540 --> 01:10:58,440 and so it's clear that just 1706 01:10:58,540 --> 01:11:00,440 differentiating in your head 1707 01:11:00,540 --> 01:11:02,440 the local derivative here is 1 over n-1 1708 01:11:02,540 --> 01:11:04,440 times 1 1709 01:11:04,540 --> 01:11:06,440 for each one of these a's 1710 01:11:06,540 --> 01:11:08,440 and 1711 01:11:08,540 --> 01:11:10,440 basically the derivative of b1 1712 01:11:10,540 --> 01:11:12,440 has to flow through the columns of a 1713 01:11:12,540 --> 01:11:14,440 scaled by 1 over n-1 1714 01:11:14,540 --> 01:11:16,440 and that's roughly 1715 01:11:16,540 --> 01:11:18,440 what's happening here 1716 01:11:18,540 --> 01:11:20,440 so intuitively the derivative flow 1717 01:11:20,540 --> 01:11:22,440 tells us that 1718 01:11:22,440 --> 01:11:24,340 db and df2 1719 01:11:24,440 --> 01:11:26,340 will be 1720 01:11:26,440 --> 01:11:28,340 the local derivative of this operation 1721 01:11:28,440 --> 01:11:30,340 and there are many ways to do this by the way 1722 01:11:30,440 --> 01:11:32,340 but I like to do something like this 1723 01:11:32,440 --> 01:11:34,340 torch dot ones like 1724 01:11:34,440 --> 01:11:36,340 of b and df2 1725 01:11:36,440 --> 01:11:38,340 so I'll create a large array 1726 01:11:38,440 --> 01:11:40,340 two dimensional of ones 1727 01:11:40,440 --> 01:11:42,340 and then I will scale it 1728 01:11:42,440 --> 01:11:44,340 so 1.0 divided by n-1 1729 01:11:44,440 --> 01:11:46,340 so this is an array of 1730 01:11:46,440 --> 01:11:48,340 1 over n-1 1731 01:11:48,440 --> 01:11:50,340 and that's sort of like the local derivative 1732 01:11:50,440 --> 01:11:52,340 and now for the chain rule 1733 01:11:52,340 --> 01:11:54,240 I will simply just multiply it by 1734 01:11:54,340 --> 01:11:56,240 db and var 1735 01:11:58,340 --> 01:12:00,240 and notice here what's going to happen 1736 01:12:00,340 --> 01:12:02,240 this is 32 by 64 1737 01:12:02,340 --> 01:12:04,240 and this is just 1 by 64 1738 01:12:04,340 --> 01:12:06,240 so I'm letting the broadcasting 1739 01:12:06,340 --> 01:12:08,240 do the replication 1740 01:12:08,340 --> 01:12:10,240 because internally in pytorch 1741 01:12:10,340 --> 01:12:12,240 basically db and var 1742 01:12:12,340 --> 01:12:14,240 which is 1 by 64 row vector 1743 01:12:14,340 --> 01:12:16,240 will in this multiplication get 1744 01:12:16,340 --> 
01:12:18,240 copied vertically 1745 01:12:18,340 --> 01:12:20,240 until the two are of the same shape 1746 01:12:20,340 --> 01:12:22,240 and then there will be an elementwise multiply 1747 01:12:22,240 --> 01:12:24,140 so the broadcasting is basically doing the replication 1748 01:12:24,240 --> 01:12:26,140 and I will end up 1749 01:12:26,240 --> 01:12:28,140 with the derivatives of db and df2 1750 01:12:28,240 --> 01:12:30,140 here 1751 01:12:30,240 --> 01:12:32,140 so this is the candidate solution 1752 01:12:32,240 --> 01:12:34,140 let's bring it down here 1753 01:12:34,240 --> 01:12:36,140 let's uncomment this line 1754 01:12:36,240 --> 01:12:38,140 where we check it 1755 01:12:38,240 --> 01:12:40,140 and let's hope for the best 1756 01:12:40,240 --> 01:12:42,140 and indeed we see that this is the correct formula 1757 01:12:42,240 --> 01:12:44,140 next up let's differentiate here 1758 01:12:44,240 --> 01:12:46,140 into b and df 1759 01:12:46,240 --> 01:12:48,140 so here we have that b and df 1760 01:12:48,240 --> 01:12:50,140 is elementwise squared to create b and df2 1761 01:12:50,240 --> 01:12:52,140 so this is a 1762 01:12:52,140 --> 01:12:54,040 relatively simple derivative 1763 01:12:54,140 --> 01:12:56,040 because it's a simple elementwise operation 1764 01:12:56,140 --> 01:12:58,040 so it's kind of like the scalar case 1765 01:12:58,140 --> 01:13:00,040 and we have that db and df 1766 01:13:00,140 --> 01:13:02,040 should be 1767 01:13:02,140 --> 01:13:04,040 if this is x squared 1768 01:13:04,140 --> 01:13:06,040 then the derivative of this is 2x 1769 01:13:06,140 --> 01:13:08,040 so it's simply 2 times b and df 1770 01:13:08,140 --> 01:13:10,040 that's the local derivative 1771 01:13:10,140 --> 01:13:12,040 and then times chain rule 1772 01:13:12,140 --> 01:13:14,040 and the shape of these is the same 1773 01:13:14,140 --> 01:13:16,040 they are of the same shape 1774 01:13:16,140 --> 01:13:18,040 so times this 1775 01:13:18,140 --> 01:13:20,040 so that's the backward pass for this variable 1776 01:13:20,140 --> 01:13:22,040 let me bring it down here 1777 01:13:22,040 --> 01:13:23,940 I've already calculated db and df 1778 01:13:24,040 --> 01:13:25,940 so this is just the end of the other 1779 01:13:26,040 --> 01:13:27,940 branch coming back to b and df 1780 01:13:28,040 --> 01:13:29,940 because b and df 1781 01:13:30,040 --> 01:13:31,940 were already back propagated to 1782 01:13:32,040 --> 01:13:33,940 way over here 1783 01:13:34,040 --> 01:13:35,940 from b and raw 1784 01:13:36,040 --> 01:13:37,940 so we now completed the second branch 1785 01:13:38,040 --> 01:13:39,940 and so that's why I have to do plus equals 1786 01:13:40,040 --> 01:13:41,940 and if you recall 1787 01:13:42,040 --> 01:13:43,940 we had an incorrect derivative for b and df before 1788 01:13:44,040 --> 01:13:45,940 and I'm hoping that once we append 1789 01:13:46,040 --> 01:13:47,940 this last missing piece 1790 01:13:48,040 --> 01:13:49,940 we have the exact correctness 1791 01:13:50,040 --> 01:13:51,940 so let's run 1792 01:13:52,040 --> 01:13:53,940 and b and df now actually shows 1793 01:13:54,040 --> 01:13:55,940 the exact correct derivative 1794 01:13:56,040 --> 01:13:57,940 so that's comforting 1795 01:13:58,040 --> 01:13:59,940 okay so let's now back propagate 1796 01:14:00,040 --> 01:14:01,940 through this line here 1797 01:14:02,040 --> 01:14:03,940 the first thing we do of course 1798 01:14:04,040 --> 01:14:05,940 is we check the shapes 1799 01:14:06,040 --> 01:14:07,940 and I wrote them out here 
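(Before the next line, here is a compact sketch of the two steps just verified, with n standing for the batch size and the variable names as I read them from the notebook:)

```python
import torch

# bnvar = 1/(n-1) * bndiff2.sum(0, keepdim=True): the sum in the forward pass
# turns into a broadcasted replication of dbnvar in the backward pass
dbndiff2 = (1.0 / (n - 1)) * torch.ones_like(bndiff2) * dbnvar

# bndiff2 = bndiff**2 is elementwise, so the local derivative is 2*bndiff;
# this is the second branch flowing into bndiff, hence the +=
dbndiff += (2 * bndiff) * dbndiff2
```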
1800 01:14:08,040 --> 01:14:09,940 and basically the shape of this 1801 01:14:10,040 --> 01:14:11,940 is 32 by 64 1802 01:14:12,040 --> 01:14:13,940 h pre bn is the same shape 1803 01:14:14,040 --> 01:14:15,940 but b and mean i is a row vector 1804 01:14:16,040 --> 01:14:17,940 1 by 64 1805 01:14:18,040 --> 01:14:19,940 so this minus here will actually do broadcasting 1806 01:14:20,040 --> 01:14:21,940 and so we have to be careful with that 1807 01:14:21,940 --> 01:14:23,840 again because of the duality 1808 01:14:23,940 --> 01:14:25,840 a broadcasting in the forward pass 1809 01:14:25,940 --> 01:14:27,840 means a variable reuse 1810 01:14:27,940 --> 01:14:29,840 and therefore there will be a sum 1811 01:14:29,940 --> 01:14:31,840 in the backward pass 1812 01:14:31,940 --> 01:14:33,840 so let's write out the backward pass here now 1813 01:14:33,940 --> 01:14:35,840 back propagate into the h pre bn 1814 01:14:35,940 --> 01:14:37,840 because these are the same shape 1815 01:14:37,940 --> 01:14:39,840 then the local derivative 1816 01:14:39,940 --> 01:14:41,840 for each one of the elements here 1817 01:14:41,940 --> 01:14:43,840 is just 1 for the corresponding element 1818 01:14:43,940 --> 01:14:45,840 in here 1819 01:14:45,940 --> 01:14:47,840 so basically what this means is that 1820 01:14:47,940 --> 01:14:49,840 the gradient just simply copies 1821 01:14:49,940 --> 01:14:51,840 it's just a variable assignment 1822 01:14:51,840 --> 01:14:53,740 so I'm just going to clone this tensor 1823 01:14:53,840 --> 01:14:55,740 just for safety to create an exact copy 1824 01:14:55,840 --> 01:14:57,740 of db and diff 1825 01:14:57,840 --> 01:14:59,740 and then here 1826 01:14:59,840 --> 01:15:01,740 to back propagate into this one 1827 01:15:01,840 --> 01:15:03,740 what I'm inclined to do here 1828 01:15:03,840 --> 01:15:05,740 is 1829 01:15:05,840 --> 01:15:07,740 d bn mean i 1830 01:15:07,840 --> 01:15:09,740 will basically be 1831 01:15:09,840 --> 01:15:11,740 what is the local derivative 1832 01:15:11,840 --> 01:15:13,740 well it's negative torch.once like 1833 01:15:13,840 --> 01:15:15,740 of the shape of 1834 01:15:15,840 --> 01:15:17,740 b and diff 1835 01:15:17,840 --> 01:15:19,740 right 1836 01:15:19,840 --> 01:15:21,740 so 1837 01:15:22,240 --> 01:15:23,740 and then times 1838 01:15:23,840 --> 01:15:25,740 the 1839 01:15:25,840 --> 01:15:27,740 derivative here 1840 01:15:27,840 --> 01:15:29,740 db and diff 1841 01:15:29,840 --> 01:15:31,740 and this here 1842 01:15:31,840 --> 01:15:33,740 is the back propagation 1843 01:15:33,840 --> 01:15:35,740 for the replicated 1844 01:15:35,840 --> 01:15:37,740 b and mean i 1845 01:15:37,840 --> 01:15:39,740 so I still have to back propagate 1846 01:15:39,840 --> 01:15:41,740 through the replication 1847 01:15:41,840 --> 01:15:43,740 in the broadcasting 1848 01:15:43,840 --> 01:15:45,740 and I do that by doing a sum 1849 01:15:45,840 --> 01:15:47,740 so I'm going to take this whole thing 1850 01:15:47,840 --> 01:15:49,740 and I'm going to do a sum 1851 01:15:49,740 --> 01:15:51,640 and I'm going to do a replication 1852 01:15:51,740 --> 01:15:53,640 so if you scrutinize this by the way 1853 01:15:53,740 --> 01:15:55,640 you'll notice that 1854 01:15:55,740 --> 01:15:57,640 this is the same shape as that 1855 01:15:57,740 --> 01:15:59,640 and so what I'm doing 1856 01:15:59,740 --> 01:16:01,640 what I'm doing here doesn't actually make that much sense 1857 01:16:01,740 --> 01:16:03,640 because it's just a 1858 01:16:03,740 --> 01:16:05,640 array of ones 
multiplying db and diff 1859 01:16:05,740 --> 01:16:07,640 so in fact I can just do 1860 01:16:07,740 --> 01:16:09,640 this 1861 01:16:09,740 --> 01:16:11,640 and that is equivalent 1862 01:16:11,740 --> 01:16:13,640 so this is the candidate 1863 01:16:13,740 --> 01:16:15,640 backward pass 1864 01:16:15,740 --> 01:16:17,640 let me copy it here 1865 01:16:17,640 --> 01:16:19,540 let me comment out this one 1866 01:16:19,640 --> 01:16:21,540 and this one 1867 01:16:21,640 --> 01:16:23,540 enter 1868 01:16:23,640 --> 01:16:25,540 and it's wrong 1869 01:16:25,640 --> 01:16:27,540 damn 1870 01:16:27,640 --> 01:16:29,540 actually sorry 1871 01:16:29,640 --> 01:16:31,540 this is supposed to be wrong 1872 01:16:31,640 --> 01:16:33,540 and it's supposed to be wrong because 1873 01:16:33,640 --> 01:16:35,540 we are back propagating 1874 01:16:35,640 --> 01:16:37,540 from b and diff into h pre bn 1875 01:16:37,640 --> 01:16:39,540 but we're not done 1876 01:16:39,640 --> 01:16:41,540 because b and mean i depends 1877 01:16:41,640 --> 01:16:43,540 on h pre bn and there will be 1878 01:16:43,640 --> 01:16:45,540 a second portion of that derivative coming from 1879 01:16:45,640 --> 01:16:47,540 this second branch 1880 01:16:47,540 --> 01:16:49,440 but we're not done yet and we expect it to be incorrect 1881 01:16:49,540 --> 01:16:51,440 so there you go 1882 01:16:51,540 --> 01:16:53,440 so let's now back propagate from b and mean i 1883 01:16:53,540 --> 01:16:55,440 into h pre bn 1884 01:16:57,540 --> 01:16:59,440 and so here again we have to be careful 1885 01:16:59,540 --> 01:17:01,440 because there's a broadcasting along 1886 01:17:01,540 --> 01:17:03,440 or there's a sum along the 0th dimension 1887 01:17:03,540 --> 01:17:05,440 so this will turn into broadcasting 1888 01:17:05,540 --> 01:17:07,440 in the backward pass now 1889 01:17:07,540 --> 01:17:09,440 and I'm going to go a little bit faster on this line 1890 01:17:09,540 --> 01:17:11,440 because it is very similar to the line 1891 01:17:11,540 --> 01:17:13,440 that we had before 1892 01:17:13,540 --> 01:17:15,440 multiple lines in the past in fact 1893 01:17:15,540 --> 01:17:17,440 so d h pre bn 1894 01:17:17,540 --> 01:17:19,440 will be 1895 01:17:19,540 --> 01:17:21,440 the gradient will be scaled 1896 01:17:21,540 --> 01:17:23,440 by 1 over n and then 1897 01:17:23,540 --> 01:17:25,440 basically this gradient here on d bn 1898 01:17:25,540 --> 01:17:27,440 mean i 1899 01:17:27,540 --> 01:17:29,440 is going to be scaled by 1 over n 1900 01:17:29,540 --> 01:17:31,440 and then it's going to flow across 1901 01:17:31,540 --> 01:17:33,440 all the columns and deposit itself 1902 01:17:33,540 --> 01:17:35,440 into d h pre bn 1903 01:17:35,540 --> 01:17:37,440 so what we want is this thing 1904 01:17:37,540 --> 01:17:39,440 scaled by 1 over n 1905 01:17:39,540 --> 01:17:41,440 let me put the constant up front here 1906 01:17:45,540 --> 01:17:47,440 so scale down the gradient 1907 01:17:47,440 --> 01:17:49,340 and we need to replicate it 1908 01:17:49,440 --> 01:17:51,340 across all the 1909 01:17:51,440 --> 01:17:53,340 across all the rows here 1910 01:17:53,440 --> 01:17:55,340 so I like to do that 1911 01:17:55,440 --> 01:17:57,340 by torch dot once like 1912 01:17:57,440 --> 01:17:59,340 of basically 1913 01:17:59,440 --> 01:18:01,340 h pre bn 1914 01:18:03,440 --> 01:18:05,340 and I will let broadcasting 1915 01:18:05,440 --> 01:18:07,340 do the work of 1916 01:18:07,440 --> 01:18:09,340 replication 1917 01:18:09,440 --> 01:18:11,340 so 1918 
01:18:11,340 --> 01:18:15,340 ... 1919 01:18:15,340 --> 01:18:17,240 like that 1920 01:18:17,340 --> 01:18:19,240 so this is 1921 01:18:19,340 --> 01:18:21,240 dhprebn 1922 01:18:21,340 --> 01:18:23,240 and hopefully 1923 01:18:23,340 --> 01:18:25,240 we can plus equals that 1924 01:18:29,340 --> 01:18:31,240 so this here is broadcasting 1925 01:18:31,340 --> 01:18:33,240 and then this is the scaling 1926 01:18:33,340 --> 01:18:35,240 so this should be correct 1927 01:18:35,340 --> 01:18:37,240 okay 1928 01:18:37,340 --> 01:18:39,240 so that completes the backpropagation 1929 01:18:39,240 --> 01:18:41,140 let's backpropagate through the linear layer 1 1930 01:18:41,240 --> 01:18:43,140 here now because 1931 01:18:43,240 --> 01:18:45,140 everything is getting a little vertically crazy 1932 01:18:45,240 --> 01:18:47,140 I copy pasted the line here 1933 01:18:47,240 --> 01:18:49,140 and let's just backpropagate through this one line 1934 01:18:49,240 --> 01:18:51,140 so first of course 1935 01:18:51,240 --> 01:18:53,140 we inspect the shapes and we see that 1936 01:18:53,240 --> 01:18:55,140 this is 32 by 64 1937 01:18:55,240 --> 01:18:57,140 embcat is 32 1938 01:18:57,240 --> 01:18:59,140 by 30 1939 01:18:59,240 --> 01:19:01,140 W1 is 30 by 64 1940 01:19:01,240 --> 01:19:03,140 and b1 is just 64 1941 01:19:03,240 --> 01:19:05,140 so as I mentioned 1942 01:19:05,240 --> 01:19:07,140 backpropagating through linear layers 1943 01:19:07,240 --> 01:19:09,140 is fairly easy just by matching the shapes 1944 01:19:09,240 --> 01:19:11,140 so let's do that 1945 01:19:11,240 --> 01:19:13,140 we have that dembcat 1946 01:19:13,240 --> 01:19:15,140 should be 1947 01:19:15,240 --> 01:19:17,140 some matrix multiplication 1948 01:19:17,240 --> 01:19:19,140 of dhprebn with 1949 01:19:19,240 --> 01:19:21,140 W1, with a transpose 1950 01:19:21,240 --> 01:19:23,140 thrown in there 1951 01:19:23,240 --> 01:19:25,140 so to make dembcat 1952 01:19:25,240 --> 01:19:27,140 be 32 by 30 1953 01:19:27,240 --> 01:19:29,140 I need to take 1954 01:19:29,240 --> 01:19:31,140 dhprebn 1955 01:19:31,240 --> 01:19:33,140 32 by 64 1956 01:19:33,240 --> 01:19:35,140 and multiply it by 1957 01:19:35,240 --> 01:19:37,140 W1 transposed 1958 01:19:37,240 --> 01:19:39,140 ... 1959 01:19:39,240 --> 01:19:41,140 to get dW1 1960 01:19:41,240 --> 01:19:43,140 I need to end up with 1961 01:19:43,240 --> 01:19:45,140 30 by 64 1962 01:19:45,240 --> 01:19:47,140 so to get that I need to take 1963 01:19:47,240 --> 01:19:49,140 embcat transposed 1964 01:19:49,240 --> 01:19:51,140 ... 1965 01:19:51,240 --> 01:19:53,140 and multiply that by 1966 01:19:53,240 --> 01:19:55,140 dhprebn 1967 01:19:55,240 --> 01:19:57,140 ...
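(Pulling the last few steps together in one sketch: the two contributions to dhprebn, and then the shape-matching for the linear layer; db1 is the bias gradient derived just below. These are the variable names as I read them from the notebook:)

```python
import torch

# bndiff = hprebn - bnmeani: hprebn gets a straight copy of dbndiff, while the
# broadcasted bnmeani collects a sum over the batch dimension
dhprebn = dbndiff.clone()
dbnmeani = (-dbndiff).sum(0, keepdim=True)  # keepdim to keep the 1-by-64 row vector shape

# bnmeani = 1/n * hprebn.sum(0, keepdim=True): the sum becomes a broadcast,
# scaled by 1/n, adding the second contribution into dhprebn
dhprebn += (1.0 / n) * torch.ones_like(hprebn) * dbnmeani

# linear layer: hprebn = embcat @ W1 + b1, gradients found by matching shapes
dembcat = dhprebn @ W1.T      # (32, 64) @ (64, 30) -> (32, 30)
dW1 = embcat.T @ dhprebn      # (30, 32) @ (32, 64) -> (30, 64)
db1 = dhprebn.sum(0)          # (64,), derived in the next step
```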
1968 01:19:57,240 --> 01:19:59,140 and finally to get 1969 01:19:59,240 --> 01:20:01,140 d b1 1970 01:20:01,240 --> 01:20:03,140 this is an addition 1971 01:20:03,240 --> 01:20:05,140 and we saw that basically 1972 01:20:05,240 --> 01:20:07,140 I need to just sum the elements 1973 01:20:07,240 --> 01:20:09,140 in d h pre bn along some dimensions 1974 01:20:09,240 --> 01:20:11,140 and to make the dimensions work out 1975 01:20:11,240 --> 01:20:13,140 I need to sum along the 0th axis 1976 01:20:13,240 --> 01:20:15,140 here to eliminate 1977 01:20:15,240 --> 01:20:17,140 this dimension 1978 01:20:17,240 --> 01:20:19,140 and we do not keep dims 1979 01:20:19,240 --> 01:20:21,140 so that we want to just get a single 1980 01:20:21,240 --> 01:20:23,140 one-dimensional vector of 64 1981 01:20:23,240 --> 01:20:25,140 so these are the claimed derivatives 1982 01:20:25,240 --> 01:20:27,140 let me put that here 1983 01:20:27,240 --> 01:20:29,140 and let me 1984 01:20:29,240 --> 01:20:31,140 uncomment three lines 1985 01:20:31,240 --> 01:20:33,140 and cross our fingers 1986 01:20:33,240 --> 01:20:35,140 everything is great 1987 01:20:35,240 --> 01:20:37,140 okay so we now continue almost there 1988 01:20:37,240 --> 01:20:39,140 we have the derivative of mcat 1989 01:20:39,140 --> 01:20:41,040 and we want to backpropagate 1990 01:20:41,140 --> 01:20:43,040 into mb 1991 01:20:43,140 --> 01:20:45,040 so I again copied this line over here 1992 01:20:45,140 --> 01:20:47,040 so this is the forward pass 1993 01:20:47,140 --> 01:20:49,040 and then this is the shapes 1994 01:20:49,140 --> 01:20:51,040 so remember that the shape here 1995 01:20:51,140 --> 01:20:53,040 was 32 by 30 1996 01:20:53,140 --> 01:20:55,040 and the original shape of mb 1997 01:20:55,140 --> 01:20:57,040 was 32 by 3 by 10 1998 01:20:57,140 --> 01:20:59,040 so this layer in the forward pass 1999 01:20:59,140 --> 01:21:01,040 as you recall did the concatenation 2000 01:21:01,140 --> 01:21:03,040 of these three 10-dimensional 2001 01:21:03,140 --> 01:21:05,040 character vectors 2002 01:21:05,140 --> 01:21:07,040 and so now we just want to undo that 2003 01:21:07,140 --> 01:21:09,040 so this is actually a relatively 2004 01:21:09,040 --> 01:21:10,940 simple iteration because 2005 01:21:11,040 --> 01:21:12,940 the backward pass of the 2006 01:21:13,040 --> 01:21:14,940 what is the view? 
view is just a 2007 01:21:15,040 --> 01:21:16,940 representation of the array 2008 01:21:17,040 --> 01:21:18,940 it's just a logical form of how 2009 01:21:19,040 --> 01:21:20,940 you interpret the array 2010 01:21:21,040 --> 01:21:22,940 so let's just reinterpret it 2011 01:21:23,040 --> 01:21:24,940 to be what it was before 2012 01:21:25,040 --> 01:21:26,940 so in other words dmb is not 32 by 30 2013 01:21:27,040 --> 01:21:28,940 it is basically dmpcat 2014 01:21:29,040 --> 01:21:30,940 but if you view it 2015 01:21:31,040 --> 01:21:32,940 as the original shape 2016 01:21:33,040 --> 01:21:34,940 so just m.shape 2017 01:21:37,040 --> 01:21:38,940 you can pass and tuple 2018 01:21:38,940 --> 01:21:40,840 into view 2019 01:21:40,940 --> 01:21:42,840 and so this should just be 2020 01:21:42,940 --> 01:21:44,840 okay 2021 01:21:44,940 --> 01:21:46,840 we just re-represent that view 2022 01:21:46,940 --> 01:21:48,840 and then we uncomment this line here 2023 01:21:48,940 --> 01:21:50,840 and hopefully 2024 01:21:50,940 --> 01:21:52,840 yeah, so the derivative of m 2025 01:21:52,940 --> 01:21:54,840 is correct 2026 01:21:54,940 --> 01:21:56,840 so in this case we just have to re-represent 2027 01:21:56,940 --> 01:21:58,840 the shape of those derivatives 2028 01:21:58,940 --> 01:22:00,840 into the original view 2029 01:22:00,940 --> 01:22:02,840 so now we are at the final line 2030 01:22:02,940 --> 01:22:04,840 and the only thing that's left to backpropagate through 2031 01:22:04,940 --> 01:22:06,840 is this indexing operation here 2032 01:22:06,940 --> 01:22:08,840 m is c at xb 2033 01:22:08,840 --> 01:22:10,740 or I copy pasted this line here 2034 01:22:10,840 --> 01:22:12,740 and let's look at the shapes of everything that's involved 2035 01:22:12,840 --> 01:22:14,740 and remind ourselves how this worked 2036 01:22:14,840 --> 01:22:16,740 so m.shape 2037 01:22:16,840 --> 01:22:18,740 was 32 by 3 by 10 2038 01:22:18,840 --> 01:22:20,740 so it's 32 examples 2039 01:22:20,840 --> 01:22:22,740 and then we have 3 characters 2040 01:22:22,840 --> 01:22:24,740 each one of them has a 10 dimensional 2041 01:22:24,840 --> 01:22:26,740 embedding 2042 01:22:26,840 --> 01:22:28,740 and this was achieved by taking the 2043 01:22:28,840 --> 01:22:30,740 lookup table c which have 27 2044 01:22:30,840 --> 01:22:32,740 possible characters 2045 01:22:32,840 --> 01:22:34,740 each of them 10 dimensional 2046 01:22:34,840 --> 01:22:36,740 and we looked up at the rows 2047 01:22:36,840 --> 01:22:38,740 that were specified 2048 01:22:38,740 --> 01:22:40,640 inside this tensor xb 2049 01:22:40,740 --> 01:22:42,640 so xb is 32 by 3 2050 01:22:42,740 --> 01:22:44,640 and it's basically giving us for each example 2051 01:22:44,740 --> 01:22:46,640 the identity or the index 2052 01:22:46,740 --> 01:22:48,640 of which character 2053 01:22:48,740 --> 01:22:50,640 is part of that example 2054 01:22:50,740 --> 01:22:52,640 and so here I'm showing the first 5 rows 2055 01:22:52,740 --> 01:22:56,640 of this tensor xb 2056 01:22:56,740 --> 01:22:58,640 and so we can see that for example here 2057 01:22:58,740 --> 01:23:00,640 it was the first example in this batch 2058 01:23:00,740 --> 01:23:02,640 is that the first character 2059 01:23:02,740 --> 01:23:04,640 and the first character and the fourth character 2060 01:23:04,740 --> 01:23:06,640 comes into the neural net 2061 01:23:06,740 --> 01:23:08,640 and then we want to predict the next character 2062 01:23:08,640 --> 01:23:10,540 in the sequence after the character is 114 
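(The view step just verified, as a one-liner before we route gradients through the indexing; the forward lines being discussed are the embedding lookup emb = C[Xb] and the flattening of emb into embcat:)

```python
# .view() only changes how the same numbers are interpreted, so the backward
# pass just re-views dembcat into the original (32, 3, 10) embedding shape
demb = dembcat.view(emb.shape)
```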
2063 01:23:10,640 --> 01:23:12,540 so basically 2064 01:23:12,640 --> 01:23:14,540 what's happening here is 2065 01:23:14,640 --> 01:23:16,540 there are integers inside xb 2066 01:23:16,640 --> 01:23:18,540 and each one of these integers 2067 01:23:18,640 --> 01:23:20,540 is specifying which row of c 2068 01:23:20,640 --> 01:23:22,540 we want to pluck out 2069 01:23:22,640 --> 01:23:24,540 right and then we arrange 2070 01:23:24,640 --> 01:23:26,540 those rows that we've plucked out 2071 01:23:26,640 --> 01:23:28,540 into 32 by 3 by 10 tensor 2072 01:23:28,640 --> 01:23:30,540 and we just package them in 2073 01:23:30,640 --> 01:23:32,540 we just package them into this tensor 2074 01:23:32,640 --> 01:23:34,540 and now what's happening 2075 01:23:34,640 --> 01:23:36,540 is that we have dimp 2076 01:23:36,640 --> 01:23:38,540 so for every one of these 2077 01:23:38,540 --> 01:23:40,440 basically plucked out rows 2078 01:23:40,540 --> 01:23:42,440 we have their gradients now 2079 01:23:42,540 --> 01:23:46,440 but they're arranged inside this 32 by 3 by 10 tensor 2080 01:23:46,540 --> 01:23:48,440 so all we have to do now 2081 01:23:48,540 --> 01:23:50,440 is we just need to route this gradient 2082 01:23:50,540 --> 01:23:52,440 backwards through this assignment 2083 01:23:52,540 --> 01:23:54,440 so we need to find which row of c 2084 01:23:54,540 --> 01:23:56,440 that every one of these 2085 01:23:56,540 --> 01:23:58,440 10 dimensional embeddings come from 2086 01:23:58,540 --> 01:24:00,440 and then we need to deposit them 2087 01:24:00,540 --> 01:24:02,440 into dc 2088 01:24:02,540 --> 01:24:04,440 so we just need to undo the indexing 2089 01:24:04,540 --> 01:24:06,440 and of course 2090 01:24:06,540 --> 01:24:08,440 if any of these rows of c 2091 01:24:08,440 --> 01:24:10,340 were used multiple times 2092 01:24:10,440 --> 01:24:12,340 which almost certainly is the case 2093 01:24:12,440 --> 01:24:14,340 like the row 1 and 1 was used multiple times 2094 01:24:14,440 --> 01:24:16,340 then we have to remember that the gradients 2095 01:24:16,440 --> 01:24:18,340 that arrive there have to add 2096 01:24:18,440 --> 01:24:20,340 so for each occurrence 2097 01:24:20,440 --> 01:24:22,340 we have to have an addition 2098 01:24:22,440 --> 01:24:24,340 so let's now write this out 2099 01:24:24,440 --> 01:24:26,340 and I don't actually know of like 2100 01:24:26,440 --> 01:24:28,340 a much better way to do this 2101 01:24:28,440 --> 01:24:30,340 than a for loop unfortunately in python 2102 01:24:30,440 --> 01:24:32,340 so maybe someone can come up with 2103 01:24:32,440 --> 01:24:34,340 a vectorized efficient operation 2104 01:24:34,440 --> 01:24:36,340 but for now let's just use for loops 2105 01:24:36,440 --> 01:24:38,340 so let me create torch.zeros like c 2106 01:24:38,340 --> 01:24:40,240 and I'm going to utilize just 2107 01:24:40,340 --> 01:24:42,240 a 27 by 10 tensor of all zeros 2108 01:24:42,340 --> 01:24:44,240 and then honestly 2109 01:24:44,340 --> 01:24:46,240 for k in range xb.shape at 0 2110 01:24:46,340 --> 01:24:48,240 maybe someone has a better way to do this 2111 01:24:48,340 --> 01:24:50,240 but for j in range xb.shape at 1 2112 01:24:50,340 --> 01:24:52,240 this is going to iterate over 2113 01:24:52,340 --> 01:24:54,240 all the elements of xb 2114 01:24:54,340 --> 01:24:56,240 all these integers 2115 01:24:56,340 --> 01:24:58,240 and then let's get the index 2116 01:24:58,340 --> 01:25:00,240 at this position 2117 01:25:00,340 --> 01:25:02,240 so the index is basically 2118 
01:25:02,340 --> 01:25:04,240 the value of xb 2119 01:25:04,340 --> 01:25:06,240 and then let's get the index 2120 01:25:06,340 --> 01:25:08,240 at this position 2121 01:25:08,240 --> 01:25:10,140 which is basically xb at kj 2122 01:25:10,240 --> 01:25:12,140 so an example of that 2123 01:25:12,240 --> 01:25:14,140 is 11 or 14 and so on 2124 01:25:14,240 --> 01:25:16,140 and now in a forward pass 2125 01:25:16,240 --> 01:25:18,140 we took 2126 01:25:18,240 --> 01:25:20,140 we basically took 2127 01:25:20,240 --> 01:25:22,140 um 2128 01:25:22,240 --> 01:25:24,140 the row of c at index 2129 01:25:24,240 --> 01:25:26,140 and we deposited it 2130 01:25:26,240 --> 01:25:28,140 into emb at k at j 2131 01:25:28,240 --> 01:25:30,140 that's what happened 2132 01:25:30,240 --> 01:25:32,140 that's where they are packaged 2133 01:25:32,240 --> 01:25:34,140 so now we need to go backwards 2134 01:25:34,240 --> 01:25:36,140 and we just need to route 2135 01:25:36,240 --> 01:25:38,140 deemb at the position 2136 01:25:38,140 --> 01:25:40,040 kj 2137 01:25:40,140 --> 01:25:42,040 we now have these derivatives 2138 01:25:42,140 --> 01:25:44,040 for each position 2139 01:25:44,140 --> 01:25:46,040 and it's 10 dimensional 2140 01:25:46,140 --> 01:25:48,040 and you just need to go into the correct 2141 01:25:48,140 --> 01:25:50,040 row of c 2142 01:25:50,140 --> 01:25:52,040 so dc rather 2143 01:25:52,140 --> 01:25:54,040 at ix is this 2144 01:25:54,140 --> 01:25:56,040 but plus equals 2145 01:25:56,140 --> 01:25:58,040 because there could be multiple occurrences 2146 01:25:58,140 --> 01:26:00,040 like the same row could have been used 2147 01:26:00,140 --> 01:26:02,040 many many times 2148 01:26:02,140 --> 01:26:04,040 and so all those derivatives will 2149 01:26:04,140 --> 01:26:06,040 just go backwards through the indexing 2150 01:26:06,140 --> 01:26:08,040 and they will add 2151 01:26:08,040 --> 01:26:09,940 so this is my candidate 2152 01:26:10,040 --> 01:26:11,940 solution 2153 01:26:12,040 --> 01:26:13,940 let's copy it here 2154 01:26:16,040 --> 01:26:17,940 let's uncomment this 2155 01:26:18,040 --> 01:26:19,940 and cross our fingers 2156 01:26:20,040 --> 01:26:21,940 yay 2157 01:26:22,040 --> 01:26:23,940 so that's it 2158 01:26:24,040 --> 01:26:25,940 we've backpropagated through 2159 01:26:26,040 --> 01:26:27,940 this entire beast 2160 01:26:28,040 --> 01:26:29,940 so there we go 2161 01:26:30,040 --> 01:26:31,940 totally makes sense 2162 01:26:32,040 --> 01:26:33,940 so now we come to exercise 2 2163 01:26:34,040 --> 01:26:35,940 it basically turns out that in this first exercise 2164 01:26:36,040 --> 01:26:37,940 we were doing way too much work 2165 01:26:37,940 --> 01:26:39,840 we were backpropagating way too much 2166 01:26:39,940 --> 01:26:41,840 and it was all good practice and so on 2167 01:26:41,940 --> 01:26:43,840 but it's not what you would do in practice 2168 01:26:43,940 --> 01:26:45,840 and the reason for that is for example 2169 01:26:45,940 --> 01:26:47,840 here I separated out this loss calculation 2170 01:26:47,940 --> 01:26:49,840 over multiple lines 2171 01:26:49,940 --> 01:26:51,840 and I broke it up all to like 2172 01:26:51,940 --> 01:26:53,840 its smallest atomic pieces 2173 01:26:53,940 --> 01:26:55,840 and we backpropagated through all of those individually 2174 01:26:55,940 --> 01:26:57,840 but it turns out that if you just look at 2175 01:26:57,940 --> 01:26:59,840 the mathematical expression for the loss 2176 01:26:59,940 --> 01:27:01,840 then actually you can do 2177 
01:27:01,940 --> 01:27:03,840 the differentiation on pen and paper 2178 01:27:03,940 --> 01:27:05,840 and a lot of terms cancel and simplify 2179 01:27:05,940 --> 01:27:07,840 and the mathematical expression you end up with 2180 01:27:07,940 --> 01:27:09,740 is significantly shorter 2181 01:27:09,840 --> 01:27:11,740 and easier to implement 2182 01:27:11,840 --> 01:27:13,740 than backpropagating through all the little pieces 2183 01:27:13,840 --> 01:27:15,740 of everything you've done 2184 01:27:15,840 --> 01:27:17,740 so before we had this complicated forward pass 2185 01:27:17,840 --> 01:27:19,740 going from the logits to the loss 2186 01:27:19,840 --> 01:27:21,740 but in pytorch everything can just be 2187 01:27:21,840 --> 01:27:23,740 glued together into a single call 2188 01:27:23,840 --> 01:27:25,740 to F.cross_entropy 2189 01:27:25,840 --> 01:27:27,740 you just pass in the logits and the labels 2190 01:27:27,840 --> 01:27:29,740 and you get the exact same loss 2191 01:27:29,840 --> 01:27:31,740 as I verify here 2192 01:27:31,840 --> 01:27:33,740 so our previous loss and the fast loss agree 2193 01:27:33,840 --> 01:27:35,740 and treating this chunk of operations 2194 01:27:35,840 --> 01:27:37,740 as a single mathematical expression 2195 01:27:37,840 --> 01:27:39,640 is not only faster in the forward pass 2196 01:27:39,740 --> 01:27:41,640 it's also much much faster in the backward pass 2197 01:27:41,740 --> 01:27:43,640 and the reason for that is if you just look at 2198 01:27:43,740 --> 01:27:45,640 the mathematical form of this and differentiate again 2199 01:27:45,740 --> 01:27:47,640 you will end up with a very small and short expression 2200 01:27:47,740 --> 01:27:49,640 so that's what we want to do here 2201 01:27:49,740 --> 01:27:51,640 we want to in a single operation 2202 01:27:51,740 --> 01:27:53,640 or in a single go or like very quickly 2203 01:27:53,740 --> 01:27:55,640 go directly into dlogits 2204 01:27:55,740 --> 01:27:57,640 and we need to implement dlogits 2205 01:27:57,740 --> 01:27:59,640 as a function of logits 2206 01:27:59,740 --> 01:28:01,640 and Yb 2207 01:28:01,740 --> 01:28:03,640 but it will be significantly shorter 2208 01:28:03,740 --> 01:28:05,640 than whatever we did here 2209 01:28:05,740 --> 01:28:07,640 where to get to dlogits 2210 01:28:07,640 --> 01:28:09,540 we need to go all the way here 2211 01:28:09,640 --> 01:28:11,540 so all of this work can be skipped 2212 01:28:11,640 --> 01:28:13,540 by a much much simpler mathematical expression 2213 01:28:13,640 --> 01:28:15,540 that you can implement here 2214 01:28:15,640 --> 01:28:17,540 so you can 2215 01:28:17,640 --> 01:28:19,540 give it a shot yourself 2216 01:28:19,640 --> 01:28:21,540 basically look at what exactly 2217 01:28:21,640 --> 01:28:23,540 is the mathematical expression of the loss 2218 01:28:23,640 --> 01:28:25,540 and differentiate with respect to the logits 2219 01:28:25,640 --> 01:28:27,540 so let me show you 2220 01:28:27,640 --> 01:28:29,540 a hint 2221 01:28:29,640 --> 01:28:31,540 you can of course try it fully yourself 2222 01:28:31,640 --> 01:28:33,540 but if not I can give you some hint 2223 01:28:33,640 --> 01:28:35,540 of how to get started mathematically 2224 01:28:35,640 --> 01:28:37,540 so basically what's happening here 2225 01:28:37,640 --> 01:28:39,540 is we have the logits 2226 01:28:39,640 --> 01:28:41,540 then there's the softmax 2227 01:28:41,640 --> 01:28:43,540 that takes the logits and gives you probabilities 2228 01:28:43,640 --> 01:28:45,540 then we are using the identity 2229
01:28:45,640 --> 01:28:47,540 of the correct next character 2230 01:28:47,640 --> 01:28:49,540 to pluck out a row of probabilities 2231 01:28:49,640 --> 01:28:51,540 take the negative log of it 2232 01:28:51,640 --> 01:28:53,540 to get our negative log probability 2233 01:28:53,640 --> 01:28:55,540 and then we average up 2234 01:28:55,640 --> 01:28:57,540 all the log probabilities 2235 01:28:57,640 --> 01:28:59,540 or negative log probabilities 2236 01:28:59,640 --> 01:29:01,540 to get our loss 2237 01:29:01,640 --> 01:29:03,540 so basically what we have 2238 01:29:03,640 --> 01:29:05,540 is for a single individual example 2239 01:29:05,640 --> 01:29:07,540 we have that loss is equal to 2240 01:29:07,540 --> 01:29:09,440 where p here is kind of like 2241 01:29:09,540 --> 01:29:11,440 thought of as a vector 2242 01:29:11,540 --> 01:29:13,440 of all the probabilities 2243 01:29:13,540 --> 01:29:15,440 so at the yth position 2244 01:29:15,540 --> 01:29:17,440 where y is the label 2245 01:29:17,540 --> 01:29:19,440 and we have that p here of course 2246 01:29:19,540 --> 01:29:21,440 is the softmax 2247 01:29:21,540 --> 01:29:23,440 so the ith component of p 2248 01:29:23,540 --> 01:29:25,440 of this probability vector 2249 01:29:25,540 --> 01:29:27,440 is just the softmax function 2250 01:29:27,540 --> 01:29:29,440 so raising all the logits 2251 01:29:29,540 --> 01:29:31,440 basically to the power of e 2252 01:29:31,540 --> 01:29:33,440 and normalizing 2253 01:29:33,540 --> 01:29:35,440 so everything sums to one 2254 01:29:35,540 --> 01:29:37,440 now if you write out 2255 01:29:37,440 --> 01:29:39,340 this expression here 2256 01:29:39,440 --> 01:29:41,340 you can just write out the softmax 2257 01:29:41,440 --> 01:29:43,340 and then basically what we're interested in 2258 01:29:43,440 --> 01:29:45,340 is we're interested in the derivative of the loss 2259 01:29:45,440 --> 01:29:47,340 with respect to the ith logit 2260 01:29:47,440 --> 01:29:49,340 and so basically it's a d by dLi 2261 01:29:49,440 --> 01:29:51,340 of this expression here 2262 01:29:51,440 --> 01:29:53,340 where we have l indexed 2263 01:29:53,440 --> 01:29:55,340 with the specific label y 2264 01:29:55,440 --> 01:29:57,340 and on the bottom we have a sum over j 2265 01:29:57,440 --> 01:29:59,340 of e to the lj 2266 01:29:59,440 --> 01:30:01,340 and the negative log of all that 2267 01:30:01,440 --> 01:30:03,340 so potentially give it a shot 2268 01:30:03,440 --> 01:30:05,340 pen and paper and see if you can actually 2269 01:30:05,440 --> 01:30:07,340 derive the expression for the loss by dLi 2270 01:30:07,340 --> 01:30:09,240 and to implement it here 2271 01:30:09,340 --> 01:30:11,240 okay so I'm going to give away the result here 2272 01:30:11,340 --> 01:30:13,240 so this is some of the math I did 2273 01:30:13,340 --> 01:30:15,240 to derive the gradients 2274 01:30:15,340 --> 01:30:17,240 analytically 2275 01:30:17,340 --> 01:30:19,240 and so we see here that I'm just applying 2276 01:30:19,340 --> 01:30:21,240 the rules of calculus from your first or second year 2277 01:30:21,340 --> 01:30:23,240 of bachelor's degree if you took it 2278 01:30:23,340 --> 01:30:25,240 and we see that the expressions 2279 01:30:25,340 --> 01:30:27,240 actually simplify quite a bit 2280 01:30:27,340 --> 01:30:29,240 you have to separate out the analysis 2281 01:30:29,340 --> 01:30:31,240 in the case where the ith index 2282 01:30:31,340 --> 01:30:33,240 that you're interested in inside logits 2283 01:30:33,340 --> 01:30:35,240 is either equal 
to the label 2284 01:30:35,340 --> 01:30:37,240 or it's not equal to the label 2285 01:30:37,240 --> 01:30:39,140 in a slightly different way 2286 01:30:39,240 --> 01:30:41,140 and what we end up with is something 2287 01:30:41,240 --> 01:30:43,140 very very simple 2288 01:30:43,240 --> 01:30:45,140 we either end up with basically p at i 2289 01:30:45,240 --> 01:30:47,140 where p is again this vector of 2290 01:30:47,240 --> 01:30:49,140 probabilities after a softmax 2291 01:30:49,240 --> 01:30:51,140 or p at i minus one 2292 01:30:51,240 --> 01:30:53,140 where we just simply subtract a one 2293 01:30:53,240 --> 01:30:55,140 but in any case we just need to calculate 2294 01:30:55,240 --> 01:30:57,140 the softmax p 2295 01:30:57,240 --> 01:30:59,140 and then in the correct dimension 2296 01:30:59,240 --> 01:31:01,140 we need to subtract a one 2297 01:31:01,240 --> 01:31:03,140 and that's the gradient 2298 01:31:03,240 --> 01:31:05,140 the form that it takes analytically 2299 01:31:05,240 --> 01:31:07,140 so let's implement this basically 2300 01:31:07,140 --> 01:31:09,040 but here we are working with batches of examples 2301 01:31:09,140 --> 01:31:11,040 so we have to be careful of that 2302 01:31:11,140 --> 01:31:13,040 and then the loss for a batch 2303 01:31:13,140 --> 01:31:15,040 is the average loss over all the examples 2304 01:31:15,140 --> 01:31:17,040 so in other words 2305 01:31:17,140 --> 01:31:19,040 is the example for all the individual examples 2306 01:31:19,140 --> 01:31:21,040 is the loss for each individual example 2307 01:31:21,140 --> 01:31:23,040 summed up and then divided by n 2308 01:31:23,140 --> 01:31:25,040 and we have to backpropagate through that 2309 01:31:25,140 --> 01:31:27,040 as well and be careful with it 2310 01:31:27,140 --> 01:31:29,040 so dlogits 2311 01:31:29,140 --> 01:31:31,040 is going to be f dot softmax 2312 01:31:31,140 --> 01:31:33,040 pytorch has a softmax function 2313 01:31:33,140 --> 01:31:35,040 that you can call 2314 01:31:35,140 --> 01:31:37,040 and we want to apply the softmax 2315 01:31:37,040 --> 01:31:38,940 on the logits and we want to go 2316 01:31:39,040 --> 01:31:40,940 in the dimension 2317 01:31:41,040 --> 01:31:42,940 that is one 2318 01:31:43,040 --> 01:31:44,940 so basically we want to do the softmax 2319 01:31:45,040 --> 01:31:46,940 along the rows of these logits 2320 01:31:47,040 --> 01:31:48,940 then at the correct positions 2321 01:31:49,040 --> 01:31:50,940 we need to subtract a one 2322 01:31:51,040 --> 01:31:52,940 so dlogits at iterating over all the rows 2323 01:31:53,040 --> 01:31:54,940 and indexing 2324 01:31:55,040 --> 01:31:56,940 into the columns 2325 01:31:57,040 --> 01:31:58,940 provided by the correct labels 2326 01:31:59,040 --> 01:32:00,940 inside yb 2327 01:32:01,040 --> 01:32:02,940 we need to subtract one 2328 01:32:03,040 --> 01:32:04,940 and then finally it's the average loss 2329 01:32:05,040 --> 01:32:06,940 that is the loss 2330 01:32:06,940 --> 01:32:08,840 so in average there's a one over n 2331 01:32:08,940 --> 01:32:10,840 of all the losses added up 2332 01:32:10,940 --> 01:32:12,840 and so we need to also backpropagate 2333 01:32:12,940 --> 01:32:14,840 through that division 2334 01:32:14,940 --> 01:32:16,840 so the gradient has to be scaled down 2335 01:32:16,940 --> 01:32:18,840 by n as well 2336 01:32:18,940 --> 01:32:20,840 because of the mean 2337 01:32:20,940 --> 01:32:22,840 but this otherwise should be the result 2338 01:32:22,940 --> 01:32:24,840 so now if we verify this 2339 
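(In symbols, the result described above is dloss/dlogit_i = p_i away from the label and p_y - 1 at the label, divided by the batch size because of the mean. A sketch of the three lines just described, assuming logits, Yb and the batch size n from the notebook:)

```python
import torch.nn.functional as F

dlogits = F.softmax(logits, 1)   # probabilities, softmax along each row
dlogits[range(n), Yb] -= 1       # subtract 1 at each example's correct label
dlogits /= n                     # backprop through the mean over the batch
```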
01:32:24,940 --> 01:32:26,840 we see that we don't get an exact match 2340 01:32:26,940 --> 01:32:28,840 but at the same time 2341 01:32:28,940 --> 01:32:30,840 the maximum difference from 2342 01:32:30,940 --> 01:32:32,840 logits from pytorch 2343 01:32:32,940 --> 01:32:34,840 and rdlogits here 2344 01:32:34,840 --> 01:32:36,740 is on the order of 5e-9 2345 01:32:36,840 --> 01:32:38,740 so it's a tiny tiny number 2346 01:32:38,840 --> 01:32:40,740 so because of floating point wonkiness 2347 01:32:40,840 --> 01:32:42,740 we don't get the exact bitwise result 2348 01:32:42,840 --> 01:32:44,740 but we basically get 2349 01:32:44,840 --> 01:32:46,740 the correct answer 2350 01:32:46,840 --> 01:32:48,740 approximately 2351 01:32:48,840 --> 01:32:50,740 now I'd like to pause here briefly 2352 01:32:50,840 --> 01:32:52,740 before we move on to the next exercise 2353 01:32:52,840 --> 01:32:54,740 because I'd like us to get an intuitive sense 2354 01:32:54,840 --> 01:32:56,740 of what dlogits is 2355 01:32:56,840 --> 01:32:58,740 because it has a beautiful and very simple 2356 01:32:58,840 --> 01:33:00,740 explanation honestly 2357 01:33:00,840 --> 01:33:02,740 so here I'm taking dlogits 2358 01:33:02,840 --> 01:33:04,740 and I'm visualizing it 2359 01:33:04,740 --> 01:33:06,640 and I see that we have a batch of 32 examples 2360 01:33:06,740 --> 01:33:08,640 of 27 characters 2361 01:33:08,740 --> 01:33:10,640 and what is dlogits intuitively? 2362 01:33:10,740 --> 01:33:12,640 dlogits is the probabilities 2363 01:33:12,740 --> 01:33:14,640 that the probabilities matrix 2364 01:33:14,740 --> 01:33:16,640 in the forward pass 2365 01:33:16,740 --> 01:33:18,640 but then here these black squares 2366 01:33:18,740 --> 01:33:20,640 are the positions of the correct indices 2367 01:33:20,740 --> 01:33:22,640 where we subtracted a 1 2368 01:33:22,740 --> 01:33:24,640 and so what is this doing? 
2369 01:33:24,740 --> 01:33:26,640 these are the derivatives on dlogits 2370 01:33:26,740 --> 01:33:28,640 and so let's look at 2371 01:33:28,740 --> 01:33:30,640 just the first row here 2372 01:33:30,740 --> 01:33:32,640 so that's what I'm doing here 2373 01:33:32,740 --> 01:33:34,640 I'm calculating the probabilities 2374 01:33:34,640 --> 01:33:36,540 and then I'm taking just the first row 2375 01:33:36,640 --> 01:33:38,540 and this is the probability row 2376 01:33:38,640 --> 01:33:40,540 and then dlogits of the first row 2377 01:33:40,640 --> 01:33:42,540 and multiplying by n 2378 01:33:42,640 --> 01:33:44,540 just for us so that 2379 01:33:44,640 --> 01:33:46,540 we don't have the scaling by n in here 2380 01:33:46,640 --> 01:33:48,540 and everything is more interpretable 2381 01:33:48,640 --> 01:33:50,540 we see that it's exactly equal to the probability 2382 01:33:50,640 --> 01:33:52,540 of course but then the position 2383 01:33:52,640 --> 01:33:54,540 of the correct index has a minus equals 1 2384 01:33:54,640 --> 01:33:56,540 so minus 1 on that position 2385 01:33:56,640 --> 01:33:58,540 and so notice that 2386 01:33:58,640 --> 01:34:00,540 if you take dlogits at 0 2387 01:34:00,640 --> 01:34:02,540 and you sum it 2388 01:34:02,640 --> 01:34:04,540 it actually sums to 0 2389 01:34:04,640 --> 01:34:06,540 and so you should think of these 2390 01:34:06,640 --> 01:34:08,540 gradients here 2391 01:34:08,640 --> 01:34:10,540 at each cell 2392 01:34:10,640 --> 01:34:12,540 as like a force 2393 01:34:12,640 --> 01:34:14,540 we are going to be basically 2394 01:34:14,640 --> 01:34:16,540 pulling down on the probabilities 2395 01:34:16,640 --> 01:34:18,540 of the incorrect characters 2396 01:34:18,640 --> 01:34:20,540 and we're going to be pulling up 2397 01:34:20,640 --> 01:34:22,540 on the probability 2398 01:34:22,640 --> 01:34:24,540 at the correct index 2399 01:34:24,640 --> 01:34:26,540 and that's what's basically happening 2400 01:34:26,640 --> 01:34:28,540 in each row 2401 01:34:28,640 --> 01:34:30,540 and the amount of push and pull 2402 01:34:30,640 --> 01:34:32,540 is exactly equalized 2403 01:34:32,640 --> 01:34:34,540 because the sum is 0 2404 01:34:34,540 --> 01:34:36,440 and the amount to which we pull down 2405 01:34:36,540 --> 01:34:38,440 on the probabilities 2406 01:34:38,540 --> 01:34:40,440 and the amount that we push up 2407 01:34:40,540 --> 01:34:42,440 on the probability of the correct character 2408 01:34:42,540 --> 01:34:44,440 is equal 2409 01:34:44,540 --> 01:34:46,440 so the repulsion and the attraction are equal 2410 01:34:46,540 --> 01:34:48,440 and think of the neural net now 2411 01:34:48,540 --> 01:34:50,440 as a massive pulley system 2412 01:34:50,540 --> 01:34:52,440 or something like that 2413 01:34:52,540 --> 01:34:54,440 we're up here on top of dlogits 2414 01:34:54,540 --> 01:34:56,440 and we're pulling up 2415 01:34:56,540 --> 01:34:58,440 we're pulling down the probabilities of incorrect 2416 01:34:58,540 --> 01:35:00,440 and pulling up the probability of the correct 2417 01:35:00,540 --> 01:35:02,440 and in this complicated pulley system 2418 01:35:02,440 --> 01:35:04,340 we think of it as sort of like 2419 01:35:04,440 --> 01:35:06,340 this tension translating to this 2420 01:35:06,440 --> 01:35:08,340 complicating pulley mechanism 2421 01:35:08,440 --> 01:35:10,340 and then eventually we get a tug 2422 01:35:10,440 --> 01:35:12,340 on the weights and the biases 2423 01:35:12,440 --> 01:35:14,340 and basically in each update 2424 01:35:14,440 
--> 01:35:16,340 we just kind of like tug in the direction 2425 01:35:16,440 --> 01:35:18,340 that we'd like for each of these elements 2426 01:35:18,440 --> 01:35:20,340 and the parameters are slowly given in 2427 01:35:20,440 --> 01:35:22,340 to the tug and that's what training in neural net 2428 01:35:22,440 --> 01:35:24,340 kind of like looks like on a high level 2429 01:35:24,440 --> 01:35:26,340 and so I think the forces of push and pull 2430 01:35:26,440 --> 01:35:28,340 in these gradients are actually 2431 01:35:28,440 --> 01:35:30,340 very intuitive here 2432 01:35:30,440 --> 01:35:32,340 we're pushing and pulling on the correct answer 2433 01:35:32,340 --> 01:35:34,240 and the amount of force that we're applying 2434 01:35:34,340 --> 01:35:36,240 is actually proportional to 2435 01:35:36,340 --> 01:35:38,240 the probabilities that came out 2436 01:35:38,340 --> 01:35:40,240 in the forward pass 2437 01:35:40,340 --> 01:35:42,240 and so for example if our probabilities came out 2438 01:35:42,340 --> 01:35:44,240 exactly correct so they would have had 2439 01:35:44,340 --> 01:35:46,240 zero everywhere except for one 2440 01:35:46,340 --> 01:35:48,240 at the correct position 2441 01:35:48,340 --> 01:35:50,240 then the dlogits would be all 2442 01:35:50,340 --> 01:35:52,240 a row of zeros for that example 2443 01:35:52,340 --> 01:35:54,240 there would be no push and pull 2444 01:35:54,340 --> 01:35:56,240 so the amount to which your prediction is incorrect 2445 01:35:56,340 --> 01:35:58,240 is exactly the amount 2446 01:35:58,340 --> 01:36:00,240 by which you're going to get a pull 2447 01:36:00,340 --> 01:36:02,240 or a push in that dimension 2448 01:36:02,240 --> 01:36:04,140 so if you have for example 2449 01:36:04,240 --> 01:36:06,140 a very confidently mispredicted element here 2450 01:36:06,240 --> 01:36:08,140 then what's going to happen is 2451 01:36:08,240 --> 01:36:10,140 that element is going to be pulled down 2452 01:36:10,240 --> 01:36:12,140 very heavily and the correct answer 2453 01:36:12,240 --> 01:36:14,140 is going to be pulled up to the same amount 2454 01:36:14,240 --> 01:36:16,140 and the other characters 2455 01:36:16,240 --> 01:36:18,140 are not going to be influenced too much 2456 01:36:18,240 --> 01:36:20,140 so the amount to which 2457 01:36:20,240 --> 01:36:22,140 you mispredict is then proportional 2458 01:36:22,240 --> 01:36:24,140 to the strength of the pull 2459 01:36:24,240 --> 01:36:26,140 and that's happening independently 2460 01:36:26,240 --> 01:36:28,140 in all the dimensions of this tensor 2461 01:36:28,240 --> 01:36:30,140 and it's sort of very intuitive 2462 01:36:30,240 --> 01:36:32,140 and very easy to think through 2463 01:36:32,140 --> 01:36:34,040 and that's basically the magic of the cross entropy loss 2464 01:36:34,140 --> 01:36:36,040 and what it's doing dynamically 2465 01:36:36,140 --> 01:36:38,040 in the backward pass of the neural net 2466 01:36:38,140 --> 01:36:40,040 so now we get to exercise number three 2467 01:36:40,140 --> 01:36:42,040 which is a very fun exercise 2468 01:36:42,140 --> 01:36:44,040 depending on your definition of fun 2469 01:36:44,140 --> 01:36:46,040 and we are going to do for batch normalization 2470 01:36:46,140 --> 01:36:48,040 exactly what we did for cross entropy loss 2471 01:36:48,140 --> 01:36:50,040 in exercise number two 2472 01:36:50,140 --> 01:36:52,040 that is we are going to consider it as a glued 2473 01:36:52,140 --> 01:36:54,040 single mathematical expression 2474 01:36:54,140 --> 01:36:56,040 
and back propagate through it in a very efficient manner 2475 01:36:56,140 --> 01:36:58,040 because we are going to derive a much simpler formula 2476 01:36:58,140 --> 01:37:00,040 for the backward pass of batch normalization 2477 01:37:00,140 --> 01:37:02,040 and we're going to do that 2478 01:37:02,040 --> 01:37:03,940 using pen and paper 2479 01:37:04,040 --> 01:37:05,940 so previously we've broken up batch normalization 2480 01:37:06,040 --> 01:37:07,940 into all of the little intermediate pieces 2481 01:37:08,040 --> 01:37:09,940 and all the atomic operations inside it 2482 01:37:10,040 --> 01:37:11,940 and then we back propagated through it 2483 01:37:12,040 --> 01:37:13,940 one by one 2484 01:37:14,040 --> 01:37:15,940 now we just have a single sort of forward pass 2485 01:37:16,040 --> 01:37:17,940 of a batch form 2486 01:37:18,040 --> 01:37:19,940 and it's all glued together 2487 01:37:20,040 --> 01:37:21,940 and we see that we get the exact same result as before 2488 01:37:22,040 --> 01:37:23,940 now for the backward pass 2489 01:37:24,040 --> 01:37:25,940 we'd like to also implement 2490 01:37:26,040 --> 01:37:27,940 a single formula basically 2491 01:37:28,040 --> 01:37:29,940 for back propagating through this entire operation 2492 01:37:30,040 --> 01:37:31,940 that is the batch normalization 2493 01:37:32,040 --> 01:37:33,940 so in the forward pass previously 2494 01:37:34,040 --> 01:37:35,940 we took h pre bn 2495 01:37:36,040 --> 01:37:37,940 the hidden states of the pre batch normalization 2496 01:37:38,040 --> 01:37:39,940 and created h preact 2497 01:37:40,040 --> 01:37:41,940 which is the hidden states 2498 01:37:42,040 --> 01:37:43,940 just before the activation 2499 01:37:44,040 --> 01:37:45,940 in the batch normalization paper 2500 01:37:46,040 --> 01:37:47,940 h pre bn is x 2501 01:37:48,040 --> 01:37:49,940 and h preact is y 2502 01:37:50,040 --> 01:37:51,940 so in the backward pass what we'd like to do now 2503 01:37:52,040 --> 01:37:53,940 is we have dh preact 2504 01:37:54,040 --> 01:37:55,940 and we'd like to produce dh pre bn 2505 01:37:56,040 --> 01:37:57,940 and we'd like to do that in a very efficient manner 2506 01:37:58,040 --> 01:37:59,940 so that's the name of the game 2507 01:38:00,040 --> 01:38:01,940 calculate dh pre bn 2508 01:38:02,040 --> 01:38:03,940 given dh preact 2509 01:38:04,040 --> 01:38:05,940 and for the purposes of this exercise 2510 01:38:06,040 --> 01:38:07,940 we're going to ignore gamma and beta 2511 01:38:08,040 --> 01:38:09,940 and their derivatives 2512 01:38:10,040 --> 01:38:11,940 because they take on a very simple form 2513 01:38:12,040 --> 01:38:13,940 in a very similar way to what we did up above 2514 01:38:14,040 --> 01:38:15,940 so let's calculate this 2515 01:38:16,040 --> 01:38:17,940 given that right here 2516 01:38:18,040 --> 01:38:19,940 so to help you a little bit 2517 01:38:20,040 --> 01:38:21,940 like I did before 2518 01:38:22,040 --> 01:38:23,940 I started off the implementation here 2519 01:38:24,040 --> 01:38:25,940 on pen and paper 2520 01:38:26,040 --> 01:38:27,940 and I took two sheets of paper 2521 01:38:28,040 --> 01:38:29,940 to derive the mathematical formulas 2522 01:38:30,040 --> 01:38:31,940 for the backward pass 2523 01:38:31,940 --> 01:38:33,840 so to solve the problem 2524 01:38:33,940 --> 01:38:35,840 just write out the mu sigma square variance 2525 01:38:35,940 --> 01:38:37,840 xi hat and yi 2526 01:38:37,940 --> 01:38:39,840 exactly as in the paper 2527 01:38:39,940 --> 01:38:41,840 except for the 
And then, in the backward pass, we have the derivative of the loss with respect to all the elements of y, and remember that y is a vector, there are multiple numbers here. So we have all the derivatives with respect to all the y's, and then there's a gamma and a beta, and this is kind of like the compute graph: there's the gamma and the beta, there's the x hat, and then the mu and the sigma squared and the x. So we have dL by dy_i, and we want dL by dx_i, for all the i's in these vectors. So this is the compute graph, and you have to be careful, because I'm trying to note here that these are vectors: there are many nodes here inside x, x hat and y, but mu and the sigma — sorry, sigma squared — are each just a single value. So you have to be careful with that: you have to imagine there are multiple nodes here, or you're going to get your math wrong. So as an example, I would suggest that you go in the following order — one, two, three, four — in terms of the backpropagation. So backpropagate into x hat, then into sigma squared, then into mu, and then into x, just like in a topological sort. In micrograd we would go from right to left; you're doing the exact same thing, except you're doing it with symbols and on a piece of paper. So for number one, I'm not giving away too much: if you want dL by dx_i hat, then we just take dL by dy_i and multiply it by gamma, because of this expression here, where any individual y_i is just gamma times x_i hat plus beta. So that didn't help you too much there, but this gives you basically the derivatives for all the x hats.
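Continuing the sketch above, step one is just an elementwise scale by gamma (dy here is a random placeholder standing in for the upstream dL/dy):

    dy = torch.randn(m)    # stands in for dL/dy_i coming from the rest of the network
    dxhat = dy * gamma     # dL/dxhat_i = dL/dy_i * gamma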
And so now, try to go through this computational graph and derive what is dL by d sigma squared, and then what is dL by d mu, and then, eventually, what is dL by dx. So give it a go, and I'm going to be revealing the answer one piece at a time. Okay, so to get dL by d sigma squared, we have to remember again, like I mentioned, that there are many x hats here, and remember that sigma squared is just a single individual number here. So when we look at the expression for dL by d sigma squared, we have to actually consider all the possible paths: there are many x hats, and they all depend on sigma squared. So sigma squared has a large fan-out — there are lots of arrows coming out from sigma squared into all the x hats — and then there's a backpropagating signal from each x hat into sigma squared, and that's why we actually need to sum, over all those i's from 1 to m, the dL by dx_i hat, which is the global gradient, times dx_i hat by d sigma squared, which is the local gradient of this operation here. And then mathematically I'm just working it out here, and I'm simplifying, and you get a certain expression for dL by d sigma squared, and we're going to be using this expression when we backpropagate into mu and then eventually into x.
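In code form, continuing the sketch, that sum over the batch looks roughly like this; the local derivative of each xhat_i with respect to the variance is (x_i - mu) * -0.5 * (var + eps)**-1.5:

    # sigma squared fans out into every xhat_i, so we sum the contributions over the batch
    dvar = (dxhat * (x - mu) * -0.5 * (var + eps)**-1.5).sum()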
So now let's continue our backpropagation into mu, which is dL by d mu. Now again, be careful that mu influences x hat, and x hat is actually lots of values. So for example, if our mini-batch size is 32, as it is in our example that we were working on, then this is 32 numbers and 32 arrows going back to mu. And then mu going to sigma squared is just a single arrow, because sigma squared is a scalar. So in total there are 33 arrows emanating from mu, and then all of them have gradients coming into mu, and they all need to be summed up. And so that's why, when we look at the expression for dL by d mu, I'm summing up over all the gradients of dL by dx_i hat times dx_i hat by d mu — so that's this arrow, and that's 32 arrows here — and then plus the one arrow from here, which is dL by d sigma squared times d sigma squared by d mu. So now we have to work out that expression, and let me just reveal the rest of it. Simplifying here is not complicated for the first term, and you just get an expression here. For the second term, though, there's something really interesting that happens when we look at d sigma squared by d mu and we simplify: if we assume the special case where mu is actually the average of the x_i's, as it is in this case, then if we plug that in, the gradient actually vanishes and becomes exactly zero, and that makes the entire second term cancel. And so, if you have a mathematical expression like this and you look at d sigma squared by d mu, you would get some mathematical formula for how mu impacts sigma squared, but if it is the special case that mu is actually equal to the average, as it is in the case of batch normalization, that gradient will actually vanish and become zero. So the whole term cancels, and we just get a fairly straightforward expression here for dL by d mu.
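Continuing the sketch, here are both pieces written out explicitly; the second term is included just to show that it evaluates to zero, because sum(x - mu) is zero when mu is the batch mean:

    dmu = (dxhat * -1.0 / torch.sqrt(var + eps)).sum() \
          + dvar * (-2.0 / (m - 1)) * (x - mu).sum()   # this second term is exactly zero here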
Okay, and now we get to the craziest part, which is deriving dL by dx_i, which is ultimately what we're after. Now let's count, first of all, how many numbers there are inside x: as I mentioned, there are 32 numbers, there are 32 little x_i's. And let's count the number of arrows emanating from each x_i: there's an arrow going to mu, an arrow going to sigma squared, and then there's an arrow going to x hat. But this arrow here — let's scrutinize that a little bit: each x_i hat is just a function of x_i and all the other scalars, so x_i hat only depends on x_i and all the other x's. And so, therefore, in this single arrow there are actually 32 arrows, but those 32 arrows are going exactly parallel — they don't interfere, they're just going parallel between x and x hat; you can look at it that way. And so, how many arrows are emanating from each x_i? There are three arrows: mu, sigma squared, and the associated x hat. And so in backpropagation we now need to apply the chain rule, and we need to add up those three contributions. Like, if I just write that out, we're chaining through mu, sigma squared, and through x hat, and those three terms are just here. Now, we already have three of these: we have dL by dx_i hat, we have dL by d mu, which we derived here, and we have dL by d sigma squared, which we derived here. But we need three other terms here — this one, this one, and this one — so I invite you to try to derive them. If you find it complicated, you're just looking at these expressions here and differentiating with respect to x_i. So give it a shot, but here's the result, or at least what I got: I'm just differentiating with respect to x_i for all of these expressions, and honestly I don't think there's anything too tricky here, it's basic calculus.
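In the running sketch, those three chain-rule contributions come out as follows; the local derivatives are dxhat_i/dx_i = 1/sqrt(var+eps), dmu/dx_i = 1/m, and dvar/dx_i = 2(x_i - mu)/(m-1) with the Bessel-corrected variance:

    # each x_i reaches the loss through its own xhat_i, through mu, and through sigma squared
    dx = dxhat / torch.sqrt(var + eps) \
         + dmu / m \
         + dvar * 2.0 * (x - mu) / (m - 1)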
Now, what gets a little bit more tricky is that we are now going to plug everything together — all of these terms multiplied with all of these terms and added up according to this formula — and that gets a little bit hairy. So what ends up happening is you get a large expression, and the thing to be very careful with here, of course, is that we are working with dL by dx_i for a specific i here. But when we are plugging in some of these terms — like, say, this term here, dL by d sigma squared — you see how for dL by d sigma squared I end up with an expression where I'm iterating over little i's, but I can't use i as the variable when I plug in here, because this is a different i from this i. This i here is just a placeholder, like a local variable for a for-loop in here. So here, when I plug that in, you notice that I rename the i to a j, because I need to make sure that this j is not this i — this j is like a little local iterator over 32 terms. And so you have to be careful with that: when you are plugging in the expressions from here to here, you may have to rename i's into j's, and you have to be very careful about what is actually an i with respect to dL by dx_i. So some of these are j's, and some of these are i's. And then we simplify this expression, and I guess the big thing to notice here is that a bunch of terms are just going to come out to the front, and you can refactor them. There is a sigma squared plus epsilon raised to the power of negative 3 over 2; this sigma squared plus epsilon can actually be separated out into three terms, each of them sigma squared plus epsilon raised to the power of negative 1 over 2, so the three of them multiplied is equal to this. And then those three terms can go to different places, because of the multiplication: one of them actually comes out to the front and will end up here outside, one of them joins up with this term, and one of them joins up with this other term.
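To make the i-versus-j bookkeeping concrete, here is the fully plugged-in expression from the running sketch written as an explicit double loop, with j kept strictly as the inner summation index; it should agree with the vectorized dx above:

    dx_loop = torch.zeros(m)
    for i in range(m):
        term1 = dxhat[i] / torch.sqrt(var + eps)                        # through xhat_i
        term2 = (1.0 / m) * sum(dxhat[j] * -1.0 / torch.sqrt(var + eps)
                                for j in range(m))                      # through mu
        term3 = (2.0 * (x[i] - mu) / (m - 1)) * \
                sum(dxhat[j] * (x[j] - mu) * -0.5 * (var + eps)**-1.5
                    for j in range(m))                                  # through sigma squared
        dx_loop[i] = term1 + term2 + term3
    print((dx_loop - dx).abs().max())                                   # should be ~0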
And then when you simplify the expression, you will notice that some of these terms that are coming out are just the x_i hats, so you can simplify just by rewriting that. And what we end up with at the end is a fairly simple mathematical expression over here, which I cannot simplify further, but basically you'll notice that it only uses the stuff we have and it derives the thing we need: we have dL by dy_i for all the i's, and those are used plenty of times here, and in addition what we're using is these x_i hats and x_j hats, and they just come from the forward pass. And otherwise this is a simple expression, and it gives us dL by dx_i for all the i's, and that's ultimately what we're interested in. So that's the end of the batch norm backward pass, analytically. Let's now implement this final result. Okay, so I implemented the expression into a single line of code here, and you can see that the max diff is tiny, so this is the correct implementation of this formula. Now I'll just basically tell you that getting this formula here from this mathematical expression was not trivial, and there's a lot going on packed into this one formula. And this is a whole exercise by itself, because you have to consider the fact that this formula here is just for a single neuron and a batch of 32 examples, but what I'm doing here is — we actually have 64 neurons, and so this expression has to evaluate the batch norm backward pass for all of those 64 neurons in parallel and independently. So this has to happen, basically, in every single column of the inputs here. And in addition to that, you see how there are a bunch of sums here, and I want to make sure that when I do those sums they broadcast correctly onto everything else that's here.
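Here is a self-contained sketch of what such a single-line backward pass can look like when applied to the whole (32, 64) layer and checked against PyTorch's autograd. The intermediate names (hprebn, bngain, bnraw, bnvar_inv, and so on) are assumptions in the spirit of the notebook, not necessarily its exact code:

    import torch

    n, d = 32, 64                                          # batch size, number of neurons
    hprebn = torch.randn(n, d, requires_grad=True)
    bngain = torch.randn(1, d)
    bnbias = torch.randn(1, d)
    eps = 1e-5

    # forward pass (Bessel-corrected variance, matching the training-time forward pass)
    bnmean = hprebn.mean(0, keepdim=True)
    bnvar = hprebn.var(0, keepdim=True, unbiased=True)
    bnvar_inv = (bnvar + eps)**-0.5
    bnraw = (hprebn - bnmean) * bnvar_inv
    hpreact = bngain * bnraw + bnbias

    dhpreact = torch.randn(n, d)                           # placeholder upstream gradient
    hpreact.backward(dhpreact)                             # PyTorch's answer lands in hprebn.grad

    # the compact analytic backward pass; the sums over dim 0 broadcast back over the batch
    dhprebn = bngain * bnvar_inv / n * (
        n * dhpreact
        - dhpreact.sum(0)
        - n / (n - 1) * bnraw * (dhpreact * bnraw).sum(0)
    )
    print((dhprebn - hprebn.grad).abs().max())             # max diff should be tiny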
And so getting this expression right is just highly non-trivial, and I invite you to basically look through it and step through it; it's a whole exercise in itself to make sure that this checks out. But once all the shapes agree, and once you convince yourself that it's correct, you can also verify that PyTorch gets the exact same answer as well, and that gives you a lot of peace of mind that this mathematical formula is correctly implemented here, broadcasted correctly, and replicated in parallel for all of the 64 neurons inside this batch norm layer. Okay, and finally, exercise number 4 asks you to put it all together. And here we have a redefinition of the entire problem: you see that we re-initialize the neural net from scratch and everything, and then here, instead of calling loss.backward, we want to have the manual backpropagation here, as we derived it up above. So go up, copy-paste all the chunks of code that we've already derived, put them here, and derive your own gradients, and then optimize this model using this neural net — basically using your own gradients — all the way to the calibration of the batch norm and the evaluation of the loss. And I was able to achieve quite a good loss, basically the same loss you would achieve before, and that shouldn't be surprising, because all we've done is we've really gone into loss.backward, and we've pulled out all the code and inserted it here. But those gradients are identical, and everything is identical, and the results are identical; it's just that we have full visibility in this specific case. Okay, and this is all of our code: this is the full backward pass, using basically the simplified backward pass for the cross-entropy loss and the batch normalization. So we're backpropagating through the cross-entropy, the second layer, the tanh non-linearity, the batch normalization, through the first layer, and through the embedding. And so you see that this is only maybe, what is this, 20 lines of code or something like that, and that's what gives us the gradients, instead of loss.backward in this case.
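For readers following along without the notebook, here is a rough, self-contained sketch (random data, made-up sizes) of what those roughly 20 lines of manual backprop look like for this architecture — embedding, linear, batch norm, tanh, linear, cross-entropy. The variable names (Xb, Yb, C, W1, bngain, bnraw, and so on) mirror the spirit of the notebook but are assumptions, not its exact code:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(42)
    vocab_size, block_size, n_emb, n_hidden, n = 27, 3, 10, 64, 32
    Xb = torch.randint(0, vocab_size, (n, block_size))       # input character indices
    Yb = torch.randint(0, vocab_size, (n,))                   # target character indices
    C  = torch.randn(vocab_size, n_emb)                       # embedding table
    W1 = torch.randn(block_size * n_emb, n_hidden); b1 = torch.randn(n_hidden)
    W2 = torch.randn(n_hidden, vocab_size);         b2 = torch.randn(vocab_size)
    bngain = torch.ones(1, n_hidden); bnbias = torch.zeros(1, n_hidden)

    # forward pass, keeping the intermediates that the backward pass needs
    emb = C[Xb]
    embcat = emb.view(n, -1)
    hprebn = embcat @ W1 + b1
    bnmean = hprebn.mean(0, keepdim=True)
    bnvar = hprebn.var(0, keepdim=True, unbiased=True)
    bnvar_inv = (bnvar + 1e-5)**-0.5
    bnraw = (hprebn - bnmean) * bnvar_inv
    hpreact = bngain * bnraw + bnbias
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)

    # manual backward pass
    dlogits = F.softmax(logits, 1)
    dlogits[range(n), Yb] -= 1
    dlogits /= n                                              # cross-entropy (mean over the batch)
    dh = dlogits @ W2.T
    dW2 = h.T @ dlogits
    db2 = dlogits.sum(0)
    dhpreact = (1.0 - h**2) * dh                              # tanh
    dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
    dbnbias = dhpreact.sum(0, keepdim=True)
    dhprebn = bngain * bnvar_inv / n * (n * dhpreact - dhpreact.sum(0)
              - n / (n - 1) * bnraw * (dhpreact * bnraw).sum(0))
    dembcat = dhprebn @ W1.T
    dW1 = embcat.T @ dhprebn
    db1 = dhprebn.sum(0)
    demb = dembcat.view(emb.shape)
    dC = torch.zeros_like(C)
    for k in range(n):
        for j in range(block_size):
            dC[Xb[k, j]] += demb[k, j]                        # scatter-add into the embedding table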
So the way I have the code set up is that you should be able to run this entire cell once you fill this in, and it will run for only 100 iterations and then break. And it breaks because it gives you an opportunity to check your gradients against PyTorch. So here we see our gradients are not exactly equal — they are approximately equal, and the differences are tiny, 1e-9 or so, and I don't exactly know where they're coming from, to be honest. But if we're basically correct, we can take out the gradient checking, we can disable this break statement, and then we can basically disable loss.backward — we don't need it anymore, and it feels amazing to say that. And then here, when we are doing the update, we're not going to use p.grad — that's the old PyTorch way; we don't have that anymore, because we're not doing backward. We are going to use this update instead: I'm iterating over — I've arranged the grads to be in the same order as the parameters — and I'm zipping them up, the gradients and the parameters, into p and grad, and then here I'm going to step with just the grad that we derived manually. So the last piece is that none of this now requires gradients from PyTorch, and so one thing you can do here is wrap this whole code block in a with torch.no_grad() context, and really what you're saying is you're telling PyTorch: hey, I'm not going to call backward on any of this, and that allows PyTorch to be a bit more efficient with all of it.
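Continuing the sketch above, the update step without p.grad might look roughly like this, stepping directly with the manually derived gradients:

    lr = 0.1
    parameters = [C, W1, b1, bngain, bnbias, W2, b2]
    grads = [dC, dW1, db1, dbngain, dbnbias, dW2, db2]   # arranged in the same order as parameters
    with torch.no_grad():
        for p, grad in zip(parameters, grads):
            p += -lr * grad                              # manual SGD step, no p.grad involved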
And then we should be able to just run this. And it's running, and you see that loss.backward is commented out and we're optimizing. So we're going to leave this running and hopefully we get a good result. Okay, so I let the neural net finish its optimization, and then here I calibrate the batch norm parameters, because I did not keep track of the running mean and variance in the training loop. Then here I ran the loss, and you see that we actually obtain a pretty good loss, very similar to what we've achieved before. And then here I'm sampling from the model, and we see some of the name-like gibberish that we're sort of used to. So basically the model worked and samples pretty decent results, compared to what we were used to. So everything is the same, but of course the big deal is that we did not use loss.backward, we did not use PyTorch autograd, and we estimated our gradients ourselves, by hand. And so hopefully you're looking at this — the backward pass of this neural net — and you're thinking to yourself: actually, that's not too complicated. Each one of these layers is like three lines of code or something like that, and most of it is fairly straightforward, potentially with the notable exception of the batch normalization backward pass; otherwise it's pretty good. Okay, and that's everything I wanted to cover, so hopefully you found this interesting. And what I liked about it, honestly, is that it gave us a very nice diversity of layers to backpropagate through, and I think it gives a pretty nice and comprehensive sense of how these backward passes are implemented and how they work, and you'd be able to derive them yourself. But of course, in practice you probably
don't want to, and you want to use PyTorch autograd. But hopefully you have some intuition about how gradients flow backwards through the neural net, starting at the loss, and how they flow through all the variables. And if you understood a good chunk of it, and if you have a sense of that, then you can count yourself as one of these buff doges on the left, instead of the doges on the right here. Now, in the next lecture, we're actually going to go to recurrent neural nets, LSTMs, and all the other variants of RNNs, and we're going to start to complexify the architecture and start to achieve better log likelihoods. And so I'm really looking forward to that, and I'll see you then.