Created
January 5, 2021 05:41
-
-
Save mike-lawrence/e3794cdedf91e98859d3232b8020aa1f to your computer and use it in GitHub Desktop.
Hierarchical within-subjects model with a Gaussian outcome and reduced redundant computation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Install any packages that are not already installed
#'
#' Entries containing `github.com/` are treated as GitHub repositories
#' (`github.com/<user>/<repo>`) and installed with `remotes::install_github()`;
#' all other entries are installed from CRAN with `install.packages()`.
#' An entry counts as installed when `find.package()` locates a package named
#' after its last path component (e.g. `cmdstanr` for
#' `github.com/stan-dev/cmdstanr`).
#'
#' @param pkgs Character vector of CRAN package names and/or
#'   `github.com/<user>/<repo>` specifications.
#' @return `NULL`, invisibly; called for its side effect of installing packages.
#' @examples
#' \dontrun{
#' install_if_missing(c('tidyverse','github.com/stan-dev/cmdstanr'))
#' }
install_if_missing <- function(pkgs) {
  pkgs <- as.character(pkgs)
  # find.package(quiet = TRUE) returns character(0) -- not NULL and not an
  # error -- for a package it cannot locate, so test the result's length
  is_installed <- vapply(
    basename(pkgs),
    function(pkg) {
      tryCatch(
        length(find.package(pkg, quiet = TRUE, verbose = FALSE)) > 0,
        error = function(e) FALSE
      )
    },
    logical(1),
    USE.NAMES = FALSE
  )
  missing_pkgs <- pkgs[!is_installed]
  is_github <- grepl('github.com/', missing_pkgs, fixed = TRUE)
  cran_missing <- missing_pkgs[!is_github]
  github_missing <- gsub('github.com/', '', missing_pkgs[is_github], fixed = TRUE)
  if (length(cran_missing) > 0) {
    message(
      'The following required but uninstalled CRAN packages will now be installed:\n',
      paste(cran_missing, collapse = '\n')
    )
    install.packages(cran_missing)
  }
  if (length(github_missing) > 0) {
    message(
      'The following required but uninstalled Github packages will now be installed:\n',
      paste(github_missing, collapse = '\n')
    )
    # remotes is needed for install_github(); fetch it from CRAN if absent
    if (length(find.package('remotes', quiet = TRUE, verbose = FALSE)) == 0) {
      install.packages('remotes')
    }
    remotes::install_github(github_missing)
  }
  invisible()
}
# Attach sampler-diagnostic flags to `x` as its 'meta' attribute.
#
# Runs fit$cmdstan_diagnose(), suppresses the report it prints, and searches
# the report text (its $stdout) for the warning phrases emitted by CmdStan's
# diagnose utility.
#
# x:   object to annotate (here, the summary table of `fit`)
# fit: object with a cmdstan_diagnose() method, e.g. a cmdstanr fit
# returns `x` with attr(x, 'meta') set to list(diagnostic_bools = <flags>),
#   where <flags> is a named list of single logicals: treedepth_maxed,
#   ebfmi_low, essp_low and rhat_high
add_diagnostic_bools <- function(x, fit) {
  diagnostics <- NULL
  # cmdstan_diagnose() always echoes its report. capture.output() discards
  # that text and restores the output stream even if the call fails, and,
  # unlike sink('/dev/null'), does not depend on a Unix-only device path
  utils::capture.output(diagnostics <- fit$cmdstan_diagnose()$stdout)
  # fixed-string matching: several phrases contain '.', which a regex would
  # treat as a wildcard
  mentions <- function(phrase) any(grepl(phrase, diagnostics, fixed = TRUE))
  diagnostic_bools <- list(
    treedepth_maxed = mentions('transitions hit the maximum'),
    ebfmi_low = mentions(' is below the nominal threshold'),
    essp_low = mentions('The following parameters had fewer than 0.001 effective draws per transition:'),
    rhat_high = mentions('The following parameters had split R-hat greater than 1.05')
  )
  attr(x, 'meta') <- list(diagnostic_bools = diagnostic_bools)
  x
}
# Print method for 'stan_summary_tbl' objects.
#
# Prints a red-highlighted warning line for every diagnostic flag stored by
# add_diagnostic_bools() in attr(x, 'meta')$diagnostic_bools, then delegates
# to the next print method in the class vector (e.g. the tibble printer).
# Extra arguments (such as `n`) are forwarded by NextMethod().
# Returns `x` invisibly.
print.stan_summary_tbl <- function(x, ...) {
  flags <- attr(x, 'meta')$diagnostic_bools
  flag_messages <- c(
    treedepth_maxed = 'Treedepth maxed\n',
    ebfmi_low = 'E-BFMI low\n',
    essp_low = 'ESS% low for one or more parameters\n',
    rhat_high = 'R-hat high for one or more parameters\n'
  )
  # isTRUE() keeps printing working when the meta attribute was never set or
  # has been dropped (NULL flags simply count as "not raised")
  raised <- vapply(
    names(flag_messages),
    function(flag) isTRUE(flags[[flag]]),
    logical(1)
  )
  highlight <- function(text) {
    if (requireNamespace('crayon', quietly = TRUE)) crayon::bgRed(text) else text
  }
  if (any(raised)) {
    cat(highlight('WARNING:\n'))
    for (msg in flag_messages[raised]) {
      cat(highlight(msg))
    }
  }
  NextMethod()
  invisible(x)
}
# Tag `x` with the 'stan_summary_tbl' S3 class so that print.stan_summary_tbl()
# is dispatched for it. Idempotent: tagging an already-tagged object does not
# duplicate the class, which would otherwise make NextMethod() re-enter
# print.stan_summary_tbl() and print the diagnostics banner twice.
add_stan_summary_tbl_class <- function(x) {
  class(x) <- unique(c('stan_summary_tbl', class(x)))
  x
}
# Flag variable names whose base name (the part before a trailing '[...]'
# index) ends in an underscore, e.g. 'sd_coef_' or 'sd_coef_[3]'. In this
# project trailing-underscore parameters are the still-scaled internals that
# the summary table hides.
#
# x: character vector of variable names
# returns a logical vector the same length as `x`
has_underscore_suffix <- function(x) {
  x <- as.character(x)
  # only names ending in ']' are treated as indexed; strip from the first '['
  base_name <- ifelse(endsWith(x, ']'), sub('\\[.*$', '', x), x)
  endsWith(base_name, '_')
}
# Flag elements of a 2-D matrix parameter that lie on the diagonal or in the
# lower triangle, i.e. names of the form '<prefix>...[i,j]' with i >= j. For a
# symmetric correlation matrix these duplicate the upper triangle (or are the
# unit diagonal), so the summary table drops them.
#
# x:      character vector of variable names (e.g. 'cor[2,1]')
# prefix: regular expression the name must start with (e.g. 'cor')
# returns a logical vector the same length as `x`; names without the prefix
#   or without a trailing two-integer index are FALSE
is_cor_diag_or_lower_tri <- function(x, prefix) {
  x <- as.character(x)
  index_pattern <- '\\[([0-9]+),([0-9]+)\\]$'
  candidate <- grepl(paste0('^(?:', prefix, ')'), x, perl = TRUE) &
    grepl(index_pattern, x)
  out <- logical(length(x))
  # compare indices as integers: as strings, '10' would sort below '9' and
  # elements of matrices with more than 9 rows would be misclassified
  row_i <- as.integer(sub(paste0('^.*', index_pattern), '\\1', x[candidate]))
  col_j <- as.integer(sub(paste0('^.*', index_pattern), '\\2', x[candidate]))
  out[candidate] <- row_i >= col_j
  out
}
# Sort a Stan summary table so that variables with fewer elements come first.
#
# Rows are ordered by (1) how many rows share the variable's base name (the
# part before '['), (2) that base name, (3) each index position compared as a
# number, so 'x[2]' precedes 'x[10]', and (4) the full name as a final
# tie-break. Character keys use C-locale (radix) ordering.
#
# x: data frame / tibble with a character column `variable`
# returns `x` with its rows reordered; columns and class are unchanged
sort_by_variable_size <- function(x) {
  variable <- as.character(x$variable)
  base_name <- sub('\\[.*$', '', variable)
  group_id <- match(base_name, unique(base_name))
  group_size <- tabulate(group_id, nbins = length(unique(base_name)))[group_id]
  # pull out the comma-separated index text between the brackets
  has_index <- grepl('[', variable, fixed = TRUE)
  index_text <- rep('', length(variable))
  index_text[has_index] <- sub(
    '^[^\\[]*\\[(.*)\\]$', '\\1', variable[has_index], perl = TRUE
  )
  index_parts <- strsplit(index_text, ',', fixed = TRUE)
  n_dims <- max(c(0L, lengths(index_parts)))
  # one integer sort key per index position (NA where a name has fewer dims)
  index_keys <- lapply(seq_len(n_dims), function(d) {
    vapply(
      index_parts,
      function(p) if (length(p) >= d) strtoi(p[[d]], base = 10L) else NA_integer_,
      integer(1)
    )
  })
  row_order <- do.call(
    order,
    c(list(group_size, base_name), index_keys, list(variable, method = 'radix'))
  )
  out <- x[row_order, , drop = FALSE]
  rownames(out) <- NULL
  out
}
# Sum-to-zero contrasts scaled by one half: a 2-level factor is coded
# +0.5 / -0.5, so its coefficient is the difference between the two levels.
# Takes the same arguments as contr.sum().
halfsum_contrasts <- function(...) {
  sum_coded <- contr.sum(...)
  sum_coded / 2
}
# Build a model (contrast) matrix for the main-effect variables of `formula`.
#
# data:          data.frame or tibble containing every variable in `formula`
# formula:       one-sided model formula, e.g. ~ v1*v2
# contrast_kind: optional contrasts (a function such as halfsum_contrasts, or
#                a contrast matrix) assigned to every factor/character variable
# returns the stats::model.matrix() result with
#   - the '(Intercept)' column renamed '(I)',
#   - columns of 2-level factors renamed from '<var>1' to '<var>', including
#     inside interaction terms (e.g. 'v11:v21' -> 'v1:v2'),
#   - attributes 'formula' (the input formula) and 'data' (the processed
#     predictor columns, with character columns converted to factors)
get_contrast_matrix <- function(
  data,
  formula,
  contrast_kind = NULL
) {
  if (inherits(data, 'tbl_df')) {
    data <- as.data.frame(data)
  }
  vars <- attr(terms(formula), 'term.labels')
  vars <- vars[!grepl(':', vars)]
  # drop = FALSE keeps a data frame even when there is a single variable
  data <- data[, vars, drop = FALSE]
  vars_to_rename <- NULL
  for (i in vars) {
    if (is.character(data[, i])) {
      data[, i] <- factor(data[, i])
    }
    if (is.factor(data[, i])) {
      if (nlevels(data[, i]) == 2) {
        vars_to_rename <- c(vars_to_rename, i)
      }
      if (!is.null(contrast_kind)) {
        contrasts(data[, i]) <- contrast_kind
      }
    }
  }
  mm <- model.matrix(data = data, object = formula)
  col_names <- dimnames(mm)[[2]]
  col_names[col_names == '(Intercept)'] <- '(I)'
  # Rename '<var>1' -> '<var>' by exact match on each ':'-separated component.
  # A regex gsub() over whole names would also rewrite other columns that
  # merely contain the pattern (and treats '.' in a name as a wildcard).
  if (length(vars_to_rename) > 0) {
    targets <- paste0(vars_to_rename, 1)
    col_names <- vapply(
      strsplit(col_names, ':', fixed = TRUE),
      function(parts) {
        hit <- parts %in% targets
        parts[hit] <- substr(parts[hit], 1, nchar(parts[hit]) - 1)
        paste(parts, collapse = ':')
      },
      character(1)
    )
  }
  dimnames(mm)[[2]] <- col_names
  attr(mm, 'formula') <- formula
  attr(mm, 'data') <- data
  mm
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# preamble (options, installs, imports & custom functions) ----
# report warnings as soon as they occur (really should be the default in R)
options(warn = 1)
# negated %in% (should be in base R!)
`%!in%` <- Negate(`%in%`)
# packages used by this script; entries of the form 'github.com/<user>/<repo>'
# are installed from GitHub via remotes, everything else from CRAN.
# cmdstanr is listed before rethinking so it is installed first.
required_packages <- c(
  'remotes',                          # for installing the GitHub packages
  'crayon',                           # for coloring terminal output
  'bayesplot',                        # for convenient posterior plots
  'tidyverse',                        # for all that is good and holy
  'github.com/stan-dev/cmdstanr',     # for Stan stuff
  'github.com/rmcelreath/rethinking'  # for rlkjcorr & rmvnorm2 (not on CRAN)
)
# load the helper functions
source('helper_functions.r')
# helper_functions.r defines:
# - install_if_missing()
# - add_diagnostic_bools()
# - print.stan_summary_tbl()
# - add_stan_summary_tbl_class()
# - has_underscore_suffix()
# - is_cor_diag_or_lower_tri()
# - sort_by_variable_size()
# - halfsum_contrasts()
# - get_contrast_matrix()
# install any required packages not already present
install_if_missing(required_packages)
# shorthand for the pipe operator
`%>%` <- magrittr::`%>%`
# simulate data ----

# fixed seed => reproducible simulated data (change it to get new data)
set.seed(1)

# settings of the simulated population
sim_pars <- tibble::lst(
  # -- adjustable --
  # number of subjects (integer > 1)
  num_subj = 100,
  # number of crossed, within-subject, 2-level variables (integer > 0)
  num_vars = 3,
  # trials per subject/condition combination (integer > 1)
  num_trials = 100,
  # -- derived or randomly drawn; leave these alone --
  num_coef = 2^(num_vars),
  coef_means = rnorm(num_coef),
  coef_sds = rweibull(num_coef, 2, 1),
  cor_mat = rethinking::rlkjcorr(1, num_coef, eta = 1),
  noise = rweibull(1, 2, 1)
)

# full-factorial lo/hi design over v1..vN, coded with half-sum contrasts
contrast_matrix <- rep(list(factor(c('lo', 'hi'))), sim_pars$num_vars) %>%
  stats::setNames(paste0('v', 1:sim_pars$num_vars)) %>%
  purrr::cross_df() %>%
  get_contrast_matrix(
    formula = as.formula(paste('~', paste0('v', 1:sim_pars$num_vars, collapse = '*'))),
    contrast_kind = halfsum_contrasts
  )

# per-subject coefficients drawn from the population multivariate normal,
# one column per coefficient (X1, X2, ...) plus a subject id
subj_coef <- rethinking::rmvnorm2(
    n = sim_pars$num_subj,
    Mu = sim_pars$coef_means,
    sigma = sim_pars$coef_sds,
    Rho = sim_pars$cor_mat
  ) %>%
  (function(coef_mat) {
    dimnames(coef_mat) <- list(NULL, paste0('X', 1:ncol(coef_mat)))
    coef_mat
  }) %>%
  tibble::as_tibble(.name_repair = 'unique') %>%
  dplyr::mutate(subj = 1:sim_pars$num_subj)

# condition means implied by each subject's coefficients and the design
subj_cond <- subj_coef %>%
  dplyr::group_by(subj) %>%
  dplyr::summarise(
    (function(coefs) {
      conditions <- attr(contrast_matrix, 'data')
      conditions$cond_mean <- as.vector(contrast_matrix %*% t(coefs))
      conditions
    })(dplyr::cur_data()),
    .groups = 'drop'
  )

# noisy trial-level observations for every subject/condition combination
dat <- subj_cond %>%
  tidyr::expand_grid(trial = 1:sim_pars$num_trials) %>%
  dplyr::mutate(obs = rnorm(dplyr::n(), cond_mean, sim_pars$noise))
# Compute inputs to model ----

# W: the full trial-by-trial contrast matrix (one row per row of dat)
W <- dat %>%
  # get_contrast_matrix() is a wrapper on stats::model.matrix()
  get_contrast_matrix(
    # the formula is built programmatically so this example works for any
    # sim_pars$num_vars; with two variables you would simply write
    # formula = ~ v1*v2
    formula = as.formula(paste0('~', paste0('v', 1:sim_pars$num_vars, collapse = '*'))),
    # half-sum contrasts are nice for 2-level variables because they yield
    # parameters whose value is the difference between conditions
    contrast_kind = halfsum_contrasts
  ) %>%
  tibble::as_tibble(.name_repair = 'unique')
# quick glimpse; lots of rows
print(W)

# uW: the unique rows (conditions) of W
uW <- dplyr::distinct(W)
print(uW)
# far fewer rows!

# The Stan model computes a value for every (unique condition) x (subject)
# pair, stored as a num_rows_uW-by-num_subj matrix that it flattens
# column-major (condition varies fastest within subject). obs_index gives
# each observation's position in that flattened vector.
# The observations are the LEFT table of the join so the result keeps dat's
# row order whatever order dat is in (a right_join orders its output by the
# left table, which only matches dat when dat is sorted by subject/condition).
obs_index <- dplyr::mutate(W, subj = dat$subj) %>%
  dplyr::left_join(
    uW %>%
      # one copy of the condition rows per subject, subjects in sorted order
      dplyr::slice(rep(dplyr::row_number(), length(unique(dat$subj)))) %>%
      dplyr::mutate(
        subj = rep(sort(unique(dat$subj)), each = nrow(uW)),
        # position in the flattened subject-by-condition matrix
        row = 1:dplyr::n()
      ),
    by = c(names(uW), 'subj')
  ) %>%
  dplyr::pull(row)
# every observation must map to exactly one subject/condition cell
stopifnot(length(obs_index) == nrow(dat), !anyNA(obs_index))
# package for stan & sample ----
data_for_stan <- tibble::lst( # lst() lets later entries refer to earlier ones
  ####
  # Entries we need to specify ourselves
  ####
  # uW: unique rows of the within-subject predictor matrix W
  uW = as.matrix(uW),
  # num_subj: number of subjects
  num_subj = length(unique(dat$subj)),
  # obs: outcome on each trial
  obs = dat$obs,
  # obs_index: index of each trial in the flattened subject-by-condition value matrix
  obs_index = obs_index,
  ####
  # Entries computable from the above
  ####
  # num_obs_total: total number of observations
  num_obs_total = length(obs),
  # num_rows_uW: number of rows (unique conditions) in uW
  num_rows_uW = nrow(uW),
  # num_cols_uW: number of columns (coefficients) in uW
  num_cols_uW = ncol(uW)
)
# double-check:
tibble::glimpse(data_for_stan)

# compile the model
mod <- cmdstanr::cmdstan_model('hwg_fast.stan')

# How many chains to run in parallel: parallel::detectCores() typically
# reports twice the physical core count because of hyperthreading, which Stan
# generally doesn't benefit from (it can even hurt), so run one chain per
# physical core and leave one core free for other processes (inc. monitoring
# the Stan progress). detectCores() can return NA or an odd count, so round
# down and never go below one parallel chain.
num_logical_cores <- parallel::detectCores()
if (is.na(num_logical_cores)) {
  num_logical_cores <- 1
}
phys_cores_minus_one <- max(1, floor(num_logical_cores / 2) - 1)
# we want at least 4 chains; parallel_chains caps how many run at once, so on
# machines with fewer free cores the remaining chains run once a slot frees up
num_chains <- max(4, phys_cores_minus_one)

# warmup and sampling iterations per chain. If the model samples well, 1e3
# should be plenty (num_chains*1e3 draws in total) for stable inference on
# even tail quantities of the posterior
num_samples_to_obtain <- 1e3

# setting the sampling seed explicitly helps ensure reproducibility
sampling_seed <- 1

# sample the model
fit <- mod$sample(
  data = data_for_stan,
  chains = num_chains,
  parallel_chains = phys_cores_minus_one,
  seed = sampling_seed,
  iter_warmup = num_samples_to_obtain,
  iter_sampling = num_samples_to_obtain,
  # update every 10%
  refresh = (num_samples_to_obtain * 2) / 10
)
# gather summary (inc. diagnostics)
fit_summary <- fit$summary() %>%
  dplyr::select(variable, mean, q5, q95, rhat, contains('ess')) %>%
  dplyr::filter(
    # hide the Cholesky factor, the non-centering helper and the scaled
    # (trailing-underscore) parameters ...
    !stringr::str_starts(variable, 'chol_corr'),
    !stringr::str_detect(variable, 'helper'),
    !has_underscore_suffix(variable),
    # ... and the redundant diagonal / lower triangle of cor
    !is_cor_diag_or_lower_tri(variable, prefix = 'cor')
  ) %>%
  sort_by_variable_size() %>%
  add_stan_summary_tbl_class() %>%
  add_diagnostic_bools(fit)
print(fit_summary, n = nrow(fit_summary))

# Parameter-recovery plots: posterior intervals with the true (simulated)
# values overlaid in red. Each plot is print()ed explicitly because top-level
# values are not auto-printed when the script is run with source().

print(
  bayesplot::mcmc_intervals(fit$draws(variables = 'noise')) +
    ggplot2::geom_point(
      data = tibble::tibble(par = 'noise', value = sim_pars$noise),
      mapping = ggplot2::aes(y = par, x = value),
      colour = 'red'
    )
)

print(
  bayesplot::mcmc_intervals(fit$draws(variables = 'mean_coef')) +
    ggplot2::geom_point(
      data = tibble::tibble(value = sim_pars$coef_means) %>%
        dplyr::mutate(y = paste0('mean_coef[', 1:dplyr::n(), ']')),
      mapping = ggplot2::aes(y = y, x = value),
      colour = 'red'
    )
)

print(
  bayesplot::mcmc_intervals(fit$draws(variables = 'sd_coef')) +
    ggplot2::geom_point(
      data = tibble::tibble(value = sim_pars$coef_sds) %>%
        dplyr::mutate(y = paste0('sd_coef[', 1:dplyr::n(), ']')),
      mapping = ggplot2::aes(y = y, x = value),
      colour = 'red'
    )
)

print(
  bayesplot::mcmc_intervals(fit$draws(variables = 'cor')) +
    ggplot2::geom_point(
      # as.data.frame() names unnamed matrix columns V1, V2, ... explicitly,
      # rather than relying on as_tibble()'s deprecated compatibility names
      data = as.data.frame(sim_pars$cor_mat) %>%
        tibble::as_tibble() %>%
        dplyr::mutate(var1 = 1:dplyr::n()) %>%
        tidyr::pivot_longer(
          names_to = 'var2',
          names_prefix = 'V',
          cols = c(-var1)
        ) %>%
        dplyr::mutate(y = paste0('cor[', var1, ',', var2, ']')),
      mapping = ggplot2::aes(y = y, x = value),
      colour = 'red'
    )
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
data{
  // num_obs_total: number of trials
  int<lower=1> num_obs_total ;
  // obs: observation on each trial
  vector[num_obs_total] obs ;
  // num_subj: number of subjects
  int<lower=1> num_subj ;
  // num_rows_uW: number of rows (unique conditions) in uW
  int<lower=1> num_rows_uW ;
  // num_cols_uW: number of columns (coefficients) in uW
  int<lower=1> num_cols_uW ;
  // uW: unique rows of the within-subject predictor matrix
  matrix[num_rows_uW,num_cols_uW] uW ;
  // obs_index: position of each trial in the column-major flattening of the
  //   num_rows_uW-by-num_subj matrix of subject-by-condition values.
  //   The bounds make Stan reject out-of-range indices when reading data.
  //   (array[] syntax: the old `int x[N]` form was removed in Stan 2.33)
  array[num_obs_total] int<lower=1,upper=num_rows_uW*num_subj> obs_index ;
}
transformed data{
  // obs_mean, obs_sd: location and scale of the outcome
  real obs_mean = mean(obs) ;
  real obs_sd = sd(obs) ;
  // standardizing below divides by obs_sd, so it must be positive
  if (obs_sd <= 0) {
    reject("obs must contain at least two distinct values; sd(obs) = ", obs_sd) ;
  }
  // obs_: observations scaled to have zero mean and unit variance
  vector[num_obs_total] obs_ = (obs-obs_mean)/obs_sd ;
}
parameters{
  // chol_corr: Cholesky factor of the population-level correlation matrix
  //   amongst the within-subject coefficients
  cholesky_factor_corr[num_cols_uW] chol_corr ;
  // A trailing underscore marks parameters on the standardized-outcome scale;
  // generated quantities converts them back to the scale of obs.
  // mean_coef_: mean (across subjects) of each coefficient
  row_vector[num_cols_uW] mean_coef_ ;
  // sd_coef_: sd (across subjects) of each coefficient
  vector<lower=0>[num_cols_uW] sd_coef_ ;
  // multi_normal_helper: standard-normal deviates for the non-centered
  //   parameterization of the subject coefficients (one column per subject)
  matrix[num_cols_uW,num_subj] multi_normal_helper ;
  // noise_: measurement noise
  real<lower=0> noise_ ;
}
model{
  ////
  // Priors
  ////
  // multi_normal_helper must have a normal(0,1) prior for the non-centered parameterization
  to_vector(multi_normal_helper) ~ std_normal() ;
  // LKJ(2): mildly favours weaker correlations
  chol_corr ~ lkj_corr_cholesky(2) ;
  // normal(0,1) priors on all coefficient sds
  sd_coef_ ~ std_normal() ;
  // normal(0,1) priors on all coefficient means
  mean_coef_ ~ std_normal() ;
  // weibull(2,1) has its mode at sqrt(0.5) (about .71) and little mass near zero
  noise_ ~ weibull(2,1) ;
  ////
  // Subject coefficients (non-centered multivariate normal):
  // row s = mean_coef_ + transpose(diag(sd_coef_) * chol_corr * multi_normal_helper[,s])
  ////
  matrix[num_subj,num_cols_uW] subj_coef_ = (
    rep_matrix(mean_coef_,num_subj)
    + transpose(
      diag_pre_multiply(sd_coef_,chol_corr)
      * multi_normal_helper
    )
  ) ;
  // value_for_subj_cond[c,s] = dot_product(uW[c], subj_coef_[s]); a single
  // matrix product computes every unique condition x subject value at once
  matrix[num_rows_uW,num_subj] value_for_subj_cond = uW * transpose(subj_coef_) ;
  ////
  // Likelihood
  ////
  // to_vector() flattens column-major (condition varies fastest within
  // subject); obs_index must index into that ordering
  obs_ ~ normal(
    to_vector(value_for_subj_cond)[obs_index]
    , noise_
  ) ;
}
generated quantities{
  // cor: population correlation matrix amongst the within-subject coefficients
  corr_matrix[num_cols_uW] cor = multiply_lower_tri_self_transpose(chol_corr) ;
  // sd_coef: between-subject sd of each coefficient, on the scale of obs
  vector[num_cols_uW] sd_coef = obs_sd * sd_coef_ ;
  // mean_coef: population mean of each coefficient, on the scale of obs
  row_vector[num_cols_uW] mean_coef = obs_sd * mean_coef_ ;
  // restore the outcome's mean on the intercept
  // (assumes column 1 of uW is the intercept column)
  mean_coef[1] += obs_mean ;
  // noise: measurement noise, on the scale of obs
  real noise = obs_sd * noise_ ;
  // The diagonal of cor is exactly 1 in every draw, which can make R-hat for
  // those elements come out undefined/flagged; add negligible jitter.
  for(k in 1:num_cols_uW){
    cor[k,k] = cor[k,k] + uniform_rng(1e-16, 1e-15) ;
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment