Created
March 12, 2024 07:03
-
-
Save pszemraj/8f10d362bdb56329532bf31c4df821a5 to your computer and use it in GitHub Desktop.
Hugging Face `datasets` `train_test_split` with `stratify_by_column` for columns of any type, by temporarily encoding the column as a `ClassLabel` (i.e., by tricking it)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import numpy as np | |
from datasets import ClassLabel, Dataset, DatasetDict | |
def split_dataset(
    dataset: Dataset,
    test_size: float = 0.025,
    validation_size: float = 0.025,
    stratify_by_column: str = None,
):
    """
    Split a dataset into train, test, and validation sets, optionally stratified.

    Hugging Face's ``train_test_split`` only accepts ``stratify_by_column`` for
    ClassLabel columns. If the requested column is not a ClassLabel, a temporary
    ClassLabel-encoded copy of it is added, used for both splits, and removed
    from the result.

    Parameters:
    - dataset: The Hugging Face ``Dataset`` to split.
    - test_size: Proportion of the whole dataset for the test split.
    - validation_size: Proportion of the whole dataset for the validation split.
    - stratify_by_column: Column name to stratify by, or None for no stratification.

    Returns:
    - A ``DatasetDict`` with 'train', 'test', and 'validation' splits.

    Raises:
    - ValueError: If ``test_size + validation_size >= 1``.
    """
    nontrain_size = test_size + validation_size
    # Validate BEFORE the (potentially expensive) map/cast below, so bad
    # arguments fail fast instead of after a full pass over the data.
    if nontrain_size >= 1:
        raise ValueError(
            "The combined size of test and validation sets must be less than 1."
        )

    tmp_stratify_col = None  # temporary helper column to drop at the end
    stratify_col = stratify_by_column  # column actually passed to train_test_split
    if stratify_by_column is not None:
        feature = dataset.features[stratify_by_column]
        # BUG FIX: the original compared `feature.dtype != "ClassLabel"`, but
        # `.dtype` holds a storage type such as "string" or "int64" — never the
        # literal "ClassLabel". As a result, columns that already WERE
        # ClassLabels fell into the else-branch and stratification was
        # silently skipped. Check the feature type itself instead.
        if not isinstance(feature, ClassLabel):
            # Encode arbitrary values as dense integer labels 0..K-1.
            unique_values = sorted(set(dataset[stratify_by_column]))
            value_to_int = {v: i for i, v in enumerate(unique_values)}
            tmp_stratify_col = f"{stratify_by_column}-ClassLabel"
            dataset = dataset.map(
                lambda example: {
                    tmp_stratify_col: value_to_int[example[stratify_by_column]]
                },
                load_from_cache_file=False,
                num_proc=os.cpu_count(),
            )
            dataset = dataset.cast_column(
                tmp_stratify_col,
                # ClassLabel names must be strings; the original passed raw
                # values, which fails for non-string columns (e.g. int years).
                ClassLabel(
                    num_classes=len(unique_values),
                    names=[str(v) for v in unique_values],
                ),
            )
            stratify_col = tmp_stratify_col

    # First split: carve off the combined test+validation pool.
    train_test_split = dataset.train_test_split(
        test_size=nontrain_size,
        stratify_by_column=stratify_col,
    )
    train_set = train_test_split["train"]
    non_train_set = train_test_split["test"]

    # Second split: divide the pool so test/validation keep the requested ratio.
    test_validation_split = non_train_set.train_test_split(
        test_size=test_size / nontrain_size,
        stratify_by_column=stratify_col,
    )
    split_ds = DatasetDict(
        {
            "train": train_set,
            "test": test_validation_split["test"],
            "validation": test_validation_split["train"],
        }
    )
    # Drop the helper column so the output schema matches the input's.
    if tmp_stratify_col is not None:
        split_ds = split_ds.remove_columns(tmp_stratify_col)
    return split_ds
# Example: stratified 95/2.5/2.5 split of the deduplicated dataset by year.
ds_u_split = split_dataset(
    ds_unique,
    test_size=0.025,
    validation_size=0.025,
    stratify_by_column="year",
)
ds_u_split
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment