GraphBLAS prefix scan
eriknw/9a334b74809232c1f1bc2e97141ebb16
Created June 28, 2021 17:53
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c0cd52f4",
   "metadata": {},
   "source": [
    "## Goal: demonstrate calculating cumsum with matrix multiplications\n",
    "\n",
    "We use the up-sweep/down-sweep structure of the Blelloch (work-efficient) parallel prefix scan algorithm to compute an inclusive cumsum.\n",
    "NumPy is used for illustration, but the goal is to support a prefix scan in GraphBLAS with any Monoid.\n",
    "\n",
    "https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda"
   ]
  },
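  {
   "cell_type": "markdown",
   "id": "e3a1f5c2",
   "metadata": {},
   "source": [
    "As a point of reference, any prefix scan can be written as one matrix multiply `X @ U`, where `U[k, j] = 1` whenever `k <= j`. This is a baseline and not the method this notebook develops. The sketch below checks it with NumPy on `X`, a hypothetical copy of the data used below. The dense triangular `U` has `n * (n + 1) / 2` entries for `n` columns, whereas each of the scan matrices `S1` through `S5` below has at most `n` entries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d47b0e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "X = np.arange(16).reshape((2, 8))  # same values as A below\n",
    "n = X.shape[1]\n",
    "\n",
    "# U[k, j] = 1 when k <= j, so (X @ U)[i, j] = X[i, 0] + ... + X[i, j]\n",
    "U = np.triu(np.ones((n, n), dtype=X.dtype))\n",
    "baseline = X @ U\n",
    "assert (baseline == X.cumsum(axis=1)).all()\n",
    "baseline"
   ]
  },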
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "513b64bf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  2,  3,  4,  5,  6,  7],\n",
       "       [ 8,  9, 10, 11, 12, 13, 14, 15]])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "A = np.arange(16).reshape((2, 8))\n",
    "A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "699a8553",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  3,  6, 10, 15, 21, 28],\n",
       "       [ 8, 17, 27, 38, 50, 63, 77, 92]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# End result\n",
    "A.cumsum(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7a5b1756",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We use `S` names for \"Scan\" matrices, which are patterns to perform prefix scans.\n",
    "# Each literal is transposed: the 1s in row j mark the columns that are combined into column j of the product.\n",
    "S1 = np.array(\n",
    "    [\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [1, 1, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 1, 1, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 1, 1, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 1, 1],\n",
    "    ]\n",
    ").T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dfb42902",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  0,  5,  0,  9,  0, 13],\n",
       "       [ 0, 17,  0, 21,  0, 25,  0, 29]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# B holds intermediate sums.\n",
    "#\n",
    "# Observe that we are using the standard plus-times semiring.\n",
    "# In GraphBLAS, the Semiring would use FIRST as the multiplicative BinaryOp and any Monoid as the additive operation.\n",
    "# Hence, this method only supports prefix scans whose operator is a Monoid.\n",
    "B = A @ S1\n",
    "B"
   ]
  },
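  {
   "cell_type": "markdown",
   "id": "71c2d9a4",
   "metadata": {},
   "source": [
    "The plus-times product above can be generalized. `monoid_first_matmul` below is a hypothetical NumPy emulation, not a GraphBLAS API, of a semiring whose multiply is FIRST and whose add is a Monoid, passed in as a NumPy ufunc together with its identity. With `np.add` and the same pairing pattern as `S1`, it reproduces `B`. With `np.maximum`, the same pattern gives the first step of a running maximum instead of a running sum."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8f06e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def monoid_first_matmul(X, S, op, identity):\n",
    "    # Hypothetical helper emulating a GraphBLAS semiring with FIRST as the multiply\n",
    "    # and the Monoid (op, identity) as the add:\n",
    "    #     out[i, j] = op-reduction of X[i, k] over every k where S[k, j] is set.\n",
    "    # A column of S with no entries gives `identity` (GraphBLAS would leave it empty).\n",
    "    out = np.full((X.shape[0], S.shape[1]), identity, dtype=X.dtype)\n",
    "    for j in range(S.shape[1]):\n",
    "        ks = np.flatnonzero(S[:, j])\n",
    "        if ks.size:\n",
    "            out[:, j] = op.reduce(X[:, ks], axis=1)\n",
    "    return out\n",
    "\n",
    "X = np.arange(16).reshape((2, 8))  # same values as A\n",
    "\n",
    "# Same pattern as S1: each odd column j combines columns j - 1 and j\n",
    "pairs = np.zeros((8, 8), dtype=bool)\n",
    "for j in range(1, 8, 2):\n",
    "    pairs[j - 1, j] = pairs[j, j] = True\n",
    "\n",
    "plus_step = monoid_first_matmul(X, pairs, np.add, 0)  # equals A @ S1\n",
    "max_step = monoid_first_matmul(X, pairs, np.maximum, np.iinfo(X.dtype).min)\n",
    "max_step"
   ]
  },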
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6e51243b",
   "metadata": {},
   "outputs": [],
   "source": [
    "S2 = np.array(\n",
    "    [\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 1, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 1, 0, 0],\n",
    "    ]\n",
    ").T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e934d0b8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  0,  6,  0,  9,  0, 22],\n",
       "       [ 0, 17,  0, 38,  0, 25,  0, 54]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "B += B @ S2\n",
    "B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "90c33300",
   "metadata": {},
   "outputs": [],
   "source": [
    "S3 = np.array(\n",
    "    [\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 1, 0, 0, 0, 0],\n",
    "    ]\n",
    ").T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b7dc1205",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  0,  6,  0,  9,  0, 28],\n",
       "       [ 0, 17,  0, 38,  0, 25,  0, 92]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "B += B @ S3\n",
    "B"
   ]
  },
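  {
   "cell_type": "markdown",
   "id": "2a9c4e7f",
   "metadata": {},
   "source": [
    "The up-sweep matrices `S1`, `S2` and `S3` follow a regular pattern. The hypothetical helper `upsweep_matrix` below builds them for `n` columns, assuming `n` is a power of two. Step `d` uses stride `s = 2**d`: column `j = s-1, 2s-1, ...` receives the partial sum in column `j - s/2`, and the first step also keeps column `j` itself. Running all the steps reproduces `B` above, with each row's total in the last column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5d81b36",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def upsweep_matrix(n, d):\n",
    "    # Hypothetical helper, assuming n is a power of two.\n",
    "    # Step d (d = 1, 2, ...) uses stride s = 2**d: column j = s-1, 2s-1, ...\n",
    "    # receives the partial sum stored in column j - s/2.\n",
    "    s = 2 ** d\n",
    "    S = np.zeros((n, n), dtype=int)\n",
    "    for j in range(s - 1, n, s):\n",
    "        S[j - s // 2, j] = 1\n",
    "        if d == 1:\n",
    "            S[j, j] = 1  # the first step also keeps column j itself\n",
    "    return S\n",
    "\n",
    "X = np.arange(16).reshape((2, 8))  # same values as A\n",
    "n = X.shape[1]\n",
    "partial = X @ upsweep_matrix(n, 1)\n",
    "for d in range(2, int(np.log2(n)) + 1):\n",
    "    partial += partial @ upsweep_matrix(n, d)\n",
    "partial"
   ]
  },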
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6ebcef49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# downsweep\n",
    "S4 = np.array(\n",
    "    [\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 1, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "    ]\n",
    ").T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5fd544d9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  0,  6,  0, 15,  0, 28],\n",
       "       [ 0, 17,  0, 38,  0, 63,  0, 92]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "B += B @ S4\n",
    "B"
   ]
  },
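  {
   "cell_type": "markdown",
   "id": "4f0a7e92",
   "metadata": {},
   "source": [
    "The down-sweep matrices can be generated the same way. The hypothetical helper below assumes, again, that `n` is a power of two and that the stride `s` is one of `n/2, n/4, ..., 2`. Column `j = 3s/2 - 1, 5s/2 - 1, ...` receives the prefix sum in column `j - s/2`. For `n = 8`, stride 4 gives the single entry of `S4` (column 3 into column 5) and stride 2 gives the three entries of `S5`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d63b2c58",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def downsweep_matrix(n, s):\n",
    "    # Hypothetical helper, assuming n is a power of two and s is one of n/2, n/4, ..., 2.\n",
    "    # Column j = 3s/2 - 1, 5s/2 - 1, ... receives the prefix sum in column j - s/2.\n",
    "    S = np.zeros((n, n), dtype=int)\n",
    "    for j in range(s + s // 2 - 1, n, s):\n",
    "        S[j - s // 2, j] = 1\n",
    "    return S\n",
    "\n",
    "n = 8\n",
    "# (source column, target column) pairs for each stride\n",
    "entries = {\n",
    "    s: [tuple(int(v) for v in ij) for ij in np.argwhere(downsweep_matrix(n, s))]\n",
    "    for s in (4, 2)\n",
    "}\n",
    "entries"
   ]
  },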
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6aeabcc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "S5 = np.array(\n",
    "    [\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 1, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 1, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 1, 0, 0],\n",
    "        [0, 0, 0, 0, 0, 0, 0, 0],\n",
    "    ]\n",
    ").T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "433e7981",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  3,  6, 10, 15, 21, 28],\n",
       "       [ 8, 17, 27, 38, 50, 63, 77, 92]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Finish 1 (copy A)\n",
    "rv = A.copy()\n",
    "rv += B @ S5  # can be before or after `np.where`: B @ S5 only fills columns where B is 0\n",
    "rv = np.where(B != 0, B, rv)\n",
    "rv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "31ff6eb9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  3,  6, 10, 15, 21, 28],\n",
       "       [ 8, 17, 27, 38, 50, 63, 77, 92]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Finish 2 (copy B)\n",
    "rv = B.copy()\n",
    "rv = np.where(rv == 0, A, rv)\n",
    "rv += B @ S5\n",
    "rv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "726ad8e5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  3,  6, 10, 15, 21, 28],\n",
       "       [ 8, 17, 27, 38, 50, 63, 77, 92]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Finish 3 (use B)\n",
    "C = B @ S5\n",
    "B = np.where(B == 0, A, B)\n",
    "B += C\n",
    "B"
   ]
  },
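  {
   "cell_type": "markdown",
   "id": "8e15f4a0",
   "metadata": {},
   "source": [
    "All three finishes use `B == 0` (or `B != 0`) to decide which columns of `B` hold partial sums. In NumPy that stands in for the sparse structure that GraphBLAS tracks explicitly, and it only works because no partial sum in this example is 0. The sketch below uses hypothetical names (`X`, `Y`, `pattern`, `has_value`) and scan matrices with the same entries as `S1` through `S5`. Its input row starts with two values that cancel. The zero-based finish then returns -1 in column 1, while a finish driven by an explicit column mask matches `np.cumsum`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b7d93c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "X = np.array([[1, -1, 2, 3, -5, 4, 0, 6]])  # columns 0 and 1 cancel\n",
    "n = X.shape[1]\n",
    "\n",
    "def pattern(pairs):\n",
    "    # Hypothetical helper: scan matrix with a 1 at (source, target) for each pair\n",
    "    P = np.zeros((n, n), dtype=int)\n",
    "    for k, j in pairs:\n",
    "        P[k, j] = 1\n",
    "    return P\n",
    "\n",
    "# Same entries as S1 through S5 above\n",
    "Y = X @ pattern([(0, 1), (1, 1), (2, 3), (3, 3), (4, 5), (5, 5), (6, 7), (7, 7)])\n",
    "Y += Y @ pattern([(1, 3), (5, 7)])\n",
    "Y += Y @ pattern([(3, 7)])\n",
    "Y += Y @ pattern([(3, 5)])\n",
    "last = pattern([(1, 2), (3, 4), (5, 6)])\n",
    "\n",
    "# Zero-based finish: the genuine 0 in column 1 is mistaken for a missing entry\n",
    "zero_based = np.where(Y != 0, Y, X) + Y @ last\n",
    "# Mask-based finish: Y stores partial sums in the odd columns, whatever their values\n",
    "has_value = np.arange(n) % 2 == 1\n",
    "mask_based = np.where(has_value, Y, X) + Y @ last\n",
    "\n",
    "expected = X.cumsum(axis=1)\n",
    "assert zero_based[0, 1] == -1 and expected[0, 1] == 0\n",
    "assert (mask_based == expected).all()\n",
    "mask_based"
   ]
  },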
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1320903b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  3,  6, 10, 15, 21, 28],\n",
       "       [ 8, 17, 27, 38, 50, 63, 77, 92]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check\n",
    "A.cumsum(axis=1)"
   ]
  },
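  {
   "cell_type": "markdown",
   "id": "f2c84a1d",
   "metadata": {},
   "source": [
    "The cells above hard-code the five scan matrices for 8 columns. The sketch below generalizes them to any power-of-two number of columns `n`. `scan_matrices` and `matmul_cumsum` are hypothetical names, the finish uses an explicit mask of the columns stored in `B` rather than testing for zeros, and only the plus Monoid is covered. It also counts the matrix products, which come to `2 * log2(n) - 1` for this scheme."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a39e0b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def scan_matrices(n):\n",
    "    # Hypothetical helper, assuming n is a power of two (n >= 2).\n",
    "    # Returns the up-sweep and down-sweep scan matrices in the order they are applied.\n",
    "    up, down = [], []\n",
    "    s = 2\n",
    "    while s <= n:\n",
    "        S = np.zeros((n, n), dtype=int)\n",
    "        for j in range(s - 1, n, s):\n",
    "            S[j - s // 2, j] = 1\n",
    "            if s == 2:\n",
    "                S[j, j] = 1  # the first step also keeps column j itself\n",
    "        up.append(S)\n",
    "        s *= 2\n",
    "    s = n // 2\n",
    "    while s >= 2:\n",
    "        S = np.zeros((n, n), dtype=int)\n",
    "        for j in range(s + s // 2 - 1, n, s):\n",
    "            S[j - s // 2, j] = 1\n",
    "        down.append(S)\n",
    "        s //= 2\n",
    "    return up, down\n",
    "\n",
    "def matmul_cumsum(A):\n",
    "    # Inclusive cumsum along axis 1; returns (result, number of matrix products)\n",
    "    n = A.shape[1]\n",
    "    up, down = scan_matrices(n)\n",
    "    B = A @ up[0]\n",
    "    products = 1\n",
    "    for S in up[1:] + down[:-1]:\n",
    "        B += B @ S\n",
    "        products += 1\n",
    "    rv = np.where(np.arange(n) % 2 == 1, B, A)  # odd columns from B, even columns from A\n",
    "    if down:\n",
    "        rv += B @ down[-1]  # last step: each even column j >= 2 adds column j - 1\n",
    "        products += 1\n",
    "    return rv, products\n",
    "\n",
    "for n in (2, 4, 8, 16, 32):\n",
    "    X = np.arange(2 * n).reshape((2, n))\n",
    "    result, products = matmul_cumsum(X)\n",
    "    assert (result == X.cumsum(axis=1)).all()\n",
    "    assert products == 2 * int(np.log2(n)) - 1"
   ]
  },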
  {
   "cell_type": "markdown",
   "id": "bc6927a3",
   "metadata": {},
   "source": [
    "### Comments\n",
    "- We should limit the entries in the S matrices according to which columns in A have values.\n",
    "- For hypersparse matrices, we should \"compress\" or \"project\" the matrix into a matrix with fewer columns so that every column has values.\n",
    "  - Should we always do this?\n",
    "- Cumsum for a diagonal matrix is a pathologically bad case. How can we make it better? What does this teach us?\n",
    "  - For example, the first element in a diagonal matrix with `N` elements will be duplicated `log2(N)` times in the `B` matrix, but none of the duplicated values are used.\n",
    "- This method requires roughly `2 * log2(ncols)` matrix multiplies (exactly `2 * log2(ncols) - 1` when `ncols` is a power of two, i.e. the 5 products with `S1` through `S5` above).\n",
    "- This would be straightforward to implement in GraphBLAS.\n",
    "- It's desirable to perform prefix scans within GraphBLAS instead of exporting to C arrays.\n",
    "- I don't have a good understanding of how well this would perform in practice."
   ]
  }
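  ,
  {
   "cell_type": "markdown",
   "id": "a7e2c6f1",
   "metadata": {},
   "source": [
    "The compression idea above can be sketched in NumPy: keep only the columns that hold at least one value, scan the narrower matrix, and scatter the results back. Here zeros stand in for missing entries, `np.cumsum` stands in for the matrix-multiply scan, and the variable names are hypothetical. Because an empty column only contributes the Monoid's identity, the scan at every kept column is unchanged, and the number of scan steps depends on the 4 occupied columns instead of all 1000."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13d5b9e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# A hypersparse 2 x 1000 example; zeros play the role of missing entries\n",
    "X = np.zeros((2, 1000), dtype=int)\n",
    "X[0, [3, 400, 999]] = [1, 2, 3]\n",
    "X[1, [400, 700]] = [5, 7]\n",
    "\n",
    "cols = np.flatnonzero((X != 0).any(axis=0))  # occupied columns: 3, 400, 700, 999\n",
    "compact = X[:, cols]                         # 2 x 4 instead of 2 x 1000\n",
    "compact_scan = compact.cumsum(axis=1)        # stand-in for the matmul-based scan\n",
    "\n",
    "# Scatter back: the scan at each occupied column equals the full-width cumsum there\n",
    "scanned = np.zeros_like(X)\n",
    "scanned[:, cols] = compact_scan\n",
    "assert (scanned[:, cols] == X.cumsum(axis=1)[:, cols]).all()\n",
    "compact_scan"
   ]
  }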
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}