# Chapter 4 Classification Models for Computer Vision

## 4.1 Hand-Written Digit Classification

In this chapter, we will examine the MNIST dataset, a collection of 70,000 images of digits handwritten by high school students and employees of the US Census Bureau.

MNIST is often called the “hello world” of machine learning, and it is a good place to start when learning about classification algorithms.

```r
library(MicrosoftML)
library(tidyverse)
library(magrittr)
```

```
## 
## Attaching package: 'magrittr'
```

```
## The following object is masked from 'package:purrr':
## 
##     set_names
```

```
## The following object is masked from 'package:tidyr':
## 
##     extract
```

```r
library(dplyrXdf)
theme_set(theme_minimal())

# Point an xdf (external data frame) reference at the MNIST file on disk.
mnist_xdf <- file.path("data", "MNIST.xdf")
mnist_xdf <- RxXdfData(mnist_xdf)
```

Let’s take a look at the data:

```r
rxGetInfo(mnist_xdf)
```
```
## File name: /home/alizaidi/bookdown-demo/data/MNIST.xdf
## Number of observations: 70000
## Number of variables: 786
## Number of blocks: 7
## Compression type: zlib
```

Our dataset contains 70,000 records and 786 columns. Of these, 784 are features: each image is a 28x28 pixel grid, which flattens to 784 pixel-intensity columns. The remaining two columns hold the label and a pre-sampled train/test split indicator.
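As a quick sanity check on that column count (assuming the layout just described, with `Label` and `splitVar` as the two non-pixel columns):

```r
# 28 x 28 pixels flatten to 784 feature columns; the Label and
# splitVar columns account for the remaining two.
28 * 28 + 2

# Peek at the first few variable names to confirm.
head(rxGetVarNames(mnist_xdf))
```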

## 4.2 Visualizing Digits

Let’s make some visualizations of the MNIST data to see what features a classifier could use to distinguish the digits.

```r
# Pull the full xdf into an in-memory tibble.
mnist_df <- rxDataStep(inData = mnist_xdf, outFile = NULL,
                       maxRowsByCols = nrow(mnist_xdf) * ncol(mnist_xdf)) %>%
  tbl_df
```

Let’s compute the average pixel intensity of each image:

```r
# Average the pixel columns row-wise to get an overall
# intensity per image, then bind it back onto the data.
mnist_df %>%
  keep(is.numeric) %>%
  rowMeans() %>%
  data.frame(intensity = .) %>%
  tbl_df %>%
  bind_cols(mnist_df) %T>%
  print -> mnist_df
```
```
## # A tibble: 70,000 x 787
##    intensity  Label    V2    V3    V4    V5    V6    V7    V8    V9   V10
##        <dbl> <fctr> <int> <int> <int> <int> <int> <int> <int> <int> <int>
##  1  35.10842      5     0     0     0     0     0     0     0     0     0
##  2  39.66199      0     0     0     0     0     0     0     0     0     0
##  3  24.79974      4     0     0     0     0     0     0     0     0     0
##  4  21.85587      1     0     0     0     0     0     0     0     0     0
##  5  29.60969      9     0     0     0     0     0     0     0     0     0
##  6  37.75638      2     0     0     0     0     0     0     0     0     0
##  7  22.50765      1     0     0     0     0     0     0     0     0     0
##  8  45.74872      3     0     0     0     0     0     0     0     0     0
##  9  13.86990      1     0     0     0     0     0     0     0     0     0
## 10  27.93878      4     0     0     0     0     0     0     0     0     0
## # ... with 69,990 more rows, and 776 more variables: V11 <int>,
## #   V12 <int>, V13 <int>, ...
```

Visualize average intensity by label:

```r
ggplot(mnist_df, aes(x = intensity, y = ..density..)) +
  geom_density(aes(fill = Label), alpha = 0.3)
```

Let’s try a boxplot:

```r
ggplot(mnist_df, aes(x = Label, y = intensity)) +
  geom_boxplot(aes(fill = Label), alpha = 0.3)
```
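Digits that cover fewer pixels, such as 1, tend to have lower average intensity than rounder digits like 0. A numeric summary makes the same point as the boxplot (a quick sketch):

```r
# Mean and spread of average intensity per digit.
mnist_df %>%
  group_by(Label) %>%
  summarise(mean_intensity = mean(intensity),
            sd_intensity = sd(intensity))
```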

## 4.3 Plotting Sample Digits

Let’s plot a sample set of digits:

```r
# Reverse the rows of a matrix so that image() draws the
# digit right side up.
flip <- function(matrix) {
  apply(matrix, 2, rev)
}

# Reshape a single row of 784 pixel values into a 28 x 28
# matrix and render it as a greyscale image.
plot_digit <- function(samp) {
  digit <- unlist(samp)
  m <- flip(matrix(rev(as.numeric(digit)), nrow = 28))
  image(m, col = grey.colors(255))
}

# Plot the 11th image, keeping only the pixel columns.
mnist_df[11, ] %>%
  select(-Label, -intensity, -splitVar) %>%
  plot_digit
```
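To eyeball several digits at once, we can reuse `plot_digit` over a few random rows. A small sketch, using base graphics to lay the plots out in a grid:

```r
# Lay out a 3 x 3 grid and plot nine randomly sampled digits.
op <- par(mfrow = c(3, 3), mar = c(1, 1, 1, 1))
mnist_df %>%
  select(-Label, -intensity, -splitVar) %>%
  sample_n(9) %>%
  split(seq_len(9)) %>%
  walk(plot_digit)
par(op)
```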

## 4.4 Split the Data into Train and Test Sets

```r
# Split the xdf into train and test sets using the pre-sampled
# splitVar factor; rxSplit returns one xdf per factor level,
# in level order.
splits <- rxSplit(mnist_xdf,
                  splitByFactor = "splitVar",
                  overwrite = TRUE)
names(splits) <- c("train", "test")
```
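We can sanity-check the split before training (a quick sketch; `rxGetInfo()` reports the row count of each xdf):

```r
# Each element of `splits` is an xdf; the row counts should
# add up to the original 70,000 observations.
sapply(splits, function(xdf) rxGetInfo(xdf)$numRows)
```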

Let’s first train a softmax classifier using `rxLogisticRegression` with `type = "multiClass"`, by way of the `estimate_model()` and `make_form()` helpers:

```r
softmax <- estimate_model(xdf_data = splits$train,
                          form = make_form(splits$train,
                                           resp_var = "Label",
                                           vars_to_skip = c("splitVar")),
                          model = rxLogisticRegression,
                          type = "multiClass")
```
```
## Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.
## LBFGS multi-threading will attempt to load dataset into memory. In case of out-of-memory issues, turn off multi-threading by setting trainThreads to 1.
## Beginning optimization
## num vars: 7850
## improvement criterion: Mean Improvement
## L1 regularization selected 3699 of 7850 weights.
## Not training a calibrator because it is not needed.
## Elapsed time: 00:00:18.6072601
## Elapsed time: 00:00:00.0334469
```
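The `estimate_model()` and `make_form()` helpers are defined elsewhere in this book. In case you are reading this section on its own, here is a minimal sketch of what they plausibly look like; the argument names follow the call above, but the bodies are assumptions:

```r
# Hypothetical reconstruction: build "resp_var ~ V2 + V3 + ..." from
# the xdf's column names, skipping any variables we don't want.
make_form <- function(xdf_data, resp_var, vars_to_skip = character(0)) {
  vars <- rxGetVarNames(xdf_data)
  features <- setdiff(vars, c(resp_var, vars_to_skip))
  as.formula(paste(resp_var, "~", paste(features, collapse = " + ")))
}

# Hypothetical reconstruction: fit the supplied model function to the data.
estimate_model <- function(xdf_data, form, model = rxLogisticRegression, ...) {
  model(formula = form, data = xdf_data, ...)
}
```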

Let’s see how we did by examining our predictions on the held-out test set:

```r
softmax_scores <- rxPredict(modelObject = softmax,
                            data = splits$test,
                            outData = tempfile(fileext = ".xdf"),
                            overwrite = TRUE,
                            extraVarsToWrite = "Label")
```
```
## Elapsed time: 00:00:01.0520198
```

We can make a confusion matrix of all our results:

```r
# Cross-tabulate actual versus predicted labels.
rxCube(~ Label : PredictedLabel, data = softmax_scores,
       returnDataFrame = TRUE) -> softmax_scores_df

softmax_scores_df %>%
  ggplot(aes(x = Label, y = PredictedLabel, fill = Counts)) +
  geom_raster() +
  scale_fill_continuous(low = "steelblue2", high = "mediumblue")
```
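The same cube gives us an overall accuracy figure directly, as a quick sketch using the counts computed above:

```r
# Overall accuracy: the share of predictions on the diagonal.
softmax_scores_df %>%
  summarise(accuracy = sum(Counts[Label == PredictedLabel]) / sum(Counts))
```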

Here we are plotting raw counts, which can overstate the more populated classes. Let’s instead weight each count by the total number of samples in its class:

```r
label_rates <- softmax_scores_df %>%
  tbl_df %>%
  group_by(Label) %>%
  mutate(rate = Counts / sum(Counts))

label_rates %>%
  ggplot(aes(x = Label, y = PredictedLabel, fill = rate)) +
  geom_raster() +
  scale_fill_continuous(low = "steelblue2", high = "mediumblue")
```

Let’s zero out the correctly classified cells so the errors stand out more clearly:

```r
label_rates %>%
  mutate(error_rate = ifelse(Label == PredictedLabel, 0, rate)) %>%
  ggplot(aes(x = Label, y = PredictedLabel, fill = error_rate)) +
  geom_raster() +
  scale_fill_continuous(low = "steelblue2", high = "mediumblue",
                        labels = scales::percent)
```
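To read off the worst offenders as a table rather than a heatmap, we can list the largest off-diagonal rates (a sketch):

```r
# The five most common misclassifications.
label_rates %>%
  ungroup() %>%
  filter(Label != PredictedLabel) %>%
  arrange(desc(rate)) %>%
  head(5)
```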

## 4.5 Exercises

1. Take a look at David Robinson’s tweet on using a single pixel to distinguish between pairs of digits.
2. You can find his gist saved in the Rscripts directory.