XGBoost

XGBoost, which stands for eXtreme Gradient Boosting, is an efficient implementation of gradient boosting. Gradient boosting is an ensemble technique in machine learning: rather than training one model on the data in a single pass, it sequentially combines the predictions of many weak learners into a single, more accurate strong learner.
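The sequential idea can be sketched with a toy loop: each new weak learner is fit to the residuals of the current ensemble. This is illustrative only, using shallow {rpart} trees, and is not how xgboost is implemented internally.

```r
# Minimal sketch of gradient boosting for squared-error loss,
# using shallow rpart trees as weak learners (illustration only)
library(rpart)

set.seed(1)
x <- runif(100)
y <- sin(2 * pi * x) + rnorm(100, sd = 0.2)

pred <- rep(mean(y), length(y))   # start from a constant prediction
learn_rate <- 0.3

for (i in 1:25) {
  resid <- y - pred               # residuals = negative gradient of squared error
  stump <- rpart(resid ~ x, data = data.frame(x = x, resid = resid),
                 control = rpart.control(maxdepth = 2))
  pred <- pred + learn_rate * predict(stump)  # add a scaled weak learner
}

sqrt(mean((y - pred)^2))          # training RMSE shrinks as trees are added
```

Each iteration nudges the ensemble's predictions toward the data; the learning rate (`eta` in xgboost) controls the size of each step.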

An XGBoost model is tree-based, so little preprocessing is required: we do not need to worry about dummy coding factors or centering and scaling our data.

Available R packages

There are multiple packages that can be used to implement xgboost in R.

{tidymodels} and {caret} both provide easy access to xgboost. This example uses {tidymodels} because of its broad functionality and because it is actively supported by Posit.

Data used

The data used for this example is birthwt, which is part of the {MASS} package. This dataset considers a number of risk factors associated with birth weight in infants.

library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ readr     2.1.5
✔ forcats   1.0.0     ✔ stringr   1.5.1
✔ ggplot2   3.5.1     ✔ tibble    3.2.1
✔ lubridate 1.9.3     ✔ tidyr     1.3.1
✔ purrr     1.0.2     
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(MASS)

Attaching package: 'MASS'

The following object is masked from 'package:dplyr':

    select
library(tidymodels)
── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
✔ broom        1.0.7     ✔ rsample      1.2.1
✔ dials        1.3.0     ✔ tune         1.2.1
✔ infer        1.0.7     ✔ workflows    1.1.4
✔ modeldata    1.4.0     ✔ workflowsets 1.1.0
✔ parsnip      1.2.1     ✔ yardstick    1.3.1
✔ recipes      1.1.0     
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ scales::discard() masks purrr::discard()
✖ dplyr::filter()   masks stats::filter()
✖ recipes::fixed()  masks stringr::fixed()
✖ dplyr::lag()      masks stats::lag()
✖ MASS::select()    masks dplyr::select()
✖ yardstick::spec() masks readr::spec()
✖ recipes::step()   masks stats::step()
• Use suppressPackageStartupMessages() to eliminate package startup messages
library(xgboost)

Attaching package: 'xgboost'

The following object is masked from 'package:dplyr':

    slice
head(birthwt)
   low age lwt race smoke ptl ht ui ftv  bwt
85   0  19 182    2     0   0  0  1   0 2523
86   0  33 155    3     0   0  0  0   3 2551
87   0  20 105    1     1   0  0  0   1 2557
88   0  21 108    1     1   0  0  1   2 2594
89   0  18 107    1     1   0  0  1   0 2600
91   0  21 124    3     0   0  0  0   0 2622

Our modeling goal using the birthwt dataset is to predict whether the birth weight is low or not low based on factors such as mother’s age, smoking status, and history of hypertension.
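Before modeling, it can be useful to check how balanced the outcome is. A quick sketch using base R (base `table()` avoids the `MASS::select()` / `dplyr::select()` masking noted in the startup messages above):

```r
# Distribution of the outcome: 0 = not low birth weight, 1 = low
table(birthwt$low)
prop.table(table(birthwt$low))
```

An imbalanced outcome would suggest using a stratified split, which is what `strata = low` does below.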

Example Code

Use the {rsample} package from {tidymodels} to split the data into training and testing sets. For classification, we need to convert the low variable into a factor, since it is currently coded as an integer (0, 1).

birthwt <- 
  birthwt %>% 
  mutate(
    low_f = lvls_revalue(factor(low), c("Not Low", "Low")),
    smoke_f = lvls_revalue(factor(smoke), c("Non-smoker", "Smoker"))
  )


brthwt_split <- initial_split(birthwt, strata = low)
brthwt_train <- training(brthwt_split)
brthwt_test <- testing(brthwt_split)

Classification

After creating the data split, we set up the parameters of the model.

xgboost_spec <- 
  boost_tree(trees = 15) %>% 
  # This model can be used for classification or regression, so set mode
  set_mode("classification") %>% 
  set_engine("xgboost")

xgboost_spec
Boosted Tree Model Specification (classification)

Main Arguments:
  trees = 15

Computational engine: xgboost 
xgboost_cls_fit <- xgboost_spec %>% fit(low_f ~ ., data = brthwt_train)
xgboost_cls_fit
parsnip model object

##### xgb.Booster
raw: 15.2 Kb 
call:
  xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
    colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
    subsample = 1), data = x$data, nrounds = 15, watchlist = x$watchlist, 
    verbose = 0, nthread = 1, objective = "binary:logistic")
params (as set within xgb.train):
  eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", nthread = "1", objective = "binary:logistic", validate_parameters = "TRUE"
xgb.attributes:
  niter
callbacks:
  cb.evaluation.log()
# of features: 12 
niter: 15
nfeatures : 12 
evaluation_log:
  iter training_logloss
 <num>            <num>
     1       0.44894346
     2       0.31033704
   ---              ---
    14       0.01571681
    15       0.01396153
bind_cols(
  predict(xgboost_cls_fit, brthwt_test),
  predict(xgboost_cls_fit, brthwt_test, type = "prob")
)
# A tibble: 48 × 3
   .pred_class `.pred_Not Low` .pred_Low
   <fct>                 <dbl>     <dbl>
 1 Not Low               0.985    0.0151
 2 Not Low               0.985    0.0151
 3 Not Low               0.988    0.0116
 4 Not Low               0.988    0.0116
 5 Not Low               0.988    0.0116
 6 Not Low               0.988    0.0116
 7 Not Low               0.988    0.0116
 8 Not Low               0.988    0.0116
 9 Not Low               0.988    0.0116
10 Not Low               0.988    0.0116
# ℹ 38 more rows
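These predictions can then be scored against the observed outcome, for example with {yardstick} metrics. A sketch (not run above); `dplyr::select` is namespaced because `MASS::select` masks it in this session:

```r
# Combine test-set predictions with the observed outcome
cls_preds <- bind_cols(
  predict(xgboost_cls_fit, brthwt_test),
  dplyr::select(brthwt_test, low_f)
)

# Overall accuracy and the confusion matrix
accuracy(cls_preds, truth = low_f, estimate = .pred_class)
conf_mat(cls_preds, truth = low_f, estimate = .pred_class)
```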

Regression

To perform regression with xgboost, set the mode of the model to "regression" when setting up the specification. Apart from that switch, and from using the numeric variable of interest rather than the factor, the rest of the code is the same.

xgboost_reg_spec <- 
  boost_tree(trees = 15) %>% 
  # This model can be used for classification or regression, so set mode
  set_mode("regression") %>% 
  set_engine("xgboost")

xgboost_reg_spec
Boosted Tree Model Specification (regression)

Main Arguments:
  trees = 15

Computational engine: xgboost 
# For a regression model, the outcome should be `numeric`, not a `factor`.
xgboost_reg_fit <- xgboost_reg_spec %>% fit(low ~ ., data = brthwt_train)
xgboost_reg_fit 
parsnip model object

##### xgb.Booster
raw: 15.2 Kb 
call:
  xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
    colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
    subsample = 1), data = x$data, nrounds = 15, watchlist = x$watchlist, 
    verbose = 0, nthread = 1, objective = "reg:squarederror")
params (as set within xgb.train):
  eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", nthread = "1", objective = "reg:squarederror", validate_parameters = "TRUE"
xgb.attributes:
  niter
callbacks:
  cb.evaluation.log()
# of features: 13 
niter: 15
nfeatures : 13 
evaluation_log:
  iter training_rmse
 <num>         <num>
     1   0.352094163
     2   0.247943366
   ---           ---
    14   0.003690328
    15   0.002599106
predict(xgboost_reg_fit , brthwt_test)
# A tibble: 48 × 1
     .pred
     <dbl>
 1 0.00253
 2 0.00253
 3 0.00253
 4 0.00253
 5 0.00253
 6 0.00253
 7 0.00253
 8 0.00253
 9 0.00253
10 0.00253
# ℹ 38 more rows
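As with classification, the regression predictions can be scored with {yardstick}; a sketch computing the test-set RMSE against the numeric outcome used in the fit:

```r
# Combine test-set predictions with the observed numeric outcome
reg_preds <- bind_cols(
  predict(xgboost_reg_fit, brthwt_test),
  dplyr::select(brthwt_test, low)
)

rmse(reg_preds, truth = low, estimate = .pred)
```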

Reference

─ Session info ───────────────────────────────────────────────────────────────
 setting  value
 version  R version 4.4.0 (2024-04-24)
 os       Ubuntu 22.04.5 LTS
 system   x86_64, linux-gnu
 ui       X11
 language (EN)
 collate  C.UTF-8
 ctype    C.UTF-8
 tz       UTC
 date     2024-10-25
 pandoc   3.2 @ /opt/quarto/bin/tools/ (via rmarkdown)

─ Packages ───────────────────────────────────────────────────────────────────
 ! package      * version    date (UTC) lib source
 P backports      1.5.0      2024-05-23 [?] RSPM (R 4.4.0)
 P broom        * 1.0.7      2024-09-26 [?] RSPM (R 4.4.0)
   class          7.3-22     2023-05-03 [2] CRAN (R 4.4.0)
 P cli            3.6.3      2024-06-21 [?] RSPM (R 4.4.0)
   codetools      0.2-20     2024-03-31 [2] CRAN (R 4.4.0)
 P colorspace     2.1-1      2024-07-26 [?] RSPM (R 4.4.0)
 P data.table     1.16.0     2024-08-27 [?] RSPM (R 4.4.0)
 P dials        * 1.3.0      2024-07-30 [?] RSPM (R 4.4.0)
 P DiceDesign     1.10       2023-12-07 [?] RSPM (R 4.4.0)
 P digest         0.6.37     2024-08-19 [?] RSPM (R 4.4.0)
 P dplyr        * 1.1.4      2023-11-17 [?] RSPM (R 4.4.0)
 P evaluate       1.0.0      2024-09-17 [?] RSPM (R 4.4.0)
 P fansi          1.0.6      2023-12-08 [?] RSPM (R 4.4.0)
 P fastmap        1.2.0      2024-05-15 [?] RSPM (R 4.4.0)
 P forcats      * 1.0.0      2023-01-29 [?] RSPM (R 4.4.0)
 P foreach        1.5.2      2022-02-02 [?] RSPM (R 4.4.0)
 P furrr          0.3.1      2022-08-15 [?] RSPM (R 4.4.0)
 P future         1.34.0     2024-07-29 [?] RSPM (R 4.4.0)
 P future.apply   1.11.2     2024-03-28 [?] RSPM (R 4.4.0)
 P generics       0.1.3      2022-07-05 [?] RSPM (R 4.4.0)
 P ggplot2      * 3.5.1      2024-04-23 [?] RSPM (R 4.4.0)
 P globals        0.16.3     2024-03-08 [?] RSPM (R 4.4.0)
 P glue           1.8.0      2024-09-30 [?] RSPM (R 4.4.0)
 P gower          1.0.1      2022-12-22 [?] RSPM (R 4.4.0)
 P GPfit          1.0-8      2019-02-08 [?] RSPM (R 4.4.0)
 P gtable         0.3.5      2024-04-22 [?] RSPM (R 4.4.0)
 P hardhat        1.4.0      2024-06-02 [?] RSPM (R 4.4.0)
 P hms            1.1.3      2023-03-21 [?] RSPM (R 4.4.0)
 P htmltools      0.5.8.1    2024-04-04 [?] RSPM (R 4.4.0)
 P htmlwidgets    1.6.4      2023-12-06 [?] RSPM (R 4.4.0)
 P infer        * 1.0.7      2024-03-25 [?] RSPM (R 4.4.0)
 P ipred          0.9-15     2024-07-18 [?] RSPM (R 4.4.0)
 P iterators      1.0.14     2022-02-05 [?] RSPM (R 4.4.0)
 P jsonlite       1.8.9      2024-09-20 [?] RSPM (R 4.4.0)
 P knitr          1.48       2024-07-07 [?] RSPM (R 4.4.0)
   lattice        0.22-6     2024-03-20 [2] CRAN (R 4.4.0)
 P lava           1.8.0      2024-03-05 [?] RSPM (R 4.4.0)
 P lhs            1.2.0      2024-06-30 [?] RSPM (R 4.4.0)
 P lifecycle      1.0.4      2023-11-07 [?] RSPM (R 4.4.0)
 P listenv        0.9.1      2024-01-29 [?] RSPM (R 4.4.0)
 P lubridate    * 1.9.3      2023-09-27 [?] RSPM (R 4.4.0)
 P magrittr       2.0.3      2022-03-30 [?] RSPM (R 4.4.0)
   MASS         * 7.3-60.2   2024-08-22 [2] local
   Matrix         1.7-0      2024-03-22 [2] CRAN (R 4.4.0)
 P modeldata    * 1.4.0      2024-06-19 [?] RSPM (R 4.4.0)
 P munsell        0.5.1      2024-04-01 [?] RSPM (R 4.4.0)
   nnet           7.3-19     2023-05-03 [2] CRAN (R 4.4.0)
 P parallelly     1.38.0     2024-07-27 [?] RSPM (R 4.4.0)
 P parsnip      * 1.2.1      2024-03-22 [?] RSPM (R 4.4.0)
 P pillar         1.9.0      2023-03-22 [?] RSPM (R 4.4.0)
 P pkgconfig      2.0.3      2019-09-22 [?] RSPM (R 4.4.0)
 P prodlim        2024.06.25 2024-06-24 [?] RSPM (R 4.4.0)
 P purrr        * 1.0.2      2023-08-10 [?] RSPM (R 4.4.0)
 P R6             2.5.1      2021-08-19 [?] RSPM (R 4.4.0)
 P Rcpp           1.0.13     2024-07-17 [?] RSPM (R 4.4.0)
 P readr        * 2.1.5      2024-01-10 [?] RSPM (R 4.4.0)
 P recipes      * 1.1.0      2024-07-04 [?] RSPM (R 4.4.0)
   renv           1.0.10     2024-10-05 [1] RSPM (R 4.4.0)
 P rlang          1.1.4      2024-06-04 [?] RSPM (R 4.4.0)
 P rmarkdown      2.28       2024-08-17 [?] RSPM (R 4.4.0)
   rpart          4.1.23     2023-12-05 [2] CRAN (R 4.4.0)
 P rsample      * 1.2.1      2024-03-25 [?] RSPM (R 4.4.0)
 P rstudioapi     0.16.0     2024-03-24 [?] RSPM (R 4.4.0)
 P scales       * 1.3.0      2023-11-28 [?] RSPM (R 4.4.0)
 P sessioninfo    1.2.2      2021-12-06 [?] RSPM (R 4.4.0)
 P stringi        1.8.4      2024-05-06 [?] RSPM (R 4.4.0)
 P stringr      * 1.5.1      2023-11-14 [?] RSPM (R 4.4.0)
 P survival       3.7-0      2024-06-05 [?] RSPM (R 4.4.0)
 P tibble       * 3.2.1      2023-03-20 [?] RSPM (R 4.4.0)
 P tidymodels   * 1.2.0      2024-03-25 [?] RSPM (R 4.4.0)
 P tidyr        * 1.3.1      2024-01-24 [?] RSPM (R 4.4.0)
 P tidyselect     1.2.1      2024-03-11 [?] RSPM (R 4.4.0)
 P tidyverse    * 2.0.0      2023-02-22 [?] RSPM (R 4.4.0)
 P timechange     0.3.0      2024-01-18 [?] RSPM (R 4.4.0)
 P timeDate       4041.110   2024-09-22 [?] RSPM (R 4.4.0)
 P tune         * 1.2.1      2024-04-18 [?] RSPM (R 4.4.0)
 P tzdb           0.4.0      2023-05-12 [?] RSPM (R 4.4.0)
 P utf8           1.2.4      2023-10-22 [?] RSPM (R 4.4.0)
 P vctrs          0.6.5      2023-12-01 [?] RSPM (R 4.4.0)
 P withr          3.0.1      2024-07-31 [?] RSPM (R 4.4.0)
 P workflows    * 1.1.4      2024-02-19 [?] RSPM (R 4.4.0)
 P workflowsets * 1.1.0      2024-03-21 [?] RSPM (R 4.4.0)
 P xfun           0.48       2024-10-03 [?] RSPM (R 4.4.0)
 P xgboost      * 1.7.8.1    2024-07-24 [?] RSPM (R 4.4.0)
 P yaml           2.3.10     2024-07-26 [?] RSPM (R 4.4.0)
 P yardstick    * 1.3.1      2024-03-21 [?] RSPM (R 4.4.0)

 [1] /home/runner/work/CAMIS/CAMIS/renv/library/linux-ubuntu-jammy/R-4.4/x86_64-pc-linux-gnu
 [2] /opt/R/4.4.0/lib/R/library

 P ── Loaded and on-disk path mismatch.

──────────────────────────────────────────────────────────────────────────────