aggregate and push an hf dataset from text files
""" | |
Create & save an hf dataset with train/test/val splits from dir w/ text files | |
Ideal structure: | |
root / section_name_1 / file 1 | |
root / section_name_1 / file 2 | |
root / section_name_1 / file YYY | |
root / section_name_2 / file 1 | |
root / section_name_2 / file ZZZ | |
... | |
but will work for just a single folder. | |
Usage: | |
python push_dataset_from_text.py --help | |
""" | |
import datetime
import logging
import logging.handlers  # RotatingFileHandler lives in this submodule; `import logging` alone is not enough
from pathlib import Path

import datasets
import fire
from datasets import DatasetDict
from tqdm import tqdm


def get_filename_safe_timestamp():
    """Return a filename-safe timestamp, e.g. '2024-Mar-05_14-30'."""
    timestamp = datetime.datetime.now().strftime("%Y-%b-%d_%H-%M")
    return timestamp


def setup_logging(log_level="INFO", log_file=None):
    """
    Set up logging configuration.

    Args:
        log_level (str): Minimum log level for messages to handle.
        log_file (str): Path to the log file. If not specified, logs won't be saved to a file.
    """
    LOG_LEVELS = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }
    minimalist_formatter = logging.Formatter("%(levelname)s: %(message)s")

    logger = logging.getLogger()
    logger.setLevel(LOG_LEVELS.get(log_level.upper(), logging.INFO))
    if logger.hasHandlers():
        logger.handlers.clear()

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(minimalist_formatter)  # set formatter for console
    logger.addHandler(console_handler)

    if log_file:
        file_handler = logging.handlers.RotatingFileHandler(
            log_file,
            maxBytes=1 * 1024 * 1024,
            backupCount=2,
        )
        file_handler.setFormatter(minimalist_formatter)  # set formatter for file
        logger.addHandler(file_handler)

    logger.debug("Logging is set up.")


def split_dataset(
    data_dir: Path,
    extensions: str = "txt",
    val_test_size: float = 0.0,
    repo_id: str = None,
    config_name: str = None,
    commit_message: str = None,
    private: bool = True,
    save_dir: Path = None,
    log_level="INFO",
):
    """
    Discover all text files, read their content, convert into a Hugging Face dataset,
    and create train/test/validation splits.

    :param Path data_dir: root directory to search for text files.
    :param str extensions: extensions to look for, as one comma-separated string.
    :param float val_test_size: combined size of the validation and test sets.
        Default: 0.0 (train split only); when > 0, it is split evenly between the two.
    :param str repo_id: repository ID for pushing to the Hugging Face Hub.
    :param str config_name: config name for pushing to the Hugging Face Hub.
    :param str commit_message: commit message for pushing to the Hugging Face Hub.
    :param bool private: whether the dataset should be pushed as private.
    :param Path save_dir: directory where to save the dataset.
    :param str log_level: logging level.
    :raises ValueError: if data_dir is not a directory.
    """
    setup_logging(log_level)
    logger = logging.getLogger(__name__)

    data_dir = Path(data_dir)
    if not data_dir.is_dir():
        raise ValueError(f"Data directory not found: {data_dir}")
    if repo_id is None and save_dir is None:
        logger.warning("No repo_id or save_dir specified, 'dry-run' mode.")

    if isinstance(extensions, tuple):
        extensions = ",".join(extensions)
    extensions = [ext.strip().lstrip(".") for ext in extensions.split(",")]
    logger.info(f"Looking for files with extensions: {extensions}")

    data = []
    for extension in extensions:
        txt_files = list(data_dir.rglob(f"*.{extension}"))
        for file_path in tqdm(txt_files, desc=f"Processing {extension} files"):
            with file_path.open("r", encoding="utf-8", errors="ignore") as f:
                content = f.read()
            if len(content.strip()) == 0:
                logger.warning(f"Empty file: {file_path}, skipping...")
                continue
            section_name = file_path.parts[-2]  # immediate parent dir is used as the section label
            filename = file_path.stem
            rel_path = file_path.relative_to(data_dir).parent
            data.append(
                {
                    "relative_path": str(rel_path),
                    "section": section_name,
                    "filename": filename,
                    "text": content,
                }
            )

    # Build the Dataset directly from the list of dicts (avoids a pandas round-trip)
    ds = datasets.Dataset.from_list(data)
    logger.info(f"Found {len(data)} text files - {get_filename_safe_timestamp()}")
    del data

    # Check for columns with only one unique value and remove them
    columns_to_remove = []
    for column in ["relative_path", "section"]:
        if column in ds.column_names:
            unique_values = set(ds[column])
            if len(unique_values) == 1:
                columns_to_remove.append(column)
                logger.info(
                    f"Removing column '{column}' as it only contains one unique value: {list(unique_values)[0]}"
                )
    if columns_to_remove:
        ds = ds.remove_columns(columns_to_remove)
        logger.info(f"Removed columns: {columns_to_remove}")

    if val_test_size > 0:
        logger.debug("train-test split...")
        train_temp_split = ds.train_test_split(
            test_size=val_test_size, shuffle=True, seed=80085
        )
        train_dataset = train_temp_split["train"]
        temp_dataset = train_temp_split["test"]

        logger.debug("validation-test split...")
        val_test_split = temp_dataset.train_test_split(
            test_size=0.5, shuffle=True, seed=80085
        )
        val_dataset = val_test_split["train"]
        test_dataset = val_test_split["test"]

        logger.info(
            f"Train size: {len(train_dataset)}\n"
            f"Validation size: {len(val_dataset)}\n"
            f"Test size: {len(test_dataset)}"
        )
        aggregated_dataset = DatasetDict(
            {"train": train_dataset, "validation": val_dataset, "test": test_dataset}
        )
    else:
        logger.info(f"Only creating train split (val_test_size={val_test_size})")
        train_dataset = ds
        logger.info(f"Train size: {len(train_dataset)}")
        aggregated_dataset = DatasetDict({"train": train_dataset})

    if repo_id:
        logger.info(f"Pushing dataset to the Hugging Face Hub ({repo_id})...")
        commit_message = (
            commit_message
            if commit_message is not None
            else f"update dataset {get_filename_safe_timestamp()}"
        )
        logger.info(f"Using repo id:\t{repo_id}, config name:\t{config_name}")
        if config_name:
            aggregated_dataset.push_to_hub(
                repo_id=repo_id,
                config_name=config_name,
                private=private,
                max_shard_size="800MB",
                commit_message=commit_message,
            )
        else:
            aggregated_dataset.push_to_hub(
                repo_id=repo_id,
                private=private,
                max_shard_size="800MB",
                commit_message=commit_message,
            )
        logger.info(f"Dataset pushed to {repo_id}")

    if save_dir:
        logger.info(f"Saving dataset to {save_dir}...")
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        aggregated_dataset.save_to_disk(save_dir)
        logger.info(f"Dataset saved to {save_dir}")

    logger.info("Done!")


if __name__ == "__main__":
    fire.Fire(split_dataset)
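
# Loading the result later (illustrative sketch; the repo id and path are placeholders):
#   ds = datasets.load_dataset("your-username/my-text-dataset")  # if pushed to the Hub
#   ds = datasets.load_from_disk("path/to/save_dir")             # if saved locally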