Chapter 8 K-Nearest Neighbors

8.1 TL;DR

What it does
K-Nearest Neighbors classifies each observation by a majority vote of its \(k\) nearest labeled neighbors.
When to do it
When you’re working through the various classification methods to see which one works best for your data
How to do it
With the knn() function from the class library
How to assess it
Check the prediction accuracy, as with the other classification methods

8.2 What it does

K-Nearest Neighbors starts with known categories of labeled data, and assigns new data to one of these categories by grouping it with its \(k\) nearest neighbor(s). Proximity is judged in whatever dimensional space is appropriate to the number of predictor variables: if only using one predictor, as in the example here, the points are simply arranged on a line; 2 predictors put them on a plane, and so forth.

Once the new observation is located against the known categories, a sort of “lasso” extends out from the point in all available directions until it contains \(k\) other neighbors. Whichever category holds the majority of those \(k\) points is assigned to the new observation.
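As a minimal sketch of that idea in base R (the points, labels, and the choice of \(k\) = 3 here are invented purely for illustration):

# Hypothetical one-dimensional training data with known categories
x      <- c(1.0, 1.5, 2.0, 6.0, 6.5, 7.0)
labels <- factor(c("A", "A", "A", "B", "B", "B"))
new_point <- 2.8
k <- 3

# Distance from the new point to every labeled point,
# then a majority vote among the k nearest
dists   <- abs(x - new_point)
nearest <- labels[order(dists)[1:k]]
names(which.max(table(nearest)))
## [1] "A"

The three nearest neighbors of 2.8 are all labeled “A”, so the new point is assigned to “A”. The knn() function does exactly this, but in as many dimensions as there are predictors.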

8.3 When to do it

When you’re still looking for the best classification method for your data. Maybe this will be it this time!

8.4 How to do it

We will continue with the Boston data set and the same training split.

library(class)   # provides the knn() function
library(dplyr)
library(ISLR2)   # Boston data set (also available in MASS)

data(Boston)
boston <- Boston %>%
  mutate(
    # Create the categorical crim_above_med response variable
    crim_above_med = as.factor(ifelse(crim > median(crim), "Yes", "No"))
  )

We again split the data into training and test sets:

set.seed(1235)
# Randomly mark 80% of the rows as training data; the rest form the test set
boston.training <- rep(FALSE, nrow(boston))
boston.training[sample(nrow(boston), nrow(boston) * 0.8)] <- TRUE
boston.test <- !boston.training

Unlike the other classification methods used so far, knn() has no formula interface: it expects the predictor columns as a plain matrix or data frame. Also, we supply the training and test data together in a single call to knn(), which fits and predicts in one step.

So first, create the matrices from the data columns. We only have one column in our model here, so we need to wrap it in a data.frame, because knn() does not handle a plain vector correctly (it would treat a bare test vector as a single multi-feature observation):

train.X <- as.data.frame(boston[boston.training, ]$nox)
test.X <- as.data.frame(boston[boston.test, ]$nox)

If we had more than one column, we could instead bind them with cbind() (something_else is a stand-in for whatever second predictor you choose):

train.X <- cbind(boston$nox, boston$something_else)[boston.training, ]
test.X <- cbind(boston$nox, boston$something_else)[boston.test, ]

Note that when using more than one predictor, it is highly recommended to standardize all the columns to the same scale: KNN works on raw distances, so a variable measured in larger units would otherwise dominate the vote. If a weighted scale is genuinely desired, it is still better to standardize first and then apply the weights from there.

# scale() centers each column and divides it by its standard deviation
train.X <- cbind(scale(boston$nox), scale(boston$something_else))[boston.training, ]
test.X <- cbind(scale(boston$nox), scale(boston$something_else))[boston.test, ]

We also extract the outcome variable into a vector.

# Class labels for the training rows, in the same order as train.X
train.crim_above_med <- boston$crim_above_med[boston.training]

Now set a seed and run the knn() function. Select a value for \(k\), the number of nearest neighbors to use in the classifier; here we will start with 1, which means each new observation will be classified according to its single closest neighbor. (Higher values decide based on the majority of the \(k\) nearest neighbors; keeping \(k\) odd, for a two-class problem like this one, will avoid ties; when ties do occur, knn() breaks them at random, which is why we set a seed.)

set.seed(1235)
knn.pred <- knn(train.X, test.X, train.crim_above_med, k = 1)

8.5 How to assess it

Check the table of predictions against the test data and see how it did.

table(knn.pred, boston[boston.test,]$crim_above_med)
##         
## knn.pred No Yes
##      No  40   2
##      Yes  1  59
mean(knn.pred == boston[boston.test,]$crim_above_med)
## [1] 0.9705882

In this case, we achieved 97% accuracy with knn() and \(k\) = 1. If we didn’t get a good enough result, we could run knn() again with different values for \(k\) to see if we could improve on it. The best result here can then be compared against those from the other classification methods.

(It’s unusual for \(k\) = 1 to be the best value, as it will tend to overfit; some exploration will be needed to find the best results, and it’s simple enough to try a range in a loop, as sketched below.)
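A minimal sketch of such a loop, reusing the objects defined above (the range 1 through 10 is an arbitrary choice for illustration):

set.seed(1235)
for (k in 1:10) {
  knn.pred <- knn(train.X, test.X, train.crim_above_med, k = k)
  accuracy <- mean(knn.pred == boston[boston.test, ]$crim_above_med)
  cat("k =", k, "accuracy =", round(accuracy, 4), "\n")
}

Whichever \(k\) scores highest on the test set would then be the value to carry forward into the comparison with other methods.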

8.6 Where to learn more

References

James, Gareth, Daniela Witten, Trevor Hastie, and Robert Tibshirani. 2021. An Introduction to Statistical Learning. 2nd ed. Springer. https://link.springer.com/book/10.1007/978-1-0716-1418-1.