Efficient GPU Work Expansion

What is "Work Expansion"

In a GPU-driven renderer, "work expansion" is a commonly occurring problem. "Work expansion" means that a single work item spawns N follow-up work items. Typically one work item is executed by one shader thread/invocation.

An example of work expansion is GPU-driven meshlet culling following mesh culling. In this example a "work item" is culling a mesh, and each mesh cull work item spawns N follow-up meshlet cull work items.

There are many variations of this problem and many solutions. Some are trivial to solve, for example when N (how many work items are spawned) is fixed: simply dispatch N times more threads for the second pass than for the first pass, easy!

But those cases are not interesting. In this gist I want to address the more complex cases where N (how many work items get spawned) can vary a lot.

Terminology

  • pass 1: the first pass. Each work item in this pass spawns dst work items for pass 2. In the culling example this would be the mesh culling.
  • pass 2: the second pass. Each dst work item in pass 2 is spawned from a src work item in pass 1. In the culling example this would be the meshlet culling.
  • P1: number of src work items in pass 1
  • P2: number of dst work items in pass 2
  • src work item: work item in pass 1 that spawns new work items for pass 2.
  • dst work item: work item in pass 2, spawned from a src work item
  • N: number of dst work items spawned by a src work item
  • MAX_N: maximum number of dst work items that a single src work item can spawn
  • MAX_DST_WORK_ITEMS: maximum number of TOTAL dst work items
  • MAX_SRC_WORK_ITEMS: maximum number of TOTAL src work items
  • local/dst work item index: a src work item spawns N dst work items; their local work item indices are 0, 1, 2, ..., N-1.

Gfx API Solutions

The latest graphics APIs actually have a set of solutions to these problems, for example DX12 Work Graphs. With work graphs, each thread can immediately start new threads within the same dispatch! This directly solves most work expansion problems. With work graphs we don't even need two passes, we could do it all in one!

Another example is the mesh shader dispatch API. In Vulkan and DX12, the mesh cull threads can simply append to a buffer of dispatch indirect structs using an atomic counter and then call vkCmdDrawMeshTasksIndirectCountEXT, which allows a GPU-driven count to be written by shaders. But the real power of mesh shaders is the task/amplification shader stage: it can directly launch N runtime-determined mesh shader workgroups. Effectively this is a very limited form of work graphs. It is a great solution for meshlet culling and LOD selection.

Lastly there is Device Generated Commands. With this extension, we can simply insert new dispatch indirects for spawned work items directly into the command buffer!

The OG is of course the dispatch indirect command. It allows shaders to write a dispatch command to a buffer that is then read by the GPU to determine the size of a following dispatch. This solution is great and is used extensively in GPU-driven renderers. With some extra work in shaders we can get a lot of efficient work expansion done with this command. But the command alone is very barebones and does not deal with most expansion problems directly (and that's totally fine!).

Ideally we would also have something like "dispatch indirect count" for other work expansion problems; sadly no gfx API has this feature to my knowledge.

All these api solutions are great, but they have downsides:

  • no "dispatch indirect count", its possible to abuse task/amplification -> mesh shader expansion but that is very akward.
  • work graphs only run on very recent gpus and have very poor performance on nvidia.
  • device generated commands put a lot of pressure on the frontend of the gpu. Inserting many thousands of dispatches can be slow.
  • thread granularity and dispatch overhead for small Ns. Dispatches have a non-insignificant fixed minimal performance overhead on gpus. Inserting many many small dispatches via DGC can be very slow as the gpus frontend wont be able to process them fast enough for many small workloads.
  • thread granulariy within dispatches. GPU dispatches are made to work on many work items, workgroup sizes should typically be at least 64 threads and optimally there should also be many workgroups. Having small dispatches, that use way less then say 64 threads, will leave a lot of the gpus hardware wasted. This is the case with vkCmdDrawMeshTasksIndirectCountEXT and DGC.

By far the best solution out of these is work graphs. But as they are so recent, most GPUs can't run them at all and Nvidia only poorly, so they are not yet a generally applicable solution. So let's look at software solutions.

Software Solutions

- Solution 0 (flat expansion):

Dispatch P1 threads in pass 1 and P1 * MAX_N threads for pass 2. Each thread working on a single item.

Each thread in pass 1 knows its src work item by its global dispatch index. Each thread in pass 2 knows its src work item via dispatchThreadIdx / MAX_N and its local work item index via dispatchThreadIdx % MAX_N.

Now, not all work items in pass 1 will spawn MAX_N work items for pass 2, yet we spawn MAX_N threads for each unconditionally. To solve this, we can write a buffer entry for each work item in pass 1, describing how many dst work items it actually wants. The threads in pass 2 read that buffer and early out if their local work item index is greater than or equal to that count, as shown in the sketch below.
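
As a rough sketch (HLSL-flavored; the src_work_item_n buffer, the 64-thread workgroup size and MAX_N being a compile-time constant are illustrative assumptions), pass 2 could look like this:

// Sketch of a pass 2 shader for Solution 0 (flat expansion).
// MAX_N is assumed to be a compile-time constant.
// src_work_item_n is written by pass 1 and holds N for every src work item.
StructuredBuffer<uint> src_work_item_n; // size: MAX_SRC_WORK_ITEMS

[numthreads(64, 1, 1)]
void pass2_main(uint dispatch_thread_idx : SV_DispatchThreadID)
{
    uint src_work_item_idx   = dispatch_thread_idx / MAX_N;
    uint local_work_item_idx = dispatch_thread_idx % MAX_N;

    // Early out for threads that were dispatched unconditionally
    // but are not backed by an actual dst work item.
    if (local_work_item_idx >= src_work_item_n[src_work_item_idx]) { return; }

    // ... work on dst work item (src_work_item_idx, local_work_item_idx) ...
}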

This is a good solution when most src work items want to spawn close to MAX_N dst work items.

It is a very poor solution when a lot or most src work items spawn far fewer dst work items than MAX_N. In most cases, such as the mesh and meshlet culling example, N varies a lot between src work items.

- Solution 1 (simple dispatch indirect):

Dispatch P1 threads in pass 1 and perform one indirect dispatch for pass 2. Each thread working on a single item.

We have a buffer containing a dst work item counter, a dispatch indirect struct and an array of dst work item infos:

struct DstWorkItemInfo
{
  uint src_work_item_idx;
  uint local_work_item_idx;
}

struct Buffer
{
  uint3 pass2_dispatch_indirect;
  uint dst_work_item_count;
  DstWorkItemInfo dst_work_item_infos[MAX_DST_WORK_ITEMS];
}

Each thread in pass 1 performs atomic operations to update pass2_dispatch_indirect and dst_work_item_count, using the value returned by the atomic add on dst_work_item_count as the offset at which to write its DstWorkItemInfos. Each thread then loops N times, writing N DstWorkItemInfos.
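
A sketch of that pass 1 write-out. The Buffer struct from above is renamed ExpansionBuffer here (to avoid HLSL's built-in Buffer type), and pass 2 is assumed to use 64-thread workgroups:

// Sketch of the pass 1 write-out for Solution 1.
// ExpansionBuffer is the Buffer struct from above, renamed to avoid HLSL's built-in Buffer type.
RWStructuredBuffer<ExpansionBuffer> expansion; // single element

void spawn_dst_work_items(uint src_work_item_idx, uint n)
{
    if (n == 0) { return; }

    // Reserve a contiguous range of n slots in dst_work_item_infos.
    uint first_slot;
    InterlockedAdd(expansion[0].dst_work_item_count, n, first_slot);

    // Bump the workgroup count of the indirect dispatch (x component), assuming
    // 64 threads per workgroup in pass 2. Depending on the compiler, the uint3
    // may need to be split into scalar uints for the atomic to be legal.
    InterlockedMax(expansion[0].pass2_dispatch_indirect.x, (first_slot + n + 63) / 64);

    // Write one DstWorkItemInfo per spawned dst work item.
    // This loop is what becomes slow and divergent for large N.
    for (uint i = 0; i < n; ++i)
    {
        DstWorkItemInfo info;
        info.src_work_item_idx   = src_work_item_idx;
        info.local_work_item_idx = i;
        expansion[0].dst_work_item_infos[first_slot + i] = info;
    }
}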

Each thread in pass 2 uses its dispatch thread index to index the buffer's dst_work_item_infos to know what to work on.
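
And the matching pass 2 side, again just a sketch (expansion_ro is an illustrative read-only view of the same buffer, 64-thread workgroups assumed):

// Sketch of the pass 2 lookup for Solution 1.
StructuredBuffer<ExpansionBuffer> expansion_ro;

[numthreads(64, 1, 1)]
void pass2_main(uint dispatch_thread_idx : SV_DispatchThreadID)
{
    // Threads of the last, partially filled workgroup early out.
    if (dispatch_thread_idx >= expansion_ro[0].dst_work_item_count) { return; }

    DstWorkItemInfo info = expansion_ro[0].dst_work_item_infos[dispatch_thread_idx];
    // ... work on dst work item (info.src_work_item_idx, info.local_work_item_idx) ...
}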

This solution is very good when N is mostly very small (let's say smaller than ~16) and MAX_N is also very small.

This solution is very bad when N varies a lot and gets very big. This is because the threads in pass 1 have to perform a write in a loop N times. For large Ns this is very slow and can become very divergent, hurting performance.

- Solution 2 (prefix sum and binary search):

Dispatch P1 threads in pass 1 and perform one indirect dispatch for pass 2. Each thread working on a single item.

Now it gets a little complicated. Like in Solution 1 we have a Buffer with a counter for the number of needed dst work items and a dispatch indirect struct, but we have no array of dst work item infos.

Instead of writing a single dst work item info for every dst work item, we write a list of multi dst work item infos and make the threads in pass 2 search for their src work item.

struct MultiDstWorkItemInfo
{
  uint src_work_item_idx;
  uint work_item_count; // Count instead of single index!
}

struct Buffer
{
  uint3 pass2_dispatch_indirect;
  uint multi_dst_work_item_count;
  MultiDstWorkItemInfo multi_dst_work_item_infos[MAX_SRC_WORK_ITEMS];
}

Writing only one multi dst work item instead of many dst work items can be much faster for the src work item threads.

Wait, but how do the threads in pass 2 know what dst and src work item to work on? They only have their dispatch thread id and there is no direct mapping from that to a dst work item like there is in Solution 1!

Each thread in pass 2 can easily find its MultiDstWorkItemInfo index by iterating over all MultiDstWorkItemInfos, subtracting each work_item_count from its dispatch thread index, until the result would become negative. The last MultiDstWorkItemInfo for which the remaining thread index is still >= 0 is that thread's MultiDstWorkItemInfo, and the remaining value is its local work item index.
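
A sketch of that linear search (expansion is assumed to be the Solution 2 Buffer from above, bound read-only; the function name is illustrative):

// Sketch of the linear search a pass 2 thread could perform in Solution 2.
void find_work_item_linear(uint dispatch_thread_idx,
                           out uint src_work_item_idx, out uint local_work_item_idx)
{
    uint remaining = dispatch_thread_idx;
    uint multi_idx = 0;
    // Walk the MultiDstWorkItemInfos until the remaining index falls inside one.
    while (remaining >= expansion[0].multi_dst_work_item_infos[multi_idx].work_item_count)
    {
        remaining -= expansion[0].multi_dst_work_item_infos[multi_idx].work_item_count;
        multi_idx += 1;
    }
    src_work_item_idx   = expansion[0].multi_dst_work_item_infos[multi_idx].src_work_item_idx;
    local_work_item_idx = remaining;
}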

If this is not fully clear, pick a thread in the following example and perform the search described above.

MultiDstWorkItemInfo multi_dst_work_item_infos[MAX_SRC_WORK_ITEMS] = {
  { 0, 3 },
  { 1, 1 },
  { 2, 2 }
}

pass 2 threads:
dispatch thread index 0: src work item 0, local work item index 0
dispatch thread index 1: src work item 0, local work item index 1
dispatch thread index 2: src work item 0, local work item index 2
dispatch thread index 3: src work item 1, local work item index 0
dispatch thread index 4: src work item 2, local work item index 0
dispatch thread index 5: src work item 2, local work item index 1

Linear search is very coherent for the pass 2 threads but scales very poorly with the number of multi dst work items.

The better approach is to do the following:

  • every pass 1 thread writes out the N for its src work item into an array
  • build a prefix sum array over these Ns
  • every pass 2 thread performs a binary search for its dispatch thread index in the prefix sum array; the found index is the MultiDstWorkItemInfo index for that thread

The idea behind this is that the prefix sum value for each MultiDstWorkItemInfo marks the end of the range of pass 2 dispatch thread indices it covers, so the prefix sum maps pass 2 thread indices to MultiDstWorkItemInfos. Binary search is also a very good choice here as it is very data and execution coherent for similar thread indices.
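
A sketch of this lookup, assuming the prefix sum is stored as an inclusive prefix sum array (work_item_prefix_sums[i] = N_0 + N_1 + ... + N_i); buffer and function names are illustrative:

// Sketch of the prefix-sum binary search for Solution 2.
StructuredBuffer<uint> work_item_prefix_sums; // inclusive prefix sum over the Ns

void find_work_item_binary(uint dispatch_thread_idx, uint multi_dst_work_item_count,
                           out uint multi_idx, out uint local_work_item_idx)
{
    // Find the first entry whose prefix sum is greater than the thread index.
    uint lo = 0;
    uint hi = multi_dst_work_item_count;
    while (lo < hi)
    {
        uint mid = (lo + hi) / 2;
        if (work_item_prefix_sums[mid] <= dispatch_thread_idx)
            lo = mid + 1;
        else
            hi = mid;
    }
    multi_idx = lo;

    // Subtract everything covered by the preceding infos to get the local index.
    uint preceding = (multi_idx == 0) ? 0 : work_item_prefix_sums[multi_idx - 1];
    local_work_item_idx = dispatch_thread_idx - preceding;
}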

This solution is very good for most cases. It has "constant" time for writing expansion info in the pass 1 threads, a low memory footprint of only MAX_SRC_WORK_ITEMS entries, and the search is relatively fast due to its very high data and execution coherence in the pass 2 threads.

Generally this is the most commonly used solution for work expansion, as it is simple to implement and performs well.

- Solution 2.5 (turbo charging Solution 2 with 64-bit atomics):

Now, building a prefix sum is typically done in many passes and can require significant gpu time due to the barriers and work that has to be performed.

Luckily there is a trick to speed the prefix sum part up by a lot!

Instead of only atomically counting the number of needed MultiDstWorkItems in pass 1 threads, we can also build the prefix sum at the same time!

The 32-bit counter variable multi_dst_work_item_count is bumped to a 64-bit variable named multi_dst_work_item_count_sum. The upper 32 bits of the new 64-bit variable still contain the count of multi dst work items, but we put a payload in the lower 32 bits: the sum of all needed dst work items (NOT multi dst work items, but the actual total number of threads needed in pass 2).

Using 64-bit atomics, we can perform an atomic compare-exchange loop:

  • read the current value of multi_dst_work_item_count_sum
  • unpack multi_dst_work_item_count and the sum of all dst work items
  • add 1 to multi_dst_work_item_count and N to the sum of needed dst work items
  • pack both values into a 64-bit value
  • try to write the packed value BUT ONLY if it didn't change in the meantime, using the atomic compare exchange

After performing this loop, the src work item threads write their MultiDstWorkItemInfo out. We also need the prefix sum value for the needed dst work items inside the MultiDstWorkItemInfo now:

struct MultiDstWorkItemInfo
{
  uint src_work_item_idx;
  uint work_item_count;
  uint work_item_prefix_sum; // new: prefix sum of dst work items spawned by preceding src work items
}
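
A sketch of that compare-exchange loop follows. It assumes 64-bit integer atomics are available (for example HLSL Shader Model 6.6 int64 atomics or the equivalent Vulkan feature); the resource and function names are illustrative.

// Sketch of the pass 1 counter update for Solution 2.5.
// Packing: upper 32 bits = multi dst work item count, lower 32 bits = sum of dst work items.
RWStructuredBuffer<uint64_t> multi_dst_work_item_count_sum; // single element

void allocate_expansion(uint n, out uint multi_dst_work_item_idx, out uint work_item_prefix_sum)
{
    uint64_t expected = multi_dst_work_item_count_sum[0];
    while (true)
    {
        uint count = uint(expected >> 32); // current multi dst work item count
        uint sum   = uint(expected);       // current sum of dst work items

        uint64_t desired = (uint64_t(count + 1u) << 32) | uint64_t(sum + n);

        uint64_t previous;
        InterlockedCompareExchange(multi_dst_work_item_count_sum[0], expected, desired, previous);
        if (previous == expected)
        {
            // Our update went through: `count` is this thread's MultiDstWorkItemInfo slot,
            // `sum` is the prefix sum of dst work items spawned before it.
            multi_dst_work_item_idx = count;
            work_item_prefix_sum    = sum;
            return;
        }
        // Another thread won the race; retry with the value it wrote.
        expected = previous;
    }
}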

Doing this, we can build the prefix sum entirely within pass 1 with minimal extra overhead, and we completely get rid of the multi-pass prefix sum generation, making the surrounding code simpler and faster.

- Solution 3 (power of two expansion, aka. Total Bikeshed)

Implementation Example Shader File

Now there is one more solution I want to present. It has slightly more cost in the pass 1 threads but does not need the binary search in the pass 2 threads.

Essentially, this is an optimized version of Solution 1. The threads in pass 1 write DstWorkItemInfos based on the src work item's N. But instead of writing N DstWorkItemInfos, each thread writes one info per set bit of N, i.e. on the order of log2(N) infos.

struct DstWorkItemInfo
{
  uint src_work_item_idx;
  uint work_item_idx_offset;
  // The work_item_count is now implicitly known by the bucket index
}

Instead of writing a DstWorkItemInfo for each dst work item, we bundle multiple infos into one.

In order for this to work, we need multiple DstWorkItemInfo arrays and one dispatch for each. Each of these arrays is called a "bucket".

The trick here is that for each of the buckets, a DstWorkItemInfo represents a different amount of work items. The larger the bucket index, the more dst work items a single DstWorkItemInfo represents: numOfDstWorkItemsPerInfo = 2^bucketIndex.

So for bucket 0 a DstWorkItemInfo represents 1 dst work item, for bucket 1 its 2, for bucket 2 its 4 and so on.

Now the mapping from the dispatch thread index of a dst work item thread is trivial again:

numOfDstWorkItemsPerInfo = 2^bucketIndex
dstWorkItemInfoIndex = dispatchThreadIndex / numOfDstWorkItemsPerInfo
DstWorkItemInfo info = bucket[dstWorkItemInfoIndex]
localWorkItemIndex = info.work_item_idx_offset + dispatchThreadIndex % numOfDstWorkItemsPerInfo

The work_item_idx_offset can also be inferred, if N can be loaded from a buffer with the src work item index.

An example: we are looking at a thread in pass 1 that wants to spawn N = 11 dst work items. Now we want to distribute the 11 dst work items among the 32 buckets. With binary numbers this is trivial: each set bit of N corresponds to a bucket!

So for N = 11 (0b1011 in binary), the thread appends a DstWorkItemInfo to buckets 0, 1 and 3, as sketched below.
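
A sketch of that pass 1 bucket write. The per-bucket resources and their layout are illustrative; the actual example shader file linked above may organize this differently (e.g. it merges the dispatches):

// Sketch of the pass 1 write-out for Solution 3 (power of two expansion).
// One DstWorkItemInfo array and one counter per bucket; layout is illustrative.
RWStructuredBuffer<DstWorkItemInfo> bucket_infos[32];
RWStructuredBuffer<uint>            bucket_counters; // one counter per bucket

void spawn_dst_work_items(uint src_work_item_idx, uint n)
{
    uint work_item_idx_offset = 0;
    uint remaining_bits = n;
    while (remaining_bits != 0)
    {
        uint bucket = firstbitlow(remaining_bits); // lowest set bit of N
        remaining_bits &= remaining_bits - 1u;     // clear that bit

        // Append one DstWorkItemInfo to this bucket; it covers 2^bucket dst work items.
        uint slot;
        InterlockedAdd(bucket_counters[bucket], 1, slot);

        DstWorkItemInfo info;
        info.src_work_item_idx    = src_work_item_idx;
        info.work_item_idx_offset = work_item_idx_offset;
        bucket_infos[NonUniformResourceIndex(bucket)][slot] = info;

        work_item_idx_offset += (1u << bucket);
    }
    // The per-bucket dispatch indirect sizes are derived from bucket_counters
    // (omitted here, e.g. via additional atomics or a tiny follow-up pass).
}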

This solution has very nice scaling properties. For low N there are few bits set in N, so the overhead for the pass 1 threads is low. And for larger Ns the overhead becomes less important overall, considering that the pass 2 workload will be much greater in comparison to pass 1.

The big downside of this solution is implementation complexity and, to some extent, memory requirements. It needs more memory than Solution 2 and is harder to get correct. The performance is better in most scenarios, as the search in the pass 2 threads is entirely eliminated.

Due to the higher memory use and the increased number of dispatch indirects (x32!!), this solution works well if few of these expansions are needed. But if many expansions are needed at the same time, the memory use and frontend load from dispatches can become a problem.

NOTE: It is possible to merge all 32 dispatches to avoid dispatch overhead. This is done in the example code. It saves quite some fixed dispatch overhead (~70 microseconds on my RTX 4080).

Conclusion and comparison

Solution 2.5 has low minimal overhead, but its overhead scales with the dst work item count because of the binary search.

Solution 3 needs a lot more memory than 2.5 (depending on the max expansion count, 8-16x more) but it consistently outperforms it, as the power-of-two work expansion lookup is much faster than the binary search in 2.5.

In the end the decision has to be made based on a performance/memory tradeoff.
