Skip to content

Instantly share code, notes, and snippets.

@peterwmwong
Last active October 13, 2024 21:51
Show Gist options
  • Save peterwmwong/03773c9a3ff3a13000807a866a5eb56c to your computer and use it in GitHub Desktop.
Save peterwmwong/03773c9a3ff3a13000807a866a5eb56c to your computer and use it in GitHub Desktop.
Ridiculous atomic-less, simd reduction mesh shader for itty-bitty 32v/32t meshlets...
// Mesh function: one threadgroup per selected meshlet, 32 threads (one SIMD-group).
// Each lane transforms at most one vertex and evaluates at most one triangle.
// Triangles that survive back-face / frustum culling are compacted with
// SIMD-group prefix sums over the lanes still active, instead of atomics.
//
// Assumes meshlets hold at most 32 vertices and 32 triangles, matching the
// 32-thread threadgroup and the __builtin_assume below — TODO confirm against
// the meshlet builder.
[[mesh, max_total_threads_per_threadgroup(32)]]
void select_triangles(
    object_data SelectedMeshlets const & selected_meshlets [[payload]],
    ushort const t_in_sg [[thread_index_in_simdgroup]],
    ushort const tg_in_g [[threadgroup_position_in_grid]],
    constant Model const & model [[buffer(0)]],
    Mesh out
) {
    Meshlet const meshlet = selected_meshlets.meshlets[tg_in_g];

    // This lane's projected vertex, read by other lanes via simd_shuffle below.
    // Zero-initialized so lanes without a vertex never feed an indeterminate
    // value into the shuffle.
    half2 cached_pos_2d = half2(0);
    // Whether this lane's vertex is in front of the eye plane (w > 0). Dividing
    // by w <= 0 does not produce a meaningful 2D position, so the 2D culling
    // tests are only valid when all three triangle vertices have w > 0.
    bool in_front = false;
    if (t_in_sg < meshlet.vertices_count()) {
        uint const index = model.meshlet_vertices[meshlet.vertices_start() + t_in_sg];
        float3 const position = float4((constant rgba16snorm<float4> &) model.encoded_vertex_positions[index]).xyz;
        float4 const pos = selected_meshlets.model_to_camera_projection * float4(position, 1);
        out.set_vertex(t_in_sg, { .position = pos });
        in_front = pos.w > 0.0f;
        cached_pos_2d = half2(pos.xy / pos.w);
    }
    // Bit i is set iff lane i's vertex has w > 0. No lane has returned yet,
    // so all lanes contribute to the ballot.
    simd_vote::vote_t const in_front_lanes = static_cast<simd_vote::vote_t>(simd_ballot(in_front));

    bool const has_tri = t_in_sg < meshlet.triangles_count();
    constant packed_uchar3 const * tris = &model.meshlet_triangles[meshlet.triangles_start()];
    uchar3 const tri = has_tri ? tris[t_in_sg] : uchar3(0);
    __builtin_assume(tri.x < 32 && tri.y < 32 && tri.z < 32);
    // Shuffle before any early return so every source lane is still active.
    half2 const a = simd_shuffle(cached_pos_2d, tri.x);
    half2 const b = simd_shuffle(cached_pos_2d, tri.y);
    half2 const c = simd_shuffle(cached_pos_2d, tri.z);
    if (!has_tri) return;

    bool const all_in_front =
        ((in_front_lanes >> tri.x) & (in_front_lanes >> tri.y) & (in_front_lanes >> tri.z) & 1) != 0;
    bool const any_in_front =
        (((in_front_lanes >> tri.x) | (in_front_lanes >> tri.y) | (in_front_lanes >> tri.z)) & 1) != 0;
    // Every vertex at w <= 0: w is linear over the triangle, so no point of it
    // can lie inside the clip volume.
    if (!any_in_front) return;
    if (all_in_front) {
        if (!mimeo::is_tri_front_facing_2d(a, b, c)) return;
        if (!mimeo::is_tri_within_frustum_ndc_2d(a, b, c)) return;
    }
    // Otherwise the triangle straddles w = 0. Its 2D projection is not
    // meaningful, so keep it and let hardware clipping handle it.

    // Lanes that returned above are inactive, so the prefix sum and sum below
    // count only the surviving triangles. This compacts them without atomics.
    // TODO(0): Change back to `ushort`, workaround for FB15482904
    short const primitive = simd_prefix_exclusive_sum(1);
    out.set_primitive(primitive, MeshPrimitive {
        .meshlet_id = selected_meshlets.meshlet_ids[tg_in_g],
    });
    ushort3 const indices = ushort3(primitive * 3u) + ushort3(0u, 1u, 2u);
    out.set_index(indices.x, tri.x);
    out.set_index(indices.y, tri.y);
    out.set_index(indices.z, tri.z);
    ushort const num_tris = simd_sum(1u);
    // NOTE(review): if every lane returned early, set_primitive_count is never
    // called; confirm the mesh's default primitive count is 0.
    if (simd_is_first()) out.set_primitive_count(num_tris);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment