Last active
May 5, 2022 13:21
-
-
Save brunaw/a194086be958b73dd1c25a93b84730b8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Loading packages
library(tidyverse)
library(hebart)

# Input data: the friedman.txt file from the rBART repository. `y` is set to
# standard-normal noise and every row gets a random group label in 1..5.
# NOTE(review): no set.seed() is called in this script, so `y` and `group`
# differ between runs.
df <- read.table("https://raw.githubusercontent.com/andrewcparnell/rBART/master/friedman.txt")
df$y <- rnorm(nrow(df))
df$group <- sample.int(5, nrow(df), replace = TRUE)
group_variable <- "group"
formula <- y ~ V2 + V3 + V4 + V5 + V6

# Prior settings, bundled in `pars` (the list later passed to lk_ratio_grow())
alpha <- 0.5
beta <- 1
mu_mu <- 0
pars <- list(
  k1 = 0.001, k2 = 5, alpha = alpha, beta = beta, mu_mu = mu_mu
)
#------------------------------------------------------------------------------
dataset <- df
# HEBART parameters
min_u <- 0
max_u <- 20
prior_k1 <- TRUE
num.trees <- 5
sample_k1 <- TRUE
# burn_in <- 50
# alpha_grow / beta_grow feed p_grow and tree_prior() further down
alpha_grow <- 0.98
beta_grow <- 0.05
action_taken <- "grow"
#------------------------------------------------------------------------------
# HEBART Function
#---------------------------------------------------------------------
# Handling initial dataset
#---------------------------------------------------------------------
# data_handler() is not defined in this script (presumably provided by
# hebart); its `data`, `names` and `group` elements are read below.
results_data <- data_handler(formula, data = dataset, group_variable)
data <- results_data$data
depara_names <- results_data$names
# Standardise the grouping column name to "group"
names(data)[names(data) == group_variable] <- "group"
N <- n <- nrow(data)
group <- results_data$group

# Response and predictor count are taken from the model frame of the raw data
mf <- stats::model.frame(formula, dataset)
y <- stats::model.extract(mf, "response")
num.variables <- ncol(mf) - 1 # every model-frame column except the response
name_y <- names(mf)[1]
# Standardise the response column name to "y"
names(data)[names(data) == name_y] <- "y"
P <- num.trees # number of trees
#---------------------------------------------------------------------
# Defining current distribution parameters
#---------------------------------------------------------------------
# Prior hyperparameters -------------
J <- dplyr::n_distinct(group) # number of distinct groups
# Exact-name extraction ([[ ]]) avoids the partial matching that `$` allows
beta <- pars[["beta"]]
alpha <- pars[["alpha"]]
mu_mu <- pars[["mu_mu"]]
# NOTE(review): k1 is hard-coded to 0.0001 here, while pars$k1 is 0.001 and
# is the value stored in sampled_k1[1] and passed on through `pars` --
# confirm which one is intended.
# k1 <- pars$k1
k1 <- 0.0001
k2 <- pars[["k2"]]
# Minimum batch for each node (5% of the rows)
keep_node <- 0.05 * nrow(data)
x_vars <- all.vars(formula[[3]]) # variable names on the formula's RHS
p_vars <- length(x_vars) # number of vars
to_do <- vector() # Actions that can be taken
#---------------------------------------------------------------------
# Initializing useful vectors
#---------------------------------------------------------------------
# For grow or prune
selec_var <- vector() # To save the selected variable when growing
rule <- vector() # To save the selected splitting rule when growing
drawn_node <- vector() # To save the selected node to grow or prune
r <- r_k <- vector() # To save the ratios of grow or prune
u <- u_k <- vector() # To save the sampled uniform values
sampled_k1 <- vector() # To save the sampled values for k1
sampled_k1[1] <- pars$k1

# For the trees ------------
my_trees_l <- list() # To save each new tree
# The first tree is the data itself: the 'root' tree, with no nodes
my_trees_l[[1]] <- data

# For the sampling of posterior values -------------
tau_post <- vector() # To save posterior values of tau
# Initial value: one Gamma draw with shape 1 / alpha and rate beta
# (arguments named explicitly; same call as the positional form)
tau_post[1] <- stats::rgamma(n = 1, shape = 1 / alpha, rate = beta)
parent_action <- vector()

# Empty action log and the initial single-row tree table
results <- data.frame(node = NA, var = NA, rule = NA, action = NA)
my_trees <- dplyr::tibble(est_tree = my_trees_l, parent_action = NA)
# One results and one tree_data per tree.
# seq_len(P) rather than 1:P so that P == 0 yields zero rows, not c(1, 0).
all_trees <- dplyr::tibble(
  tree_index = seq_len(P),
  tree_data = list(my_trees),
  results = list(results)
)
#---------------------------------------------------------------------
i <- 1 # row of the tree's tree_data to read
# Grow only one tree
p <- 1 # row of all_trees to work on
# Growing, pruning or staying on the same tree ---------------------
# Unnesting trees data by i and p
current_tree <- all_trees[p, ] %>%
  tidyr::unnest(tree_data) %>%
  dplyr::slice(i)
# Expand the stored tree into one row per observation and keep only the
# columns used by the proposal step
my_tree <- current_tree %>%
  tidyr::unnest(est_tree) %>%
  dplyr::select(
    dplyr::starts_with("X"), y, node, d, group, parent, node_index
  )
# Action log recorded so far for this tree
results_current <- current_tree %>%
  dplyr::select(results) %>%
  tidyr::unnest(results)
#----------------------------------------------------------------------
# Sampling details ----------------------------------------------------
dn <- dplyr::n_distinct(my_tree$d)
p_grow <- alpha_grow * (1 + dn)^(-beta_grow)
u_grow <- stats::runif(1)
depth <- dplyr::n_distinct(my_tree$node)

# Selecting the node to grow, uniformly.
# Pick by index instead of sample(x, size = 1): when x is a single number,
# sample() draws from 1:x rather than returning x itself.
candidate_nodes <- unique(my_tree$node)
drawn_node <- candidate_nodes[sample.int(length(candidate_nodes), size = 1)]

# Selecting the variable and splitting rule, uniformly
selec_var <- depara_names$new[sample.int(p_vars, size = 1)]
rule <- p_rule(
  variable_index = selec_var,
  data = my_tree, sel_node = drawn_node
)

# Grow the tree
sample_tree <- hebart::grow_tree(
  current_tree = my_tree, selec_var = selec_var,
  drawn_node = drawn_node, rule = rule
)

# Distinct `parent` values on the rows whose parent is the drawn node
parent_action <- sample_tree %>%
  dplyr::filter(parent == drawn_node) %>%
  dplyr::pull(parent) %>%
  unique()

# Append this move to the tree's action log.
# NOTE(review): suppressWarnings() hides every warning raised by bind_rows();
# confirm which warning it is meant to silence.
results_new <- suppressWarnings(
  dplyr::bind_rows(
    results_current,
    data.frame(
      node = drawn_node,
      var = selec_var,
      rule = rule,
      action = action_taken
    )
  )
)
#----------------------------------------------------------------------
# Ratio for the grow proposal. `lk` and `pr_ratio` are summed and then
# exponentiated, i.e. both are treated as log-scale quantities.
lk <- lk_ratio_grow(sample_tree, parent_action, pars)
new_tree_prior <- tree_prior(sample_tree, alpha_grow, beta_grow)
old_tree_prior <- tree_prior(my_tree, alpha_grow, beta_grow)
pr_ratio <- new_tree_prior - old_tree_prior
# Final ratio, capped at 1
r <- min(1, exp(lk + pr_ratio))
#----------------------------------------------------------------------
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment