library(hBayesDMregression)
# Tidyverse packages
library(tibble)
library(dplyr)
library(ggplot2)
library(tidyr)
library(purrr)
library(magrittr)
# For MCMC diagnostics
library(rstan)
# Load the bundled Iowa Gambling Task example data and peek at the first rows.
data(igt_example)
igt_example %>% head()
## subjID trial choice gain loss age x1 x2 x3 x4 sex
## 1 1001 1 3 50 0 31 1.22201 -0.6302797 0.401264 0.7221759 0
## 2 1001 2 2 100 0 31 1.22201 -0.6302797 0.401264 0.7221759 0
## 3 1001 3 3 50 0 31 1.22201 -0.6302797 0.401264 0.7221759 0
## 4 1001 4 4 50 0 31 1.22201 -0.6302797 0.401264 0.7221759 0
## 5 1001 5 4 50 0 31 1.22201 -0.6302797 0.401264 0.7221759 0
## 6 1001 6 4 50 0 31 1.22201 -0.6302797 0.401264 0.7221759 0
## cond1 cond2 cond3
## 1 0 1 0
## 2 0 1 0
## 3 0 1 0
## 4 0 1 0
## 5 0 1 0
## 6 0 1 0
Note that the included example data set has a `trial` column that records the
trial number for each subject. This column should be excluded when building the
model, which is done with the `exclude_cols` argument below.
## NOT RUN
# This chunk is marked NOT RUN in the original document; the results shown
# further down come from the saved fit loaded with readRDS() instead.
#
# Regress the ORL parameters K and betaF on the remaining subject-level
# columns. 3 chains x (3000 - 1000) post-warmup iterations yields the 6000
# draws per parameter seen in str(posterior_samples).
orl_model <- igt_orl_regression(
  igt_example,
  exclude_cols = "trial",  # per-trial index, must not be used as a covariate
  regression_pars = c("K", "betaF"),
  nchain = 3,
  niter = 3000,
  nwarmup = 1000,
  # Half the available cores as a whole number, never fewer than one.
  # detectCores() / 2 could be fractional (odd core counts), below one
  # (single-core machines), or NA (detectCores() cannot always determine it).
  ncore = max(1L, parallel::detectCores() %/% 2L, na.rm = TRUE)
)
orl_model <- readRDS(file = "orl_model.rds")  # previously saved model fit; presumably written with saveRDS() — confirm
The posterior samples are contained in the `parVals` element of the model
output. We extract them and assign them to a variable in our workspace, since
we will refer to them often.

Note: extracting with `orl_model %>% pluck("parVals")` is equivalent to calling
either `orl_model$parVals` or `orl_model[["parVals"]]`.
# Extract the posterior draws from the model output. `posterior_samples` was
# previously used without ever being assigned. pluck("parVals") is equivalent
# to orl_model$parVals or orl_model[["parVals"]].
posterior_samples <- orl_model %>% pluck("parVals")
str(posterior_samples)
## List of 15
## $ mu_Arew : num [1:6000(1d)] 0.165 0.192 0.14 0.221 0.163 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ mu_Apun : num [1:6000(1d)] 0.129 0.137 0.159 0.157 0.163 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ mu_K : num [1:6000(1d)] 2.73 3.2 4.78 4.63 1.27 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ mu_betaF : num [1:6000(1d)] 0.536 0.54 -0.508 -0.536 1.257 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ mu_betaP : num [1:6000(1d)] -0.363 0.243 0.343 -1.058 1.733 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ sigma : num [1:6000, 1:5] 0.2881 0.0352 0.1675 0.0414 0.2144 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ iterations: NULL
## .. ..$ : NULL
## $ Arew : num [1:6000, 1:4] 0.114 0.186 0.143 0.218 0.153 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ iterations: NULL
## .. ..$ : NULL
## $ Apun : num [1:6000, 1:4] 0.114 0.148 0.118 0.153 0.224 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ iterations: NULL
## .. ..$ : NULL
## $ K : num [1:6000, 1:4] 5 0 5 5 0.159 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ iterations: NULL
## .. ..$ : NULL
## $ betaF : num [1:6000, 1:4] 2.907 3.209 2.202 2.853 0.843 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ iterations: NULL
## .. ..$ : NULL
## $ betaP : num [1:6000, 1:4] 0.2994 0.0233 8.915 1.1277 1.755 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ iterations: NULL
## .. ..$ : NULL
## $ log_lik : num [1:6000, 1:4] -88.9 -89.6 -89.9 -89.2 -90.3 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ iterations: NULL
## .. ..$ : NULL
## $ beta : num [1:6000, 1:18] 3.0608 -6.5937 2.2449 8.8583 -0.0384 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ : NULL
## .. ..$ : chr [1:18] "K_age" "K_x1" "K_x2" "K_x3" ...
## $ sigma_beta: num [1:6000, 1:18] 2.779 10.783 2.325 13.629 0.703 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ : NULL
## .. ..$ : chr [1:18] "sigma_K_age" "sigma_K_x1" "sigma_K_x2" "sigma_K_x3" ...
## $ lp__ : num [1:6000(1d)] -429 -439 -422 -422 -424 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
# Group-level means: every `mu_*` element of the posterior list is a 1-d
# array of draws. Bind them side by side, then stack into a long tibble
# with columns `parameter` and `value`.
mu <- posterior_samples %>%
  keep_at(~ startsWith(.x, "mu_")) %>%   # only the mu_* elements
  do.call(what = cbind) %>%              # draws x parameters matrix
  as_tibble() %>%
  pivot_longer(cols = everything(), names_to = "parameter")
mu
## # A tibble: 30,000 x 2
## parameter value
## <chr> <dbl>
## 1 mu_Arew 0.165
## 2 mu_Apun 0.129
## 3 mu_K 2.73
## 4 mu_betaF 0.536
## 5 mu_betaP -0.363
## 6 mu_Arew 0.192
## 7 mu_Apun 0.137
## 8 mu_K 3.20
## 9 mu_betaF 0.540
## 10 mu_betaP 0.243
## # i 29,990 more rows
# Posterior density histograms of the group-level means, one panel each.
mu %>%
  ggplot(aes(x = value, y = after_stat(density))) +
  geom_histogram(colour = "#56B4E9", fill = "#0072B2", alpha = 0.7) +
  facet_wrap(~ parameter, scales = "free", ncol = 2) +
  labs(x = "Value", y = "Density") +
  theme_bw()
# Group-level standard deviations. `sigma` is a draws x parameters matrix
# with unnamed columns; they are assumed to follow the order of
# orl_model$modelPars (confirm against the Stan model).
sigma <- posterior_samples[["sigma"]] %>%
  set_colnames(paste("sigma", orl_model$modelPars, sep = "_")) %>%
  as_tibble() %>%
  pivot_longer(cols = everything(), names_to = "parameter")
sigma
## # A tibble: 30,000 x 2
## parameter value
## <chr> <dbl>
## 1 sigma_Arew 0.288
## 2 sigma_Apun 0.0665
## 3 sigma_K 0.407
## 4 sigma_betaF 0.215
## 5 sigma_betaP 1.12
## 6 sigma_Arew 0.0352
## 7 sigma_Apun 0.0689
## 8 sigma_K 0.0739
## 9 sigma_betaF 1.92
## 10 sigma_betaP 0.336
## # i 29,990 more rows
# Posterior density histograms of the group-level standard deviations.
sigma %>%
  ggplot(aes(x = value, y = after_stat(density))) +
  geom_histogram(colour = "#56B4E9", fill = "#0072B2", alpha = 0.7) +
  facet_wrap(~ parameter, scales = "free", ncol = 2) +
  labs(x = "Value", y = "Density") +
  theme_bw()
# Subject-level model parameters: a named list (one entry per element of
# orl_model$modelPars). Each entry is turned from a draws x subjects matrix
# into a long tibble with columns `subjID` and `value`.
model_parameters <- posterior_samples %>%
  keep_at(orl_model$modelPars) %>%
  map(function(draws) {
    draws %>%
      set_colnames(as.character(orl_model$subjID)) %>%  # label columns by subject
      as_tibble() %>%
      pivot_longer(everything(), names_to = "subjID")
  })
model_parameters
## $Arew
## # A tibble: 24,000 x 2
## subjID value
## <chr> <dbl>
## 1 1001 0.114
## 2 1002 0.142
## 3 1003 0.163
## 4 1004 0.0651
## 5 1001 0.186
## 6 1002 0.202
## 7 1003 0.185
## 8 1004 0.187
## 9 1001 0.143
## 10 1002 0.161
## # i 23,990 more rows
##
## $Apun
## # A tibble: 24,000 x 2
## subjID value
## <chr> <dbl>
## 1 1001 0.114
## 2 1002 0.154
## 3 1003 0.124
## 4 1004 0.108
## 5 1001 0.148
## 6 1002 0.122
## 7 1003 0.151
## 8 1004 0.158
## 9 1001 0.118
## 10 1002 0.136
## # i 23,990 more rows
##
## $K
## # A tibble: 24,000 x 2
## subjID value
## <chr> <dbl>
## 1 1001 5
## 2 1002 5
## 3 1003 5
## 4 1004 5
## 5 1001 0
## 6 1002 0
## 7 1003 0
## 8 1004 0
## 9 1001 5
## 10 1002 5
## # i 23,990 more rows
##
## $betaF
## # A tibble: 24,000 x 2
## subjID value
## <chr> <dbl>
## 1 1001 2.91
## 2 1002 3.62
## 3 1003 2.19
## 4 1004 2.28
## 5 1001 3.21
## 6 1002 4.99
## 7 1003 1.42
## 8 1004 0.976
## 9 1001 2.20
## 10 1002 3.91
## # i 23,990 more rows
##
## $betaP
## # A tibble: 24,000 x 2
## subjID value
## <chr> <dbl>
## 1 1001 0.299
## 2 1002 1.06
## 3 1003 -1.91
## 4 1004 -2.10
## 5 1001 0.0233
## 6 1002 -0.631
## 7 1003 0.0240
## 8 1004 0.572
## 9 1001 8.91
## 10 1002 9.49
## # i 23,990 more rows
# Per-subject posterior of Arew; panels are labelled "subjID: <id>".
model_parameters[["Arew"]] %>%
  ggplot(aes(x = value, y = after_stat(density))) +
  geom_histogram(colour = "#56B4E9", fill = "#0072B2", alpha = 0.7) +
  facet_wrap(~ subjID, scales = "free", labeller = label_both) +
  labs(
    x = "Value",
    y = "Density",
    caption = "Model parameter Arew by subject"
  ) +
  theme_bw()
# Regression coefficients: a draws x coefficients matrix whose columns are
# already named "<parameter>_<covariate>", converted to a tibble.
beta <- as_tibble(posterior_samples[["beta"]])
# Coefficients of the regression on model parameter K, in long format with
# the "K_" prefix stripped from the covariate names.
K_betas <- beta %>%
  # starts_with() ignores case by default, and a bare "K" prefix would also
  # match any other column starting with K/k. Match "K_" exactly instead.
  select(starts_with("K_", ignore.case = FALSE)) %>%
  pivot_longer(
    everything(),
    names_to = "parameter",
    names_prefix = "K_"  # "K_age" -> "age"; removes the prefix only at the start
  )
K_betas
## # A tibble: 54,000 x 2
## parameter value
## <chr> <dbl>
## 1 age 3.06
## 2 x1 0.122
## 3 x2 -0.000365
## 4 x3 0.185
## 5 x4 -0.0606
## 6 sex 3.35
## 7 cond1 -0.697
## 8 cond2 0.557
## 9 cond3 1.42
## 10 age -6.59
## # i 53,990 more rows
# Posterior density histograms of K's regression coefficients, one per covariate.
K_betas %>%
  ggplot(aes(x = value, y = after_stat(density))) +
  geom_histogram(colour = "#56B4E9", fill = "#0072B2", alpha = 0.7) +
  facet_wrap(~ parameter, scales = "free") +
  labs(x = "Value", y = "Density", caption = "Betas for model parameter K") +
  theme_bw()
# Standard deviations of the regression coefficients: a draws x coefficients
# matrix whose columns are already named "sigma_<parameter>_<covariate>".
sigma_beta <- as_tibble(posterior_samples[["sigma_beta"]])
# Coefficient standard deviations for model parameter K, in long format with
# the "sigma_K_" prefix stripped from the covariate names.
K_sigma_betas <- sigma_beta %>%
  # Case-sensitive match on the full "sigma_K_" prefix. starts_with() ignores
  # case by default, and "sigma_K" alone would also match e.g. "sigma_K2_*".
  select(starts_with("sigma_K_", ignore.case = FALSE)) %>%
  pivot_longer(
    everything(),
    names_to = "parameter",
    names_prefix = "sigma_K_"  # "sigma_K_age" -> "age"
  )
K_sigma_betas
## # A tibble: 54,000 x 2
## parameter value
## <chr> <dbl>
## 1 age 2.78
## 2 x1 1.18
## 3 x2 0.362
## 4 x3 0.475
## 5 x4 0.0617
## 6 sex 2.97
## 7 cond1 2.08
## 8 cond2 0.602
## 9 cond3 1.07
## 10 age 10.8
## # i 53,990 more rows
# Posterior density histograms of the coefficient SDs for K, one per covariate.
K_sigma_betas %>%
  ggplot(aes(x = value, y = after_stat(density))) +
  geom_histogram(colour = "#56B4E9", fill = "#0072B2", alpha = 0.7) +
  facet_wrap(~ parameter, scales = "free") +
  labs(
    x = "Value",
    y = "Density",
    caption = "Sigma betas for model parameter K"
  ) +
  theme_bw()
Since the original Stan fit is included in the `hBayesDMregression` model
output, we can pass it directly to `rstan::traceplot()` to create traceplots.
The parameters that can be plotted are given by the names of the Stan fit:
# List every parameter name stored in the underlying Stan fit.
orl_model$fit %>% names()
## [1] "mu_Arew" "mu_Apun" "mu_K"
## [4] "mu_betaF" "mu_betaP" "sigma[1]"
## [7] "sigma[2]" "sigma[3]" "sigma[4]"
## [10] "sigma[5]" "Arew[1]" "Arew[2]"
## [13] "Arew[3]" "Arew[4]" "Apun[1]"
## [16] "Apun[2]" "Apun[3]" "Apun[4]"
## [19] "K[1]" "K[2]" "K[3]"
## [22] "K[4]" "betaF[1]" "betaF[2]"
## [25] "betaF[3]" "betaF[4]" "betaP[1]"
## [28] "betaP[2]" "betaP[3]" "betaP[4]"
## [31] "log_lik[1]" "log_lik[2]" "log_lik[3]"
## [34] "log_lik[4]" "K_age" "betaF_age"
## [37] "K_x1" "betaF_x1" "K_x2"
## [40] "betaF_x2" "K_x3" "betaF_x3"
## [43] "K_x4" "betaF_x4" "K_sex"
## [46] "betaF_sex" "K_cond1" "betaF_cond1"
## [49] "K_cond2" "betaF_cond2" "K_cond3"
## [52] "betaF_cond3" "sigma_K_age" "sigma_betaF_age"
## [55] "sigma_K_x1" "sigma_betaF_x1" "sigma_K_x2"
## [58] "sigma_betaF_x2" "sigma_K_x3" "sigma_betaF_x3"
## [61] "sigma_K_x4" "sigma_betaF_x4" "sigma_K_sex"
## [64] "sigma_betaF_sex" "sigma_K_cond1" "sigma_betaF_cond1"
## [67] "sigma_K_cond2" "sigma_betaF_cond2" "sigma_K_cond3"
## [70] "sigma_betaF_cond3" "lp__"
Suppose we want the traceplots for all parameters related to covariate `x4`:
# Every Stan parameter tied to covariate x4: the coefficient on each
# regressed model parameter and its sigma. Anchoring on the "_" separator
# stops a covariate whose name merely ends in "x4" (e.g. "max4") from matching.
pars_x4 <- grep("_x4$", names(orl_model$fit), value = TRUE)
pars_x4
## [1] "K_x4" "betaF_x4" "sigma_K_x4" "sigma_betaF_x4"
traceplot(orl_model$fit, pars = pars_x4)
Note that `rstan::traceplot()` uses ggplot2 for plotting, so the resulting plot
can be customized further with ggplot2 functions.
Interval plots are constructed similarly, using `rstan::stan_plot()`:
# Every Stan parameter tied to covariate age. Anchoring on the "_" separator
# keeps covariates such as "stage" or "page" from also matching.
pars_age <- grep("_age$", names(orl_model$fit), value = TRUE)
pars_age
## [1] "K_age" "betaF_age" "sigma_K_age" "sigma_betaF_age"
stan_plot(orl_model$fit, pars = pars_age)
Note that `rstan::stan_plot()` uses ggplot2 for plotting, so the resulting plot
can be customized further with ggplot2 functions.