Skip to content

Instantly share code, notes, and snippets.

@brunaw
Last active May 5, 2022 13:21
Show Gist options
  • Save brunaw/a194086be958b73dd1c25a93b84730b8 to your computer and use it in GitHub Desktop.
Save brunaw/a194086be958b73dd1c25a93b84730b8 to your computer and use it in GitHub Desktop.
# Loading packages
library(tidyverse)
library(hebart)
# Covariates are read from the Friedman example file hosted in the rBART
# repository (requires network access). The formula used later refers to
# columns V2..V6 of this table.
df <- read.table(
  file = "https://raw.githubusercontent.com/andrewcparnell/rBART/master/friedman.txt"
)
# Response: one standard-normal draw per row. No seed is set, so every run
# of the script produces a different response.
df[["y"]] <- rnorm(n = nrow(df))
# Grouping variable: each row is assigned, uniformly at random and with
# replacement, to one of the groups 1..5.
df[["group"]] <- sample.int(5, size = nrow(df), replace = TRUE)
# Name of the column in `df` that holds the group label of each observation.
group_variable <- "group"
# Model formula: response `y` regressed on the covariates V2..V6.
formula <- y ~ V2 + V3 + V4 + V5 + V6
# Hyperparameter values. They are collected into `pars`, which is the object
# handed to the likelihood-ratio helper further down the script.
alpha <- 0.5
beta <- 1
mu_mu <- 0
# Every entry that has a stand-alone variable above is taken from that
# variable (including `mu_mu`), so editing the value in one place is enough.
pars <- list(
  k1 = 0.001, k2 = 5, alpha = alpha, beta = beta, mu_mu = mu_mu
)
#------------------------------------------------------------------------------
# Data set handed to the sampler steps below.
dataset <- df
# HEBART parameters
# NOTE(review): in the visible part of this script only `num.trees`,
# `alpha_grow`, `beta_grow` and `action_taken` are read again; the remaining
# settings are presumably consumed by the full HEBART function - confirm.
min_u <- 0
max_u <- 20
prior_k1 <- TRUE
num.trees <- 5
sample_k1 <- TRUE
# burn_in <- 50
# Parameters of the grow probability alpha_grow * (1 + depth)^(-beta_grow).
alpha_grow <- 0.98
beta_grow <- 0.05
# Label written to the `action` column of the results table for this move.
action_taken <- "grow"
#------------------------------------------------------------------------------
# HEBART Function
#---------------------------------------------------------------------
# Handling initial dataset
#---------------------------------------------------------------------
# Pre-process the raw data set; the handler returns (at least) the processed
# data, a table of variable names and the group vector, all used below.
results_data <- data_handler(formula, data = dataset, group_variable)
depara_names <- results_data$names
group <- results_data$group
data <- results_data$data

# Standardise the name of the grouping column to "group".
names(data)[names(data) == group_variable] <- "group"

# Number of observations, kept under both names used by the script.
n <- nrow(data)
N <- n

# Model frame built from the original data set: first column is the response,
# the remaining columns are the covariates named in the formula.
mf <- stats::model.frame(formula, dataset)
y <- stats::model.extract(mf, "response")
name_y <- names(mf)[1]
num.variables <- ncol(mf) - 1

# Standardise the name of the response column to "y".
names(data)[names(data) == name_y] <- "y"

# Number of trees.
P <- num.trees
#---------------------------------------------------------------------
# Defining current distribution parameters
#---------------------------------------------------------------------
# Prior hyperparameters -------------
# Number of distinct groups in the data.
J <- dplyr::n_distinct(group)
# Values pulled from the `pars` list.
alpha <- pars[["alpha"]]
beta <- pars[["beta"]]
mu_mu <- pars[["mu_mu"]]
k2 <- pars[["k2"]]
# NOTE(review): k1 is fixed to 0.0001 here instead of being read from
# `pars` (where it is 0.001) - confirm this override is intentional.
# k1 <- pars$k1
k1 <- 0.0001
# Minimum batch for each node: 5% of the rows of the processed data.
keep_node <- nrow(data) * 0.05
# Covariate names taken from the right-hand side of the formula, and
# how many there are.
x_vars <- all.vars(formula[[3]])
p_vars <- length(x_vars)
# Actions that can be taken (starts empty).
to_do <- logical(0)
#---------------------------------------------------------------------
# Initializing useful vectors
#---------------------------------------------------------------------
# Containers for the grow / prune steps (all start as empty vectors)
selec_var <- logical(0)  # selected variable when growing
rule <- logical(0)       # selected splitting rule when growing
drawn_node <- logical(0) # selected node to grow or prune
r_k <- logical(0)        # ratios of grow or prune
r <- r_k
u_k <- logical(0)        # sampled uniform values
u <- u_k
# Sampled values for k1; the first entry is the value stored in `pars`.
sampled_k1 <- c(pars$k1)

# Containers for the trees ------------
# The first tree is the data itself: the 'root' tree, with no splits yet.
my_trees_l <- list(data)

# Containers for the posterior samples -------------
# Posterior values of tau; the first one is a single gamma draw with
# shape 1/alpha and rate beta.
tau_post <- stats::rgamma(n = 1, shape = 1 / alpha, rate = beta)
parent_action <- logical(0)

# Empty (all-NA) record of the moves applied to a tree.
results <- data.frame(node = NA, var = NA, rule = NA, action = NA)
# Sequence of estimated trees together with the parent involved in each move.
my_trees <- dplyr::tibble(est_tree = my_trees_l, parent_action = NA)

# One results table and one tree_data table per tree
all_trees <- dplyr::tibble(
  tree_index = 1:P,
  tree_data = list(my_trees),
  results = list(results)
)
#---------------------------------------------------------------------
# Iteration index
i <- 1
# Tree index: only one tree is grown in this script
p <- 1
# Growing, pruning or staying on the same tree ---------------------
# Row of `all_trees` for tree p, expanded to its stored trees and then
# restricted to the i-th one.
current_tree <- dplyr::slice(
  tidyr::unnest(all_trees[p, ], tree_data),
  i
)
# Observation-level data of that tree, keeping the covariate columns
# (prefixed with "X"), the response and the tree bookkeeping columns.
my_tree <- dplyr::select(
  tidyr::unnest(current_tree, est_tree),
  dplyr::starts_with("X"), y, node, d, group, parent, node_index
)
# Record of the moves applied so far to this tree.
results_current <- current_tree %>%
  dplyr::select(results) %>%
  tidyr::unnest(results)
#----------------------------------------------------------------------
# Sampling details ----------------------------------------------------
# Number of distinct values in the `d` column of the current tree and the
# implied grow probability alpha_grow * (1 + dn)^(-beta_grow).
# NOTE(review): `p_grow`, `u_grow` and `depth` are not read again in the
# visible part of the script; `u_grow` is kept because it still consumes one
# uniform draw from the random number stream.
dn <- dplyr::n_distinct(my_tree$d)
p_grow <- alpha_grow * (1 + dn)^(-beta_grow)
u_grow <- stats::runif(1)
# Despite its name, this is the number of distinct node labels.
depth <- dplyr::n_distinct(my_tree$node)

# Selecting the node to grow, uniformly.
# Indexing through sample.int() avoids the sample() scalar trap: with a
# single numeric label k, sample(k, 1) would draw from 1:k instead of
# returning k. When there are several labels the draw is unchanged.
node_pool <- unique(my_tree$node)
drawn_node <- node_pool[sample.int(length(node_pool), size = 1)]

# Selecting the variable and splitting rule, uniformly.
# sample.int(p_vars, ...) fails loudly when there are no covariates, whereas
# sample(1:p_vars, ...) would silently draw from c(1, 0).
selec_var <- depara_names$new[sample.int(p_vars, size = 1)]
rule <- p_rule(
  variable_index = selec_var,
  data = my_tree,
  sel_node = drawn_node
)
# Propose a new tree by splitting `drawn_node` on `selec_var` at `rule`.
sample_tree <- hebart::grow_tree(
  current_tree = my_tree,
  selec_var = selec_var,
  drawn_node = drawn_node,
  rule = rule
)
# Distinct value(s) of the `parent` column among the rows of the proposed
# tree whose parent is the node that was just split.
parent_action <- unique(
  dplyr::pull(
    dplyr::filter(sample_tree, parent == drawn_node),
    parent
  )
)
# Append a record of this move to the moves already stored for the tree;
# any warnings raised while combining the rows are silenced.
results_new <- suppressWarnings(
  results_current %>%
    dplyr::bind_rows(
      data.frame(
        node = drawn_node,
        var = selec_var,
        rule = rule,
        action = action_taken
      )
    )
)
#----------------------------------------------------------------------
# Likelihood ratio of the proposed (grown) tree against the current one.
lk <- lk_ratio_grow(sample_tree, parent_action, pars)
# Tree prior evaluated for the proposed and for the current tree.
new_tree_prior <- tree_prior(sample_tree, alpha_grow, beta_grow)
old_tree_prior <- tree_prior(my_tree, alpha_grow, beta_grow)
pr_ratio <- new_tree_prior - old_tree_prior
# Acceptance ratio, capped at 1. `lk` and `pr_ratio` are summed and then
# exponentiated, so they are treated as log-scale quantities; capping the
# exponent at 0 is the same as capping the exponentiated value at 1.
r <- exp(min(0, lk + pr_ratio))
#----------------------------------------------------------------------
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment