```r
library(tidyverse)
library(MASS)
library(tidymodels)
library(xgboost)

head(birthwt)
```
XGBoost
XGBoost, which stands for eXtreme Gradient Boosting, is an efficient implementation of gradient boosting. Gradient boosting is an ensemble technique in machine learning: rather than training one model on the data in a single pass, boosting fits a sequence of weak learners, each correcting the errors of its predecessors, and combines their predictions into a single, more accurate strong learner.
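The sequential-residual idea behind gradient boosting can be sketched in a few lines of base R. This is only a toy illustration of the principle, not how {xgboost} is implemented: each round fits a one-split "stump" to the current residuals and adds a shrunken copy of its predictions.

```r
# Toy gradient boosting for squared-error loss.
set.seed(1)
x <- runif(200)
y <- sin(2 * pi * x) + rnorm(200, sd = 0.1)

# Weak learner: a single split on x, chosen to minimize squared error
# on the residuals r over a coarse grid of candidate split points.
stump <- function(x, r) {
  best <- NULL
  best_sse <- Inf
  for (s in quantile(x, probs = seq(0.1, 0.9, by = 0.1))) {
    left  <- mean(r[x <= s])
    right <- mean(r[x > s])
    sse <- sum((r - ifelse(x <= s, left, right))^2)
    if (sse < best_sse) {
      best_sse <- sse
      best <- list(split = s, left = left, right = right)
    }
  }
  best
}

# Boosting loop: each round fits the residuals of the current ensemble,
# and its prediction is added with a learning rate (eta) of 0.3.
pred <- rep(0, length(y))
eta <- 0.3
for (i in 1:50) {
  fit <- stump(x, y - pred)
  pred <- pred + eta * ifelse(x <= fit$split, fit$left, fit$right)
}

mean((y - pred)^2)  # training MSE, far below the variance of y
```

After enough rounds, the accumulated stumps approximate the sine curve closely even though no single stump could.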
An XGBoost model is tree-based, so little preprocessing of the data is required: we do not need to worry about dummy-coding factors by hand, or about centering or scaling the predictors.
Available R packages
There are multiple packages that can be used to implement xgboost in R.
{tidymodels} and {caret} both provide easy access to xgboost. This example uses {tidymodels} because of the functionality it includes and because it is actively supported by Posit. {caret} was the precursor to {tidymodels}, and since no new features are being added to {caret}, {tidymodels} is recommended over it.
Data used
The data used for this example is birthwt, which is part of the {MASS} package. This dataset considers a number of risk factors associated with birth weight in infants.
Our modeling goal using the birthwt dataset is to predict whether the birth weight is low or not low based on factors such as mother's age, smoking status, and history of hypertension.
Example Code
Use the {rsample} package from {tidymodels} to split the data into training and testing sets. For classification, we need to convert the low variable into a factor, since it is currently coded as an integer (0/1).
```r
birthwt <-
  birthwt %>%
  mutate(
    low_f = lvls_revalue(factor(low), c("Not Low", "Low")),
    smoke_f = lvls_revalue(factor(smoke), c("Non-smoker", "Smoker"))
  )

brthwt_split <- initial_split(birthwt, strata = low)
brthwt_train <- training(brthwt_split)
brthwt_test <- testing(brthwt_split)
```
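Stratifying on low keeps the proportion of low-weight births similar in the training and testing sets. A quick check of this is sketched below; the chunk repeats the split so it runs standalone, and the seed is an assumption added here purely for reproducibility.

```r
library(tidyverse)
library(MASS)
library(tidymodels)

set.seed(123)  # assumption: seed added so the check is reproducible
bw <- birthwt %>%
  mutate(low_f = lvls_revalue(factor(low), c("Not Low", "Low")))
split <- initial_split(bw, strata = low)

# Class proportions in each set should be close to each other
bind_rows(
  training(split) %>% count(low_f) %>% mutate(set = "train"),
  testing(split)  %>% count(low_f) %>% mutate(set = "test")
) %>%
  group_by(set) %>%
  mutate(prop = n / sum(n))
```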
Classification
After creating the data split, we set up the parameters of the model.
```r
xgboost_spec <-
  boost_tree(trees = 15) %>%
  # This model can be used for classification or regression, so set mode
  set_mode("classification") %>%
  set_engine("xgboost")

xgboost_spec
```
Boosted Tree Model Specification (classification)
Main Arguments:
trees = 15
Computational engine: xgboost
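The specification above fixes trees = 15. In practice you would usually tune the main hyperparameters; a sketch of a tunable specification is below. The tune() placeholders are later filled in by a tuning function such as tune_grid() with a resampling scheme, which is beyond the scope of this example.

```r
library(tidymodels)

xgboost_tune_spec <-
  boost_tree(
    trees = tune(),       # number of boosting rounds (nrounds)
    tree_depth = tune(),  # max_depth in xgboost terms
    learn_rate = tune()   # eta in xgboost terms
  ) %>%
  set_mode("classification") %>%
  set_engine("xgboost")

xgboost_tune_spec
```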
```r
xgboost_cls_fit <- xgboost_spec %>% fit(low_f ~ ., data = brthwt_train)
xgboost_cls_fit
```
parsnip model object
##### xgb.Booster
raw: 15.2 Kb
call:
xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0,
colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1,
subsample = 1), data = x$data, nrounds = 15, watchlist = x$watchlist,
verbose = 0, nthread = 1, objective = "binary:logistic")
params (as set within xgb.train):
eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", nthread = "1", objective = "binary:logistic", validate_parameters = "TRUE"
xgb.attributes:
niter
callbacks:
cb.evaluation.log()
# of features: 12
niter: 15
nfeatures : 12
evaluation_log:
iter training_logloss
<num> <num>
1 0.44894346
2 0.31033704
--- ---
14 0.01571681
15 0.01396153
```r
bind_cols(
  predict(xgboost_cls_fit, brthwt_test),
  predict(xgboost_cls_fit, brthwt_test, type = "prob")
)
```
# A tibble: 48 × 3
.pred_class `.pred_Not Low` .pred_Low
<fct> <dbl> <dbl>
1 Not Low 0.985 0.0151
2 Not Low 0.985 0.0151
3 Not Low 0.985 0.0151
4 Not Low 0.988 0.0116
5 Not Low 0.988 0.0116
6 Not Low 0.988 0.0116
7 Not Low 0.988 0.0116
8 Not Low 0.988 0.0116
9 Not Low 0.988 0.0116
10 Not Low 0.988 0.0116
# ℹ 38 more rows
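The predictions can be scored on the test set with {yardstick}. The sketch below repeats the split and fit so it runs standalone, with two assumptions not made in the example above: a seed is set, and the low and bwt columns are dropped before fitting, since both directly encode the outcome.

```r
library(tidyverse)
library(MASS)
library(tidymodels)

set.seed(123)  # assumption: seed added for reproducibility
bw <- birthwt %>%
  mutate(low_f = lvls_revalue(factor(low), c("Not Low", "Low"))) %>%
  dplyr::select(-low, -bwt)  # assumption: drop columns that encode the outcome
split <- initial_split(bw, strata = low_f)

cls_fit <- boost_tree(trees = 15) %>%
  set_mode("classification") %>%
  set_engine("xgboost") %>%
  fit(low_f ~ ., data = training(split))

# Collect hard-class and probability predictions alongside the truth
preds <- bind_cols(
  predict(cls_fit, testing(split)),
  predict(cls_fit, testing(split), type = "prob"),
  testing(split) %>% dplyr::select(low_f)
)

accuracy(preds, truth = low_f, estimate = .pred_class)
roc_auc(preds, truth = low_f, `.pred_Not Low`)
```

roc_auc treats the first factor level ("Not Low") as the event by default, which is why the `.pred_Not Low` column is passed here.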
Regression
To perform regression with xgboost, set the mode of the model specification to "regression" when setting it up. Apart from that switch, and from using the numeric variable low instead of the factor low_f as the outcome, the rest of the code is the same.
```r
xgboost_reg_spec <-
  boost_tree(trees = 15) %>%
  # This model can be used for classification or regression, so set mode
  set_mode("regression") %>%
  set_engine("xgboost")

xgboost_reg_spec
```
Boosted Tree Model Specification (regression)
Main Arguments:
trees = 15
Computational engine: xgboost
```r
# For a regression model, the outcome should be `numeric`, not a `factor`.
xgboost_reg_fit <- xgboost_reg_spec %>% fit(low ~ ., data = brthwt_train)
xgboost_reg_fit
```
parsnip model object
##### xgb.Booster
raw: 15.2 Kb
call:
xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0,
colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1,
subsample = 1), data = x$data, nrounds = 15, watchlist = x$watchlist,
verbose = 0, nthread = 1, objective = "reg:squarederror")
params (as set within xgb.train):
eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", nthread = "1", objective = "reg:squarederror", validate_parameters = "TRUE"
xgb.attributes:
niter
callbacks:
cb.evaluation.log()
# of features: 13
niter: 15
nfeatures : 13
evaluation_log:
iter training_rmse
<num> <num>
1 0.352094163
2 0.247943366
--- ---
14 0.003690328
15 0.002599106
```r
predict(xgboost_reg_fit, brthwt_test)
```
# A tibble: 48 × 1
.pred
<dbl>
1 0.00253
2 0.00253
3 0.00253
4 0.00253
5 0.00253
6 0.00253
7 0.00253
8 0.00253
9 0.00253
10 0.00253
# ℹ 38 more rows
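Regression predictions can likewise be scored with {yardstick}, for example with test-set RMSE. The sketch below repeats the split and fit so it runs standalone; the seed is an assumption added for reproducibility.

```r
library(tidyverse)
library(MASS)
library(tidymodels)

set.seed(123)  # assumption: seed added for reproducibility
split <- initial_split(birthwt, strata = low)

reg_fit <- boost_tree(trees = 15) %>%
  set_mode("regression") %>%
  set_engine("xgboost") %>%
  fit(low ~ ., data = training(split))

# RMSE of the predictions against the observed outcome on the test set
predict(reg_fit, testing(split)) %>%
  bind_cols(testing(split) %>% dplyr::select(low)) %>%
  rmse(truth = low, estimate = .pred)
```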
Reference
─ Session info ───────────────────────────────────────────────────────────────
setting value
version R version 4.4.2 (2024-10-31)
os Ubuntu 24.04.2 LTS
system x86_64, linux-gnu
ui X11
language (EN)
collate C.UTF-8
ctype C.UTF-8
tz UTC
date 2025-03-27
pandoc 3.4 @ /opt/quarto/bin/tools/ (via rmarkdown)
─ Packages ───────────────────────────────────────────────────────────────────
! package * version date (UTC) lib source
P backports 1.5.0 2024-05-23 [?] RSPM (R 4.4.0)
P broom * 1.0.7 2024-09-26 [?] RSPM (R 4.4.0)
class 7.3-22 2023-05-03 [2] CRAN (R 4.4.2)
P cli 3.6.3 2024-06-21 [?] RSPM (R 4.4.0)
codetools 0.2-20 2024-03-31 [2] CRAN (R 4.4.2)
P colorspace 2.1-1 2024-07-26 [?] RSPM (R 4.4.0)
P data.table 1.16.0 2024-08-27 [?] RSPM (R 4.4.0)
P dials * 1.3.0 2024-07-30 [?] RSPM (R 4.4.0)
P DiceDesign 1.10 2023-12-07 [?] RSPM (R 4.4.0)
P digest 0.6.37 2024-08-19 [?] RSPM (R 4.4.0)
P dplyr * 1.1.4 2023-11-17 [?] RSPM (R 4.4.0)
P evaluate 1.0.0 2024-09-17 [?] RSPM (R 4.4.0)
P fansi 1.0.6 2023-12-08 [?] RSPM (R 4.4.0)
P fastmap 1.2.0 2024-05-15 [?] RSPM (R 4.4.0)
P forcats * 1.0.0 2023-01-29 [?] RSPM (R 4.4.0)
P foreach 1.5.2 2022-02-02 [?] RSPM (R 4.4.0)
P furrr 0.3.1 2022-08-15 [?] RSPM (R 4.4.0)
P future 1.34.0 2024-07-29 [?] RSPM (R 4.4.0)
P future.apply 1.11.2 2024-03-28 [?] RSPM (R 4.4.0)
P generics 0.1.3 2022-07-05 [?] RSPM (R 4.4.0)
P ggplot2 * 3.5.1 2024-04-23 [?] RSPM (R 4.4.0)
P globals 0.16.3 2024-03-08 [?] RSPM (R 4.4.0)
P glue 1.8.0 2024-09-30 [?] RSPM (R 4.4.0)
P gower 1.0.1 2022-12-22 [?] RSPM (R 4.4.0)
P GPfit 1.0-8 2019-02-08 [?] RSPM (R 4.4.0)
P gtable 0.3.5 2024-04-22 [?] RSPM (R 4.4.0)
P hardhat 1.4.0 2024-06-02 [?] RSPM (R 4.4.0)
P hms 1.1.3 2023-03-21 [?] RSPM (R 4.4.0)
P htmltools 0.5.8.1 2024-04-04 [?] RSPM (R 4.4.0)
P htmlwidgets 1.6.4 2023-12-06 [?] RSPM (R 4.4.0)
P infer * 1.0.7 2024-03-25 [?] RSPM (R 4.4.0)
P ipred 0.9-15 2024-07-18 [?] RSPM (R 4.4.0)
P iterators 1.0.14 2022-02-05 [?] RSPM (R 4.4.0)
P jsonlite 1.8.9 2024-09-20 [?] RSPM (R 4.4.0)
P knitr 1.48 2024-07-07 [?] RSPM (R 4.4.0)
lattice 0.22-6 2024-03-20 [2] CRAN (R 4.4.2)
P lava 1.8.0 2024-03-05 [?] RSPM (R 4.4.0)
P lhs 1.2.0 2024-06-30 [?] RSPM (R 4.4.0)
P lifecycle 1.0.4 2023-11-07 [?] RSPM (R 4.4.0)
P listenv 0.9.1 2024-01-29 [?] RSPM (R 4.4.0)
P lubridate * 1.9.3 2023-09-27 [?] RSPM (R 4.4.0)
P magrittr 2.0.3 2022-03-30 [?] RSPM (R 4.4.0)
MASS * 7.3-61 2024-06-13 [2] CRAN (R 4.4.2)
Matrix 1.7-1 2024-10-18 [2] CRAN (R 4.4.2)
P modeldata * 1.4.0 2024-06-19 [?] RSPM (R 4.4.0)
P munsell 0.5.1 2024-04-01 [?] RSPM (R 4.4.0)
nnet 7.3-19 2023-05-03 [2] CRAN (R 4.4.2)
P parallelly 1.38.0 2024-07-27 [?] RSPM (R 4.4.0)
P parsnip * 1.2.1 2024-03-22 [?] RSPM (R 4.4.0)
P pillar 1.9.0 2023-03-22 [?] RSPM (R 4.4.0)
P pkgconfig 2.0.3 2019-09-22 [?] RSPM (R 4.4.0)
P prodlim 2024.06.25 2024-06-24 [?] RSPM (R 4.4.0)
P purrr * 1.0.2 2023-08-10 [?] RSPM (R 4.4.0)
P R6 2.5.1 2021-08-19 [?] RSPM (R 4.4.0)
P Rcpp 1.0.13 2024-07-17 [?] RSPM (R 4.4.0)
P readr * 2.1.5 2024-01-10 [?] RSPM (R 4.4.0)
P recipes * 1.1.0 2024-07-04 [?] RSPM (R 4.4.0)
renv 1.0.10 2024-10-05 [1] RSPM (R 4.4.0)
P rlang 1.1.4 2024-06-04 [?] RSPM (R 4.4.0)
P rmarkdown 2.28 2024-08-17 [?] RSPM (R 4.4.0)
rpart 4.1.23 2023-12-05 [2] CRAN (R 4.4.2)
P rsample * 1.2.1 2024-03-25 [?] RSPM (R 4.4.0)
P rstudioapi 0.16.0 2024-03-24 [?] RSPM (R 4.4.0)
P scales * 1.3.0 2023-11-28 [?] RSPM (R 4.4.0)
P sessioninfo 1.2.2 2021-12-06 [?] RSPM (R 4.4.0)
P stringi 1.8.4 2024-05-06 [?] RSPM (R 4.4.0)
P stringr * 1.5.1 2023-11-14 [?] RSPM (R 4.4.0)
survival 3.7-0 2024-06-05 [2] CRAN (R 4.4.2)
P tibble * 3.2.1 2023-03-20 [?] RSPM (R 4.4.0)
P tidymodels * 1.2.0 2024-03-25 [?] RSPM (R 4.4.0)
P tidyr * 1.3.1 2024-01-24 [?] RSPM (R 4.4.0)
P tidyselect 1.2.1 2024-03-11 [?] RSPM (R 4.4.0)
P tidyverse * 2.0.0 2023-02-22 [?] RSPM (R 4.4.0)
P timechange 0.3.0 2024-01-18 [?] RSPM (R 4.4.0)
P timeDate 4041.110 2024-09-22 [?] RSPM (R 4.4.0)
P tune * 1.2.1 2024-04-18 [?] RSPM (R 4.4.0)
P tzdb 0.4.0 2023-05-12 [?] RSPM (R 4.4.0)
P utf8 1.2.4 2023-10-22 [?] RSPM (R 4.4.0)
P vctrs 0.6.5 2023-12-01 [?] RSPM (R 4.4.0)
P withr 3.0.1 2024-07-31 [?] RSPM (R 4.4.0)
P workflows * 1.1.4 2024-02-19 [?] RSPM (R 4.4.0)
P workflowsets * 1.1.0 2024-03-21 [?] RSPM (R 4.4.0)
P xfun 0.48 2024-10-03 [?] RSPM (R 4.4.0)
P xgboost * 1.7.8.1 2024-07-24 [?] RSPM (R 4.4.0)
P yaml 2.3.10 2024-07-26 [?] RSPM (R 4.4.0)
P yardstick * 1.3.1 2024-03-21 [?] RSPM (R 4.4.0)
[1] /home/runner/work/CAMIS/CAMIS/renv/library/linux-ubuntu-noble/R-4.4/x86_64-pc-linux-gnu
[2] /opt/R/4.4.2/lib/R/library
P ── Loaded and on-disk path mismatch.
──────────────────────────────────────────────────────────────────────────────