Getting started

library(hBayesDMregression)

# Tidyverse packages
library(tibble)
library(dplyr)
library(ggplot2)
library(tidyr)
library(purrr)
library(magrittr)

# For MCMC diagnostics
library(rstan)

An example model

# Load the example data set into the workspace and preview its first rows
utils::data("igt_example")

utils::head(igt_example, n = 6L)
##   subjID trial choice gain loss age      x1         x2       x3        x4 sex
## 1   1001     1      3   50    0  31 1.22201 -0.6302797 0.401264 0.7221759   0
## 2   1001     2      2  100    0  31 1.22201 -0.6302797 0.401264 0.7221759   0
## 3   1001     3      3   50    0  31 1.22201 -0.6302797 0.401264 0.7221759   0
## 4   1001     4      4   50    0  31 1.22201 -0.6302797 0.401264 0.7221759   0
## 5   1001     5      4   50    0  31 1.22201 -0.6302797 0.401264 0.7221759   0
## 6   1001     6      4   50    0  31 1.22201 -0.6302797 0.401264 0.7221759   0
##   cond1 cond2 cond3
## 1     0     1     0
## 2     0     1     0
## 3     0     1     0
## 4     0     1     0
## 5     0     1     0
## 6     0     1     0

Note that the included example data set has a trial column that records the trial number for each subject. This column should be excluded when building the model.

## NOT RUN
# Use half of the detected cores, but always a whole number >= 1:
# parallel::detectCores() can return NA or an odd count, and a plain `/ 2`
# would then pass NA or a fractional core count to `ncore`.
n_cores <- max(1L, parallel::detectCores() %/% 2L, na.rm = TRUE)

# Fit the model with regression on K and betaF; the `trial` column is left
# out of model building via `exclude_cols`.
orl_model <- igt_orl_regression(
  igt_example,
  exclude_cols = "trial",
  regression_pars = c("K", "betaF"),
  nchain = 3,
  niter = 3000,
  nwarmup = 1000,
  ncore = n_cores
)
# Read a saved fit from disk; this replaces `orl_model` if the fit above ran.
orl_model <- readRDS("orl_model.rds")

Distributions of posterior samples

The posterior samples are stored in the parVals element of the model output. We will extract it and assign it to a variable in our workspace, since we will be using it often.

# Keep the posterior draws (the `parVals` element) in their own variable
posterior_samples <- pluck(orl_model, "parVals")

Note: the above is equivalent to calling either orl_model$parVals or orl_model[["parVals"]].

# Compact overview of every element in the list of posterior draws
utils::str(posterior_samples)
## List of 15
##  $ mu_Arew   : num [1:6000(1d)] 0.165 0.192 0.14 0.221 0.163 ...
##   ..- attr(*, "dimnames")=List of 1
##   .. ..$ iterations: NULL
##  $ mu_Apun   : num [1:6000(1d)] 0.129 0.137 0.159 0.157 0.163 ...
##   ..- attr(*, "dimnames")=List of 1
##   .. ..$ iterations: NULL
##  $ mu_K      : num [1:6000(1d)] 2.73 3.2 4.78 4.63 1.27 ...
##   ..- attr(*, "dimnames")=List of 1
##   .. ..$ iterations: NULL
##  $ mu_betaF  : num [1:6000(1d)] 0.536 0.54 -0.508 -0.536 1.257 ...
##   ..- attr(*, "dimnames")=List of 1
##   .. ..$ iterations: NULL
##  $ mu_betaP  : num [1:6000(1d)] -0.363 0.243 0.343 -1.058 1.733 ...
##   ..- attr(*, "dimnames")=List of 1
##   .. ..$ iterations: NULL
##  $ sigma     : num [1:6000, 1:5] 0.2881 0.0352 0.1675 0.0414 0.2144 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ iterations: NULL
##   .. ..$           : NULL
##  $ Arew      : num [1:6000, 1:4] 0.114 0.186 0.143 0.218 0.153 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ iterations: NULL
##   .. ..$           : NULL
##  $ Apun      : num [1:6000, 1:4] 0.114 0.148 0.118 0.153 0.224 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ iterations: NULL
##   .. ..$           : NULL
##  $ K         : num [1:6000, 1:4] 5 0 5 5 0.159 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ iterations: NULL
##   .. ..$           : NULL
##  $ betaF     : num [1:6000, 1:4] 2.907 3.209 2.202 2.853 0.843 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ iterations: NULL
##   .. ..$           : NULL
##  $ betaP     : num [1:6000, 1:4] 0.2994 0.0233 8.915 1.1277 1.755 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ iterations: NULL
##   .. ..$           : NULL
##  $ log_lik   : num [1:6000, 1:4] -88.9 -89.6 -89.9 -89.2 -90.3 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ iterations: NULL
##   .. ..$           : NULL
##  $ beta      : num [1:6000, 1:18] 3.0608 -6.5937 2.2449 8.8583 -0.0384 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ : NULL
##   .. ..$ : chr [1:18] "K_age" "K_x1" "K_x2" "K_x3" ...
##  $ sigma_beta: num [1:6000, 1:18] 2.779 10.783 2.325 13.629 0.703 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ : NULL
##   .. ..$ : chr [1:18] "sigma_K_age" "sigma_K_x1" "sigma_K_x2" "sigma_K_x3" ...
##  $ lp__      : num [1:6000(1d)] -429 -439 -422 -422 -424 ...
##   ..- attr(*, "dimnames")=List of 1
##   .. ..$ iterations: NULL

Hyperparameters

Mu

# Collect the list elements whose names begin with `mu_`
mu_draws <- keep_at(posterior_samples, function(nm) startsWith(nm, "mu_"))

# Column-bind them into one matrix (one column per `mu_*` element)
mu_matrix <- do.call(cbind, mu_draws)

# Matrix -> tibble -> long format with columns `parameter` and `value`
mu <- pivot_longer(
  as_tibble(mu_matrix),
  cols = everything(),
  names_to = "parameter"
)

mu
## # A tibble: 30,000 x 2
##    parameter  value
##    <chr>      <dbl>
##  1 mu_Arew    0.165
##  2 mu_Apun    0.129
##  3 mu_K       2.73 
##  4 mu_betaF   0.536
##  5 mu_betaP  -0.363
##  6 mu_Arew    0.192
##  7 mu_Apun    0.137
##  8 mu_K       3.20 
##  9 mu_betaF   0.540
## 10 mu_betaP   0.243
## # i 29,990 more rows
# Density-scaled histograms of the `mu_*` draws, one panel per parameter.
# `bins` is stated explicitly (30 is the ggplot2 default) so the binning is
# visible in the code and `stat_bin()` does not emit its "pick better value"
# message.
ggplot(mu) +
  geom_histogram(
    aes(x = value, y = after_stat(density)),
    bins = 30,
    colour = "#56B4E9",
    fill = "#0072B2",
    alpha = 0.7
  ) +
  facet_wrap(~parameter, scales = "free", ncol = 2) + # Free axes per panel
  theme_bw() +
  labs(x = "Value", y = "Density")

Sigma

# The `sigma` element is a matrix; label its columns `sigma_<model parameter>`
sigma_draws <- pluck(posterior_samples, "sigma")
colnames(sigma_draws) <- paste0("sigma_", orl_model$modelPars)

# Matrix -> tibble -> long format with columns `parameter` and `value`
sigma <- pivot_longer(
  as_tibble(sigma_draws),
  cols = everything(),
  names_to = "parameter"
)

sigma
## # A tibble: 30,000 x 2
##    parameter    value
##    <chr>        <dbl>
##  1 sigma_Arew  0.288 
##  2 sigma_Apun  0.0665
##  3 sigma_K     0.407 
##  4 sigma_betaF 0.215 
##  5 sigma_betaP 1.12  
##  6 sigma_Arew  0.0352
##  7 sigma_Apun  0.0689
##  8 sigma_K     0.0739
##  9 sigma_betaF 1.92  
## 10 sigma_betaP 0.336 
## # i 29,990 more rows
# Density-scaled histograms of the `sigma_*` draws, one panel per parameter.
# `bins` is stated explicitly (30 is the ggplot2 default) so the binning is
# visible in the code and `stat_bin()` does not emit its "pick better value"
# message.
ggplot(sigma) +
  geom_histogram(
    aes(x = value, y = after_stat(density)),
    bins = 30,
    colour = "#56B4E9",
    fill = "#0072B2",
    alpha = 0.7
  ) +
  facet_wrap(~parameter, scales = "free", ncol = 2) + # Free axes per panel
  theme_bw() +
  labs(x = "Value", y = "Density")

Model parameters

# Column labels for each parameter matrix: the subject IDs as character
subject_ids <- as.character(orl_model$subjID)

# Label one matrix of draws by subject and reshape it to long format
# (columns `subjID` and `value`)
to_long_by_subject <- function(draws) {
  colnames(draws) <- subject_ids
  pivot_longer(as_tibble(draws), cols = everything(), names_to = "subjID")
}

# Keep only the model-parameter matrices and convert each one
model_parameters <- posterior_samples %>%
  keep_at(orl_model$modelPars) %>%
  map(to_long_by_subject)

model_parameters
## $Arew
## # A tibble: 24,000 x 2
##    subjID  value
##    <chr>   <dbl>
##  1 1001   0.114 
##  2 1002   0.142 
##  3 1003   0.163 
##  4 1004   0.0651
##  5 1001   0.186 
##  6 1002   0.202 
##  7 1003   0.185 
##  8 1004   0.187 
##  9 1001   0.143 
## 10 1002   0.161 
## # i 23,990 more rows
## 
## $Apun
## # A tibble: 24,000 x 2
##    subjID value
##    <chr>  <dbl>
##  1 1001   0.114
##  2 1002   0.154
##  3 1003   0.124
##  4 1004   0.108
##  5 1001   0.148
##  6 1002   0.122
##  7 1003   0.151
##  8 1004   0.158
##  9 1001   0.118
## 10 1002   0.136
## # i 23,990 more rows
## 
## $K
## # A tibble: 24,000 x 2
##    subjID value
##    <chr>  <dbl>
##  1 1001       5
##  2 1002       5
##  3 1003       5
##  4 1004       5
##  5 1001       0
##  6 1002       0
##  7 1003       0
##  8 1004       0
##  9 1001       5
## 10 1002       5
## # i 23,990 more rows
## 
## $betaF
## # A tibble: 24,000 x 2
##    subjID value
##    <chr>  <dbl>
##  1 1001   2.91 
##  2 1002   3.62 
##  3 1003   2.19 
##  4 1004   2.28 
##  5 1001   3.21 
##  6 1002   4.99 
##  7 1003   1.42 
##  8 1004   0.976
##  9 1001   2.20 
## 10 1002   3.91 
## # i 23,990 more rows
## 
## $betaP
## # A tibble: 24,000 x 2
##    subjID   value
##    <chr>    <dbl>
##  1 1001    0.299 
##  2 1002    1.06  
##  3 1003   -1.91  
##  4 1004   -2.10  
##  5 1001    0.0233
##  6 1002   -0.631 
##  7 1003    0.0240
##  8 1004    0.572 
##  9 1001    8.91  
## 10 1002    9.49  
## # i 23,990 more rows
# Density-scaled histograms of the Arew draws, one panel per subject.
# `bins` is stated explicitly (30 is the ggplot2 default) so the binning is
# visible in the code and `stat_bin()` does not emit its "pick better value"
# message.
ggplot(model_parameters$Arew) +
  geom_histogram(
    aes(x = value, y = after_stat(density)),
    bins = 30,
    colour = "#56B4E9",
    fill = "#0072B2",
    alpha = 0.7
  ) +
  facet_wrap(~subjID, scales = "free", labeller = "label_both") +
  theme_bw() +
  labs(x = "Value", y = "Density", caption = "Model parameter Arew by subject")

Regression parameters

Beta

# Posterior draws of the regression coefficients
beta <- posterior_samples %>%
  pluck("beta") %>% # This is a matrix and already has column names
  as_tibble() # Matrix -> Tibble

# Betas for model parameter K. The prefix includes the underscore and is
# matched case-sensitively, so only `K_*` columns are selected (the default
# `starts_with("K")` ignores case and would also pick up any other column
# whose name merely begins with "k" or "K").
K_betas <- beta %>%
  select(starts_with("K_", ignore.case = FALSE)) %>%
  pivot_longer( # Reshape to long format
    everything(),
    names_to = "parameter",
    names_prefix = "K_" # Strip the leading `K_`, leaving the covariate name
  )

K_betas
## # A tibble: 54,000 x 2
##    parameter     value
##    <chr>         <dbl>
##  1 age        3.06    
##  2 x1         0.122   
##  3 x2        -0.000365
##  4 x3         0.185   
##  5 x4        -0.0606  
##  6 sex        3.35    
##  7 cond1     -0.697   
##  8 cond2      0.557   
##  9 cond3      1.42    
## 10 age       -6.59    
## # i 53,990 more rows
# Density-scaled histograms of the K betas, one panel per covariate.
# `bins` is stated explicitly (30 is the ggplot2 default) so the binning is
# visible in the code and `stat_bin()` does not emit its "pick better value"
# message.
ggplot(K_betas) +
  geom_histogram(
    aes(x = value, y = after_stat(density)),
    bins = 30,
    colour = "#56B4E9",
    fill = "#0072B2",
    alpha = 0.7
  ) +
  facet_wrap(~parameter, scales = "free") + # Free axes per panel
  theme_bw() +
  labs(x = "Value", y = "Density", caption = "Betas for model parameter K")

Sigma beta

# Posterior draws of the sigma betas
sigma_beta <- posterior_samples %>%
  pluck("sigma_beta") %>% # This is a matrix and already has column names
  as_tibble() # Matrix -> Tibble

# Sigma betas for model parameter K. The prefix includes the trailing
# underscore and is matched case-sensitively, so only `sigma_K_*` columns are
# selected (the default `starts_with("sigma_K")` ignores case and would also
# pick up other columns whose names merely begin with that text).
K_sigma_betas <- sigma_beta %>%
  select(starts_with("sigma_K_", ignore.case = FALSE)) %>%
  pivot_longer( # Reshape to long format
    everything(),
    names_to = "parameter",
    names_prefix = "sigma_K_" # Strip the prefix, leaving the covariate name
  )

K_sigma_betas
## # A tibble: 54,000 x 2
##    parameter   value
##    <chr>       <dbl>
##  1 age        2.78  
##  2 x1         1.18  
##  3 x2         0.362 
##  4 x3         0.475 
##  5 x4         0.0617
##  6 sex        2.97  
##  7 cond1      2.08  
##  8 cond2      0.602 
##  9 cond3      1.07  
## 10 age       10.8   
## # i 53,990 more rows
# Density-scaled histograms of the K sigma betas, one panel per covariate.
# `bins` is stated explicitly (30 is the ggplot2 default) so the binning is
# visible in the code and `stat_bin()` does not emit its "pick better value"
# message.
ggplot(K_sigma_betas) +
  geom_histogram(
    aes(x = value, y = after_stat(density)),
    bins = 30,
    colour = "#56B4E9",
    fill = "#0072B2",
    alpha = 0.7
  ) +
  facet_wrap(~parameter, scales = "free") + # Free axes per panel
  theme_bw() +
  labs(x = "Value", y = "Density", caption = "Sigma betas for model parameter K")

Traceplots

Since the original Stan fit is included in the hBayesDMregression model output, we can pass it directly to rstan::traceplot() to create traceplots. The parameters whose traceplots can be plotted can be found by getting the names of the Stan fit:

# List every parameter name available in the Stan fit
print(names(orl_model$fit))
##  [1] "mu_Arew"           "mu_Apun"           "mu_K"             
##  [4] "mu_betaF"          "mu_betaP"          "sigma[1]"         
##  [7] "sigma[2]"          "sigma[3]"          "sigma[4]"         
## [10] "sigma[5]"          "Arew[1]"           "Arew[2]"          
## [13] "Arew[3]"           "Arew[4]"           "Apun[1]"          
## [16] "Apun[2]"           "Apun[3]"           "Apun[4]"          
## [19] "K[1]"              "K[2]"              "K[3]"             
## [22] "K[4]"              "betaF[1]"          "betaF[2]"         
## [25] "betaF[3]"          "betaF[4]"          "betaP[1]"         
## [28] "betaP[2]"          "betaP[3]"          "betaP[4]"         
## [31] "log_lik[1]"        "log_lik[2]"        "log_lik[3]"       
## [34] "log_lik[4]"        "K_age"             "betaF_age"        
## [37] "K_x1"              "betaF_x1"          "K_x2"             
## [40] "betaF_x2"          "K_x3"              "betaF_x3"         
## [43] "K_x4"              "betaF_x4"          "K_sex"            
## [46] "betaF_sex"         "K_cond1"           "betaF_cond1"      
## [49] "K_cond2"           "betaF_cond2"       "K_cond3"          
## [52] "betaF_cond3"       "sigma_K_age"       "sigma_betaF_age"  
## [55] "sigma_K_x1"        "sigma_betaF_x1"    "sigma_K_x2"       
## [58] "sigma_betaF_x2"    "sigma_K_x3"        "sigma_betaF_x3"   
## [61] "sigma_K_x4"        "sigma_betaF_x4"    "sigma_K_sex"      
## [64] "sigma_betaF_sex"   "sigma_K_cond1"     "sigma_betaF_cond1"
## [67] "sigma_K_cond2"     "sigma_betaF_cond2" "sigma_K_cond3"    
## [70] "sigma_betaF_cond3" "lp__"

Suppose we want the traceplots for all parameters related to covariate x4:

# Select every fitted parameter whose name ends in the covariate suffix `_x4`.
# The underscore is part of the pattern so that names that merely end in "x4"
# (for example a covariate called `x14`) are not picked up by accident.
pars_x4 <- grep("_x4$", names(orl_model$fit), value = TRUE)

pars_x4
## [1] "K_x4"           "betaF_x4"       "sigma_K_x4"     "sigma_betaF_x4"
traceplot(orl_model$fit, pars = pars_x4) # Traceplots for the selected parameters

Note that rstan::traceplot() uses ggplot2 for plotting. As such, additional customizations can be made via the ggplot2 package.

Interval plots

Interval plots are constructed similarly.

# Select every fitted parameter whose name ends in the covariate suffix `_age`.
# The underscore is part of the pattern so that names that merely end in "age"
# (for example a covariate called `dosage`) are not picked up by accident.
pars_age <- grep("_age$", names(orl_model$fit), value = TRUE)

pars_age
## [1] "K_age"           "betaF_age"       "sigma_K_age"     "sigma_betaF_age"
stan_plot(orl_model$fit, pars = pars_age) # Interval plot for the selected parameters

Note that rstan::stan_plot() uses ggplot2 for plotting. As such, additional customizations can be made via the ggplot2 package.