Not your grandfather's colonel
Continuing my series, I want to introduce a very useful tool for people working in Machine Learning (ML). ML developed as a subset of Artificial Intelligence in the 1960s and has surged in popularity over the past 20 years.
Machine Learning is the science of getting computers to act without being explicitly programmed. It's the stuff powering web search, self-driving cars, image recognition software and recommendation engines (not to mention those wonderful ads that follow you around the web).
One of the most widely used techniques involves the concept of a kernel. The kernel is also one of the most misunderstood, black-box tools in the ML literature. I hope this post helps unravel some of the mystery. This might get a little technical, but trudge on; I have faith.
Classification
First, the motivation: Our goal is to correctly classify data. You can think of our data as two groups of users - those who subscribe to your service and those who don't. Correctly classifying them into their respective homogeneous groups is a good thing.
In the perfect case, we imagine that these data points, when plotted on a 2-dimensional plane (think scatter plot), can be easily separated by a straight line.
Generally speaking, working with straight objects is easier than curvy ones, so we prefer to keep things linear. In this case, our boundary separating the two classes is linear.
One thing to think about is where to put this separating line.
Intuitively, you might want to put it exactly in the middle of the two groups, because that keeps it as far as possible from the closest point in each group. If you thought that way, then you've basically described something called a Support Vector Machine (SVM). Though if you googled that, I bet the results wouldn't look so simple.
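A minimal Python sketch of that intuition, using made-up points: both candidate lines below separate the two groups, but the one halfway between them sits much farther from its closest point. That distance (the margin) is the quantity an SVM maximizes. The data and helper names are hypothetical; for these particular points the halfway line x + y = 5.75 also happens to be the maximum-margin line, but this is a check of the idea, not an SVM solver.

```python
import math

# Hypothetical 2D data: two small, linearly separable groups of users.
subscribers = [(1.0, 1.0), (2.0, 1.5), (1.5, 2.0)]
non_subscribers = [(4.0, 4.0), (5.0, 4.5), (4.5, 5.0)]
all_points = subscribers + non_subscribers

def side(point, w, b):
    """Which side of the line w[0]*x + w[1]*y + b = 0 the point falls on (+1 or -1)."""
    return 1 if w[0] * point[0] + w[1] * point[1] + b > 0 else -1

def separates(w, b):
    """True if all subscribers are on one side and all non-subscribers on the other."""
    return (all(side(p, w, b) == -1 for p in subscribers)
            and all(side(p, w, b) == +1 for p in non_subscribers))

def margin(w, b):
    """Distance from the line to the closest data point (what an SVM maximizes)."""
    norm = math.hypot(w[0], w[1])
    return min(abs(w[0] * x + w[1] * y + b) / norm for x, y in all_points)

# Two candidate lines of the form x + y = c; both separate the groups.
hugging = ((1.0, 1.0), -4.0)   # x + y = 4: squeezed up against the subscribers
middle = ((1.0, 1.0), -5.75)   # x + y = 5.75: halfway between the closest points

for name, (w, b) in [("hugging", hugging), ("middle", middle)]:
    print(name, separates(w, b), round(margin(w, b), 3))
# hugging True 0.354
# middle True 1.591
```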
In this example, splitting our data is easy because it's linearly separable. As I said, we like linear stuff. But what if we couldn't just draw a line through it?
One option is to draw a circle separating red from blue. But as I said above, we prefer lines to curves. So what do we do?
Higher Dimensions != Higher Complexity
The basic idea is that we can raise the dimensionality of our dataset so that it becomes easier to linearly separate. I'll show you what this means at the end (the big reveal). For now, you're probably wondering why we'd want to increase dimensionality to make the problem easier. After all, generally speaking, the more dimensions you have, the more complex your problem. In 2 dimensions, everything is easy to compute and visualize, and the math is simple. In 1000 dimensions, things get hard: many algorithms break down, many basic assumptions fail, and you can't visualize anything. So the question becomes: how can we get to a higher-dimensional place without incurring the costs of higher dimensionality?
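A minimal sketch of what "raising the dimensionality" can mean, on hypothetical toy data: two concentric rings of points, the classic case no straight line can handle. Adding a third coordinate, z = x^2 + y^2 (an assumption chosen for illustration), lets a flat plane in 3D separate the classes. Note that this version builds the extra coordinate explicitly, which is exactly the cost the rest of this post is about avoiding.

```python
import math

def ring(radius, label, n=8):
    """Hypothetical toy data: n points evenly spaced on a circle around the origin."""
    return [((radius * math.cos(2 * math.pi * k / n),
              radius * math.sin(2 * math.pi * k / n)), label)
            for k in range(n)]

# Inner circle is one class, outer circle the other: no straight line separates them in 2D.
data = ring(1.0, -1) + ring(3.0, +1)

def lift(point):
    """Explicitly add a third coordinate: the squared distance from the origin."""
    x, y = point
    return (x, y, x * x + y * y)

def classify_lifted(point, threshold=5.0):
    """In 3D, the flat plane z = threshold splits the classes (inner z = 1, outer z = 9)."""
    return +1 if lift(point)[2] > threshold else -1

print(all(classify_lifted(p) == label for p, label in data))  # True
```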
Obscure the Complexity
We want to hide the messiness associated with higher dimensions, and we do this using something called the kernel trick. A kernel is a function that takes two inputs and returns the number you would get by first mapping both of them into a higher-dimensional space and then combining them there (with an inner product, more on that below), without ever actually performing the mapping. The math behind this can be pretty gnarly for the uninitiated, but all you have to know is that it's a neat way to get the benefits of a high-dimensional space without ever having to operate in that space, which is exactly what we're looking to do.
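To make the trick concrete, here is a small Python sketch (function names are illustrative). It takes 2D points to 3D with the explicit map phi(x1, x2) = (x1^2, sqrt(2)*x1*x2, x2^2), then shows that the inner product of the two 3D vectors matches the much cheaper 2D computation (x . y)^2, known as a degree-2 polynomial kernel.

```python
import math

def phi(v):
    """Explicit map from 2D to 3D: (x1, x2) -> (x1^2, sqrt(2)*x1*x2, x2^2)."""
    x1, x2 = v
    return (x1 * x1, math.sqrt(2) * x1 * x2, x2 * x2)

def dot(a, b):
    """Inner product: multiply coordinate by coordinate, then add up."""
    return sum(ai * bi for ai, bi in zip(a, b))

def poly_kernel(x, y):
    """Degree-2 polynomial kernel: stays in 2D, never builds the 3D vectors."""
    return dot(x, y) ** 2

x, y = (1.0, 2.0), (3.0, 4.0)
print(dot(phi(x), phi(y)))  # ~121 (up to floating-point rounding), computed in 3D
print(poly_kernel(x, y))    # 121.0, computed in 2D
```

The two routes agree for any pair of 2D points; the kernel just skips the detour through 3D.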
Returning to our example, the goal is to map a 2 dimensional object into 3 dimensional space. Using the kernel trick, we can do this at a much lower cost - more dimensions does not mean more problems.
The basic math underlying the kernel trick is something called an inner product. You can think of an inner product as multiplying two vectors together, coordinate by coordinate, and adding up the results to get one number. But I can't just use any kernel I whip up. The kernel has to satisfy certain properties (it must be symmetric and positive semi-definite), and if it does, then I am guaranteed that there is some mapping of my data into a higher-dimensional space (called a Hilbert space, possibly even infinite-dimensional) in which the kernel is exactly the inner product. That means I can work in that space using only kernel evaluations. It's a bit like computing (a + b)^3: I can add a and b and cube the result without ever opening up the brackets into a^3 + 3a^2b + 3ab^2 + b^3.
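One way to get a feel for those two properties is to spot-check them on a handful of points. The sketch below builds the kernel (Gram) matrix K, with K[i][j] = k(x_i, x_j), and tests symmetry plus the positive semi-definite condition c^T K c >= 0 for random coefficient vectors c. Random trials can only catch a violation, never prove a kernel valid, so treat this as a sanity check. The points and helper names are made up, and `not_a_kernel` is a deliberately broken, hypothetical counterexample: it is symmetric, but negating the inner product makes it fail the PSD test.

```python
import random

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def poly_kernel(x, y):
    """Degree-2 polynomial kernel: a valid (symmetric, PSD) kernel."""
    return dot(x, y) ** 2

def not_a_kernel(x, y):
    """Hypothetical counterexample: symmetric but NOT positive semi-definite."""
    return -dot(x, y)

def gram(kernel, points):
    """Kernel matrix: K[i][j] = kernel(points[i], points[j])."""
    return [[kernel(p, q) for q in points] for p in points]

def is_symmetric(kernel, points):
    return all(kernel(p, q) == kernel(q, p) for p in points for q in points)

def looks_psd(kernel, points, trials=200, seed=0):
    """Spot-check (not a proof): is c^T K c >= 0 for random coefficient vectors c?"""
    rng = random.Random(seed)
    K = gram(kernel, points)
    n = len(points)
    for _ in range(trials):
        c = [rng.uniform(-1.0, 1.0) for _ in range(n)]
        quad = sum(c[i] * K[i][j] * c[j] for i in range(n) for j in range(n))
        if quad < -1e-9:  # small tolerance for floating-point rounding
            return False
    return True

points = [(1.0, 2.0), (3.0, -1.0), (0.5, 0.5), (-2.0, 1.0)]
print(is_symmetric(poly_kernel, points), looks_psd(poly_kernel, points))
print(is_symmetric(not_a_kernel, points), looks_psd(not_a_kernel, points))
```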
With that out of the way, let's recall why we even care about mapping into higher dimensional space - we want to linearly separate messy data.
I can use a polynomial kernel (for example, a degree-2 one) to map my 2-dimensional data into 3 dimensions, and voila!
And now it's much easier for me to classify the data, because I can just draw a flat plane between the red and blue points.
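Here is a sketch of a classifier that works entirely through kernel evaluations. It swaps in a kernel perceptron, a simpler algorithm than an SVM, purely to keep the code short; the toy ring data and all names are hypothetical. It uses the kernel (x . y + 1)^2, a slight variant of the degree-2 kernel: the +1 adds constant and linear features, which roughly plays the role of the offset an SVM would learn, so the separating plane in the implicit feature space doesn't have to pass through the origin. That implicit space happens to be 6-dimensional, yet the code never builds anything bigger than a 2D point.

```python
import math

def ring(radius, label, n=8):
    """Hypothetical toy data: n points evenly spaced on a circle around the origin."""
    return [((radius * math.cos(2 * math.pi * k / n),
              radius * math.sin(2 * math.pi * k / n)), label)
            for k in range(n)]

def poly_kernel(x, y):
    """(x . y + 1)^2, computed directly on the 2D inputs."""
    return (x[0] * y[0] + x[1] * y[1] + 1.0) ** 2

def score(alphas, data, kernel, x):
    """Decision value: a weighted sum of kernel evaluations against training points."""
    return sum(a * y_j * kernel(x_j, x) for a, (x_j, y_j) in zip(alphas, data) if a)

def train_kernel_perceptron(data, kernel, max_epochs=1000):
    alphas = [0] * len(data)  # how many times each training point was misclassified
    for _ in range(max_epochs):
        mistakes = 0
        for i, (x_i, y_i) in enumerate(data):
            if y_i * score(alphas, data, kernel, x_i) <= 0:  # wrong side or on the boundary
                alphas[i] += 1
                mistakes += 1
        if mistakes == 0:  # a full pass with no errors: the training data is separated
            break
    return alphas

def predict(alphas, data, kernel, x):
    return +1 if score(alphas, data, kernel, x) > 0 else -1

data = ring(1.0, -1) + ring(3.0, +1)
alphas = train_kernel_perceptron(data, poly_kernel)
print(all(predict(alphas, data, poly_kernel, x) == y for x, y in data))  # True
```

Because the rings are separable in the kernel's feature space, the perceptron is guaranteed to stop making mistakes after finitely many updates; `max_epochs` is just a safety cap.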
And because of the kernel trick, I got there relatively easily. You can imagine how useful this is when I need to go from 10 dimensions to 1000 dimensions. Or infinite dimensions.
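To put numbers on that, a small sketch: for the homogeneous polynomial kernel (x . y)^d on n-dimensional inputs, the implicit feature space has one coordinate per monomial of degree d, which works out to C(n + d - 1, d) coordinates, while evaluating the kernel itself takes only about n multiplications plus one power. For n = 2 and d = 2 this gives exactly the 3 dimensions from the example above. The helper name is illustrative.

```python
from math import comb

def explicit_features(n_inputs, degree):
    """Number of degree-`degree` monomials in n_inputs variables: the size of the
    feature space behind the kernel (x . y) ** degree."""
    return comb(n_inputs + degree - 1, degree)

for n, d in [(2, 2), (10, 3), (1000, 2), (1000, 3)]:
    print(f"{n} inputs, degree {d}: {explicit_features(n, d):,} features")
# 2 inputs, degree 2: 3 features
# 10 inputs, degree 3: 220 features
# 1000 inputs, degree 2: 500,500 features
# 1000 inputs, degree 3: 167,167,000 features
```

Some kernels, such as the Gaussian (RBF) kernel, correspond to an infinite-dimensional feature space, which could never be built explicitly at all.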
So to recap:
- Separating data using straight lines and planes is generally easier than curves
- Solving a problem in a higher dimensional space allows you to use linear methods for non-linear problems
- While more dimensions usually result in more complexity, the kernel trick sidesteps much of it by never working in the higher-dimensional space directly
- Rotating, 3D images are cool
I have more to say about the kernel and its trickery, but I'll save that for a future post.
Cheers!