Hi all, here’s the third in my series on neural networks / machine learning / AI from scratch. In the previous articles (please read them first!), I explained how a single neuron works, and how to calculate the gradients of its weight and bias. In this article, I’ll explain how you can use those gradients to train the neuron.
I recommend opening this spreadsheet in a separate tab and viewing it as you read this post, which explains the maths: Single neuron training.
In case the linked spreadsheet is lost to posterity, here it is in slightly less well-formatted form (note: for brevity’s sake, I’ve shortened references such as B2 to simply ‘B’ when referring to a column in the same row):
| | A | C | D | F | G | H | I | J | K | L | N | O | P | Q |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | Learning rate | Training | | Neuron | | | | | | | Outputs | | | |
| 2 | 0.1 | In | Out | Input | Weight | Weight gradient | Bias | Bias gradient | Net | Output | Target | Attempt | Error | Loss |
| 3 | | 0.01 | 0.1 (C*10) | 0.01 (C) | 0.5 | J * F | 0.5 | P * (1-L²) | F*G+I | Tanh(K) | 0.1 (D) | 1 | L-N | P² / 2 |
| 4 | | 0.01 | 0.1 (C*10) | 0.01 (C) | G3 - H3 * LEARNING_RATE | J * F | I3 - J3 * LEARNING_RATE | P * (1-L²) | F*G+I | Tanh(K) | 0.1 (D) | 2 | L-N | P² / 2 |
| 5 | | 0.01 | 0.1 (C*10) | 0.01 (C) | G4 - H4 * LEARNING_RATE | J * F | I4 - J4 * LEARNING_RATE | P * (1-L²) | F*G+I | Tanh(K) | 0.1 (D) | 3 | L-N | P² / 2 |
Note: “Parameters” is the umbrella term for “weights and biases”.
A2 is the ‘learning rate’ (LEARNING_RATE in the formulas). It governs how far we ‘nudge’ the weight and bias each iteration. In this example it’s 0.1 (10%), which is higher than the more common 0.1% to 1%.
Columns C-D are the ‘training data’: a single example with input 0.01 and desired output 0.1. In other words, we want to train the neuron to multiply by 10.
Columns F-L are the neuron maths, as covered by my earlier articles. The two gradients (columns H and J) in particular are tricky and important: each one tells us whether increasing its parameter (the weight or the bias, respectively) would increase or decrease the loss, so nudging the parameter against its gradient decreases the error.
Columns N-Q are the outputs. They’re useful for producing the neat graph you’ll hopefully see in the actual spreadsheet, which shows how the error decreases over the iterations.
Row 3 is the initial data. At this point in a real implementation we would typically choose random values for the initial weight and bias; however, I’ve chosen 0.5 for both because it’s a nice round number.
🧨💣💥 Rows 4+ are the same as row 3, except that each row’s weight and bias are the previous row’s values minus their gradient multiplied by the learning rate (e.g. G4 = G3 - H3 * LEARNING_RATE). This is the important bit!
Incidentally, this might help explain why training a neural network (NN) takes so much more computation than using it: training needs all those gradient calculations, repeated over many iterations of the training data, whereas using the trained network is a single forward pass.
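To get a feel for the learning rate, here’s a minimal Rust sketch (my own illustration, not part of the spreadsheet). It trains the neuron on the same single training pair (input 0.01, target 0.1) from the same starting weight and bias of 0.5, using a few different rates for 100 attempts each, and reports how far the output still is from the target. The function name `error_after` and the particular rates are arbitrary choices.

```rust
// Train the single tanh neuron for a fixed number of attempts with the given
// learning rate, then return how far its output still is from the target.
fn error_after(learning_rate: f64, attempts: usize) -> f64 {
    let (input, target) = (0.01_f64, 0.1_f64);
    let (mut weight, mut bias) = (0.5_f64, 0.5_f64);
    for _ in 0..attempts {
        let output = (input * weight + bias).tanh();
        let error = output - target;
        let bias_gradient = error * (1.0 - output * output);
        let weight_gradient = bias_gradient * input;
        weight -= weight_gradient * learning_rate;
        bias -= bias_gradient * learning_rate;
    }
    ((input * weight + bias).tanh() - target).abs()
}

fn main() {
    for learning_rate in [0.001, 0.01, 0.1] {
        println!(
            "learning rate {:>5}: |error| after 100 attempts = {:.6}",
            learning_rate,
            error_after(learning_rate, 100)
        );
    }
}
```

With this one training pair, a bigger rate closes the gap in fewer attempts. In general, though, a rate that’s too large can overshoot the target and bounce around instead of settling.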
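One way to convince yourself the gradient formulas in columns H and J are right is a finite-difference check: nudge a parameter up and down by a tiny amount and measure how much the loss changes. This Rust sketch compares both against the formulas at row 3’s values. The helper names and the step size `h = 1e-6` are my own choices for illustration.

```rust
// Forward pass as in the spreadsheet: Net (K) = F*G + I, Output (L) = tanh(K),
// Error (P) = L - N, Loss (Q) = P² / 2.
fn loss(input: f64, weight: f64, bias: f64, target: f64) -> f64 {
    let output = (input * weight + bias).tanh();
    let error = output - target;
    error * error / 2.0
}

// Analytic gradients, returned as (weight gradient, bias gradient):
// bias gradient (J) = P * (1 - L²), weight gradient (H) = J * F.
fn gradients(input: f64, weight: f64, bias: f64, target: f64) -> (f64, f64) {
    let output = (input * weight + bias).tanh();
    let error = output - target;
    let bias_gradient = error * (1.0 - output * output);
    (bias_gradient * input, bias_gradient)
}

// Central finite differences: nudge one parameter up and down by h and see
// how much the loss changes, again returned as (weight, bias).
fn numeric_gradients(input: f64, weight: f64, bias: f64, target: f64) -> (f64, f64) {
    let h = 1e-6;
    let dw = (loss(input, weight + h, bias, target) - loss(input, weight - h, bias, target)) / (2.0 * h);
    let db = (loss(input, weight, bias + h, target) - loss(input, weight, bias - h, target)) / (2.0 * h);
    (dw, db)
}

fn main() {
    // Row 3: input 0.01, weight 0.5, bias 0.5, target 0.1.
    let (analytic_w, analytic_b) = gradients(0.01, 0.5, 0.5, 0.1);
    let (numeric_w, numeric_b) = numeric_gradients(0.01, 0.5, 0.5, 0.1);
    println!("weight gradient: analytic {:.8}, numeric {:.8}", analytic_w, numeric_w);
    println!("bias gradient:   analytic {:.8}, numeric {:.8}", analytic_b, numeric_b);
}
```

Both gradients come out positive at row 3, meaning that increasing the weight or the bias would increase the loss there, which is why training subtracts them.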
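If you can’t open the spreadsheet, this Rust sketch should reproduce the numbers behind the first ten rows, as a text stand-in for the graph. It assumes the learning rate in A2 (0.1) and row 3’s starting weight and bias of 0.5; the `Row` struct and `rows` function are just my names for illustration. Each row records the parameters going in, then works out the next row’s parameters by subtracting gradient × learning rate.

```rust
// One spreadsheet row: the parameters going in (G, I), the Error (P) and Loss (Q) coming out.
#[derive(Debug, Clone, Copy)]
struct Row {
    attempt: usize,
    weight: f64,
    bias: f64,
    error: f64,
    loss: f64,
}

const LEARNING_RATE: f64 = 0.1; // A2
const INPUT: f64 = 0.01; // C / F
const TARGET: f64 = 0.1; // D / N

fn rows(count: usize) -> Vec<Row> {
    let (mut weight, mut bias) = (0.5_f64, 0.5_f64); // Row 3's initial parameters.
    let mut result = Vec::with_capacity(count);
    for attempt in 1..=count {
        let output = (INPUT * weight + bias).tanh(); // L = tanh(F*G + I)
        let error = output - TARGET; // P = L - N
        let loss = error * error / 2.0; // Q = P² / 2
        result.push(Row { attempt, weight, bias, error, loss });
        let bias_gradient = error * (1.0 - output * output); // J
        let weight_gradient = bias_gradient * INPUT; // H
        // Next row: G(n+1) = G(n) - H(n) * LEARNING_RATE, and likewise I with J.
        weight -= weight_gradient * LEARNING_RATE;
        bias -= bias_gradient * LEARNING_RATE;
    }
    result
}

fn main() {
    for row in rows(10) {
        println!(
            "attempt {:>2}: weight {:.6}, bias {:.6}, error {:.6}, loss {:.6}",
            row.attempt, row.weight, row.bias, row.error, row.loss
        );
    }
}
```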
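To make the update concrete, here’s row 3 worked through by hand, ending with row 4’s new weight and bias. Subscripts are spreadsheet row numbers, each column letter stands for the value in that column, and η is the learning rate in A2 (0.1). The values are rounded, so treat the last digit as approximate.

```latex
\begin{aligned}
K_3 &= F_3 G_3 + I_3 = 0.01 \times 0.5 + 0.5 = 0.505 \\
L_3 &= \tanh(K_3) \approx 0.4660 \\
P_3 &= L_3 - N_3 \approx 0.4660 - 0.1 = 0.3660 \\
Q_3 &= P_3^2 / 2 \approx 0.0670 \\
J_3 &= P_3 \, (1 - L_3^2) \approx 0.3660 \times 0.7828 \approx 0.2865 \\
H_3 &= J_3 \, F_3 \approx 0.2865 \times 0.01 = 0.002865 \\
G_4 &= G_3 - H_3 \, \eta \approx 0.5 - 0.002865 \times 0.1 \approx 0.49971 \\
I_4 &= I_3 - J_3 \, \eta \approx 0.5 - 0.2865 \times 0.1 \approx 0.47135
\end{aligned}
```

Note how little the weight moves compared with the bias: its gradient is the bias gradient scaled by the tiny input (0.01).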
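To put a rough number on that, here’s a Rust sketch using the same constants as the demo below. It keeps training until the loss drops under a tolerance, then uses the neuron once. The tolerance of `1e-12`, the attempt cap and the function names are my own arbitrary choices. It should report well over a thousand attempts, each a forward pass plus gradients and an update, while using the trained neuron is a single forward pass. And that’s for just one training example!

```rust
const LEARNING_RATE: f64 = 0.01;
const INPUT: f64 = 0.01;
const TARGET: f64 = 0.1;

// Using the neuron is one forward pass.
fn forward(weight: f64, bias: f64, input: f64) -> f64 {
    (input * weight + bias).tanh()
}

// Repeat forward pass + gradients + update until the loss is below `tolerance`.
// Returns the trained weight and bias, plus the attempt on which the loss
// first dropped below the tolerance (or `max_attempts` if it never did).
fn train_until(tolerance: f64, max_attempts: usize) -> (f64, f64, usize) {
    let (mut weight, mut bias) = (0.5_f64, 0.5_f64);
    for attempt in 1..=max_attempts {
        let output = forward(weight, bias, INPUT);
        let error = output - TARGET;
        if error * error / 2.0 < tolerance {
            return (weight, bias, attempt);
        }
        let bias_gradient = error * (1.0 - output * output);
        let weight_gradient = bias_gradient * INPUT;
        weight -= weight_gradient * LEARNING_RATE;
        bias -= bias_gradient * LEARNING_RATE;
    }
    (weight, bias, max_attempts)
}

fn main() {
    let (weight, bias, attempts) = train_until(1e-12, 100_000);
    println!("Training took {} attempts", attempts);
    println!("Using it takes 1 forward pass: {}", forward(weight, bias, INPUT));
}
```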
And there you have it, that’s how to use the gradients to train a single neuron. Next I’ll explain how to calculate the gradients for a network of them!
Because I’m a Rust tragic, here’s a demo:
```rust
const LEARNING_RATE: f64 = 0.01;
const TRAINING_INPUT: f64 = 0.01;
const TRAINING_OUTPUT: f64 = 0.1;

fn main() {
    // Initial parameters.
    let mut weight: f64 = 0.5;
    let mut bias: f64 = 0.5;

    // Train.
    for _ in 0..100_000 {
        // Forward pass: net and output (columns K and L).
        let net = TRAINING_INPUT * weight + bias;
        let output = net.tanh();
        // Error and loss (columns P and Q). The loss isn't needed for the
        // update (hence the underscore); it's what the graph tracks.
        let error = output - TRAINING_OUTPUT;
        let _loss = error * error / 2.;
        // Gradients (columns J and H).
        let bias_gradient = error * (1. - output * output);
        let weight_gradient = bias_gradient * TRAINING_INPUT;
        // Nudge each parameter against its gradient.
        weight -= weight_gradient * LEARNING_RATE;
        bias -= bias_gradient * LEARNING_RATE;
    }

    // Use the trained parameters:
    let trained_net = TRAINING_INPUT * weight + bias;
    let trained_output = trained_net.tanh();
    println!("Trained output: {}", trained_output);
}
```
Which outputs:
```
Trained output: 0.1000000000000007
```
Which matches the training output nicely!
Thanks for reading! I hope you found this helpful, at least a tiny bit. God bless!
Photo by Eugene Golovesov on Unsplash
Thanks for reading! And if you want to get in touch, I'd love to hear from you: chris.hulbert at gmail.
(Comp Sci, Hons - UTS)
Software Developer (Freelancer / Contractor) in Australia.
I have worked at places such as Google, Cochlear, Assembly Payments, News Corp, Fox Sports, NineMSN, FetchTV, Coles, Woolworths, Trust Bank, and Westpac, among others. If you're looking for help developing an iOS app, drop me a line!
Get in touch:
github.com/chrishulbert
linkedin