Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save wlinInspire/48ce2ebe0d4265c16dbf10de79933ced to your computer and use it in GitHub Desktop.
Save wlinInspire/48ce2ebe0d4265c16dbf10de79933ced to your computer and use it in GitHub Desktop.
Calculating Lifetime Value for Subscription Products
library(survival)
library(data.table)
library(ggplot2)
# Load data & fit Kaplan-Meier curve ----
# IBM Telco churn sample. `tenure` is used as the survival time and
# `Churn == "Yes"` marks a churned customer.
data <- read.csv(
  "https://raw.githubusercontent.com/IBM/invoke-wml-using-cognos-custom-control/master/data/Telco-Customer-Churn.csv"
)
setDT(data)
# Event indicator for Surv(): 1 = churned, 0 = not churned (censored).
data[, churn_flag := as.numeric(Churn == "Yes")]
# `:=` modifies `data` by reference, so `churn_data` is the very same object.
churn_data <- data
km_curve <- survfit(Surv(tenure, churn_flag) ~ 1, data = churn_data)
# Calculate KM Survival and Churn Rate ----
# summary() of a survfit reports, at each time with at least one event, the
# survival estimate (surv), number at risk (n.risk) and events (n.event).
# events / at-risk gives the observed per-cycle churn rate.
km_curve_df <- summary(km_curve)
curve <- with(km_curve_df, data.table(
  cycle = time,
  survival_rate = surv,
  churn_rate = n.event / n.risk
))
# Expand dataset ----
# Build a person-period table: one row per customer per cycle 1..tenure.
# churn_flag keeps the customer's value only on the row where
# cycle == tenure; every earlier row is a "survived" row (churn_flag = 0).
#
# The cycle grid must reach the largest observed tenure. Using the last cycle
# with a churn event (max(curve$cycle)) would silently drop the later at-risk
# rows of any customer whose tenure exceeds it.
max_tenure <- data.table(cycle = 0:max(data$tenure, na.rm = TRUE))
max_tenure[, cycle_join := cycle]
data[, cycle_join := tenure]
# Non-equi join: pair each customer with every grid cycle <= their tenure.
churn_data_total <- max_tenure[data,
  on = .(cycle_join <= cycle_join),
  allow.cartesian = TRUE
]
churn_data_total[cycle < tenure, churn_flag := 0]
# Keep cycles 1..tenure (customers with tenure 0 contribute no rows).
churn_data_total <- churn_data_total[cycle >= 1]
# Indicator for the first cycle, giving the GLM a separate term for it.
churn_data_total[, special_cycle := as.numeric(cycle == 1)]
# Fit per-cycle churn model with H2O ----
# Logistic regression of churn_flag on the cycle number plus a first-cycle
# indicator, trained on the person-period table.
h2o::h2o.init()
churn_data_train_h2o <- h2o::as.h2o(churn_data_total)
fit <- h2o::h2o.glm(
  x = c("cycle", "special_cycle"),
  y = "churn_flag",
  training_frame = churn_data_train_h2o,
  family = "binomial",
  lambda = 0  # no regularisation: plain maximum-likelihood fit
)
# Predict ----
# Score the fitted GLM on cycles 1..500 to extrapolate the per-cycle churn
# rate beyond the observed tenure range.
curve_lr <- data.table(cycle = 1:500)
# Same first-cycle indicator used at training time.
curve_lr[, special_cycle := as.numeric(cycle == 1)]
curve_lr_h2o <- h2o::as.h2o(curve_lr)
# Pull the H2OFrame back into R with as.data.frame(). Nested calls are used
# instead of `%>%` because magrittr/dplyr are not attached in this script.
pred_frame <- as.data.table(
  as.data.frame(h2o::h2o.predict(fit, curve_lr_h2o))
)
# Binomial predictions carry p0/p1 columns; p1 is P(churn_flag == 1).
pred <- pred_frame[, p1]
# Attrition Curve Comparison ----
# Attach the predicted per-cycle churn rate, then chain it into a survival
# curve: survival(t) = prod over k <= t of (1 - churn_rate(k)).
curve_lr[, churn_rate := pred]
setorder(curve_lr, cycle)
curve_lr[, survival_rate := cumprod(1 - churn_rate)]
# Compare Attrition Curve of Raw and Logistic Regression ----
# Each layer gets its own `data =` rather than `table$column` inside aes(),
# so ggplot2 evaluates the mappings against the right table.
attrition_plot <- ggplot() +
  geom_line(data = curve, aes(x = cycle, y = churn_rate, colour = "raw")) +
  geom_line(
    data = curve_lr,
    aes(x = cycle, y = churn_rate, colour = "fit")
  ) +
  scale_y_continuous(labels = scales::percent) +
  xlab("cycle") +
  ylab("churn_rate") +
  theme_minimal()
# Explicit print() so the plot also renders when the script is source()d.
print(attrition_plot)
# Compare Survival Curve of Raw and Logistic Regression ----
# Raw Kaplan-Meier survival vs. survival implied by the fitted churn rates.
# Per-layer `data =` replaces `table$column` inside aes().
survival_plot <- ggplot() +
  geom_line(data = curve, aes(x = cycle, y = survival_rate, colour = "raw")) +
  geom_line(
    data = curve_lr,
    aes(x = cycle, y = survival_rate, colour = "fit")
  ) +
  scale_x_continuous(name = "cycle", limits = c(0, 500)) +
  scale_y_continuous(
    labels = scales::percent,
    name = "survival_rate",
    limits = c(0, 1)
  ) +
  theme_minimal()
# Explicit print() so the plot also renders when the script is source()d.
print(survival_plot)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment