In a GPU-driven renderer, "work expansion" is a commonly occurring problem. "Work expansion" means that a single item of work spawns N follow-up work items. Typically, one work item is executed by one shader thread/invocation.
An example of work expansion is GPU-driven meshlet culling following mesh culling. In this example a "work item" is culling a mesh, and each mesh cull work item spawns N meshlet cull work items.
There are many diverse cases of this problem and many solutions. Some are trivial to solve, for example, when N (how many work items are spawned) is fixed: simply dispatch N times more threads for the second pass than for the first pass, easy!
But those cases are not interesting. In this gist I want to address the more complex cases, where N (how many work items get spawned) can vary a lot.
Definitions:
- `pass 1`: the first pass. Each work item in this pass spawns `dst work items` for `pass 2`. In the culling example this would be the mesh culling.
- `pass 2`: the second pass. Each `dst work item` in `pass 2` is spawned from a `src work item` in `pass 1`. In the culling example this would be the meshlet culling.
- `P1`: number of `src work items` in `pass 1`
- `P2`: number of `dst work items` in `pass 2`
- `src work item`: work item in `pass 1` that spawns new work items for `pass 2`
- `dst work item`: work item in `pass 2`, spawned from a `src work item`
- `N`: number of `dst work items` spawned by a `src work item`
- `MAX_N`: maximum number of work items that each work item in `pass 1` can spawn
- `MAX_DST_WORK_ITEMS`: maximum number of TOTAL `dst work items`
- `MAX_SRC_WORK_ITEMS`: maximum number of TOTAL `src work items`
- `local/dst work item index`: a `src work item` spawns N `dst work items`; their `local work item indices` are 0, 1, 2, ..., N-1
The latest graphics APIs actually have a set of solutions to these problems. One example is DX12 Work Graphs. With work graphs, each thread can immediately start new threads within the same dispatch! This directly solves most work expansion problems. With work graphs we don't even need two passes, we could do it all in one!
Another example is the mesh shader dispatch API. In Vulkan and DX12, the mesh cull threads can simply append to a buffer of dispatch indirect structs using an atomic counter and then call vkCmdDrawMeshTasksIndirectCountEXT, which allows a GPU-driven draw count to be written by shaders. But the real power of mesh shaders is the task/amplification shader stage: it can directly launch N runtime-determined mesh shader workgroups. Effectively this is a very limited form of work graphs. This is a great solution for meshlet culling and LOD selection.
Lastly there is Device Generated Commands. With this extension, we can simply insert new dispatch indirects for spawned work items directly into the command buffer!
The OG is of course the dispatch indirect command. It allows shaders to write dispatch dimensions to a buffer that the GPU then reads to determine the size of a following dispatch. This solution is great and is used extensively in GPU-driven renderers. With some extra work in shaders we can get a lot of efficient work expansion done with this command. But the command alone is very barebones and does not deal with most expansion problems directly (and that's totally fine!).
Ideally we would also have something like a "dispatch indirect count" for other work expansion problems; sadly no graphics API has this feature to my knowledge.
All these API solutions are great, but they have downsides:
- There is no "dispatch indirect count". It is possible to abuse the task/amplification -> mesh shader expansion for this, but that is very awkward.
- Work graphs only run on very recent GPUs and have very poor performance on NVIDIA.
- Device Generated Commands put a lot of pressure on the frontend of the GPU. Inserting many thousands of dispatches can be slow.
- Thread granularity and dispatch overhead for small `N`s. Dispatches have a non-insignificant fixed minimal performance overhead on GPUs. Inserting many many small dispatches via DGC can be very slow, as the GPU's frontend won't be able to process them fast enough for many small workloads.
- Thread granularity within dispatches. GPU dispatches are made to work on many work items: workgroup sizes should typically be at least 64 threads, and optimally there should also be many workgroups. Having small dispatches that use far fewer than, say, 64 threads will leave a lot of the GPU's hardware wasted. This is the case with vkCmdDrawMeshTasksIndirectCountEXT and DGC.
By far the best solution out of these is work graphs. But as they are so recent, most GPUs can't run them at all, and NVIDIA only poorly. So they are not yet a generally applicable solution. So let's look at software solutions.
## Solution 0: Fixed-Size Expansion

Dispatch `P1` threads in `pass 1` and `P1 * MAX_N` threads for `pass 2`, each thread working on a single item.
Each thread in `pass 1` knows its `src work item` by its global dispatch index. Each thread in `pass 2` knows its `src work item` by `dispatchThreadIdx / MAX_N` and its `local work item index` by `dispatchThreadIdx % MAX_N`.
Now, not all work items in `pass 1` will spawn `MAX_N` work items for `pass 2`, yet we spawn `MAX_N` threads for each unconditionally. To solve this, we can write a buffer entry for each item in `pass 1`, describing how many work items are desired. The threads in `pass 2` read that buffer and early out if their relative thread index is greater than the work item count, as sketched below.
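A minimal sketch of `pass 2` in HLSL-style code; `work_item_counts` and the `MAX_N` value are made-up names for illustration:

```hlsl
#define MAX_N 32 // example value

// Hypothetical buffer: N for each src work item, written by pass 1.
StructuredBuffer<uint> work_item_counts;

[numthreads(64, 1, 1)]
void pass2(uint3 dtid : SV_DispatchThreadID)
{
    uint src_work_item_idx   = dtid.x / MAX_N;
    uint local_work_item_idx = dtid.x % MAX_N;
    // Threads past this src work item's real N simply exit.
    if (local_work_item_idx >= work_item_counts[src_work_item_idx])
        return;
    // ... process dst work item (src_work_item_idx, local_work_item_idx) ...
}
```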
This is a good solution when most `src work items` want to spawn close to `MAX_N` `dst work items`.
It is a very poor solution when many or most `src work items` spawn far fewer `dst work items` than `MAX_N`.
In most cases, such as the mesh and meshlet culling example, `N` varies a lot for each `src work item`.
## Solution 1: One Info Per Dst Work Item

Dispatch `P1` threads in `pass 1` and perform one indirect dispatch for `pass 2`, each thread working on a single item.
We have a buffer containing a `dst work item` counter, a dispatch indirect struct and an array of `dst work item` infos:
```hlsl
struct DstWorkItemInfo
{
    uint src_work_item_idx;
    uint local_work_item_idx;
};

struct Buffer
{
    uint3 pass2_dispatch_indirect;
    uint dst_work_item_count;
    DstWorkItemInfo dst_work_item_infos[MAX_DST_WORK_ITEMS];
};
```
Each thread in `pass 1` performs atomic operations to update `pass2_dispatch_indirect` and `dst_work_item_count`, using the value returned by the atomic add on `dst_work_item_count` as the offset at which to write its `DstWorkItemInfo`s. Each thread then loops `N` times, writing `N` `DstWorkItemInfo`s.
Each thread in `pass 2` uses its dispatch thread index to index the buffer's `dst_work_item_infos`, to know what to work on.
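A minimal sketch of both sides, assuming the `Buffer` struct above lives in a `RWStructuredBuffer<Buffer>` named `expansion` (the name, the 64-thread workgroup size and the `InterlockedMax` bookkeeping are illustrative; y/z of the indirect struct are assumed pre-initialized to 1):

```hlsl
RWStructuredBuffer<Buffer> expansion;

// Pass 1 side: reserve N slots with one atomic add, then write N infos.
void spawn_dst_work_items(uint src_work_item_idx, uint n)
{
    uint first_slot;
    InterlockedAdd(expansion[0].dst_work_item_count, n, first_slot);
    // Keep the indirect workgroup count in sync (64 threads per workgroup).
    InterlockedMax(expansion[0].pass2_dispatch_indirect.x,
                   (first_slot + n + 63) / 64);
    for (uint i = 0; i < n; ++i)
    {
        DstWorkItemInfo info;
        info.src_work_item_idx   = src_work_item_idx;
        info.local_work_item_idx = i;
        expansion[0].dst_work_item_infos[first_slot + i] = info;
    }
}

// Pass 2: the dispatch thread index maps directly to a dst work item info.
[numthreads(64, 1, 1)]
void pass2(uint3 dtid : SV_DispatchThreadID)
{
    if (dtid.x >= expansion[0].dst_work_item_count)
        return;
    DstWorkItemInfo info = expansion[0].dst_work_item_infos[dtid.x];
    // ... process (info.src_work_item_idx, info.local_work_item_idx) ...
}
```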
This solution is very good when `N` is mostly very small (let's say smaller than ~16) and `MAX_N` is also very small.
It is very bad when `N` varies a lot and gets very big. This is because the threads in `pass 1` have to perform a write in a loop `N` times. For large `N`s this is very slow and can become very divergent, hurting performance.
## Solution 2: Multi Work Item Infos + Linear Search

Dispatch `P1` threads in `pass 1` and perform one indirect dispatch for `pass 2`, each thread working on a single item.
Now it gets a little complicated. Like in Solution 1 we have a `Buffer` with a counter for the number of needed `dst work items` and a dispatch indirect struct. But we have no array of `dst work item` infos.
Instead of writing a single `dst work item` info for every `dst work item`, we write a list of multi `dst work item` infos and make the threads in `pass 2` search for their `src work item`.
```hlsl
struct MultiDstWorkItemInfo
{
    uint src_work_item_idx;
    uint work_item_count; // Count instead of a single local index!
};

struct Buffer
{
    uint3 pass2_dispatch_indirect;
    uint multi_dst_work_item_count;
    MultiDstWorkItemInfo multi_dst_work_item_infos[MAX_SRC_WORK_ITEMS];
};
```
Writing only one multi `dst work item` info instead of many `dst work item` infos can be much faster for the `src work item` threads.
Wait, but how do the threads in `pass 2` know what dst and src work item to work on? They only have their dispatch thread id, and there is no direct mapping from that to a `dst work item` like there is in Solution 1!
Each thread in `pass 2` can find its `MultiDstWorkItemInfo` index by iterating over all `MultiDstWorkItemInfo`s, subtracting each `work_item_count` from its dispatch thread index, until the result would become negative. The last `MultiDstWorkItemInfo` for which the remaining thread index is still >= 0 is that thread's info, and the remaining value is its `local work item index`.
If this is not fully clear, pick a thread in the following example and perform the search described above.

```
MultiDstWorkItemInfo multi_dst_work_item_infos[MAX_SRC_WORK_ITEMS] = {
    { 0, 3 },
    { 1, 1 },
    { 2, 2 }
};

pass 2 threads:
dispatch thread index: 0: src work item 0, local work item index: 0
dispatch thread index: 1: src work item 0, local work item index: 1
dispatch thread index: 2: src work item 0, local work item index: 2
dispatch thread index: 3: src work item 1, local work item index: 0
dispatch thread index: 4: src work item 2, local work item index: 0
dispatch thread index: 5: src work item 2, local work item index: 1
```
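As a sketch, the linear search for a `pass 2` thread could look like this (HLSL-style, with `expansion` holding the `Buffer` struct above):

```hlsl
RWStructuredBuffer<Buffer> expansion;

// Walk the multi infos, subtracting each work_item_count from the thread
// index, until the index falls inside one of the infos.
void find_work_item_linear(uint dispatch_thread_idx,
                           out uint src_work_item_idx,
                           out uint local_work_item_idx)
{
    src_work_item_idx   = 0;
    local_work_item_idx = 0;
    uint remaining = dispatch_thread_idx;
    for (uint i = 0; i < expansion[0].multi_dst_work_item_count; ++i)
    {
        MultiDstWorkItemInfo info = expansion[0].multi_dst_work_item_infos[i];
        if (remaining < info.work_item_count)
        {
            src_work_item_idx   = info.src_work_item_idx;
            local_work_item_idx = remaining;
            return;
        }
        remaining -= info.work_item_count;
    }
}
```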
Linear search is very coherent for the `pass 2` threads but scales very poorly with the number of multi `dst work item`s.
## Solution 2.5: Prefix Sum + Binary Search

The better approach is to do the following:
- every `pass 1` thread writes out the `N` for its `src work item` into an array
- build a prefix sum array over the `N`s
- every `pass 2` thread performs a binary search for its `dispatch thread index` in the prefix sum array; the found index is the `MultiDstWorkItemInfo` for that thread
The idea behind this is that the prefix sum value for each `MultiDstWorkItemInfo` represents the last `dispatch thread index` of `pass 2` that it covers. So the prefix sum maps `pass 2` thread indices to the `MultiDstWorkItemInfo`s. Binary search is also a very good choice here, as it is very data- and execution-coherent for similar thread indices.
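A sketch of that binary search, assuming a hypothetical buffer `prefix_sums[i]` holding the inclusive prefix sum over the `N`s (so `prefix_sums[i]` is one past the last `pass 2` thread index covered by multi info `i`):

```hlsl
// Hypothetical buffer: inclusive prefix sum over the Ns.
StructuredBuffer<uint> prefix_sums;

// Returns the index of the MultiDstWorkItemInfo covering this thread:
// the first info whose inclusive prefix sum exceeds the thread index.
uint find_multi_info_binary(uint dispatch_thread_idx, uint multi_info_count)
{
    uint lo = 0;
    uint hi = multi_info_count; // search in [lo, hi)
    while (lo < hi)
    {
        uint mid = (lo + hi) / 2;
        if (prefix_sums[mid] <= dispatch_thread_idx)
            lo = mid + 1; // thread lies past all items covered by mid
        else
            hi = mid;
    }
    return lo;
}
```

The `local work item index` then falls out as `dispatch_thread_idx - (idx == 0 ? 0 : prefix_sums[idx - 1])`.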
This solution is very good for most cases. It has "constant" time for writing the expansion info in the `pass 1` threads, a low memory footprint of only `MAX_SRC_WORK_ITEMS` infos, and the search is relatively fast due to its very high data and execution coherence across `pass 2` threads.
Generally this is the most commonly used solution for work expansion, as it is simple to implement and performs well.
Now, building a prefix sum is typically done in many passes and can require significant GPU time due to the barriers and work that has to be performed.
Luckily, there is a trick to speed up the prefix sum part by a lot!
Instead of only atomically counting the number of needed `MultiDstWorkItemInfo`s in the `pass 1` threads, we can also build the prefix sum at the same time!
The 32 bit counter variable `multi_dst_work_item_count` is bumped to a 64 bit variable named `multi_dst_work_item_count_sum`. The upper 32 bit of the new 64 bit variable still contain the count of multi `dst work item`s, but we put a payload in the lower 32 bit: the sum of all needed `dst work items` (NOT multi `dst work items`, but the actual total amount of threads needed in `pass 2`).
Using 64 bit atomics, we can perform an atomic compare exchange loop (sketched after the list):
- read the current value of `multi_dst_work_item_count_sum`
- unpack `multi_dst_work_item_count` and the sum of all `dst work item`s
- add 1 to `multi_dst_work_item_count` and `N` to the sum of needed `dst work item`s
- pack both values into a 64 bit value
- try to write the packed value BUT ONLY if it didn't change in the meantime, using the atomic compare exchange
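A sketch of this loop, assuming 64 bit atomics are available (e.g. HLSL Shader Model 6.6 `InterlockedCompareExchange` on a `uint64_t`; the buffer name is illustrative):

```hlsl
// Packed 64 bit counter: (multi info count << 32) | total dst work item sum.
RWStructuredBuffer<uint64_t> multi_dst_work_item_count_sum;

// Returns this thread's multi info slot and the exclusive prefix sum of
// all dst work items allocated before it.
void allocate_multi_info(uint n, out uint multi_info_idx, out uint prefix_sum)
{
    uint64_t expected = multi_dst_work_item_count_sum[0];
    while (true)
    {
        uint count = uint(expected >> 32); // upper 32 bit: multi info count
        uint sum   = uint(expected);       // lower 32 bit: dst work item sum
        uint64_t desired = (uint64_t(count + 1) << 32) | uint64_t(sum + n);
        uint64_t previous;
        InterlockedCompareExchange(multi_dst_work_item_count_sum[0],
                                   expected, desired, previous);
        if (previous == expected)
        {
            multi_info_idx = count; // our reserved MultiDstWorkItemInfo slot
            prefix_sum     = sum;   // dst work items that come before ours
            return;
        }
        expected = previous; // another thread interfered, retry with its value
    }
}
```

The returned `prefix_sum` is exactly the prefix sum value this thread needs, with no extra passes.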
After performing this loop, the `src work item` threads write their `MultiDstWorkItemInfo` out. We also need the prefix sum value of the needed `dst work items` inside the `MultiDstWorkItemInfo`s now:

```hlsl
struct MultiDstWorkItemInfo
{
    uint src_work_item_idx;
    uint work_item_count;
    uint work_item_prefix_sum;
};
```
Doing this, we can build the prefix sum entirely within `pass 1` with minimal extra performance overhead, and we can completely get rid of the multi-pass prefix sum generation, making the surrounding code simpler and faster.
*Implementation Example Shader File*
## Solution 3: Power-Of-Two Expansion Buckets

Now there is one more solution I want to present. It has slightly more cost in the `pass 1` threads but does not need the binary search in the `pass 2` threads.
Essentially, this is an optimized version of Solution 1. The threads in `pass 1` write `DstWorkItemInfo`s based on the `src work item`'s `N`. But each thread does not write `N` `DstWorkItemInfo`s; each thread writes only around log2(N) infos (one per set bit of `N`).
```hlsl
struct DstWorkItemInfo
{
    uint src_work_item_idx;
    uint work_item_idx_offset;
    // The work_item_count is now implicitly known by the bucket index.
};
```
Instead of writing a `DstWorkItemInfo` for each `dst work item`, we bundle multiple infos into one.
In order for this to work, we need multiple `DstWorkItemInfo` arrays and one dispatch for each. Each of these arrays is called a "bucket".
The trick here is that in each bucket, a `DstWorkItemInfo` represents a different amount of work items. The larger the bucket index, the more `dst work items` a single `DstWorkItemInfo` represents: `numOfWorkItemsPerInfo = 2^bucketIndex`.
So for bucket 0 a `DstWorkItemInfo` represents 1 `dst work item`, for bucket 1 it's 2, for bucket 2 it's 4, and so on.
Now the mapping from dispatch thread index to `dst work item` is trivial again:

```
numOfWorkItemsPerInfo = 2^bucketIndex
dstWorkItemInfoIndex  = dispatchThreadIndex / numOfWorkItemsPerInfo
DstWorkItemInfo info  = bucket[dstWorkItemInfoIndex]
localWorkItemIndex    = info.work_item_idx_offset + dispatchThreadIndex % numOfWorkItemsPerInfo
```
The `work_item_idx_offset` can also be inferred, if `N` can be loaded from a buffer with the `src work item` index.
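This works because the lower buckets cover exactly the low bits of `N`. As a sketch, assuming local indices are assigned from the lowest set bit of `N` upward:

```hlsl
// All set bits of N below this bucket come first in the local index range,
// so the offset covered by this bucket's info is just the low bits of N.
uint local_offset = n & ((1u << bucket_index) - 1u);
```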
An example: we are looking at a thread in `pass 1` that wants to spawn `N` = 11 `dst work items`.
Now we want to distribute the 11 `dst work items` among the 32 buckets.
With binary numbers this is trivial: each set bit of `N` represents a bucket!
So for `N` = 11 (0b1011 in binary), the thread will append a `DstWorkItemInfo` to buckets 0, 1 and 3.
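A sketch of the `pass 1` side, scanning the set bits of `N` (the bucket buffers and counters are made-up names):

```hlsl
// One DstWorkItemInfo array plus one atomic counter per bucket (32 buckets).
RWStructuredBuffer<DstWorkItemInfo> buckets[32];
RWStructuredBuffer<uint>            bucket_counts;

void spawn_dst_work_items_po2(uint src_work_item_idx, uint n)
{
    uint local_offset = 0;
    while (n != 0)
    {
        uint bucket_index = firstbitlow(n); // lowest set bit of N
        uint slot;
        InterlockedAdd(bucket_counts[bucket_index], 1, slot);
        DstWorkItemInfo info;
        info.src_work_item_idx    = src_work_item_idx;
        info.work_item_idx_offset = local_offset;
        buckets[bucket_index][slot] = info;
        local_offset += 1u << bucket_index; // this info covers 2^bucket_index items
        n &= n - 1;                         // clear the lowest set bit
    }
}
```

Each bucket then gets its own indirect dispatch with `bucket_counts[i] << i` threads.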
This solution has very nice scaling properties. For low `N`, there will be few bits set in `N`, so the overhead for the `pass 1` threads is low. And for larger `N`s the overhead becomes less important overall, considering that the `pass 2` workload will be much greater in comparison to `pass 1`.
The big downside of this solution is implementation complexity and, to some extent, the memory requirements. It needs more memory than Solution 2 and is harder to get correct. The performance is better in most scenarios, as the search in the `pass 2` threads is entirely eliminated.
Due to the higher memory use and the increased number of dispatch indirects (x32!), this solution works well if few of these expansions are needed. But if many expansions are needed at the same time, the memory use and frontend load from dispatches can become a problem.
NOTE: It is possible to merge all 32 dispatches into one to avoid dispatch overhead. This is done in the example code. It saves quite some fixed dispatch overhead (70 microseconds on my RTX 4080).
Solution 2.5 has low minimal overhead, but its overhead scales with the dst item count because of the binary search.
Solution 3 needs a lot more memory than 2.5 (depending on the max expansion count, 8-16x more), but it consistently outperforms it, as the work expansion lookup is much faster with the power-of-two buckets than with the binary search in 2.5.
In the end the decision has to be made based on a performance/memory tradeoff.