Aggregate text files into a Hugging Face dataset (with optional train/val/test splits) and push it to the Hub
"""
Create & save an HF dataset with train/test/val splits from a dir of text files.

Ideal structure:

    root / section_name_1 / file 1
    root / section_name_1 / file 2
    root / section_name_1 / file YYY
    root / section_name_2 / file 1
    root / section_name_2 / file ZZZ
    ...

but it will also work for a single flat folder.

Usage:
    python push_dataset_from_text.py --help
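
Example invocation (a sketch; the data path, extensions, and repo id below are
placeholders, not real values):
    python push_dataset_from_text.py --data_dir ./my_text_corpus --extensions txt,md --val_test_size 0.05 --repo_id your-username/your-dataset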
"""
import datetime
import logging
import logging.handlers
from pathlib import Path

import datasets
import fire
import pandas as pd
from datasets import DatasetDict
from tqdm import tqdm


def get_filename_safe_timestamp() -> str:
    """Return a timestamp string that is safe to use in filenames."""
    timestamp = datetime.datetime.now().strftime("%Y-%b-%d_%H-%M")
    return timestamp


def setup_logging(log_level="INFO", log_file=None):
    """
    Set up logging configuration.

    Args:
        log_level (str): Minimum log level for messages to handle.
        log_file (str): Path to the log file. If not specified, logs won't be saved to a file.
    """
    LOG_LEVELS = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }
    minimalist_formatter = logging.Formatter("%(levelname)s: %(message)s")

    logger = logging.getLogger()
    logger.setLevel(LOG_LEVELS.get(log_level.upper(), logging.INFO))
    if logger.hasHandlers():
        logger.handlers.clear()

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(minimalist_formatter)  # set formatter for console
    logger.addHandler(console_handler)

    if log_file:
        file_handler = logging.handlers.RotatingFileHandler(
            log_file,
            maxBytes=1 * 1024 * 1024,
            backupCount=2,
        )
        file_handler.setFormatter(minimalist_formatter)  # set formatter for file
        logger.addHandler(file_handler)

    logger.debug("Logging is set up.")


def split_dataset(
    data_dir: Path,
    extensions: str = "txt",
    val_test_size: float = 0.0,
    repo_id: str = None,
    config_name: str = None,
    commit_message: str = None,
    private: bool = True,
    save_dir: Path = None,
    log_level="INFO",
):
    """
    Discover all text files, read their content into a Pandas DataFrame, convert it into a dataset, and create train/test/validation splits.

    :param Path data_dir: root directory to search for text files.
    :param str extensions: extensions to look for, as one comma-separated string.
    :param float val_test_size: combined size of the validation and test sets, split evenly between them. Default: 0.0 (train split only).
    :param str repo_id: repository ID for pushing to the Hugging Face hub.
    :param str config_name: config name for pushing to the Hugging Face hub.
    :param str commit_message: commit message for pushing to the Hugging Face hub.
    :param bool private: whether the dataset should be pushed as private.
    :param Path save_dir: directory where the dataset is saved locally.
    :param str log_level: logging level.
    :raises ValueError: if data_dir is not a directory.
    """
    setup_logging(log_level)
    logger = logging.getLogger(__name__)

    data_dir = Path(data_dir)
    if not data_dir.is_dir():
        raise ValueError(f"Data directory not found: {data_dir}")
    if repo_id is None and save_dir is None:
        logger.warning("No repo_id or save_dir specified, 'dry-run' mode.")

    if isinstance(extensions, tuple):
        extensions = ",".join(extensions)
    extensions = [extension.strip().strip(".") for extension in extensions.split(",")]
    logger.info(f"Looking for files with extensions: {extensions}")
    data = []
    for extension in extensions:
        txt_files = list(data_dir.rglob(f"*.{extension}"))
        for file_path in tqdm(txt_files, desc=f"Processing {extension} files"):
            with file_path.open("r", encoding="utf-8", errors="ignore") as f:
                content = f.read()
            if len(content.strip()) == 0:
                logger.warning(f"Empty file: {file_path}, skipping...")
                continue
            channel_name = file_path.parts[-2]
            filename = file_path.stem
            rel_path = file_path.relative_to(data_dir).parent
            data.append(
                {
                    "relative_path": str(rel_path),
                    "section": channel_name,
                    "filename": filename,
                    "text": content,
                }
            )

    # TODO: datasets.Dataset.from_list(data) would avoid the pandas round-trip
    ds = datasets.Dataset.from_pandas(pd.DataFrame(data))
    logger.info(f"Found {len(data)} text files - {get_filename_safe_timestamp()}")
    del data
    # Check for columns with only one unique value and remove them
    columns_to_remove = []
    for column in ["relative_path", "section"]:
        if column in ds.column_names:
            unique_values = set(ds[column])
            if len(unique_values) == 1:
                columns_to_remove.append(column)
                logger.info(
                    f"Removing column '{column}' as it only contains one unique value: {list(unique_values)[0]}"
                )
    if columns_to_remove:
        ds = ds.remove_columns(columns_to_remove)
        logger.info(f"Removed columns: {columns_to_remove}")

    if val_test_size > 0:
        logger.debug("train-test split...")
        train_temp_split = ds.train_test_split(
            test_size=val_test_size, shuffle=True, seed=80085
        )
        train_dataset = train_temp_split["train"]
        temp_dataset = train_temp_split["test"]

        logger.debug("validation-test split...")
        val_test_split = temp_dataset.train_test_split(
            test_size=0.5, shuffle=True, seed=80085
        )
        val_dataset = val_test_split["train"]
        test_dataset = val_test_split["test"]

        logger.info(
            f"Train size: {len(train_dataset)}\n"
            f"Validation size: {len(val_dataset)}\n"
            f"Test size: {len(test_dataset)}"
        )
        aggregated_dataset = DatasetDict(
            {"train": train_dataset, "validation": val_dataset, "test": test_dataset}
        )
    else:
        logger.info(f"Only creating train split (val_test_size={val_test_size})")
        train_dataset = ds
        logger.info(f"Train size: {len(train_dataset)}")
        aggregated_dataset = DatasetDict({"train": train_dataset})
    if repo_id:
        logger.info(f"Pushing dataset to Huggingface hub ({repo_id})...")
        commit_message = (
            commit_message
            if commit_message is not None
            else f"update dataset {get_filename_safe_timestamp()}"
        )
        logger.info(f"Using repo id:\t{repo_id}, config name:\t{config_name}")
        if config_name:
            aggregated_dataset.push_to_hub(
                repo_id=repo_id,
                config_name=config_name,
                private=private,
                max_shard_size="800MB",
                commit_message=commit_message,
            )
        else:
            aggregated_dataset.push_to_hub(
                repo_id=repo_id,
                private=private,
                max_shard_size="800MB",
                commit_message=commit_message,
            )
        logger.info(f"Dataset pushed to {repo_id}")

    if save_dir:
        logger.info(f"Saving dataset to {save_dir}...")
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        aggregated_dataset.save_to_disk(save_dir)
        logger.info(f"Dataset saved to {save_dir}")

    logger.info("Done!")


if __name__ == "__main__":
    fire.Fire(split_dataset)
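
# Example programmatic use (a minimal sketch; the path, extension list, repo id,
# and config name below are placeholders, not real values):
#
#   from push_dataset_from_text import split_dataset
#
#   split_dataset(
#       data_dir="data/my_text_corpus",
#       extensions="txt,md",
#       val_test_size=0.05,           # 2.5% validation + 2.5% test
#       repo_id="your-username/your-dataset",
#       config_name="default",
#       private=True,
#   )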