A | B | C | D | E | F | G | H | I | J | K | L | M | N | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|

1 | ||||||||||||||

2 | Gradient Descent | |||||||||||||

3 | We know by now that neural networks learn by making guesses about the parameters of a function (filters in a convolutional layer, weights in a dense layer), and updating those guesses based on how closely the network's outputs match the known labels in our data. We might also know that networks somehow do this using derivatives, and, as I learned not too long ago, that a derivative is just the rate of change of a thing. | |||||||||||||

4 | ||||||||||||||

5 | The Gradient | |||||||||||||

6 | In order to do this, we need a way to collectively compare the outputs from our network to the labels in our data. There are multiple loss functions we can use here, but I like RMSE (root mean squared error) because it's pretty straightforward mathematically, and I'm a simple man. | |||||||||||||

7 | ||||||||||||||

8 | To understand RMSE, let's imagine that we have the following set of training data, where x and y are linearly related. A linear relationship can be represented in the form y = mx + b. In this case, m = 2 and b = 30. | x | y | |||||||||||

9 | 7 | 44 | ||||||||||||

10 | 9 | 48 | ||||||||||||

11 | 28 | 86 | ||||||||||||

12 | 30 | 90 | ||||||||||||

13 | 50 | 130 | ||||||||||||

14 | m_actual | 2 | ||||||||||||

15 | b_actual | 30 | ||||||||||||

16 | ||||||||||||||

17 | We know that the relationship between x and y is linear, but we don't know the value of the parameters m and b. What we can do is make a guess that m=1 and b=1 and see what values of y we'd predict using those values of m and b. | x | y_pred | |||||||||||

18 | 7 | 8 | ||||||||||||

19 | 9 | 10 | ||||||||||||

20 | 28 | 29 | ||||||||||||

21 | 30 | 31 | ||||||||||||

22 | 50 | 51 | ||||||||||||

23 | m_guess | 1 | ||||||||||||

24 | b_guess | 1 | ||||||||||||

25 | ||||||||||||||

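26 | We can compare our predicted y values to our actual y values by subtracting one from the other and squaring the result. Taking the square root of the sum of (y_pred - y)^2 gives us a single number that tells us how close our guesses are to our actual labels. Strictly speaking, RMSE divides that sum by the number of data points before taking the square root (which would give about 56 here rather than 125); this sheet skips the division, which rescales the number but doesn't move the minimum. Either way, the lower the number, the more accurate our guesses. | y | y_pred | (y_pred - y) ^2 | ||||||||||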

27 | 44 | 8 | 1,296 | |||||||||||

28 | 48 | 10 | 1,444 | |||||||||||

29 | 86 | 29 | 3,249 | |||||||||||

30 | 90 | 31 | 3,481 | |||||||||||

31 | 130 | 51 | 6,241 | |||||||||||

32 | ||||||||||||||

33 | rmse | 125 | ||||||||||||
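
For anyone who'd rather read code than cell formulas, here is a minimal Python sketch of the calculation above. The function names (`predict`, `squared_errors`, `sheet_error`, `rmse`) are made up for illustration and are not taken from the spreadsheet's scripts. `sheet_error` reproduces the number shown in the sheet (square root of the sum), while `rmse` is the textbook version (square root of the mean).

```python
import math

# Training data from the table above: y = 2x + 30
xs = [7, 9, 28, 30, 50]
ys = [44, 48, 86, 90, 130]

def predict(x, m, b):
    """Predict y for a single x using the line y = mx + b."""
    return m * x + b

def squared_errors(m, b):
    """(y_pred - y)^2 for every row of the training data."""
    return [(predict(x, m, b) - y) ** 2 for x, y in zip(xs, ys)]

def sheet_error(m, b):
    """Square root of the SUM of squared errors (the sheet's 'rmse' cell)."""
    return math.sqrt(sum(squared_errors(m, b)))

def rmse(m, b):
    """Textbook RMSE: square root of the MEAN of squared errors."""
    errs = squared_errors(m, b)
    return math.sqrt(sum(errs) / len(errs))

print(squared_errors(1, 1))       # [1296, 1444, 3249, 3481, 6241]
print(round(sheet_error(1, 1)))   # 125
print(round(rmse(1, 1)))          # 56
```

Both versions hit zero at m = 2, b = 30, so either one works as a target to minimize.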

34 | ||||||||||||||

35 | We can see that our guesses are not that accurate right now! That's fine though - we haven't done any optimization yet. Let's start by visualizing our loss function as a surface over m and b. Strictly, the gradient is not the surface itself but its slope at a given point: one partial derivative per parameter. Unfortunately, Google Sheets doesn't support 3D area charts, so we're going to have to make do with conditional formatting. Our rows represent different guesses for the value of m, and our columns different guesses for the value of b. RMSE decreases as we get closer to the actual values of m and b, reaching 0 at m = 2 and b = 30: | |||||||||||||

36 | ||||||||||||||

37 | ||||||||||||||

38 | ||||||||||||||

39 | ||||||||||||||

40 | ||||||||||||||

41 | ||||||||||||||

42 | ||||||||||||||

43 | ||||||||||||||

44 | | b | ||||||||||||

45 | m | 0 | 3 | 6 | 9 | 12 | 15 | 18 | 21 | 24 | 27 | 30 | ||

46 | 1 | 127 | 121 | 115 | 108 | 102 | 96 | 89 | 83 | 77 | 71 | 66 | ||

47 | 2 | 67 | 60 | 54 | 47 | 40 | 34 | 27 | 20 | 13 | 7 | 0 | ||

48 | 3 | 37 | 36 | 35 | 36 | 38 | 41 | 45 | 50 | 55 | 60 | 66 | ||

49 | 4 | 83 | 87 | 91 | 95 | 100 | 105 | 110 | 115 | 120 | 126 | 131 | ||

50 | 5 | 145 | 150 | 154 | 159 | 164 | 170 | 175 | 180 | 186 | 191 | 197 | ||

51 | 6 | 209 | 214 | 219 | 224 | 230 | 235 | 240 | 246 | 251 | 257 | 263 | ||

52 | 7 | 274 | 279 | 285 | 290 | 295 | 301 | 306 | 312 | 317 | 323 | 328 | ||

53 | 8 | 339 | 345 | 350 | 355 | 361 | 366 | 372 | 377 | 383 | 388 | 394 | ||
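
The grid above can be rebuilt in a few lines of Python. This is only a sketch: it assumes the five training rows from earlier, uses the same square-root-of-the-sum error the sheet reports, and the names `sheet_error` and `grid` are hypothetical. Rows are guesses for m (1 to 8) and columns are guesses for b (0 to 30 in steps of 3).

```python
import math

# Training data from the sheet: y = 2x + 30
xs = [7, 9, 28, 30, 50]
ys = [44, 48, 86, 90, 130]

def sheet_error(m, b):
    """Square root of the sum of squared errors for a given m and b."""
    return math.sqrt(sum((m * x + b - y) ** 2 for x, y in zip(xs, ys)))

m_values = range(1, 9)       # row labels: 1, 2, ..., 8
b_values = range(0, 31, 3)   # column labels: 0, 3, ..., 30

# Error for every (m, b) pair, rounded the way the sheet displays it
grid = {(m, b): round(sheet_error(m, b)) for m in m_values for b in b_values}

# Print the grid with b across the top and m down the side
print("m/b " + " ".join(f"{b:>4}" for b in b_values))
for m in m_values:
    print(f"{m:>3} " + " ".join(f"{grid[(m, b)]:>4}" for b in b_values))

# The lowest cell is the pair of guesses closest to the truth
best = min(grid, key=grid.get)
print(best)  # (2, 30)
```

Scanning every cell like this only works because we have two parameters and a tiny range of guesses. With more parameters the grid blows up fast, which is why we walk downhill instead.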

54 | ||||||||||||||

55 | The Descent | |||||||||||||

56 | Imagine the grid above as a landscape with peaks and valleys. If we were blindfolded and dropped onto this landscape with the goal of getting to its lowest possible point, one way to accomplish our objective would be to test the ground around us with a foot, take a step wherever the descent feels steepest, and repeat. This is what derivatives allow us to do. In the following table, we have 20 values of x along with their actual y values. In the first row, we make guesses about the values of m and b, use them to predict y, and calculate the squared error of that prediction. Next, we calculate what the squared error would be if we added 0.01 to our guess of m. From that we estimate the derivative of our error with respect to m (the change in error divided by the 0.01 change in m), and use it to pick a new value of m, essentially taking a step in the direction of steepest descent. The learn parameter (the learning rate) decides how large or small a step we take. We do the same thing with b, copy our new values of m and b to the next row, and do it all over again. When we've completed this process for every pair of x and y values in our dataset, we've completed one epoch. The "Run Epoch" button below completes as many epochs as we specify (5 by default) and records the results in a table, showing how RMSE changes after each epoch. Hitting "Reset" removes the recorded values and sets our guesses for both parameters back to 1. Try messing around with the parameters and running a few epochs to see if you can build an intuition around gradient descent. You're probably going to want your own copy of this one - the buttons don't work in view-only mode. | |||||||||||||

57 | ||||||||||||||

58 | ||||||||||||||

59 | ||||||||||||||

60 | ||||||||||||||

61 | ||||||||||||||

62 | ||||||||||||||

63 | ||||||||||||||

64 | ||||||||||||||

65 | ||||||||||||||

66 | ||||||||||||||

67 | ||||||||||||||

68 | ||||||||||||||

69 | ||||||||||||||

70 | ||||||||||||||

71 | ||||||||||||||

72 | ||||||||||||||

73 | ||||||||||||||

74 | ||||||||||||||

75 | ||||||||||||||

76 | ||||||||||||||

77 | ||||||||||||||

78 | ||||||||||||||

79 | ||||||||||||||

80 | And if it turns out that making a copy of the spreadsheet doesn't copy the scripts, I made a gist of the scripts attached to this spreadsheet here. | |||||||||||||

81 | ||||||||||||||

82 | ||||||||||||||

83 | actual | guess | ||||||||||||

84 | epochs | 5 | 1 | m | 2 | 2.051 | rmse | 17 | ||||||

85 | learn | 0.0001 | 1 | b | 30 | 25.075 | ||||||||

86 | ||||||||||||||

87 | ||||||||||||||

88 | (mx+b) | (mx+b-y)^2 | 2(mx+b-y) | 2x(mx+b-y) | ||||||||||

89 | x | y | m | b | y_pred | error_sq | err_m1 | de/dm | new_m | err_b1 | de/db | new_b | ||

90 | 7 | 44 | 2.051 | 25.075 | 39.4 | 21 | 20 | -9 | 2.052 | 21 | -64 | 25.081 | ||

91 | 9 | 48 | 2.052 | 25.081 | 43.5 | 20 | 19 | -9 | 2.053 | 20 | -80 | 25.089 | ||

92 | 28 | 86 | 2.053 | 25.089 | 82.6 | 12 | 10 | -7 | 2.053 | 12 | -192 | 25.109 | ||

93 | 30 | 90 | 2.053 | 25.109 | 86.7 | 11 | 9 | -7 | 2.054 | 11 | -197 | 25.128 | ||

94 | 50 | 130 | 2.054 | 25.128 | 127.8 | 5 | 3 | -4 | 2.055 | 5 | -216 | 25.150 | ||

95 | 22 | 74 | 2.055 | 25.150 | 70.4 | 13 | 12 | -7 | 2.055 | 13 | -161 | 25.166 | ||

96 | 3 | 36 | 2.055 | 25.166 | 31.3 | 22 | 22 | -9 | 2.056 | 22 | -28 | 25.169 | ||

97 | 9 | 48 | 2.056 | 25.169 | 43.7 | 19 | 18 | -9 | 2.057 | 19 | -78 | 25.177 | ||

98 | 30 | 90 | 2.057 | 25.177 | 86.9 | 10 | 8 | -6 | 2.058 | 10 | -187 | 25.195 | ||

99 | 23 | 76 | 2.058 | 25.195 | 72.5 | 12 | 11 | -7 | 2.058 | 12 | -160 | 25.211 | ||

100 | 40 | 110 | 2.058 | 25.211 | 107.5 | 6 | 4 | -5 | 2.059 | 6 | -196 | 25.231 |
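
If the buttons don't work for you, the same loop can be sketched in Python. This is not the script attached to the spreadsheet. It is a rough illustration of the row-by-row procedure described above, under a few assumptions: it uses only the 11 rows visible in the table, it estimates each derivative by nudging a parameter by 0.01, and the names `run_epoch`, `nudge`, and `total_error` are made up. For reference, the exact derivatives of the squared error are 2x(mx+b-y) with respect to m and 2(mx+b-y) with respect to b; the nudge estimates land close to those.

```python
# The 11 (x, y) rows visible in the table above; all satisfy y = 2x + 30
data = [(7, 44), (9, 48), (28, 86), (30, 90), (50, 130), (22, 74),
        (3, 36), (9, 48), (30, 90), (23, 76), (40, 110)]

def squared_error(x, y, m, b):
    """Squared error of the prediction mx + b for one row."""
    return (m * x + b - y) ** 2

def run_epoch(m, b, learn=0.0001, nudge=0.01):
    """One pass over the data, updating m and b after every row."""
    for x, y in data:
        err = squared_error(x, y, m, b)
        # How much does the error change if we nudge each parameter?
        de_dm = (squared_error(x, y, m + nudge, b) - err) / nudge
        de_db = (squared_error(x, y, m, b + nudge) - err) / nudge
        # Step downhill: move against the slope, scaled by the learning rate
        m -= learn * de_dm
        b -= learn * de_db
    return m, b

def total_error(m, b):
    """Square root of the sum of squared errors over all rows."""
    return sum(squared_error(x, y, m, b) for x, y in data) ** 0.5

m, b = 1.0, 1.0                    # same starting guesses as "Reset"
history = [total_error(m, b)]
for epoch in range(5):             # 5 epochs, the sheet's default
    m, b = run_epoch(m, b)
    history.append(total_error(m, b))
    print(f"epoch {epoch + 1}: m={m:.3f} b={b:.3f} error={history[-1]:.1f}")
```

Try changing `learn` and the number of epochs, just like in the sheet. With this small a learning rate m moves quickly while b creeps along, because the slope for m is multiplied by x and the slope for b is not.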
