Last active
July 31, 2018 13:37
-
-
Save jhofman/0bb6e0705083001cdf2f865c246d633e to your computer and use it in GitHub Desktop.
a more efficient way to filter a grouped data frame?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(tidyverse)
library(digest)

# create a dummy data frame with 10,000 groups and 1,000,000 rows,
# where group ids are md5 hashes of the integers from 1 to 10,000
set.seed(42)
md5 <- Vectorize(function(x) digest(x, algo = "md5"))
df <- data.frame(
  group_id = sample(md5(1:1e4), 1e6, replace = TRUE),
  val = sample(1:100, 1e6, replace = TRUE)
)

# group observations by group_id; dplyr keeps the row indices of each group
# alongside the data, which filter_groups() below takes advantage of
df <- df %>%
  group_by(group_id)
########################################
# bad: filter by group id, the naive way
########################################

# this is slow for two reasons:
# the first is that it's a linear scan over all rows,
# and the second is that there's overhead created by the grouping
system.time(
  df1 <- df %>%
    filter(group_id == "4b5630ee914e848e8d07221556b0a2fb")
)
#   user  system elapsed
#  1.416   0.485   1.957
########################################
# better: filter by group id, the smart way
########################################

# this is faster than the above because it uses the group index dplyr keeps
# for grouped data frames: the condition is evaluated once per group (on the
# table of group keys) instead of once per row, so the cost is roughly linear
# in the number of groups plus the number of rows in the matched groups

# Filter a grouped data frame by a condition on its grouping columns.
#
# Evaluates `filter_formula` against the group-key table returned by
# dplyr::group_data() (one row per group), then extracts the rows of the
# matching groups via that table's `.rows` index column.
# Requires dplyr >= 0.8.0 (for group_data()).
#
# df:             a data frame. Grouped data frames (from group_by()) take the
#                 fast path; any other data frame falls back to filter().
# filter_formula: a condition written as in filter(); on the fast path it may
#                 only refer to the grouping columns of `df`.
# returns:        the rows of `df` in the matched groups, in their original
#                 row order (for row-wise conditions on the grouping columns,
#                 the same rows filter(df, filter_formula) returns).
filter_groups <- function(df, filter_formula) {
  # capture the condition unevaluated so it can be spliced into filter()
  filter_formula <- enquo(filter_formula)

  # without grouping there is no group index to exploit; instead of silently
  # returning an empty data frame, do the ordinary (slower) filter
  if (!inherits(df, "grouped_df")) {
    return(filter(df, !!filter_formula))
  }

  # group_data() has one row per group: the grouping columns plus a `.rows`
  # list column of 1-based row indices. This replaces the pre-0.8 dplyr
  # "labels"/"indices" attributes (which were 0-based and no longer exist).
  matched <- filter(group_data(df), !!filter_formula)

  # concatenate the matched groups' indices and sort them so the rows come
  # back in their original order, as filter() would; as.integer() turns the
  # NULL produced by "no matching groups" into integer(0)
  ndx <- sort(as.integer(unlist(matched$.rows, use.names = FALSE)))
  df[ndx, ]
}
system.time(
  df2 <- filter_groups(df, group_id == "4b5630ee914e848e8d07221556b0a2fb")
)
#   user  system elapsed
#  0.002   0.000   0.001

# check that results are the same; compare dimensions first because `==` on
# two data frames errors when they are not the same size
identical(dim(df1), dim(df2)) && all(df1 == df2)
########################################
# much cleaner, slightly slower: create a nested data frame, then filter
########################################

# h/t to @hadleywickham for this solution
# nest() on a grouped data frame packs each group's remaining columns into a
# list column (named `data`), leaving one row per group to filter on
system.time(df_nested <- df %>% nest())
#   user  system elapsed
#  0.607   0.017   0.630

system.time(
  df3 <- df_nested %>%
    filter(group_id == "4b5630ee914e848e8d07221556b0a2fb") %>%
    # tidyr >= 1.0 requires naming the list column(s) to unnest
    unnest(cols = data)
)
#   user  system elapsed
#  0.005   0.000   0.005

identical(dim(df1), dim(df3)) && all(df1 == df3)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example with two grouping columns:
df <- data.frame(
  first_id = sample(md5(1:1e2), 1e6, replace = TRUE),
  second_id = sample(md5(1:1e2), 1e6, replace = TRUE),
  val = sample(1:100, 1e6, replace = TRUE)
)

# filter_groups() works off the group index, so group by the columns the
# condition refers to before calling it
df %>%
  group_by(first_id, second_id) %>%
  filter_groups(
    first_id == "06cd248dd1409b804444bd9ad5533d1d" &
      second_id == "11946e7a3ed5e1776e81c0f0ecd383d0"
  )