In a GPU-driven renderer, "work expansion" is a commonly occurring problem. "Work expansion" means that a single work item spawns N follow-up work items. Typically, one work item is executed by one shader thread/invocation.
An example of work expansion is GPU-driven meshlet culling following mesh culling. In this example, a "work item" is culling one mesh, and each mesh cull work item spawns N meshlet cull work items.
There are many variations of this problem and many solutions. Some are trivial to solve, for example when N (how many work items are spawned) is fixed: simply dispatch N times more threads for the second pass than for the first pass, easy!
But those cases are not particularly interesting. In this post, I want to address the more complex cases where N can vary a lot.
Some terminology used throughout this post:

- `pass 1`: the first pass. Each work item in this pass spawns one or more dst work items for pass 2. In the culling example this would be the mesh culling.
- `pass 2`: the second pass. Each dst work item in pass 2 is spawned from a src work item in pass 1. In the culling example this would be the meshlet culling.
- `P1`: number of src work items in pass 1.
- `P2`: number of dst work items in pass 2.
- `src work item`: work item in pass 1 that spawns new work items for pass 2.
- `dst work item`: work item in pass 2, spawned from a src work item.
- `N`: number of dst work items spawned by a src work item.
- `MAX_N`: maximum number of work items each work item in pass 1 can spawn.
- `MAX_DST_WORK_ITEMS`: maximum TOTAL number of dst work items.
- `MAX_SRC_WORK_ITEMS`: maximum TOTAL number of src work items.
- `local/dst work item index`: a src work item spawns N dst work items; their local work item indices are 0, 1, 2, ..., N-1.
The latest graphics APIs actually have a set of solutions to these problems, for example DX12 Work Graphs. With work graphs, each thread can immediately start new threads within the same dispatch! This directly solves most work expansion problems; with work graphs we wouldn't even need two passes, we could do it all in one.
Another example is the mesh shader dispatch API. In Vulkan and DX12, the mesh cull threads can simply append to a buffer of dispatch indirect structs using an atomic counter and then call vkCmdDrawMeshTasksIndirectCountEXT, which allows a GPU-driven draw count to be written by shaders. But the real power of mesh shaders is the task/amplification shader stage: it can directly launch a runtime-determined number N of mesh shader workgroups. Effectively this is a very limited form of work graphs, and it is a great solution for meshlet culling and LOD selection.
Lastly there is Device Generated Commands. With this extension, we can simply insert new dispatch indirects for spawned work items directly into the command buffer!
The OG is of course the dispatch indirect command. It lets shaders write a dispatch command to a buffer that the GPU then reads to determine the size of a following dispatch. This solution is great and is used extensively in GPU-driven renderers; with some extra work in shaders, we can get a lot of efficient work expansion done with it. But the command alone is very barebones and does not deal with most expansion problems directly (and that's totally fine!).
Ideally we would also have something like a "dispatch indirect count" for other work expansion problems; sadly, no graphics API has this feature to my knowledge.
All these API solutions are great, but they have downsides:

- There is no "dispatch indirect count"; it's possible to abuse task/amplification -> mesh shader expansion for this, but that is very awkward.
- Work graphs only run on very recent GPUs and currently have very poor performance on Nvidia.
- Device generated commands put a lot of pressure on the frontend of the GPU. Inserting many thousands of dispatches can be slow.
- Thread granularity and dispatch overhead for small Ns: dispatches have a non-insignificant fixed minimal performance overhead on GPUs. Inserting many small dispatches via DGC can be very slow, as the GPU's frontend won't be able to process them fast enough for many small workloads.
- Thread granularity within dispatches: GPU dispatches are made to work on many work items; workgroup sizes should typically be at least 64 threads, and optimally there should also be many workgroups. Small dispatches that use far fewer than, say, 64 threads will leave a lot of the GPU's hardware idle. This is the case with vkCmdDrawMeshTasksIndirectCountEXT and DGC.

By far the best solution out of these is work graphs. But as they are so recent, most GPUs can't run them at all, and Nvidia only poorly, so they are not yet a generally applicable solution. So let's look at software solutions.
Solution 0: Dispatch P1 threads in pass 1 and P1 * MAX_N threads for pass 2. Each thread works on a single item.
Each thread in pass 1 knows its src work item by its global dispatch thread index. Each thread in pass 2 finds its src work item as dispatchThreadIdx / MAX_N and its local work item index as dispatchThreadIdx % MAX_N.
Now, not all work items in pass 1 will spawn MAX_N work items for pass 2, yet we spawn MAX_N threads for each unconditionally. To solve this, each work item in pass 1 writes to a buffer how many dst work items it actually wants. The threads in pass 2 read that buffer and early out if their local work item index is greater than or equal to that count, as sketched below.
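Here is a minimal sketch of the pass 2 side, assuming a hypothetical `work_item_counts` buffer written by pass 1 and a compile-time `MAX_N`; the names and the entry point are illustrative, not from the original:

```hlsl
// Pass 2 of Solution 0 (sketch; buffer and entry point names are made up).
static const uint MAX_N = 64; // assumed compile-time maximum expansion

StructuredBuffer<uint> work_item_counts; // one N per src work item, written in pass 1

[numthreads(64, 1, 1)]
void pass2_main(uint dtid : SV_DispatchThreadID)
{
    uint src_work_item_idx   = dtid / MAX_N; // which src work item spawned this thread
    uint local_work_item_idx = dtid % MAX_N; // 0..MAX_N-1 within that src work item
    // Threads beyond the actually requested N are padding; kill them early.
    if (local_work_item_idx >= work_item_counts[src_work_item_idx])
        return;
    // ... do the dst work, e.g. cull one meshlet ...
}
```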
This is a good solution when most src work items spawn close to MAX_N dst work items and MAX_N is small.
It is a very poor solution when many or most src work items spawn far fewer dst work items than MAX_N.
And in most cases, such as the mesh and meshlet culling example, N varies a lot between src work items.
Solution 1: Dispatch P1 threads in pass 1 and perform one indirect dispatch for pass 2. Each thread works on a single item.
We have a buffer containing a dst work item counter, a dispatch indirect struct and an array of dst work item infos:
struct DstWorkItemInfo
{
    uint src_work_item_idx;   // which src work item spawned this dst work item
    uint local_work_item_idx; // 0..N-1 within that src work item
};
struct Buffer
{
    uint3 pass2_dispatch_indirect;  // written by pass 1, consumed by the indirect dispatch
    uint dst_work_item_count;       // bumped atomically by pass 1 threads
    DstWorkItemInfo dst_work_item_infos[MAX_DST_WORK_ITEMS];
};
Each thread in pass 1 performs atomic operations to update pass2_dispatch_indirect and dst_work_item_count, using the value returned by the atomic add on dst_work_item_count as the offset at which to write its DstWorkItemInfos. Each thread then loops N times, writing N DstWorkItemInfos, as sketched below.
Each thread in pass 2 uses its dispatch thread index to index the buffer's dst_work_item_infos to know what to work on.
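A sketch of the pass 1 write-out; the binding name `expansion`, the helper name and the 64-thread workgroup size for pass 2 are my assumptions:

```hlsl
// Pass 1 of Solution 1 (sketch). Assumes pass2_dispatch_indirect.y/.z are
// pre-initialized to 1 and that pass 2 runs 64-thread workgroups.
RWStructuredBuffer<Buffer> expansion;

void append_dst_work_items(uint src_work_item_idx, uint n)
{
    // Reserve a contiguous range of n slots in dst_work_item_infos.
    uint first_slot;
    InterlockedAdd(expansion[0].dst_work_item_count, n, first_slot);
    // Grow the pass 2 workgroup count so it covers all reserved slots.
    uint needed_groups = (first_slot + n + 63) / 64;
    InterlockedMax(expansion[0].pass2_dispatch_indirect.x, needed_groups);
    // The problematic part: one write per spawned dst work item.
    for (uint i = 0; i < n; ++i)
    {
        DstWorkItemInfo info;
        info.src_work_item_idx   = src_work_item_idx;
        info.local_work_item_idx = i;
        expansion[0].dst_work_item_infos[first_slot + i] = info;
    }
}
```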
This solution is very good when N is mostly very small (let's say smaller than ~16) and MAX_N is also very small.
It is very bad when N varies a lot and gets very big: the threads in pass 1 have to perform a write in a loop N times, and for large Ns this is slow and can become very divergent, hurting performance.
Note: Each dst work item could also represent a whole workgroup of threads instead. It's common that the expansion ratio is quite large, making it much more efficient to map a dst work item to a whole workgroup, even if that means some threads may be wasted (for example, 1 dst work item = 64 threads but there is only work for 24 of them).
Solution 2: Dispatch P1 threads in pass 1 and perform one indirect dispatch for pass 2. Each thread works on a single item.
Now it gets a little complicated. Like in Solution 1, we have a buffer with a counter and a dispatch indirect struct, but no array of dst work item infos.
Instead of writing a single dst work item info for every dst work item, we write a list of multi dst work item infos and make the threads in pass 2 search for their src work item.
struct MultiDstWorkItemInfo
{
    uint src_work_item_idx;
    uint work_item_count; // count instead of a single local index!
};
struct Buffer
{
    uint3 pass2_dispatch_indirect;
    uint multi_dst_work_item_count;
    MultiDstWorkItemInfo multi_dst_work_item_infos[MAX_SRC_WORK_ITEMS];
};
Writing only one multi dst work item info instead of many dst work item infos can be much faster for the src work item threads.
Wait, but how do the threads in pass 2 know which dst and src work item to work on? They only have their dispatch thread id, and there is no direct mapping from that to a dst work item like there is in Solution 1!
Each thread in pass 2 can find its MultiDstWorkItemInfo index by iterating over all MultiDstWorkItemInfos, subtracting each work_item_count from its dispatch thread index, until the result would become negative. The last MultiDstWorkItemInfo for which the remaining thread index is still >= 0 is that thread's MultiDstWorkItemInfo index, and what remains of the thread index is its local work item index.
If this is not fully clear, pick a thread in the following example and perform the search described above.
MultiDstWorkItemInfo multi_dst_work_item_infos[MAX_SRC_WORK_ITEMS] = {
    { 0, 3 },
    { 1, 1 },
    { 2, 2 }
};
pass 2 threads:
dispatch thread index 0: src work item 0, local work item index 0
dispatch thread index 1: src work item 0, local work item index 1
dispatch thread index 2: src work item 0, local work item index 2
dispatch thread index 3: src work item 1, local work item index 0
dispatch thread index 4: src work item 2, local work item index 0
dispatch thread index 5: src work item 2, local work item index 1
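In code, that linear search could look like the following sketch (the `expansion` binding is assumed to hold the Solution 2 Buffer from above):

```hlsl
// Linear search of a pass 2 thread for its src work item (Solution 2, slow path).
StructuredBuffer<Buffer> expansion; // assumed binding of the Buffer defined above

void find_work_item_linear(
    uint dispatch_thread_idx,
    out uint src_work_item_idx,
    out uint local_work_item_idx)
{
    uint remaining = dispatch_thread_idx;
    uint i = 0;
    // Subtract each info's work_item_count until our index falls inside one.
    while (remaining >= expansion[0].multi_dst_work_item_infos[i].work_item_count)
    {
        remaining -= expansion[0].multi_dst_work_item_infos[i].work_item_count;
        ++i;
    }
    src_work_item_idx   = expansion[0].multi_dst_work_item_infos[i].src_work_item_idx;
    local_work_item_idx = remaining;
}
```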
Linear search is very coherent for the pass 2 threads, but it scales very poorly with the number of multi dst work items.
The better approach is the following:

- every pass 1 thread writes out the N of its src work item into an array
- build a prefix sum array over these Ns
- every pass 2 thread performs a binary search for its dispatch thread index in the prefix sum array; the found index identifies the MultiDstWorkItemInfo for that thread
The idea behind this is that the prefix sum value stored for each MultiDstWorkItemInfo represents the exclusive upper bound of the dispatch thread indices associated with that src work item. So the prefix sum maps pass 2 thread indices to MultiDstWorkItemInfos. Binary search is also a very good choice here, as it is very data and execution coherent for similar thread indices; a sketch follows below.
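A sketch of that binary search, assuming a `prefix_sums` array where entry i holds the inclusive sum N_0 + ... + N_i (so each entry is the exclusive upper bound of its thread index range):

```hlsl
// Binary search of a pass 2 thread over the prefix sum array (sketch).
uint find_multi_dst_work_item(
    StructuredBuffer<uint> prefix_sums, // prefix_sums[i] = N_0 + ... + N_i
    uint info_count,
    uint dispatch_thread_idx)
{
    uint lo = 0;
    uint hi = info_count; // search in [lo, hi)
    while (lo < hi)
    {
        uint mid = (lo + hi) / 2;
        if (dispatch_thread_idx < prefix_sums[mid])
            hi = mid;     // mid's upper bound is above us: answer is mid or earlier
        else
            lo = mid + 1; // we lie entirely past mid's range
    }
    // lo is the first info whose upper bound exceeds the thread index.
    // The local work item index is the thread index minus the lower bound:
    // lo == 0 ? dispatch_thread_idx : dispatch_thread_idx - prefix_sums[lo - 1].
    return lo;
}
```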
This solution is very good for most cases. It has "constant" time for writing expansion info in the pass 1 threads, a low memory footprint of only MAX_SRC_WORK_ITEMS infos, and the search is relatively fast in pass 2 due to its very high data and execution coherence.
Generally this is the most commonly used solution for work expansion, as it is simple to implement and performs well.
Solution 2.5: Now, building a prefix sum is typically done in multiple passes and can require significant GPU time due to the barriers and work that have to be performed.
Luckily, there is a trick that speeds up the prefix sum part by a lot!
Instead of only atomically counting the needed MultiDstWorkItemInfos in the pass 1 threads, we can build the prefix sum at the same time.
The 32 bit counter multi_dst_work_item_count is bumped to a 64 bit variable named multi_dst_work_item_count_sum. The upper 32 bits of the new 64 bit variable still contain the current multi dst work item count, but we put a payload in the lower 32 bits: the sum of all needed dst work items (NOT multi dst work items, but the actual total number of threads needed in pass 2).
Using 64 bit atomics, we can perform an atomic compare exchange loop (sketched below):

- read the current value of multi_dst_work_item_count_sum
- unpack multi_dst_work_item_count and the sum of all dst work items
- add 1 to multi_dst_work_item_count and N to the sum of needed dst work items
- pack both values into a 64 bit value
- try to write the packed value BUT ONLY if it didn't change in the meantime, using the atomic compare exchange
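As a sketch, assuming Shader Model 6.6 style 64 bit atomics; the buffer layout and all names here are illustrative:

```hlsl
// Combined count + prefix sum allocation (Solution 2.5 sketch).
// Upper 32 bits: multi_dst_work_item_count. Lower 32 bits: running dst work item sum.
RWStructuredBuffer<uint64_t> count_sum; // multi_dst_work_item_count_sum at slot 0

void allocate(uint n, out uint info_idx, out uint prefix_sum)
{
    uint64_t expected = count_sum[0];
    while (true)
    {
        uint count = uint(expected >> 32);         // unpack current info count
        uint sum   = uint(expected & 0xFFFFFFFFu); // unpack current dst work item sum
        uint64_t desired = (uint64_t(count + 1) << 32) | uint64_t(sum + n);
        uint64_t original;
        InterlockedCompareExchange(count_sum[0], expected, desired, original);
        if (original == expected)
            break;           // our packed value went through
        expected = original; // another thread won the race; retry on its value
    }
    info_idx   = uint(expected >> 32);         // slot for this thread's MultiDstWorkItemInfo
    prefix_sum = uint(expected & 0xFFFFFFFFu); // exclusive prefix sum of dst work items
}
```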
After performing this loop, the src work item threads write out their MultiDstWorkItemInfo. We now also need the prefix sum value of the needed dst work items inside the MultiDstWorkItemInfos:
struct MultiDstWorkItemInfo
{
    uint src_work_item_idx;
    uint work_item_count;
    uint work_item_prefix_sum; // exclusive prefix sum of dst work items, taken from the packed counter
};
Doing this, we build the prefix sum entirely within pass 1 with minimal extra performance overhead, and we can completely get rid of the multi-pass prefix sum generation, making the surrounding code simpler and faster.
Implementation Example Shader File
Solution 3: This solution is an optimized alternative to Solution 1. It aims to reduce the cost of storing and accessing work items by using a bucketed power-of-two grouping strategy. While slightly more expensive in pass 1, it completely eliminates the need for a binary search in pass 2, making it highly efficient for large-scale work expansion.
Instead of each src work item writing N individual DstWorkItemInfo entries like in Solution 1, it writes one entry per set bit in the binary representation of N, so at most log2(MAX_N) + 1 entries.
Each entry is written to a dedicated bucket, where bucket i corresponds to blocks of 2^i work items.
Each bucket thus stores fewer, coarser-grained entries, reducing overall memory operations and simplifying the lookup logic in pass 2.
struct DstWorkItemInfo
{
    uint src_work_item_idx;
    uint work_item_idx_offset;
    // The work_item_count is implicitly known from the bucket index.
};
Each bucket contains an array of DstWorkItemInfos. There are MAX_BUCKETS (typically 32) such arrays, where:

- Bucket 0: each entry represents 1 work item
- Bucket 1: each entry represents 2 work items
- Bucket 2: each entry represents 4 work items
- ...
- Bucket i: each entry represents 2^i work items
In pass 1, each thread decomposes its N (the number of work items to spawn) into powers of two. For each set bit in the binary representation of N, it appends one DstWorkItemInfo to the corresponding bucket.
Example: Let's say a thread wants to spawn N = 11 work items. The binary representation is 1011, so bits 0, 1, and 3 are set. The thread appends entries to buckets 0, 1, and 3, corresponding to block sizes 1, 2, and 8, respectively.
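A sketch of this pass 1 decomposition; the bucket buffer names and per-bucket counters are assumptions, and updating the per-bucket indirect dispatch args is omitted:

```hlsl
// Pass 1 of Solution 3 (sketch): append one DstWorkItemInfo per set bit of n.
RWStructuredBuffer<uint> bucket_counts;          // one atomic counter per bucket
RWStructuredBuffer<DstWorkItemInfo> buckets[32]; // bucket i: entries covering 2^i items

void append_expansion(uint src_work_item_idx, uint n)
{
    uint offset = 0; // running local work item index offset
    while (n != 0)
    {
        uint bucket = firstbitlow(n); // lowest set bit selects the bucket
        uint slot;
        InterlockedAdd(bucket_counts[bucket], 1, slot);
        DstWorkItemInfo info;
        info.src_work_item_idx    = src_work_item_idx;
        info.work_item_idx_offset = offset;
        buckets[bucket][slot] = info;
        offset += 1u << bucket; // this entry covers 2^bucket dst work items
        n &= n - 1;             // clear the lowest set bit
    }
}
```

For N = 11 this appends to buckets 0, 1, and 3 with offsets 0, 1, and 3, matching the example above.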
Each pass 2 dispatch reads from a specific bucket (one indirect dispatch per bucket). The mapping of a thread index to its DstWorkItemInfo is simple and direct:
uint numItemsPerInfo = 1 << bucketIndex;
uint infoIndex = dispatchThreadIndex / numItemsPerInfo;
DstWorkItemInfo info = bucket[infoIndex];
uint localIndex = info.work_item_idx_offset + (dispatchThreadIndex % numItemsPerInfo);
If N can be loaded from a buffer indexed by src_work_item_idx, then work_item_idx_offset can even be computed on-the-fly rather than stored explicitly.
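For instance, with the lowest-set-bit-first append order used in the pass 1 sketch above, the offset reduces to the low bits of N; this is a hypothetical helper whose correctness depends on that append order:

```hlsl
// On-the-fly offset reconstruction (sketch): with lowest-bit-first appends,
// the entry in bucket b covers local indices starting at the sum of all
// smaller power-of-two blocks of n, which is exactly n's low b bits.
uint work_item_idx_offset(uint n, uint bucket)
{
    return n & ((1u << bucket) - 1u);
}
```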
Advantages:

- No binary search required in pass 2
- Coarse-grained memory writes from the pass 1 threads
- High performance, especially when N is diverse and potentially large

Disadvantages:

- Increased memory usage: requires up to 32 bucket arrays, with bucket i needing sizeof(uint) * (MAX_DST_WORK_ITEMS / 2^i) memory
- Higher implementation complexity
Note: In practice, all 32 dispatches can be merged into a single indirect dispatch to reduce overhead. This optimization is included in the provided example code. The fixed overhead of compute dispatches is quite significant, so this can save a lot of GPU time (>70µs on an RTX 4080).
Solution 2.5 has a low minimal overhead, but its overhead scales with the dst item count because of the binary search.
Solution 3 needs a lot more memory than 2.5 (8-16x more, depending on the maximum expansion count), but it consistently outperforms it, as the power-of-two lookup is much faster than the binary search in 2.5.
In the end, the decision comes down to a performance/memory tradeoff.