10  Optimization

For the purposes of machine learning, by far the most important application of differentiation, and of calculus in general, is to optimization. Optimization is the problem of finding the “best” values with respect to some function. In machine learning, by “best” we usually mean finding the minimum value of a loss function, which is a function that measures how far a model’s predictions are from the data it sees. Finding the minimum value of the loss function essentially means we’ve found the best weights for our model, the ones that fit the data most closely.

An interesting fact is that for a reasonably smooth function, its minimum value will always be at a point where the derivative is zero. To see why, consider our tangent line plot of \(y=x^2\) from before. What happens if we set our point of interest to be \(x_0=0\)? Clearly that’s the minimum of this function. At this point, the tangent line hugs the parabola horizontally, which means it’s a point where the slope is zero.

f = lambda x: x ** 2
dfdx = lambda x: 2 * x
x0 = 0
y0 = f(x0)
x = np.arange(-10, 10, 0.1)

f_tangent = lambda x: y0 + dfdx(x0) * (x - x0)
plot_function(x, (f, f_tangent), (-5, 5), (-2, 10), title=f'Tangent of $y=x^2$ at ${(x0, y0)}$')

This same fact holds for the maximum of a function as well. In fact it holds at any point where the function is flat, including points called saddle points, which are neither maxima nor minima. For example, the origin is a saddle point of the function \(y=x^3\). These points where the derivative is zero (minima, maxima, or saddle points) are collectively called stationary points.

In machine learning we usually care most about the minimum. I’ll just mention that we can turn any maximization problem into a minimization problem by multiplying the function by \(-1\), which flips the function upside down, turning any maxima into minima.

Now, suppose we have a univariate function \(y=f(x)\). The problem of (unconstrained) optimization is to find a point \(x^*\) such that \(y^* = f(x^*)\) is the minimum value of \(f(x)\), i.e.  \[y^* = \min f(x) \leq f(x) \text{ for all } x.\] The special point \(x^*\) that minimizes the function is called the argmin, written \[x^* = \text{argmin } f(x).\]
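To see the difference between the two concretely, here’s a crude numerical illustration (this is not an optimization algorithm, just a lookup). Reusing the f and the grid of x values from the tangent-line code above, we can evaluate the function on the grid and read off both quantities. Both printed values are essentially \(0\), as expected for \(y=x^2\).

y = f(x)
print(np.min(y))        # the minimum value y*, the smallest function value on the grid
print(x[np.argmin(y)])  # the argmin x*, the grid point that achieves that value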

I need to mention a subtle point. What do I mean when I say “the minimum”? When I say \(y^* \leq f(x)\) for all \(x\), which \(x\) values am I talking about? The point is that we’re really only talking about the minimum over some range of \(x\) values, and we have to specify what that range is. If the range is the whole real line, it really is the minimum, usually called the global minimum. If it’s over some subset of the real line it may not be the global minimum, since we’re not looking at every \(x\). It’s only the minimum in our region of interest. This sort of region-specific minimum is called a local minimum.

While this seems like a subtle point, it is an important one in machine learning. Some algorithms, like deep learning algorithms, can only reliably find a local minimum. Finding the global minimum can be harder unless there’s only one minimum to begin with. These simple functions are called convex functions. Our above example of \(y=x^2\) is a convex function. It only has one minimum, and the function just slopes up around it on both sides in a bowl shape. Deep learning loss functions on the other hand are nasty, wiggly things with lots of bumps and valleys. Such functions are called non-convex functions. In general they’ll have lots of local minima.

So back to the fact about the derivative being zero at the minimum, what we “proved” by example is that at the point \(x^*\) we should have \[\frac{d}{dx}f(x^*)=0.\] Another useful way to state the same fact is to think in terms of infinitesimals: At \(x^*\), any infinitesimal perturbation \(dx\) won’t change the value of the function at all, \(f(x^*+dx) = f(x^*)\). This is just another way of stating that \(dy=0\) at \(x^*\). The fact that small perturbations don’t change the function’s value (at least to first order in \(dx\)) is unique to minima and other stationary points.

Let’s verify this fact with the same example \(y=x^2\) by looking at small perturbations around \(x=0\). Since \(f(0)=0\) is a minimum, any perturbation should just give \(0\) as well. Choosing a \(dx\) of 1e-5, we can see that the function’s perturbed value \(f(0+dx)\) is only about 1e-10, essentially negligible since \(dx^2 \approx 0\) for infinitesimals. This won’t be true for any other value of \(x\), e.g. \(x=1\), which has a much larger change of 2e-5, which is on the order of \(dx\), as expected.

dx = 1e-5
print(f(0 + dx) - f(0))  # ~1e-10, negligible: the derivative is zero at x = 0
print(f(1 + dx) - f(1))  # ~2e-5, on the order of dx: the derivative is nonzero at x = 1

Pretty much everything I’ve said on optimization extends naturally to higher dimensions. That’s why I went into so much detail on the simple univariate case. It’s easier to explain and visualize. To extend to \(n\) dimensions we basically just need to convert inputs into vectors and derivatives into gradients. Other than this the formulas all look basically the same.

Suppose we have now a scalar-valued multivariate function \(z=f(\mathbf{x})=f(x_1,\cdots,x_n)\). The problem of (unconstrained) optimization is to find a vector \(\mathbf{x}^* \in \mathbb{R}^n\) such that \(z^* = f(\mathbf{x}^*)\) is the minimum value of \(f(\mathbf{x})\), i.e.  \[z^* = \min f(\mathbf{x}) \leq f(\mathbf{x}) \text{ for all } \mathbf{x} \in \mathbb{R}^n.\] The vector \(\mathbf{x}^*\) that minimizes the function is called the argmin, written \[\mathbf{x}^* = \text{argmin } f(\mathbf{x}).\]

Just as the derivative is zero at the minimum in the univariate case, the gradient is the zero vector at the minimum in the multivariate case, \[\frac{d}{d\mathbf{x}}f(\mathbf{x^*})=\mathbf{0}.\] Another way of stating the same fact is that at the minimum \(f(\mathbf{x^*} + d\mathbf{x}) = f(\mathbf{x^*})\) for any infinitesimal perturbation vector \(d\mathbf{x}\). Equivalently, \(dz=0\).
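As a quick sanity check in two dimensions, consider the bowl-shaped function \(z = x_1^2 + x_2^2\) (an illustrative choice of mine, with gradient \((2x_1, 2x_2)\) and minimum at the origin). Perturbing the argmin by a small vector barely changes the function, while perturbing any other point changes it by an amount on the order of the perturbation.

import numpy as np

f2 = lambda x: np.sum(x ** 2)                  # z = x1^2 + x2^2, minimized at the origin
dx = 1e-5 * np.ones(2)                         # a small perturbation vector

print(f2(np.zeros(2) + dx) - f2(np.zeros(2)))  # ~2e-10, negligible: the gradient is zero here
print(f2(np.ones(2) + dx) - f2(np.ones(2)))    # ~4e-5, on the order of dx: the gradient is nonzero here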

10.0.1 Gradient Descent

So if the minimum is so important how do we actually find the thing? For simple functions like \(y=x^2\) we can do it just by plotting the function, or by trial and error. We can also do it analytically by solving the equation \(\frac{dy}{dx}\big|_{x^*}=0\) for \(x^*\). But for complicated functions, or functions we can’t exactly write down, this isn’t feasible. We need an algorithmic way to do it.

Let’s try something simple. Since the derivative at \(x\) tells us the slope of the function at \(x\), it’s in some sense telling us how far we are away from the minimum. Suppose we perturb \(x\) by an infinitesimal \(dx\). Then \(y=f(x)\) gets perturbed to \(y+dy=f(x+dx)\). Now, observe the almost trivial fact that \[dy = \frac{dy}{dx}dx.\] So if \(\frac{dy}{dx}\) is large, small changes in \(x\) will result in large changes in \(y\). Similarly, if \(\frac{dy}{dx}\) is small, then small changes in \(x\) will result in small changes in \(y\). But we demonstrated above that if we’re near the minimum, changes in \(y\) will be tiny for any small \(dx\). Thus, the derivative serves as a kind of “how close are we to the minimum” metric.

But that’s not all the derivative tells us. Since the sign of the derivative indicates which way the slope is slanting, it also tells us which direction the minimum is in. If you’re at a point on the function, the minimum will always be in the direction that’s sloping downward from you. Since the slope slants upward in the direction of the sign of the derivative, and we want to move downward the other way, the minimum will be in the direction of the negative of the derivative.

More formally, suppose we want to find the minimum of \(y=f(x)\). To start, we’ll pick a point \(x_0\) at random. Doesn’t matter too much how. Pick a step size, we’ll call it \(\alpha\). This will multiply the derivative and tell us how big of a step to take towards the minimum (more on why this is important in a second). Now, we’ll take a step towards the minimum \[x_1 = x_0 - \alpha \frac{dy}{dx}\bigg|_{x_0}.\] This puts us at a new point \(x_1\), which will be closer to the argmin \(x^*\) if our step size is small enough. Now do it again, \[x_2 = x_1 - \alpha \frac{dy}{dx}\bigg|_{x_1}.\] And again, \[x_3 = x_2 - \alpha \frac{dy}{dx}\bigg|_{x_2}.\] Keep doing this over and over. Stop when the points aren’t changing much anymore, i.e. when \(|x_{n+1}-x_n|<\varepsilon\) for some small tolerance \(\varepsilon\). Then we can say that the argmin is \(x^* \approx x_n\), and the minimum is \(y^* \approx f(x_n)\). Done.

This simple algorithm, which finds a (local) minimum by starting at a random point and steadily marching in the direction opposite the derivative, is called gradient descent. With some relatively minor modifications here and there, gradient descent is how many machine learning algorithms are trained, including essentially all deep learning algorithms. It’s very possibly the most important algorithm in machine learning.

In machine learning, running an optimizer like gradient descent is usually called training. You can kind of imagine optimization as trying to teach something to a model. The condition of being at the minimum is analogous to the model learning whatever task it is you’re trying to teach it. The thing we’re minimizing in this case is the loss function, which is hand-picked essentially to measure how well the model is learning the given task.

The step size \(\alpha\) is so important in machine learning that it’s given a special name, the learning rate. It in essence controls how quickly a model learns, or trains. I’ll use this terminology for \(\alpha\) going forward.

Here’s what the algorithm looks like as a python function gradient_descent. It will take as arguments the function f we’re trying to minimize, the function for its derivative or gradient grad_fn, the initial point x0, and the learning rate alpha. I’ll also pass in two optional arguments, max_iter and eps, where max_iter is the maximum number of iterations to run gradient descent in the worst case, and eps is the tolerance parameter that indicates when to stop.

def gradient_descent(f, grad_fn, x0, alpha, max_iter=1000, eps=1e-5):
    x_prev = x0  # initialize the algorithm
    for i in range(max_iter):
        x_curr = x_prev - alpha * grad_fn(x_prev)  # gradient descent step
        if np.abs(x_curr - x_prev) < eps:  # if changes are smaller than eps we're done, return x*
            print(f'converged after {i} iterations')
            return x_curr
        x_prev = x_curr
    print(f'failed to converge in {max_iter} iterations')  # else warn and return x* anyway
    return x_curr

Let’s run this algorithm on our simple example \(y=x^2\). Recall its derivative function is \(\frac{dy}{dx}=2x\). I’ll choose an initial point \(x_0=5\) and a learning rate of \(\alpha=0.8\). The optional arguments won’t change.

We can see that gradient descent in this case converges (i.e. finishes) after only 27 iterations. It estimates an argmin of about \(x^* \approx 3 \cdot 10^{-6}\) and a minimum of about \(y^* \approx 9 \cdot 10^{-12}\). Since both are basically \(0\) (the true value for both) to well within our tolerance of \(10^{-5}\), we seem to have done pretty well here.

Feel free to play around with different choices of the learning rate alpha to see how that affects training time and convergence. Getting a good feel for gradient descent is essential for a machine learning practitioner.

f = lambda x: x ** 2
grad_fn = lambda x: 2 * x
x0 = 5
alpha = 0.8
x_min = gradient_descent(f, grad_fn, x0, alpha)
y_min = f(x_min)
print(f'estimated argmin: {x_min}')
print(f'estimated min: {y_min}')

While I’ve shown the math and code for gradient descent, we’ve still yet to get a good intuition for what the algorithm is doing. For this I’ll turn to a visualization. What I’m going to do is plot the function curve in black, and on top of it show each step of gradient descent. Each red dot on the curve of the function will indicate the point \((x_n,y_n)\) at step \(n\) of the algorithm. Successive steps will be connected by a red line. Each red line will show which points the algorithm jumps from and to at each step. Starting and ending points will be annotated as well.

To do this I’ll use a helper function plot_gradient_descent, which takes in the same arguments as gradient_descent as well as a few more arguments that do some styling of the plot. Internally, all this function is doing is running gradient descent on the given arguments, then plotting the functions, dots, and line segments described.
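Since I haven’t shown its code, here’s a minimal sketch of what such a helper might look like, assuming matplotlib. The actual plot_gradient_descent used to make the figures below may differ in its styling details; this version just runs the update for a fixed number of iterations and draws the curve, the dots, and the connecting line segments.

import numpy as np
import matplotlib.pyplot as plt

def plot_gradient_descent(f, grad_fn, x0, alpha, n_iters=30, xlim=None, ylim=None,
                          annotate_start_end=True, title=None):
    # run gradient descent for a fixed number of iterations, recording every point visited
    xs = [x0]
    for _ in range(n_iters):
        xs.append(xs[-1] - alpha * grad_fn(xs[-1]))
    xs = np.array(xs, dtype=float)
    ys = f(xs)

    # plot the function curve in black and the descent path in red
    if xlim is None:
        xlim = (xs.min() - 1, xs.max() + 1)
    grid = np.linspace(*xlim, 500)
    plt.plot(grid, f(grid), color='black')
    plt.plot(xs, ys, color='red', marker='o', markersize=4, linewidth=1)
    if annotate_start_end:
        plt.annotate('start', (xs[0], ys[0]))
        plt.annotate('end', (xs[-1], ys[-1]))
    plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    if title is not None:
        plt.title(title)
    plt.show()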

I’ll start by showing what gradient descent is doing on the exact same example as above. The curve of course is just a parabola sloping upward from the origin. The starting point is just \((x_0,f(x_0))=(5,25)\). After running for \(N=30\) iterations the algorithm basically settles down to \((x_N,f(x_N)) \approx (0,0)\). Notice what’s happening in between though. Imagine you dropped a marble into a bowl at the starting point. After landing, the marble bounces back and forth across the bowl several times, rolling around less and less until it eventually dissipates all its kinetic energy and comes to rest at the bottom of the bowl.

plot_gradient_descent(f=f, grad_fn=grad_fn, x0=x0, alpha=alpha, n_iters=30, 
                      title=f'$y=x^2$,  $\\alpha={alpha}$,  $N={30}$,  $x_0={x0}$')

To illustrate what the learning rate is doing, and how important it is to tune it well, let’s try the same problem in two other cases: a really high learning rate, and a really low learning rate. I’ll start with a high learning rate of \(\alpha=1.1\). I’ll run the algorithm this time for \(N=20\) iterations.

Pay particular attention in this case to the start and end labels. Evidently choosing a high learning rate caused the algorithm not to spiral down towards the minimum, but to spiral up away from the minimum! This is the hallmark of choosing too large a learning rate. The algorithm won’t converge at all. It’ll just keep shooting further and further away from the minimum.

alpha = 1.1
N = 20
plot_gradient_descent(f=f, grad_fn=grad_fn, x0=x0, alpha=alpha, n_iters=N, 
                      title=f'$y=x^2$,  $\\alpha={alpha}$,  $N={N}$,  $x_0={x0}$')

Let’s now look at a low learning rate of \(\alpha=0.01\). I’ll run this one for \(N=150\) iterations. Notice now that the algorithm is indeed converging towards the minimum, but it’s doing it really, really slowly. It’s not bouncing around the bowl at all, but rather slowly crawling down in small steps. This is the hallmark of using too low a learning rate. The algorithm will converge, but it’ll do so really, really slowly, and you’ll need to train for a lot of iterations.

alpha = 0.01
N = 150
plot_gradient_descent(f=f, grad_fn=grad_fn, x0=x0, alpha=alpha, n_iters=N,
                      title=f'$y=x^2$,  $\\alpha={alpha}$,  $N={N}$,  $x_0={x0}$')

Things may seem all well and good. We have an algorithm that seems like it can reliably find the minimum of whatever function we give it, at least in the univariate case. Unfortunately, there are a few subtleties involved that I’ve yet to mention. It turns out that the function I picked, \(y=x^2\), is a particularly easy function to minimize. It’s a convex function. Not all functions behave that nicely. Practically no loss function in deep learning does.

If a function is non-convex (i.e. not bowl-shaped) it can have multiple minima. This means that you can’t be sure gradient descent will pick out the global minimum if you run it. Which minimum it settles in will depend on your choice of initial point \(x_0\), the learning rate \(\alpha\), and perhaps even the number of iterations \(N\) you run the algorithm.

This isn’t the only problem, or even the worst problem. Perhaps the worst problem is saddle points and other nearly flat regions. If the function has spots where the derivative is zero or very close to it, gradient descent may well settle down in one of those instead of at a minimum. Here’s an example. Let’s look at the function \(y=x^3 + (x+1)^4\). Its derivative function turns out to be \(\frac{dy}{dx}=3x^2 + 4(x+1)^3\). Check WolframAlpha if you don’t believe me.
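If you’d rather check the derivative in Python than on WolframAlpha, sympy can verify it symbolically (using sympy here is an assumption of mine; it isn’t used elsewhere in this lesson).

import sympy as sp

x_sym = sp.symbols('x')
dy_dx = sp.diff(x_sym ** 3 + (x_sym + 1) ** 4, x_sym)
print(sp.simplify(dy_dx - (3 * x_sym ** 2 + 4 * (x_sym + 1) ** 3)))  # prints 0: the two expressions agree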

Now, suppose we want to find the minimum of this function. Not knowing any better, we pick an initial point \(x_0=3\), and just to be safe we pick a small learning rate \(\alpha=0.001\). Let’s run gradient descent now for \(N=500\) iterations. Surely that’s enough to find the minimum, right?

Evidently not. The true minimum is somewhere around the point \((-2.8, -12)\). The algorithm didn’t settle down anywhere near this point. It stalled out near the origin \((0,0)\). So what happened? If you look closely, you’ll see it got stuck in a nearly flat stretch of the curve, where the derivative is tiny, much like a saddle point. With a learning rate this small, the steps taken in that region are minuscule, so it would take a huge number of extra iterations to crawl out of it. At a true saddle point, where the derivative is exactly zero, gradient descent can get stuck permanently.

f = lambda x: x ** 3 + (x + 1) ** 4
grad_fn = lambda x: 3 * x ** 2 + 4 * (x + 1) ** 3
x0 = 3
N = 500
alpha = 0.001
plot_gradient_descent(f, grad_fn, x0, alpha=alpha, n_iters=N, xlim=(-4, 2), ylim=(-15, 50), 
                      title=f'$y=x^3 + (x+1)^4$,  $\\alpha={alpha}$,  $N={N}$,  $x_0={x0}$')

All isn’t necessarily lost. What happens if we pick a higher learning rate to let the algorithm bounce around the function a little bit before slowing down? Let’s pick \(\alpha=0.03\) now and run for \(N=100\) iterations. Now it looks like we’re doing just fine. Gradient descent was able to bounce across the flat region and settle down on the other side, right around the minimum.

alpha = 0.03
N = 100
plot_gradient_descent(f, grad_fn, x0, alpha=alpha, n_iters=N, xlim=(-6, 4), ylim=(-15, f(3) + 20), 
                      annotate_start_end=True,
                      title=f'$y=x^3 + (x+1)^4$,  $\\alpha={alpha}$,  $N={N}$,  $x_0={x0}$')

This example was meant to show that saddle points and flat regions can be a real issue. Gradient descent will not tell you whether the point it found is a minimum or a saddle point; it’ll just stop running and spit out a value. You thus need to be careful about things like this when running gradient descent on real-life functions. It’s even worse in higher dimensions, where it turns out that almost all stationary points will be saddle points, and very few will be minima or maxima.

For these reasons, it’s common in machine learning to not use a tolerance condition like \(|x_{n}-x_{n-1}| < \varepsilon\). Instead we just specify some number of iterations \(N\) and run the update \(N\) times. Basically, we want to give the algorithm a chance to get out of a flat spot if it gets stuck in one for some reason. Said differently, if a function is not convex, and most loss functions in machine learning are not convex, the notion of convergence doesn’t necessarily mean that much, since we don’t even know if we’re at a minimum or not.

The gradient descent algorithm works exactly the same in the multivariate case as in the univariate case, except we now use the gradient vector instead of the derivative at each step. Here’s the algorithm in steps:

1. Initialize a starting vector \(\mathbf{x}_0\).
2. For up to \(N\) iterations, perform the gradient descent update \[\mathbf{x}_n = \mathbf{x}_{n-1} - \alpha \frac{dz}{d\mathbf{x}}\bigg|_{\mathbf{x}=\mathbf{x}_{n-1}}.\]
3. Stop either when some convergence criterion is satisfied, e.g. \(||\mathbf{x}_n-\mathbf{x}_{n-1}||_2 \leq \varepsilon\), or when the maximum number of iterations \(N\) is reached.
4. Return \(\mathbf{x}_N\). The best guess for the argmin is \(\mathbf{x}^* \approx \mathbf{x}_N\), and for the minimum is \(z^* \approx f(\mathbf{x}_N)\).
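Here’s a minimal sketch of what this looks like in numpy. The helper gradient_descent_nd and the example bowl \(z = x_1^2 + 10x_2^2\) are illustrative choices of mine; the only real changes from the univariate gradient_descent above are that the update now acts on a vector and the stopping criterion uses a norm.

import numpy as np

def gradient_descent_nd(f, grad_fn, x0, alpha, max_iter=1000, eps=1e-5):
    x_prev = np.asarray(x0, dtype=float)  # initialize the algorithm with the starting vector
    for i in range(max_iter):
        x_curr = x_prev - alpha * grad_fn(x_prev)  # gradient descent step, now on a vector
        if np.linalg.norm(x_curr - x_prev) < eps:  # stop when ||x_n - x_{n-1}|| < eps
            print(f'converged after {i} iterations')
            return x_curr
        x_prev = x_curr
    print(f'failed to converge in {max_iter} iterations')
    return x_curr

# minimize z = x1^2 + 10*x2^2, whose gradient is (2*x1, 20*x2) and whose argmin is the origin
f = lambda x: x[0] ** 2 + 10 * x[1] ** 2
grad_fn = lambda x: np.array([2 * x[0], 20 * x[1]])
x_min = gradient_descent_nd(f, grad_fn, x0=[5.0, 5.0], alpha=0.05)
print(f'estimated argmin: {x_min}, estimated min: {f(x_min)}')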

Aside: I’ll quickly note that gradient descent isn’t the only minimization algorithm. Some other algorithms worth noting use not just the first derivative in their updates, but also the second derivative. Examples include algorithms like Newton’s Method and L-BFGS. The second derivative provides information about the curvature of the function, which can speed up convergence by effectively making the learning rate adaptive. While these second-order algorithms are useful in some areas of machine learning, it usually turns out to be far too computationally expensive to calculate the second derivative (also called the Hessian) of a function in high dimensions. Perhaps the main reason gradient descent is used in machine learning is that it provides a good tradeoff between speed of convergence and computational cost per step.
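To give a flavor of what using the second derivative buys you, here’s a minimal one-dimensional sketch of Newton’s method (illustrative only, not how a production optimizer like L-BFGS is implemented). The step is the derivative divided by the second derivative, so the curvature effectively sets an adaptive learning rate.

def newtons_method(grad_fn, hess_fn, x0, max_iter=100, eps=1e-10):
    x = x0
    for i in range(max_iter):
        step = grad_fn(x) / hess_fn(x)  # the curvature rescales the step size
        x = x - step
        if abs(step) < eps:             # stop when the steps become negligible
            break
    return x

# for y = x^2 the derivative is 2x and the second derivative is 2;
# Newton's method jumps from x0 = 5 straight to the minimum at x = 0 in a single step
print(newtons_method(lambda x: 2 * x, lambda x: 2.0, x0=5.0))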

This pretty much covers everything I wanted to talk about regarding optimization, the most important application of calculus to machine learning. In future lessons we’ll spend more time talking about gradient descent as well as its more modern variants like SGD and Adam.