I needed a way to easily sample rows from a data.frame with a different fraction per class (as opposed to dplyr::sample_frac(),
which only supports a single fraction for the whole dataset). So I wrote the little function below.
Last active
June 17, 2019 09:27
-
-
Save moredatapls/f9a180a85443b4dd1d5249f8b5d69400 to your computer and use it in GitHub Desktop.
R: sample rows from a data.table with different fractions per class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Samples random rows from a data.frame based on per-class fractions.
#'
#' `dplyr::sample_frac` only supports a single fraction for the entire dataset.
#' This function works best if there are few classes because a fraction
#' has to be specified for every class that occurs in `classCol`.
#'
#' The notation is consistent with the one used for `dplyr::sample_frac`.
#'
#' @param tbl The data.frame to sample from
#' @param classCol The column containing the class labels. It is captured with
#'   `dplyr::enquo` and extracted with `dplyr::pull`, so a bare column name
#'   works regardless of what the column is called.
#' @param sizes The per-class fractions as a `list("class name" = fraction)`.
#'   Every class present in `classCol` needs an entry, and every fraction must
#'   be a single non-missing number between 0 and 1 (inclusive).
#' @param ... Other parameters to pass to `dplyr::sample_frac`, such as `replace = TRUE`
#'
#' @return A subset of `tbl`: the rows sampled for each class, combined with
#'   `rbind` in the order of `names(sizes)`.
#'
sample_frac_class <- function(tbl, classCol, sizes, ...) {
  # Resolve the class column once; every later step works on this vector so
  # the function does not depend on the column being literally named "class".
  class_ <- as.factor(dplyr::pull(tbl, !!dplyr::enquo(classCol)))

  # A valid fraction is exactly one non-missing number in [0, 1].
  is_fraction <- function(size) {
    is.numeric(size) && length(size) == 1L && !is.na(size) &&
      size >= 0 && size <= 1
  }

  stopifnot(
    is.list(sizes),
    all(levels(class_) %in% names(sizes)),
    all(vapply(sizes, is_fraction, logical(1)))
  )

  # Without any named class there is nothing to sample; return an empty slice
  # of `tbl` instead of the NULL that rbind-ing an empty list would produce.
  if (length(names(sizes)) == 0L) {
    return(tbl[0L, , drop = FALSE])
  }

  sample_ <- function(clazz) {
    # which() drops NA comparisons, so rows with a missing class are skipped.
    row_idx <- which(class_ == clazz)
    # drop = FALSE keeps a single-column data.frame from collapsing to a vector.
    dplyr::sample_frac(tbl[row_idx, , drop = FALSE], size = sizes[[clazz]], ...)
  }

  do.call(rbind, lapply(names(sizes), sample_))
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
context("sampling")
testthat::test_that("per-class sampling works", {
  # NOTE(review): the expected rows below are tied to the random draws made
  # after this seed; they presumably also depend on the sampling algorithm of
  # the R version in use -- confirm when upgrading R.
  set.seed(42)
  # define input and output
  # `val` is created as a factor explicitly so that `levels(data$val)` below is
  # defined even when data.frame() does not convert strings to factors by
  # default.
  data <- data.frame(
    id = c(1, 2, 3, 4, 5),
    class = c(1, 2, 2, 2, 3),
    val = c("abc", "def", "geh", "ijk", "lmn"),
    stringsAsFactors = TRUE
  )
  # keep all of class 1, two thirds of class 2, nothing of class 3
  sizes <- list("1" = 1, "2" = 2/3, "3" = 0)
  expected <- data.frame(
    id = c(1, 2, 4),
    class = c(1, 2, 2),
    val = factor(c("abc", "def", "ijk"), levels = levels(data$val))
  )
  expected_replace <- data.frame(
    id = c(1, 2, 2),
    class = c(1, 2, 2),
    val = factor(c("abc", "def", "def"), levels = levels(data$val))
  )
  # run it
  actual <- sample_frac_class(data, class, sizes)
  actual_replace <- sample_frac_class(data, class, sizes, replace = TRUE)
  # check it
  # NOTE(review): dplyr::all_equal() appears to be deprecated in recent dplyr
  # releases -- confirm before upgrading dplyr.
  expect(isTRUE(dplyr::all_equal(expected, actual)), "sampling without replacement")
  expect(isTRUE(dplyr::all_equal(expected_replace, actual_replace)), "sampling with replacement")
})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment