Create Predictive Classification Models in R with Caret

An example project of training and testing machine learning (ML) models with caret in R to predict contraceptive use

Yu-En Hsu

--


This post focuses on the technical aspects of the project. If you are interested in the full study paper, which covers the overall workflow, analysis strategy, model comparison, and limitations, please see:

1. Overview

“An Introduction to Statistical Learning: With Applications in R”, or ISLR, was my first book on predictive analytics, and I strongly recommend it to everyone interested in machine learning. From it, I learnt how to programme in R and use various statistical tools, such as glm and randomForest, but juggling so many different packages felt inefficient. Thankfully, there are several libraries that attempt to streamline the process of building predictive models. Here I focus on the caret package (short for Classification And REgression Training) created by Max Kuhn.

The project's goal is to build classification models to predict the use of contraception in Thailand, Mongolia, and Laos. The data come from the Multiple Indicator Cluster Surveys (MICS) published by the United Nations Children’s Fund (UNICEF). Because they are microdata (individual response data from surveys and censuses), users must register to access the database. After a quick registration, I downloaded the 6th MICS datasets for the three countries and performed data cleaning and merging (a sketch of this step follows the variable list below). The resulting dataset contains 58,356 cases and 9 variables:

  1. Currently using contraception, use: Binary (Yes/No)
  2. Age, age: Numerical
  3. Highest education attainment, edu: Categorical (Less than primary/Primary/Lower secondary/Upper secondary/Higher education)
  4. Country-specific wealth percentile, wealth: Numerical
  5. Marital status, mstat: Categorical (Never/Former/Current)
  6. Residence, residence: Categorical (Urban/Rural)
  7. Country, country: Categorical (Thailand/Mongolia/Laos)
  8. Ever given birth, given_birth: Binary (Yes/No)
  9. Ever had a child or children who later died, child_died: Binary (Yes/No)
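
The cleaning and merging step amounts to reading each country's file, keeping the nine variables above, adding a country identifier, and stacking the rows. A rough sketch, with hypothetical file names and the variable selection and renaming omitted:

library(haven)  # read and write Stata .dta files
library(dplyr)  # data manipulation

thailand <- read_dta("thailand_wm.dta")  # hypothetical file names
mongolia <- read_dta("mongolia_wm.dta")
laos     <- read_dta("lao_wm.dta")

# Tag each country and stack the rows into one dataset
mics <- bind_rows(mutate(thailand, country = "THAILAND"),
                  mutate(mongolia, country = "MONGOLIA"),
                  mutate(laos, country = "LAO"))
write_dta(mics, "MICS.dta")  # the merged file imported below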

Import Data and Packages

library(tidyverse) # data manipulation
library(haven) # read Stata .dta files (provides read_dta)
library(caret) # predictive modelling
library(rpart.plot) # decision tree visualisation

mics <- read_dta("MICS.dta")

All categorical variables are stored as text strings, which machine learning models cannot work with directly. Therefore, I encoded, or factorised, the categorical features.

# Factorise variables
mics$mstat <- factor(mics$mstat)
mics$edu <- factor(mics$edu)
mics$country <- factor(mics$country)

# Recode factor levels and labels for clarity
mics$use <- factor(mics$use,
                   levels = c(1, 0),
                   labels = c("Using", "Not Using"))
mics$residence <- factor(mics$residence,
                         levels = c(1, 0),
                         labels = c("Urban", "Rural"))
mics$given_birth <- factor(mics$given_birth,
                           levels = c(1, 0),
                           labels = c("Yes", "No"))
mics$child_died <- factor(mics$child_died,
                          levels = c(1, 0),
                          labels = c("Yes", "No"))
mics$edu <- factor(mics$edu,
                   levels = c("PRE-PRIMARY OR NONE", "PRIMARY",
                              "LOWER SECONDARY", "UPPER SECONDARY",
                              "HIGHER"),
                   labels = c("Less than Primary", "Primary",
                              "Lower Secondary", "Upper Secondary",
                              "Higher Education"))
mics$country <- factor(mics$country,
                       levels = c("THAILAND", "LAO", "MONGOLIA"),
                       labels = c("Thailand", "Laos", "Mongolia"))

The structure of the data frame after recoding is:

tibble [58,356 × 10] (S3: tbl_df/tbl/data.frame)
$ age : num [1:58356] 29 35 36 24 34 38 24 16 15 36 ...
$ edu : Factor w/ 5 levels "Less than Primary",..: 5 5 5 5 5
$ mstat : Factor w/ 3 levels "Current","Former",..: 1 1 3 3 1
$ wealth : num [1:58356] 9 9 5 8 10 4 10 7 7 8 ...
$ residence : Factor w/ 2 levels "Urban","Rural": 1 1 1 1 1 1 1 1
$ country : Factor w/ 3 levels "Thailand","Laos",..: 1 1 1 1 1 1
$ given_birth: Factor w/ 2 levels "Yes","No": 1 1 2 2 1 1 2 2 2 1
$ child_died : Factor w/ 2 levels "Yes","No": 2 2 2 2 2 2 2 2 2 2
$ use : Factor w/ 2 levels "Using","Not Using": 1 1 2 2 1 1

Split Training and Testing Set

I set aside 15% of the observations as the testing set, which is reserved for the final evaluation once the models are trained and optimised. The remaining 85% is used to develop the classification models. Because I will experiment with different parameters, I also use 10-fold cross-validation on the training set to evaluate performance.

The initial training/testing split is done with the createDataPartition function, which creates splits balanced according to the outcome.

# Split data into training and testing sets
train_index <- createDataPartition(mics$use,
                                   p = .85,      # 85% for training
                                   times = 1,
                                   list = FALSE)
micsTrain <- mics[ train_index, ] # Training
micsTest  <- mics[-train_index, ] # Testing
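
Because createDataPartition samples within each class of the outcome, the proportion of contraceptive users should be almost identical in the two sets, which is easy to verify:

# Compare the outcome distribution in the training and testing sets
prop.table(table(micsTrain$use))
prop.table(table(micsTest$use))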

The 10-fold cross-validation is set up with createFolds and trainControl. The former splits the training set into ten folds, and the latter specifies cross-validation using those folds. Typically, a simple trainControl(method = "cv", number = 10) would suffice, but the result may differ every time the command is executed. While trainControl provides a seeds argument for reproducibility, I had trouble setting it up and decided to use createFolds instead.

# Create 10 folds from the training set
fold_index <- createFolds(micsTrain$use,
                          k = 10,             # number of folds
                          list = TRUE,        # return as a list
                          returnTrain = TRUE) # return training-row positions
# Cross-validation using the pre-defined folds
ctrl <- trainControl(method = "cv", index = fold_index)
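
An alternative, and arguably simpler, route to reproducibility is to call set.seed() immediately before each command that involves random sampling, so the partition and the folds come out the same on every run (the seed value below is arbitrary):

set.seed(2021)  # any fixed value makes the split repeatable
train_index <- createDataPartition(mics$use, p = .85, times = 1, list = FALSE)

set.seed(2021)
fold_index <- createFolds(micsTrain$use, k = 10, list = TRUE, returnTrain = TRUE)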

The data preparation is complete, and I am ready to train the models. Please note that data preparation involves much more sophisticated steps in real-life applications, such as handling missing data, selecting variables, engineering features, and scaling and centring variables. However, since this was my first time building models from beginning to end, I did not perform those steps. Moreover, most preprocessing techniques apply to numerical features; as the majority of my variables are categorical, there was not much to do.
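
For reference, caret can take care of centring and scaling through the preProcess argument of train; the transformation is estimated within each resampling iteration and applied to the held-out fold. A minimal sketch (not something I ran for this project, since age and wealth are the only numerical predictors):

# Centre and scale the numerical predictors; factor columns are left untouched
m_knn_scaled <- train(use ~ .,
                      data = micsTrain,
                      method = "knn",
                      trControl = ctrl,
                      preProcess = c("center", "scale"))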

Train Model and Cross-Validation

The train function streamlines the model building and evaluation process. My first model is k-nearest neighbours (KNN). If I followed the ISLR instructions and used the knn function from the class package, I would have to run knn.cv multiple times and compare the results to find the best k. With caret, I only need one command. I first look up the method value and the available tuning parameters in the documentation. For knn, there is only one tuning parameter, k. I have three options for tuning:

  1. Doing nothing: in this case, train tries three values for k. While they may look random, they are chosen by a default rule, the details of which are outside the scope of this project.
  2. tuneGrid: I specify the exact values to try, seq(2, 20, 1), a sequence of numbers from 2 to 20 with an interval of 1.
  3. tuneLength: instead of providing the values, I tell the function how many different values to try, here 10. Which ten depends on the method's default grid; it could be 1–10, 101–110, or ten even numbers.
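
Instead of checking the documentation, the tuning parameters for any method can also be listed from within R with caret's modelLookup function:

# List the tuning parameter(s) registered for the knn method;
# the output has one row per parameter (here, only k)
modelLookup("knn")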

The form parameter tells the model the target variable and the predictors, in the form outcome ~ var1 + var2. Here I used . as a wildcard, meaning all remaining variables are used as predictors. Lastly, as mentioned earlier, I use 10-fold cross-validation to evaluate the model, which is specified with trControl.

# Option 1: No specification of the tuning parameter
m_knn <- train(form = use ~ .,
               data = micsTrain,
               method = "knn",
               trControl = ctrl)        # Cross-validation

# Option 2: Try all specified values
m_knn <- train(form = use ~ .,
               data = micsTrain,
               method = "knn",
               trControl = ctrl,        # Cross-validation
               tuneGrid = data.frame(k = seq(2, 20, 1)))

# Option 3: Try 10 values chosen by caret
m_knn <- train(form = use ~ .,
               data = micsTrain,
               method = "knn",
               trControl = ctrl,        # Cross-validation
               tuneLength = 10)

Using option 3 as an example, the train function generates ten values for the tuning parameter k, and for each value, 10-fold cross-validation is performed to calculate the average accuracy. The process can take a while. Once done, it compares the average accuracy across the ten values and returns the one with the best performance as the final parameter. As shown below, print(m_knn) lists the average accuracy for each k and selects the best one, 9.

k-Nearest Neighbors

49604 samples
    9 predictor
    2 classes: 'Using', 'Not Using'

No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 44644, 44644, 44644, 44644, 44643, 44643, ...
Resampling results across tuning parameters:

   k  Accuracy   Kappa
   5  0.7392345  0.4790542
   7  0.7409078  0.4824688
   9  0.7422585  0.4852000
  11  0.7404442  0.4816136
  13  0.7418755  0.4844997
  15  0.7408474  0.4824703
  17  0.7404241  0.4816458
  19  0.7401217  0.4810566
  21  0.7404038  0.4816320
  23  0.7398998  0.4806354

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was k = 9.
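
The same information is stored in the returned train object, which is convenient when the chosen parameter or the full cross-validation table is needed later:

m_knn$bestTune  # data frame holding the selected tuning parameter (k = 9)
m_knn$results   # accuracy and Kappa for every value of k that was tried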

caret also makes it easy to visualise the cross-validation results by calling plot(m_knn, main = "KNN 10-fold Cross-Validation").

10-fold cross-validation accuracy for KNN

Test the Model

Testing the model is straightforward: predicting the target variable and evaluating the result.

pred_knn <- predict(m_knn, newdata = micsTest)
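
predict returns the predicted class for each observation in the testing set. If class probabilities are needed instead, for example to plot an ROC curve, they can be requested with type = "prob":

# Predicted probability of each class for every test observation
prob_knn <- predict(m_knn, newdata = micsTest, type = "prob")
head(prob_knn)  # two columns: Using and Not Using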

Since this is a classification model, I can use the confusion matrix to examine other performance metrics by comparing the predicted classes to the actual classes.

tbl_knn  <- confusionMatrix(pred_knn, micsTest$use)
tbl_knn

In addition to accuracy, the output includes other common measures, such as sensitivity and specificity. The model has 74.81% accuracy, which is not impressive. Type I errors (false positives) occur much more frequently than Type II errors (false negatives), suggesting that the model is generally better at identifying women currently using contraception than at labelling those who are not.

Confusion Matrix and Statistics

            Reference
Prediction   Using Not Using
  Using       3800      1652
  Not Using    553      2747

Accuracy : 0.7481
95% CI : (0.7388, 0.7571)
No Information Rate : 0.5026
P-Value [Acc > NIR] : < 2.2e-16

Kappa : 0.4968

Mcnemar's Test P-Value : < 2.2e-16

Sensitivity : 0.8730
Specificity : 0.6245
Pos Pred Value : 0.6970
Neg Pred Value : 0.8324
Prevalence : 0.4974
Detection Rate : 0.4342
Detection Prevalence : 0.6229
Balanced Accuracy : 0.7487

'Positive' Class : Using
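
The object returned by confusionMatrix is a list, so individual metrics can be extracted directly, which is handy when the same figures are reported for several models:

tbl_knn$table                                     # the 2 x 2 confusion matrix
tbl_knn$overall["Accuracy"]                       # overall accuracy
tbl_knn$byClass[c("Sensitivity", "Specificity")]  # class-level metrics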

Conclusion

This post provides an example of building classification models in R with caret. caret is an incredible package for machine learning, yet it can be complicated to navigate at first. I hope this example helps anyone who is new to caret. Good luck with your modelling and have fun!

--


Yu-En Hsu

I am passionate about using data to make the world a better place, and I write about data science, visualisation, and machine learning.