Created
January 31, 2025 20:46
-
-
Save TomAugspurger/69896a50c852897486b799d1023638be to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# DALI + Zarr (GPU) example.\n", | |
"\n", | |
"This notebook adapts the GPU example from\n", | |
"https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/external_input.html\n", | |
"to use Zarr for storage.\n", | |
"\n", | |
"To run it, you'll currently need to use my fork of zarr-python:\n", | |
"\n", | |
"    pip install git+https://github.com/TomAugspurger/zarr-python/@tom/fix/gpu\n", | |
"\n", | |
"That should be in zarr `main` soon. You'll also need the data.\n", | |
"\n", | |
"```\n", | |
"mkdir -p data/images\n", | |
"cd data/images\n", | |
"curl -O https://docs.nvidia.com/deeplearning/dali/user-guide/docs/_images/examples_general_data_loading_external_input_12_2.png\n", | |
"curl -O https://docs.nvidia.com/deeplearning/dali/user-guide/docs/_images/examples_general_data_loading_external_input_19_2.png\n", | |
"\n", | |
"```\n", | |
"\n", | |
"And a `file_list.txt` like\n", | |
"\n", | |
"```\n", | |
"examples_general_data_loading_external_input_12_2.png 0\n", | |
"examples_general_data_loading_external_input_19_2.png 1\n", | |
"```\n", | |
"\n", | |
"Then run `make_data()` to create the zarr store." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from random import shuffle\n", | |
"from nvidia.dali.pipeline import Pipeline\n", | |
"import nvidia.dali.fn as fn\n", | |
"import zarr\n", | |
"import zarr.storage\n", | |
"from PIL import Image\n", | |
"from nvidia.dali import types" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def make_data():\n", | |
"    \"\"\"Create the example Zarr store at ``data/example.zarr``.\n", | |
"\n", | |
"    The store holds an ``images`` array of shape ``(100, 231, 300, 3)`` and a\n", | |
"    ``labels`` array of shape ``(100,)``, both ``uint8`` with one sample per\n", | |
"    chunk. Only the first two samples are written, from the two PNGs\n", | |
"    downloaded into ``data/images``; the other samples are never assigned.\n", | |
"    \"\"\"\n", | |
"    TOTAL_SAMPLES = 100\n", | |
"    # note: the images from the docs vary in size while Zarr requires\n", | |
"    # uniform chunk sizes, so every image is cropped to HEIGHT x WIDTH.\n", | |
"    HEIGHT, WIDTH = 231, 300\n", | |
"\n", | |
"    store = zarr.storage.LocalStore(root=\"data/example.zarr\")\n", | |
"    group = zarr.create_group(store, overwrite=True)\n", | |
"\n", | |
"    arr = group.create_array(\n", | |
"        name=\"images\",\n", | |
"        shape=(TOTAL_SAMPLES, HEIGHT, WIDTH, 3),\n", | |
"        chunks=(1, HEIGHT, WIDTH, 3),\n", | |
"        dtype=\"uint8\",\n", | |
"        overwrite=True,\n", | |
"    )\n", | |
"\n", | |
"    labels = group.create_array(\n", | |
"        name=\"labels\",\n", | |
"        shape=(TOTAL_SAMPLES,),\n", | |
"        chunks=(1,),\n", | |
"        dtype=\"uint8\",\n", | |
"        overwrite=True,\n", | |
"    )\n", | |
"\n", | |
"    # TODO: use file list\n", | |
"    # assuming you've downloaded these two\n", | |
"    samples = [\n", | |
"        (\"data/images/examples_general_data_loading_external_input_12_2.png\", 0),\n", | |
"        (\"data/images/examples_general_data_loading_external_input_19_2.png\", 1),\n", | |
"    ]\n", | |
"    for index, (path, label) in enumerate(samples):\n", | |
"        # The context manager closes the file once the pixels are copied out.\n", | |
"        with Image.open(path) as img:\n", | |
"            # Force three channels and the stored height/width. Both steps\n", | |
"            # leave an image that is already RGB and HEIGHT x WIDTH unchanged.\n", | |
"            # NOTE(review): presumably the PNGs carry an alpha channel (the\n", | |
"            # earlier TODO mentioned the shape going 4 -> 3) - confirm.\n", | |
"            rgb = img.convert(\"RGB\").crop((0, 0, WIDTH, HEIGHT))\n", | |
"        arr[index] = rgb\n", | |
"        labels[index] = label\n", | |
"\n", | |
"make_data()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"batch_size = 16\n", | |
"\n", | |
"\n", | |
"class ExternalInputIterator:\n", | |
"    \"\"\"Endless batch source for DALI's ``fn.external_source``.\n", | |
"\n", | |
"    Each ``next()`` call returns an ``(images, labels)`` pair of lists, both\n", | |
"    ``batch_size`` long, read one sample at a time from the Zarr store at\n", | |
"    ``data/example.zarr/``. Samples are visited in a shuffled order and the\n", | |
"    position wraps around after the last sample, so ``__next__`` itself\n", | |
"    never raises ``StopIteration``.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, batch_size: int):\n", | |
"        self.root = \"data/example.zarr/\"\n", | |
"        self.variable = \"images\"\n", | |
"        self.batch_size = batch_size\n", | |
"\n", | |
"        # Does this class get serialized? Is it safe to store\n", | |
"        # references to zarr arrays here? Until that is settled, only plain\n", | |
"        # Python state lives on the instance and the arrays are reopened\n", | |
"        # inside every __next__ call.\n", | |
"\n", | |
"        # One index per sample along the first axis, in random visiting order.\n", | |
"        self.indices = list(\n", | |
"            range(zarr.open_array(self.root, path=self.variable).shape[0])\n", | |
"        )\n", | |
"        shuffle(self.indices)\n", | |
"        self.i = 0\n", | |
"        self.n = len(self.indices)\n", | |
"\n", | |
"    def __iter__(self):\n", | |
"        # Restart at the first position of the already-shuffled order.\n", | |
"        self.i = 0\n", | |
"        self.n = len(self.indices)\n", | |
"        return self\n", | |
"\n", | |
"    def __next__(self):\n", | |
"        batch = []\n", | |
"        labels = []\n", | |
"\n", | |
"        arr = zarr.open(self.root, path=self.variable)\n", | |
"        arr_labels = zarr.open(self.root, path=\"labels\")\n", | |
"\n", | |
"        for _ in range(self.batch_size):\n", | |
"            # Read through self.indices so the shuffle done in __init__\n", | |
"            # actually decides which sample comes next.\n", | |
"            sample = self.indices[self.i]\n", | |
"            batch.append(arr[sample])\n", | |
"            labels.append(arr_labels[sample])\n", | |
"            # Wrap around instead of stopping at the end of the data.\n", | |
"            self.i = (self.i + 1) % self.n\n", | |
"        return (batch, labels)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(231, 300, 3)\n", | |
"0\n" | |
] | |
} | |
], | |
"source": [ | |
"# Needs my zarr-python branch for now.\n", | |
"zarr.config.enable_gpu()\n", | |
"\n", | |
"eii = ExternalInputIterator(batch_size)\n", | |
"pipe = Pipeline(batch_size=batch_size, num_threads=2, device_id=0)\n", | |
"\n", | |
"# note: using the `device=\"gpu\"` variant from https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/external_input.html\n", | |
"with pipe:\n", | |
"    images, labels = fn.external_source(\n", | |
"        source=eii,\n", | |
"        num_outputs=2,\n", | |
"        dtype=types.UINT8,\n", | |
"        device=\"gpu\",\n", | |
"    )\n", | |
"    enhance = fn.brightness_contrast(images, contrast=2)\n", | |
"    pipe.set_outputs(enhance, labels)\n", | |
"\n", | |
"pipe.build()\n", | |
"pipe_out = pipe.run()\n", | |
"\n", | |
"# Despite the names, both hold the copies returned by .as_cpu().\n", | |
"batch_gpu, labels_gpu = (out.as_cpu() for out in pipe_out)\n", | |
"\n", | |
"# Shape of the first image of the batch, then its label.\n", | |
"print(batch_gpu.at(0).shape)\n", | |
"print(labels_gpu.at(0))\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.imshow(batch_gpu.at(1));" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "python-3.12", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment