Machine Learning Part 2: Loss Functions & Convexity

By Jyotir Sai, Engineering Student

Loss functions tell us how far our predicted value is from the ground truth. What are some important properties that loss functions should have? Why do we use squared error loss functions? The purpose of this blog is to answer those questions. Before diving into loss functions and convexity, I want to talk about the math notation we'll be using.

Vector Representation

The input matrix, $\mathbf{X}$, can be written as:

$$\mathbf{X} = \begin{pmatrix} -\ \mathbf{x}^{(1)}\ - \\ -\ \mathbf{x}^{(2)}\ - \\ \vdots \\ -\ \mathbf{x}^{(M)}\ - \end{pmatrix} \in \mathbb{R}^{M \times d}$$

In the above representation, each row vector corresponds to a sample, and we have $M$ samples/rows. Each sample has $d$ features, so the matrix is an element of the $\mathbb{R}^{M \times d}$ space. The values in a single row vector correspond to the features of that sample:

$$\mathbf{x}^{(1)} = [\, x_{1}^{(1)}, x_{2}^{(1)}, \ldots, x_{d}^{(1)} \,]$$
$$\mathbf{x}^{(2)} = [\, x_{1}^{(2)}, x_{2}^{(2)}, \ldots, x_{d}^{(2)} \,]$$
$$\vdots$$
$$\mathbf{x}^{(M)} = [\, x_{1}^{(M)}, x_{2}^{(M)}, \ldots, x_{d}^{(M)} \,]$$

It is also common to write the matrix with the rows as features and columns as samples:

$$\mathbf{X} = \begin{pmatrix} | & | & & | \\ \mathbf{x}^{(1)} & \mathbf{x}^{(2)} & \cdots & \mathbf{x}^{(M)} \\ | & | & & | \end{pmatrix}$$

In machine learning, we usually use matrices / vectors to represent our variables so we can take advantage of vectorization.

Let's talk about how to use linear regression when dealing with vectors. The equation for linear regression with a matrix of input values is written below.

$$h(\mathbf{X}) = \mathbf{w}^{T}\mathbf{X} + b$$

Instead of a single slope value, we have a vector of values, $\mathbf{w}$, known as the weights.

$$\mathbf{w} = \begin{pmatrix} w_{1}\\ w_{2}\\ \vdots\\ w_{d} \end{pmatrix} \in \mathbb{R}^{d}$$

Each input feature $x_{i}^{(m)}$ has a corresponding weight $w_{i}$. The y-intercept, $b$, is still written as $b$, but we now call this term the bias.

For convenience, the bias term is sometimes absorbed into the weight vector so we can write:

$$h(\mathbf{X}) = \mathbf{w}^{T}\mathbf{X}$$

where $\mathbf{X}$ and $\mathbf{w}$ are now:

$$\mathbf{X} = \begin{pmatrix} 1\\ x_{1}\\ \vdots\\ x_{d} \end{pmatrix} \in \mathbb{R}^{d+1}, \qquad \mathbf{w} = \begin{pmatrix} b\\ w_{1}\\ \vdots\\ w_{d} \end{pmatrix} \in \mathbb{R}^{d+1}$$

The above vector representation is the same as writing:

$$h(x) = b + w_{1}x_{1} + w_{2}x_{2} + \ldots + w_{d}x_{d}$$
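As a quick illustration of the bias-absorption trick, here is a minimal NumPy sketch (the numbers are made up for illustration) that prepends a column of ones to the input matrix so the bias rides along with the weights:

```python
import numpy as np

# Toy data: M = 4 samples, d = 2 features (made-up numbers)
X = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0],
              [7.0, 8.0]])   # shape (M, d)
w = np.array([0.5, -1.0])    # one weight per feature
b = 2.0                      # bias / y-intercept

# Prediction with an explicit bias term
h_explicit = X @ w + b       # shape (M,)

# Absorb the bias: prepend a column of ones to X and prepend b to w
X_aug = np.hstack([np.ones((X.shape[0], 1)), X])   # shape (M, d+1)
w_aug = np.concatenate([[b], w])                   # shape (d+1,)
h_absorbed = X_aug @ w_aug

assert np.allclose(h_explicit, h_absorbed)
```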

Loss Function

We previously discussed the squared error loss function:

$$J = \frac{1}{2}\sum_{m=1}^{M}\left( y_{m}-h(x_{m}) \right)^{2}$$

Another loss function we can use is the mean squared error:

$$J = \frac{1}{M}\sum_{m=1}^{M}\left( y_{m}-h(x_{m}) \right)^{2}$$

Now, let's express the loss function in matrix-vector form. The squared L2 norm of a vector $\boldsymbol{\nu}$ is defined as:

$$\left\| \boldsymbol{\nu} \right\|^{2} = \sum_{i=1}^{d}\nu_{i}^{2}$$

The squared L2 norm can be used to express the summation term in the loss function in vector form:

$$\sum_{m=1}^{M}\left( y_{m}-h(x_{m}) \right)^{2} \;\rightarrow\; \left\| \mathbf{y}-\mathbf{h} \right\|^{2}$$

As seen above, $\mathbf{h}$ is simply equal to $\mathbf{w}^{T}\mathbf{X}$.

$$\left\| \mathbf{y}-\mathbf{h} \right\|^{2} \;\rightarrow\; \left\| \mathbf{y}-\mathbf{w}^{T}\mathbf{X} \right\|^{2}$$

The matrix-vector version of the mean squared error is therefore:

$$J = \frac{1}{M} \left\| \mathbf{y}-\mathbf{w}^{T}\mathbf{X} \right\|^{2}$$

or

$$J = \frac{1}{M} \left\| \mathbf{X}\mathbf{w}-\mathbf{y} \right\|^{2}$$
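To make the vectorized loss concrete, here is a small NumPy sketch (with hypothetical random data) comparing the explicit sum over samples with the norm-based form:

```python
import numpy as np

rng = np.random.default_rng(0)
M, d = 100, 3
X = rng.normal(size=(M, d))      # hypothetical inputs
y = rng.normal(size=M)           # hypothetical targets
w = rng.normal(size=d)           # some candidate weights

# Loop form: (1/M) * sum over samples of (y_m - h(x_m))^2
J_loop = sum((y[m] - X[m] @ w) ** 2 for m in range(M)) / M

# Vectorized form: (1/M) * ||Xw - y||^2
J_vec = np.linalg.norm(X @ w - y) ** 2 / M

assert np.allclose(J_loop, J_vec)
```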

Why do we use these loss functions in particular? Why do we square the error instead of cubing it or raising it to a higher power? Since we want the minimum of the loss function, we differentiate it and look for the point where the derivative is 0, so we want a loss function that is differentiable. We also want a loss function where the point at which the derivative is 0 corresponds to the global minimum. More formally, we want a loss function that is both smooth and convex.

Convex Sets

Before talking about convex functions, we'll first have to cover convex sets. A set, $S$, is convex if and only if

$$\forall \; x, y \in S, \;\; \forall \; \lambda \in (0,1)$$
$$\lambda x + (1-\lambda)y \in S$$

In plain English, the above means that for all $x, y$ that are elements of $S$, and for all $\lambda$ values between 0 and 1, the expression on the second line yields a point that is also part of the set. Let's look at some visual examples.

The example on the left is a convex set because for all points $x$ and $y$, the line that connects them will always be inside the set. The example on the right is not a convex set, since the line between $x$ and $y$ goes outside of the set. Our value of $\lambda$ picks a point on this line. For example, $\lambda = 0.5$ results in the red point in the middle of the line.

$$\lambda x + (1-\lambda)y = 0.5x + (1-0.5)y = \frac{1}{2}(x+y)$$

Simple sets like the empty set, lines, and hyperplanes are all convex. Disconnected sets are not convex.

Convex Functions

A real-valued function, $f$, is convex if the domain of $f$ is a convex set and, for all $x$ and $y$ in the domain of $f$ and for all $\lambda \in (0, 1)$, the following relation holds:

$$\lambda f(x) + (1-\lambda)f(y) \geq f(\lambda x + (1-\lambda)y)$$

Geometrically, this says that the line segment connecting any two points on the graph of a convex function lies on or above the graph.

A quadratic is a convex function, so the green point will always lie above the red point. The line between the points $f(x)$ and $f(y)$ will always lie above the function $f$.

The cubic function shown above is not a convex function, since the green point lies below the red point, violating the above relation.
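We can also check the defining inequality numerically. The sketch below is my own illustrative check (not from the original figures): it samples $\lambda$ values and tests $\lambda f(x) + (1-\lambda)f(y) \geq f(\lambda x + (1-\lambda)y)$ for a quadratic and a cubic.

```python
import numpy as np

def chord_above_function(f, x, y, lambdas):
    """True if lambda*f(x) + (1-lambda)*f(y) >= f(lambda*x + (1-lambda)*y)
    for every lambda in `lambdas`."""
    lam = np.asarray(lambdas)
    chord = lam * f(x) + (1 - lam) * f(y)
    curve = f(lam * x + (1 - lam) * y)
    return bool(np.all(chord >= curve - 1e-12))

lambdas = np.linspace(0.0, 1.0, 101)

# Quadratic: the inequality holds for any pair of points
print(chord_above_function(lambda t: t**2, -2.0, 3.0, lambdas))   # True

# Cubic: pick points straddling the inflection at 0 and it fails
print(chord_above_function(lambda t: t**3, -2.0, 1.0, lambdas))   # False
```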

There are two more properties of convex functions that you should know. Let's start with the 1st-order condition, which states that for a differentiable convex function $f$ and for all $x, y$ in the domain of $f$,

$$f(y) \geq f(x) + \nabla_{x} f(x)^{T} (y-x)$$

The term on the right is the 1st-order Taylor expansion of $f$ around $x$, where

$$\nabla_{x} f(x) = \begin{pmatrix} \frac{\partial}{\partial x_{1}}f(x)\\ \vdots\\ \frac{\partial}{\partial x_{d}}f(x) \end{pmatrix}$$

Let's again look at a visual example.

The blue line is the 1st-order Taylor expansion (tangent line). According to the 1st-order condition, this line will never cross "inside" the function. It will always be less than or equal to $f(y)$. All the points on the graph are "above" the tangent line.

The 2nd-order condition states that

$$\nabla_{x}^{2}f(x) \geq 0$$

where $\nabla_{x}^{2}f(x)$ is the Hessian matrix. The 2nd-order condition says that if the 2nd derivative (the Hessian) is positive semi-definite, which is what the inequality above means, then the function is convex. A matrix is positive semi-definite if all of its eigenvalues are greater than or equal to 0.

To summarize, if a function, $f$, is twice differentiable, then the following conditions are equivalent.

  1. $\lambda f(x) + (1-\lambda)f(y) \geq f(\lambda x + (1-\lambda)y)$ ($f$ is convex)

  2. $f(y) \geq f(x) + \nabla_{x} f(x)^{T} (y-x)$

  3. $\nabla_{x}^{2}f(x) \geq 0$
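To connect these conditions, here is a small illustrative sketch (my own, in NumPy) that tests the 1st-order condition for the convex function $f(x) = x^2$ and shows it failing for $f(x) = x^3$:

```python
import numpy as np

def first_order_holds(f, grad, x, ys):
    """Check f(y) >= f(x) + grad(x)*(y - x) for every y in `ys`."""
    ys = np.asarray(ys)
    return bool(np.all(f(ys) >= f(x) + grad(x) * (ys - x) - 1e-12))

ys = np.linspace(-3.0, 3.0, 121)

# Convex quadratic: the tangent line at any point stays below the graph
print(first_order_holds(lambda t: t**2, lambda t: 2*t, x=-1.0, ys=ys))     # True

# Cubic: the tangent at x = -1 cuts above the graph for some y, so it fails
print(first_order_holds(lambda t: t**3, lambda t: 3*t**2, x=-1.0, ys=ys))  # False
```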

Revisiting Loss Functions

We previously looked at the mean squared error loss function:

$$J = \frac{1}{M} \left\| \mathbf{X}\mathbf{w}-\mathbf{y} \right\|^{2}$$

We add a factor of 1/2 to the above function so that when we take the derivative it cancels out the 2 from the exponent. It is done purely for aesthetic reasons.

$$J = \frac{1}{2M} \left\| \mathbf{X}\mathbf{w}-\mathbf{y} \right\|^{2}$$

Now, let's work on finding the derivative of this function. The squared L2 norm can be rewritten as

$$J = \frac{1}{2M} \left( \mathbf{X}\mathbf{w}-\mathbf{y} \right)^{T}\left( \mathbf{X}\mathbf{w}-\mathbf{y} \right)$$

Remember that $(AB)^{T} = B^{T}A^{T}$:

$$J = \frac{1}{2M} \left( \mathbf{w}^{T} \mathbf{X}^{T}-\mathbf{y}^{T} \right)\left( \mathbf{X}\mathbf{w}-\mathbf{y} \right)$$
$$J = \frac{1}{2M} \left[ \mathbf{w}^{T} \mathbf{X}^{T} \mathbf{X}\mathbf{w}-\mathbf{w}^{T} \mathbf{X}^{T} \mathbf{y}-\mathbf{y}^{T}\mathbf{X}\mathbf{w}+\mathbf{y}^{T}\mathbf{y} \right]$$

Inside the brackets, the two terms in the middle are actually the same since

$$\mathbf{y}^{T}\mathbf{X}\mathbf{w}=\left(\mathbf{y}^{T} \mathbf{X} \mathbf{w} \right)^{T}=\mathbf{w}^{T}\mathbf{X}^{T}\mathbf{y}$$

The above term is actually a scalar, which is why it equals its own transpose.
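A quick NumPy check (with hypothetical arrays) confirms that $\mathbf{y}^{T}\mathbf{X}\mathbf{w}$ and $\mathbf{w}^{T}\mathbf{X}^{T}\mathbf{y}$ are the same scalar:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(5, 3))   # hypothetical design matrix
y = rng.normal(size=5)
w = rng.normal(size=3)

# y^T X w and w^T X^T y are the same scalar
assert np.allclose(y @ X @ w, w @ X.T @ y)
```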

$$J = \frac{1}{2M} \left[ \mathbf{w}^{T} \mathbf{X}^{T} \mathbf{X}\mathbf{w}-2\mathbf{y}^{T}\mathbf{X}\mathbf{w}+\mathbf{y}^{T}\mathbf{y} \right]$$

Taking the derivative

$$\frac{\partial J(\mathbf{w})}{\partial \mathbf{w}} = \frac{1}{2M}\left[2 \mathbf{X}^{T} \mathbf{X} \mathbf{w} - 2 \mathbf{X}^{T} \mathbf{y} \right]$$

The above derivative uses the following identities

$$\frac{\partial}{\partial \mathbf{x}} \mathbf{x}^{T} \mathbf{S} \mathbf{x} = 2\mathbf{S}\mathbf{x} \quad \text{(for symmetric } \mathbf{S}\text{)}$$
$$\frac{\partial}{\partial \mathbf{x}} \mathbf{A}\mathbf{x} = \mathbf{A}^{T}$$

Here $\mathbf{S}=\mathbf{X}^{T}\mathbf{X}$ (which is symmetric) and $\mathbf{A}=\mathbf{y}^{T}\mathbf{X}$. The $\mathbf{y}^{T}\mathbf{y}$ term does not depend on $\mathbf{w}$, so its derivative is just 0. The derivative can be further simplified to

$$\frac{\partial J(\mathbf{w})}{\partial \mathbf{w}} = \frac{1}{M} \mathbf{X}^{T} \left(\mathbf{X}\mathbf{w}-\mathbf{y} \right)$$
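As a sanity check on this formula (an illustrative sketch with random data, not part of the original derivation), we can compare the analytic gradient against a central finite-difference approximation of $J$:

```python
import numpy as np

rng = np.random.default_rng(2)
M, d = 50, 4
X = rng.normal(size=(M, d))   # hypothetical data
y = rng.normal(size=M)
w = rng.normal(size=d)

def J(w):
    # J(w) = (1/2M) * ||Xw - y||^2
    r = X @ w - y
    return (r @ r) / (2 * M)

# Analytic gradient: (1/M) * X^T (Xw - y)
grad_analytic = X.T @ (X @ w - y) / M

# Central finite differences, one coordinate at a time
eps = 1e-6
grad_numeric = np.zeros(d)
for i in range(d):
    e = np.zeros(d)
    e[i] = eps
    grad_numeric[i] = (J(w + e) - J(w - e)) / (2 * eps)

assert np.allclose(grad_analytic, grad_numeric, atol=1e-6)
```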

Taking the 2nd derivative

$$\frac{\partial^{2} J(\mathbf{w})}{\partial \mathbf{w}^{2}} = \nabla_{w}^{2} J(\mathbf{w}) = \frac{1}{M} \mathbf{X}^{T} \mathbf{X}$$

We want to know whether the above Hessian satisfies the 2nd-order condition ($\nabla_{w}^{2}J(\mathbf{w}) \geq 0$). How do we tell whether the matrix $\mathbf{X}^{T}\mathbf{X}$ is positive semi-definite? A matrix $\mathbf{M}$ is positive semi-definite if the number $\mathbf{z}^{T}\mathbf{M}\mathbf{z}$ is non-negative for every column vector $\mathbf{z}$.

$$\mathbf{z}^{T} \left(\mathbf{X}^{T} \mathbf{X} \right) \mathbf{z} = \left(\mathbf{X}\mathbf{z} \right)^{T} \left(\mathbf{X}\mathbf{z} \right) = \left\| \mathbf{X}\mathbf{z} \right\|^{2} \geq 0$$

The above shows that the mean-squared error function is indeed convex.
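For a concrete numerical check (my own sketch with random data), we can verify that $\mathbf{X}^{T}\mathbf{X}$ is positive semi-definite both through $\mathbf{z}^{T}(\mathbf{X}^{T}\mathbf{X})\mathbf{z} = \|\mathbf{X}\mathbf{z}\|^{2}$ and through its eigenvalues:

```python
import numpy as np

rng = np.random.default_rng(3)
X = rng.normal(size=(50, 4))      # hypothetical design matrix
H = X.T @ X                       # Hessian of the (unscaled) squared-error loss

# z^T (X^T X) z equals ||Xz||^2, which can never be negative
z = rng.normal(size=4)
assert np.isclose(z @ H @ z, np.linalg.norm(X @ z) ** 2)

# Equivalently, all eigenvalues of X^T X are >= 0 (up to round-off)
eigvals = np.linalg.eigvalsh(H)   # eigvalsh: H is symmetric
print(eigvals)                    # all non-negative
assert np.all(eigvals >= -1e-10)
```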