Created
December 16, 2022 21:00
-
-
Save cemoody/f8457345c941da693b04675593d13a3c to your computer and use it in GitHub Desktop.
A multiprocess Parquet DataLoader for PyTorch. Great for loading large sequential-access datasets. It is easy to install, modify, and use.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import multiprocessing
import queue
import time

import pandas as pd
from loguru import logger
def chunks(df, chunk_size=1000):
    """Yield successive slices of *df*, each at most ``chunk_size`` rows long.

    Works on anything sliceable by integer ranges (DataFrames, lists,
    strings, ...). The final slice may be shorter than ``chunk_size``.
    """
    start = 0
    total = len(df)
    while start < total:
        yield df[start : start + chunk_size]
        start += chunk_size
def parquet_reader(path):
    """Load the parquet file at *path* into a pandas DataFrame."""
    frame = pd.read_parquet(path)
    return frame
def worker_fn(
    input_queue,
    output_queue,
    func=None,
    batch_size=128,
    worker_id=0,
    verbose=False,
    file_reader=parquet_reader,
):
    """Worker loop: pull file paths off ``input_queue``, read each file, and
    push batches (optionally transformed by *func*) onto ``output_queue``.

    Protocol:
      - Immediately puts its own ``worker_id`` on ``output_queue`` so the
        parent can verify startup.
      - A ``None`` path is the stop signal; on receipt the worker puts its
        ``worker_id`` on ``output_queue`` again to signal completion, then
        exits.

    Args:
        input_queue: multiprocessing.Queue of file paths (``None`` = stop).
        output_queue: multiprocessing.Queue receiving batches and worker ids.
        func: optional transform applied to each batch before queueing.
        batch_size: rows per batch yielded from each file.
        worker_id: integer id used for the startup/finish handshake.
        verbose: emit per-batch debug logging when True.
        file_reader: callable(path) -> DataFrame-like; defaults to parquet.
    """
    logger.info(f"starting worker {worker_id}")
    # Startup handshake: parent blocks until every worker has reported in.
    output_queue.put(worker_id)
    while True:
        try:
            # BUGFIX: the original used timeout=0, which raises queue.Empty
            # immediately and busy-spins a full CPU core while idle. A short
            # blocking timeout keeps the worker responsive without spinning.
            path = input_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        if path is None:
            logger.info(f"worker {worker_id} got stop signal")
            break
        logger.info(f"worker {worker_id} reading {path}")
        fh = file_reader(path)
        if verbose:
            logger.info(f"worker {worker_id} did read {path}")
        for j, chunk in enumerate(chunks(fh, batch_size)):
            if verbose:
                logger.info(f"worker {worker_id} putting batch {j}")
            # Apply the transform in the worker process so the consuming
            # process stays light; blocks when output_queue is full.
            if func:
                output_queue.put(func(chunk))
            else:
                output_queue.put(chunk)
        logger.info(f"worker {worker_id} finished {path}")
    # Finish handshake: tells the loader this worker will send no more data.
    output_queue.put(worker_id)
class FileDataLoader:
    """Stream batches from a list of files via multiprocessing workers.

    Each worker reads whole files, splits them into ``batch_size``-row
    batches, optionally applies *transform*, and pushes the batches onto a
    bounded output queue that this loader drains through iteration.
    """

    def __init__(
        self,
        paths,
        batch_size=64,
        num_workers=1,
        prefetch_batches=100,
        transform=None,
    ):
        """Spawn the worker processes and enqueue all *paths*.

        Args:
            paths: file paths handed to the workers' file reader.
            batch_size: rows per emitted batch.
            num_workers: number of reader processes to spawn.
            prefetch_batches: max batches buffered in the output queue
                (bounds memory; workers block once it fills).
            transform: optional callable applied to each batch in the worker.
        """
        self.num_workers = num_workers
        self.prefetch_batches = prefetch_batches
        self.input_queue = multiprocessing.Queue()
        self.output_queue = multiprocessing.Queue(maxsize=prefetch_batches)
        self.transform = transform
        # BUGFIX: this was a mutable CLASS attribute (`finished_workers =
        # set()` at class scope), shared across all instances — a second
        # loader would inherit the first loader's finished worker ids and
        # report is_done() prematurely. It must be per-instance state.
        self.finished_workers = set()
        # Start workers
        self.workers = []
        for i in range(num_workers):
            worker = multiprocessing.Process(
                target=worker_fn,
                args=(self.input_queue, self.output_queue, transform, batch_size, i),
            )
            worker.daemon = True
            worker.start()
            self.workers.append(worker)
        logger.debug(f"started {num_workers} workers")
        # Wait for workers to start and verify they started: each worker
        # puts its own id on the output queue as a startup handshake.
        started_worker_ids = set()
        for _ in range(num_workers):
            started_worker_ids.add(self.output_queue.get(timeout=3))
        assert started_worker_ids == set(range(num_workers))
        logger.debug(f"verified {num_workers} workers")
        # Load up input queue
        self.paths = paths
        for path in paths:
            self.input_queue.put(path)
        logger.debug(f"loaded {len(paths)} files into queue")
        # The startup ids were consumed above, so anything appearing on the
        # output queue now is real data: wait for the first batch.
        while self.output_queue.empty():
            time.sleep(1)
        logger.debug(f"discovered first batch of data")
        # One None per worker signals it to stop after draining the paths.
        for _ in range(num_workers):
            self.input_queue.put(None)
        logger.debug(f"queued stop signals to workers")

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; raise StopIteration once fully drained."""
        batch = self.get()
        if batch is None:
            raise StopIteration
        return batch

    def is_done(self):
        # Data can be in the input queue, inside a worker, or in the output
        # queue. Only when every worker has sent its finish handshake AND
        # both queues are empty can no more data arrive.
        return (
            (len(self.finished_workers) == self.num_workers)
            and self.input_queue.empty()
            and self.output_queue.empty()
        )

    def get(self):
        """Return the next batch, or None when the loader is exhausted.

        Integer items on the output queue are worker finish handshakes, not
        data; they are recorded in ``finished_workers`` and skipped.
        """
        timer = time.time()
        while True:
            try:
                batch = self.output_queue.get(timeout=0)
                if isinstance(batch, int):
                    worker_id = batch
                    logger.debug(f"dataloader acks worker {worker_id} finished")
                    self.finished_workers.add(worker_id)
                    continue
                timer = time.time()
                return batch
            except queue.Empty:  # output queue empty, keep trying
                pass
            if self.is_done():
                logger.debug("Loader is done with data")
                return None
            if time.time() - timer > 10:
                logger.debug("Loader has been waiting for data for 10 sec")
                logger.debug(f"input_queue: {self.input_queue.empty()}")
                logger.debug(f"output_queue: {self.output_queue.empty()}")
                logger.debug(f"finished_workers: {self.finished_workers}")
                # BUGFIX: reset the timer so the stall warning repeats every
                # 10 s; previously it fired on every 1 s sleep once tripped.
                timer = time.time()
            time.sleep(1)

    def __del__(self):
        """Best-effort shutdown: signal, join, close queues, then terminate
        any worker still alive."""
        try:
            for _ in self.workers:
                self.input_queue.put(None)
            for w in self.workers:
                w.join(timeout=5.0)
            # cancel_join_thread avoids blocking interpreter exit on
            # unflushed queue buffers.
            self.input_queue.cancel_join_thread()
            self.input_queue.close()
            self.output_queue.cancel_join_thread()
            self.output_queue.close()
            logger.debug("closed queues")
        finally:
            for w in self.workers:
                if w.is_alive():
                    w.terminate()
            logger.debug("terminated workers")
if __name__ == "__main__":
    import time

    # Demo: stream batches from a few local parquet files, logging with
    # exponential backoff (batches 0, 2, 4, 8, 16, ...).
    fns = [
        "/Users/chris/Downloads/temp2/4ec3fd44-d1ec-4610-82a6-5b409a796abf.wds_img_vectors.parquet_80a55350-2927-4716-94b9-996a56606bf3.wds_img_vectors.parquet.parquet",
        "/Users/chris/Downloads/temp2/3fd02503-663c-4048-830e-eab3f349ff26.wds_img_vectors.parquet_5ae4cc08-5c07-4ed2-aea3-27b2e13dd84b.wds_img_vectors.parquet.parquet",
        "/Users/chris/Downloads/temp2/1fcbf028-49b2-4c84-87c1-d0c4622cdaf3.wds_img_vectors.parquet_5ec6a0c3-6771-4b11-8d32-fcde1d6af428.wds_img_vectors.parquet.parquet",
        "/Users/chris/Downloads/temp2/4ec3fd44-d1ec-4610-82a6-5b409a796abf.wds_img_vectors.parquet_80a55350-2927-4716-94b9-996a56606bf3.wds_img_vectors.parquet.parquet",
    ]
    dl = FileDataLoader(fns, batch_size=2048, num_workers=2)
    threshold = 1
    for i, batch in enumerate(dl):
        if i % threshold == 0:
            logger.debug(f"batch {i}")
            threshold *= 2
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment