# Comparison of HEBART, an lme4 random-intercept model, and BART (dbarts)
# on the Friedman data with a synthetic grouping variable.
# Source: GitHub gist by @brunaw, created May 9, 2022
# Gist ID: af39c1de1d20d2935f38dd74081009b9
# Loading packages
library(tidyverse)
library(tidymodels)
library(hebart)
library(lme4)
library(dbarts)
library(patchwork)
# Load the Friedman data from Andrew Parnell's rBART repository.
# Column V1 is used as the response; V2-V6 are used as predictors below.
df <- read.table(
  "https://raw.githubusercontent.com/andrewcparnell/rBART/master/friedman.txt"
)
# df$y <- rnorm(nrow(df)) # alternative: replace the response with pure noise
df$y <- df$V1

# Seed BEFORE the first random draw, so that both the synthetic group
# assignment and the train/test split below are reproducible across runs.
set.seed(2022)
# Synthetic grouping column: each row is assigned at random to one of 5 groups.
df$group <- sample.int(5L, nrow(df), replace = TRUE)

# Train/test split (rsample's default proportion: 75% train / 25% test)
split <- initial_split(df)
train <- training(split)
test <- testing(split)
# Model settings shared by the fitting and prediction calls below ----
group_variable <- "group" # name of the grouping column in df
formula <- y ~ V2 + V3 + V4 + V5 + V6

# Parameter list passed to hebart(); the named scalars are defined first so
# each value lives in exactly one place.
alpha <- 0.5
beta <- 1
mu_mu <- 0
pars <- list(
  k1 = 0.001, k2 = 1, alpha = alpha, beta = beta, mu_mu = mu_mu
)
# HEBART ----------------------------------------------------------------------
# Fit on the training rows only. The argument sequence below matters:
# `formula`, `group_variable` and `pars` are passed positionally, so they must
# stay in this relative order to be matched to the same hebart() formals.
heb_model <- hebart(
  formula,
  dataset = train,
  iter = 100,         # number of iterations
  burn_in = 10,       # number of burn-in iterations
  num.trees = 10,     # number of trees
  sample_k1 = FALSE,  # keep k1 fixed at the value given in `pars`
  group_variable,     # positional
  pars,               # positional
  scale = FALSE
)

# Predictions for the held-out and training rows
pred_hebart <- predict_hebart(heb_model, test, formula, group_variable)
pred_hebart_train <- predict_hebart(heb_model, train, formula, group_variable)

# Diagnostics for the fitted model
diagnostics(heb_model)
# LMER (lme4 random-intercept model) ------------------------------------------
# Same fixed effects as the shared `formula`, plus a random intercept for each
# level of the grouping column, so the models cannot silently diverge.
lmer_formula <- update(
  formula,
  as.formula(paste0(". ~ . + (1 | ", group_variable, ")"))
)
lm3_model <- lmer(lmer_formula, data = train)

# allow.new.levels = TRUE: a group level absent from `train` receives the
# population-level (fixed-effects only) prediction instead of an error.
pred_lm3 <- predict(lm3_model, newdata = test, allow.new.levels = TRUE)
pred_lm3_train <- predict(lm3_model, newdata = train)
# BART ------------------------------------------------------------------------
# bart2() uses many more trees and iterations by default; both are reduced
# here. It also runs several chains, possibly on multiple threads, where
# set.seed() alone does not give reproducible draws, so `seed` is passed
# explicitly.
bart_0 <- dbarts::bart2(
  formula,
  data = train,
  test = test,
  n.trees = 5,
  n.samples = 100,
  keepTrees = TRUE,
  seed = 2022L
)
pred_bart_train <- bart_0$yhat.train.mean
pred_bart <- bart_0$yhat.test.mean
# Comparing -------------------------------------------------------------------
# Stack test and training predictions from the three models, compute
# residuals (observed minus predicted), and keep only the split label and the
# residual columns.
all_preds <- data.frame(
  y = c(test$y, train$y),
  pred_hebart = c(pred_hebart$pred, pred_hebart_train$pred),
  pred_lme = c(pred_lm3, pred_lm3_train),
  pred_bart = c(pred_bart, pred_bart_train),
  # Labels follow the test-then-train stacking order used in every column
  source = rep(c("Test (25%)", "Train (75%)"), c(nrow(test), nrow(train)))
) %>%
  mutate(
    res_hebart = y - pred_hebart,
    res_lme = y - pred_lme,
    res_bart = y - pred_bart
  ) %>%
  # Select by name, not position, so adding a column above cannot silently
  # change which columns are kept.
  dplyr::select(source, starts_with("res_"))
# Root-mean-squared residual of each model, separately for test and train rows.
# across() skips the grouping column `source` automatically.
all_preds %>%
  group_by(source) %>%
  summarise(across(everything(), function(res) round(sqrt(mean(res^2)), 3)))
#------------------------------------------------------------
# End of script