Create Predictive Classification Models in R with Caret
An example project of training and testing machine learning (ML) models with caret
in R to predict contraceptive use
This post focuses on the technical side of this project. If you are interested in the study paper, which covers the overall workflow, including the analysis strategy, model comparison, and limitations, please see:
1. Overview
“An Introduction to Statistical Learning: With Applications in R”, or ISLR, was my first book on predictive analytics, and I strongly recommend it to everyone interested in machine learning. I learnt how to program in R and use various statistical functions and packages, such as glm and randomForest, but juggling so many different packages felt inefficient. Fortunately, several libraries attempt to streamline the process of building predictive models. Here I focus on the caret package (short for Classification And REgression Training) created by Max Kuhn.
The project's goal is to build classification models to predict the use of contraception in Thailand, Mongolia, and Laos. The data come from the Multiple Indicator Cluster Surveys (MICS) published by the United Nations Children’s Fund (UNICEF). Because these are microdata, i.e. individual-level response data from surveys and censuses, users must register to access the database. After a quick registration, I downloaded the 6th MICS datasets for the three countries and performed data cleaning and merging. The resulting dataset contains 58,356 cases and 9 variables:
- Currently using contraception, use: Binary (Yes/No)
- Age, age: Numerical
- Highest education attainment, edu: Categorical (Less than Primary/Primary/Lower Secondary/Upper Secondary/Higher Education)
- Country-specific wealth percentile, wealth: Numerical
- Marital status, mstat: Categorical (Never/Former/Current)
- Residence, residence: Categorical (Urban/Rural)
- Country, country: Categorical (Thailand/Mongolia/Laos)
- Ever given birth, given_birth: Binary (Yes/No)
- Ever had a child or children who later died, child_died: Binary (Yes/No)
Import Data and Packages
library(tidyverse)  # data manipulation
library(haven)      # read Stata .dta files (read_dta)
library(caret)      # predictive modelling
library(rpart.plot) # decision tree visualisation

mics <- read_dta("MICS.dta")
All categorical variables arrive as text strings or characters, which most machine learning models cannot work with directly. Therefore, I encoded, or factorised, the categorical features.
# Factorise variables
mics$mstat <- factor(mics$mstat)
mics$edu <- factor(mics$edu)
mics$country <- factor(mics$country)

# Recode factors for clarity
mics$use <- factor(mics$use,
                   levels = c(1, 0),
                   labels = c("Using", "Not Using"))
mics$residence <- factor(mics$residence,
                         levels = c(1, 0),
                         labels = c("Urban", "Rural"))
mics$given_birth <- factor(mics$given_birth,
                           levels = c(1, 0),
                           labels = c("Yes", "No"))
mics$child_died <- factor(mics$child_died,
                          levels = c(1, 0),
                          labels = c("Yes", "No"))
mics$edu <- factor(mics$edu,
                   levels = c("PRE-PRIMARY OR NONE", "PRIMARY",
                              "LOWER SECONDARY", "UPPER SECONDARY",
                              "HIGHER"),
                   labels = c("Less than Primary", "Primary",
                              "Lower Secondary", "Upper Secondary",
                              "Higher Education"))
mics$country <- factor(mics$country,
                       levels = c("THAILAND", "LAO", "MONGOLIA"),
                       labels = c("Thailand", "Laos", "Mongolia"))
The structure of the data frame after recoding is:
tibble [58,356 × 10] (S3: tbl_df/tbl/data.frame)
$ age : num [1:58356] 29 35 36 24 34 38 24 16 15 36 ...
$ edu : Factor w/ 5 levels "Less than Primary",..: 5 5 5 5 5
$ mstat : Factor w/ 3 levels "Current","Former",..: 1 1 3 3 1
$ wealth : num [1:58356] 9 9 5 8 10 4 10 7 7 8 ...
$ residence : Factor w/ 2 levels "Urban","Rural": 1 1 1 1 1 1 1 1
$ country : Factor w/ 3 levels "Thailand","Laos",..: 1 1 1 1 1 1
$ given_birth: Factor w/ 2 levels "Yes","No": 1 1 2 2 1 1 2 2 2 1
$ child_died : Factor w/ 2 levels "Yes","No": 2 2 2 2 2 2 2 2 2 2
$ use : Factor w/ 2 levels "Using","Not Using": 1 1 2 2 1 1
Split Training and Testing Set
I set aside 15% of the observations as the testing set, reserved for final evaluation once the models are trained and optimised. The remaining 85% is used to develop the classification models. Because I will experiment with different parameters, I also use 10-fold cross-validation on the training set to evaluate performance.
The training/testing split is done with the createDataPartition command, which creates balanced splits of the data according to the outcome.
# Split data into training and testing sets
train_index <- createDataPartition(mics$use,
                                   p = .85,     # 85% for training
                                   times = 1,
                                   list = FALSE)

micsTrain <- mics[ train_index, ] # Training
micsTest  <- mics[-train_index, ] # Testing
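For intuition, a plain random 85/15 split can also be written in base R with sample(); unlike createDataPartition, this sketch is not stratified by the outcome, and the toy sample size n is my own:

```r
# A plain (non-stratified) 85/15 split in base R, for comparison
set.seed(1)                                   # reproducible draw
n <- 1000                                     # toy number of observations
train_idx <- sample(n, size = round(0.85 * n))
test_idx  <- setdiff(seq_len(n), train_idx)

length(train_idx) / n  # 0.85
length(test_idx) / n   # 0.15
```

createDataPartition adds the extra guarantee that the proportion of each outcome class is preserved in both halves, which matters when the classes are imbalanced.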
10-fold cross-validation is set up with createFolds and trainControl. The first splits the training set into ten folds, and the second specifies cross-validation using those folds. Typically, a simple trainControl(method = "cv", number = 10) would suffice, but the result may differ every time the command is executed. While trainControl provides a seeds parameter for reproducibility, I had trouble setting it up and decided to use createFolds instead.
# Create 10 folds
fold_index <- createFolds(micsTrain$use,
                          k = 10,             # number of folds
                          list = TRUE,        # return as a list
                          returnTrain = TRUE) # return training-set positions

# Cross-validation
ctrl <- trainControl(method = "cv", index = fold_index)
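The reproducibility issue mentioned above comes down to R's random number generator: fixing the seed before creating the folds fixes the fold assignment. A minimal base-R illustration (the toy size and fold count are my own):

```r
# Fold assignment depends on the RNG state; the same seed gives the same folds
set.seed(2024)
folds_a <- split(sample(20), rep(1:4, length.out = 20))  # 4 toy folds
set.seed(2024)
folds_b <- split(sample(20), rep(1:4, length.out = 20))

identical(folds_a, folds_b)  # TRUE
```

Calling set.seed() immediately before createFolds works the same way, which is why the folds can then be passed to trainControl once and reused.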
The data preparation is complete, and I am ready to train the models. Note that data preparation involves much more sophisticated steps in real-life applications, such as handling missing data, selecting variables, engineering features, and scaling and centring variables. However, since this was my first time building models from start to finish, I did not perform those steps. Moreover, most preprocessing techniques apply to numerical features; as the majority of the variables here are categorical, there is not much more to do.
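As an aside, centring and scaling, one of the preprocessing steps mentioned above, is simple enough to sketch by hand in base R (the toy age values are taken from the str() output earlier):

```r
# Centring and scaling a numeric feature by hand:
# subtract the mean, divide by the standard deviation
age <- c(29, 35, 36, 24, 34, 38)
age_scaled <- (age - mean(age)) / sd(age)

round(mean(age_scaled), 10)  # 0: centred
round(sd(age_scaled), 10)    # 1: unit variance
```

In caret itself, the preProcess argument of train (e.g. preProcess = c("center", "scale")) applies the same transformation automatically, fitted on the training data only.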
Train Model and Cross-Validation
The train function streamlines the model building and evaluation process. My first model is k-nearest neighbours (KNN). If I used the knn function from the class package following the ISLR instructions, I would have to run knn.cv multiple times and compare the results to find the best k. With caret, I only need one command. I first look up the value for the method argument and the available parameters in the documentation. For knn, there is only one tuning parameter, k. I have three options for tuning:
- Doing nothing: in this case, train tries 3 values for k. While I say random, they are actually not, but that is outside the scope of this project.
- tuneGrid: I specify the values to try, seq(2, 20, 1), a sequence of numbers from 2 to 20 in steps of 1.
- tuneLength: instead of providing the values, I tell the function to try 10 different values. They could be 1–10, 101–110, or ten even numbers.
The form parameter tells the model the target variable and the predictors. It takes the shape outcome ~ var1 + var2. Here I used . as a wildcard, so every other variable serves as an input. Lastly, as mentioned earlier, I used 10-fold cross-validation to evaluate the model, which is specified with trControl.
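The dot shorthand can be checked in base R with terms(), which expands a formula against a data frame; the toy columns below are my own:

```r
# The dot in "use ~ ." expands to every other column in the data frame
df <- data.frame(use = factor(c("Yes", "No")),
                 age = c(25, 40),
                 edu = factor(c("Primary", "Higher")))
tt <- terms(use ~ ., data = df)
attr(tt, "term.labels")  # "age" "edu"
```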
# Option 1: No specification on the tuning parameter
m_knn <- train(form = use ~ .,
               data = micsTrain,
               method = 'knn',
               trControl = ctrl)  # Cross-validation

# Option 2: Try all specified parameters
m_knn <- train(form = use ~ .,
               data = micsTrain,
               method = 'knn',
               trControl = ctrl,  # Cross-validation
               tuneGrid = data.frame(k = seq(2, 20, 1)))

# Option 3: Try 10 parameter values
m_knn <- train(form = use ~ .,
               data = micsTrain,
               method = 'knn',
               trControl = ctrl,  # Cross-validation
               tuneLength = 10)
Using option 3 as an example, the train function generates ten values for the tuning parameter k, and for each value performs 10-fold cross-validation to calculate the average accuracy. The process can take a while. Once done, it compares the average accuracy across the ten values and returns the one with the highest performance as the final parameter. As shown below, print(m_knn) lists the average accuracy for each k and selects the best one, 9.
k-Nearest Neighbors

49604 samples
    9 predictor
    2 classes: 'Using', 'Not Using'

No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 44644, 44644, 44644, 44644, 44643, 44643, ...
Resampling results across tuning parameters:

   k  Accuracy   Kappa
   5  0.7392345  0.4790542
   7  0.7409078  0.4824688
   9  0.7422585  0.4852000
  11  0.7404442  0.4816136
  13  0.7418755  0.4844997
  15  0.7408474  0.4824703
  17  0.7404241  0.4816458
  19  0.7401217  0.4810566
  21  0.7404038  0.4816320
  23  0.7398998  0.4806354

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was k = 9.
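The selection step that train automates, averaging the fold accuracies for each candidate k and keeping the best, can be sketched in base R with invented numbers (these are not the real results):

```r
# Pick the k whose mean cross-validated accuracy is highest
# (accuracy values below are made up for illustration)
cv_accuracy <- list("5"  = c(0.73, 0.74, 0.75),   # one accuracy per fold
                    "9"  = c(0.74, 0.75, 0.74),
                    "13" = c(0.73, 0.74, 0.74))
mean_acc <- sapply(cv_accuracy, mean)
best_k <- as.integer(names(which.max(mean_acc)))
best_k  # 9
```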
caret also makes it easy to visualise the cross-validation results by calling plot(m_knn, main = "KNN 10-fold Cross-Validation").
Test the Model
Testing the model is straightforward: predicting the target variable and evaluating the result.
pred_knn <- predict(m_knn, newdata = micsTest)
Since this is a classification model, I can use the confusion matrix to examine other performance metrics by comparing the predicted classes to the actual classes.
tbl_knn <- confusionMatrix(pred_knn, micsTest$use)
tbl_knn
In addition to accuracy, the output includes other common measures, such as sensitivity and specificity. The model has 74.81% accuracy, which is not impressive. Type I errors (false positives) occur much more frequently than Type II errors (false negatives), suggesting that the model is better at identifying women currently using contraception than at labelling those who are not.
Confusion Matrix and Statistics

            Reference
Prediction   Using Not Using
  Using       3800      1652
  Not Using    553      2747

               Accuracy : 0.7481
                 95% CI : (0.7388, 0.7571)
    No Information Rate : 0.5026
    P-Value [Acc > NIR] : < 2.2e-16

                  Kappa : 0.4968
 Mcnemar's Test P-Value : < 2.2e-16

            Sensitivity : 0.8730
            Specificity : 0.6245
         Pos Pred Value : 0.6970
         Neg Pred Value : 0.8324
             Prevalence : 0.4974
         Detection Rate : 0.4342
   Detection Prevalence : 0.6229
      Balanced Accuracy : 0.7487

       'Positive' Class : Using
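The headline metrics can be recomputed directly from the four cells of the matrix above, taking "Using" as the positive class:

```r
# Recompute accuracy, sensitivity, and specificity from the confusion matrix
TP <- 3800; FP <- 1652  # row: predicted "Using"
FN <- 553;  TN <- 2747  # row: predicted "Not Using"

accuracy    <- (TP + TN) / (TP + FP + FN + TN)
sensitivity <- TP / (TP + FN)  # true positive rate on the "Using" class
specificity <- TN / (TN + FP)  # true negative rate

round(c(accuracy, sensitivity, specificity), 4)
# 0.7481 0.8730 0.6245
```

Seeing that sensitivity (0.8730) is much higher than specificity (0.6245) makes the asymmetry between the two error types concrete.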
Conclusion
This post provides an example of building classification models with caret in R. caret is an incredible package for machine learning, yet it can be complicated to navigate at first. I hope this example helps anyone new to caret. Good luck with modelling, and have fun!