Add direct dispatch for Bitonic Sort when using a small input buffer

This commit is contained in:
Wojtek Figat
2025-08-08 17:03:39 +02:00
parent 0ea555b041
commit 519a9c0a14
5 changed files with 66 additions and 52 deletions

BIN
Content/Shaders/BitonicSort.flax (Stored with Git LFS)

Binary file not shown.

View File

@@ -940,8 +940,7 @@ void DrawEmittersGPU(RenderContextBatch& renderContextBatch)
const auto sortMode = (ParticleSortMode)module->Values[2].AsInt;
bool sortAscending = sortMode == ParticleSortMode::CustomAscending;
BitonicSort::Instance()->Sort(context, draw.Buffer->GPU.SortingKeysBuffer, draw.Buffer->GPU.Buffer, draw.Buffer->GPU.ParticleCounterOffset, sortAscending, draw.Buffer->GPU.SortedIndices, draw.Buffer->GPU.ParticlesCountMax);
// TODO: split sorted keys copy with another loop to give time for UAV transition
// TODO: use args buffer from GPUIndirectArgsBuffer instead of internal from BitonicSort to get rid of UAV barrier
// TODO: use args buffer from GPUIndirectArgsBuffer instead of internal from BitonicSort to get rid of UAV barrier (run all sorting in parallel)
}
}
}

View File

@@ -56,7 +56,7 @@ bool BitonicSort::setupResources()
// Cache compute shaders
_indirectArgsCS = shader->GetCS("CS_IndirectArgs");
_preSortCS = shader->GetCS("CS_PreSort");
_preSortCS.Get(shader, "CS_PreSort");
_innerSortCS = shader->GetCS("CS_InnerSort");
_outerSortCS = shader->GetCS("CS_OuterSort");
_copyIndicesCS = shader->GetCS("CS_CopyIndices");
@@ -73,7 +73,7 @@ void BitonicSort::Dispose()
SAFE_DELETE_GPU_RESOURCE(_dispatchArgsBuffer);
_cb = nullptr;
_indirectArgsCS = nullptr;
_preSortCS = nullptr;
_preSortCS.Clear();
_innerSortCS = nullptr;
_outerSortCS = nullptr;
_copyIndicesCS = nullptr;
@@ -86,8 +86,9 @@ void BitonicSort::Sort(GPUContext* context, GPUBuffer* sortingKeysBuffer, GPUBuf
if (checkIfSkipPass())
return;
PROFILE_GPU_CPU("Bitonic Sort");
const uint32 elementSizeBytes = sizeof(uint64);
const uint32 maxNumElements = maxElements != 0 ? maxElements : sortingKeysBuffer->GetSize() / elementSizeBytes;
uint32 maxNumElements = sortingKeysBuffer->GetSize() / sizeof(uint64);
if (maxElements > 0 && maxElements < maxNumElements)
maxNumElements = maxElements;
const uint32 alignedMaxNumElements = Math::RoundUpToPowerOf2(maxNumElements);
const uint32 maxIterations = (uint32)Math::Log2((float)Math::Max(2048u, alignedMaxNumElements)) - 10;
@@ -102,33 +103,44 @@ void BitonicSort::Sort(GPUContext* context, GPUBuffer* sortingKeysBuffer, GPUBuf
data.LoopJ = 0;
context->UpdateCB(_cb, &data);
context->BindCB(0, _cb);
// Generate execute indirect arguments
context->BindSR(0, countBuffer->View());
context->BindUA(0, _dispatchArgsBuffer->View());
context->Dispatch(_indirectArgsCS, 1, 1, 1);
// Pre-Sort the buffer up to k = 2048 (this also pads the list with invalid indices that will drift to the end of the sorted list)
context->BindUA(0, sortingKeysBuffer->View());
context->DispatchIndirect(_preSortCS, _dispatchArgsBuffer, 0);
// We have already pre-sorted up through k = 2048 when first writing our list, so we continue sorting with k = 4096
// For really large values of k, these indirect dispatches will be skipped over with thread counts of 0
uint32 indirectArgsOffset = sizeof(GPUDispatchIndirectArgs);
for (uint32 k = 4096; k <= alignedMaxNumElements; k *= 2)
// If item count is small we can do only presorting within a single dispatch thread group
if (maxNumElements <= 2048)
{
for (uint32 j = k / 2; j >= 2048; j /= 2)
{
data.LoopK = k;
data.LoopJ = j;
context->UpdateCB(_cb, &data);
// Use pre-sort with smaller thread group size (eg. for small particle emitters sorting)
const int32 permutation = maxNumElements < 128 ? 1 : 0;
context->BindUA(0, sortingKeysBuffer->View());
context->Dispatch(_preSortCS.Get(permutation), 1, 1, 1);
}
else
{
// Generate execute indirect arguments
context->BindUA(0, _dispatchArgsBuffer->View());
context->Dispatch(_indirectArgsCS, 1, 1, 1);
context->DispatchIndirect(_outerSortCS, _dispatchArgsBuffer, indirectArgsOffset);
// Pre-Sort the buffer up to k = 2048 (this also pads the list with invalid indices that will drift to the end of the sorted list)
context->BindUA(0, sortingKeysBuffer->View());
context->DispatchIndirect(_preSortCS.Get(0), _dispatchArgsBuffer, 0);
// We have already pre-sorted up through k = 2048 when first writing our list, so we continue sorting with k = 4096
// For really large values of k, these indirect dispatches will be skipped over with thread counts of 0
uint32 indirectArgsOffset = sizeof(GPUDispatchIndirectArgs);
for (uint32 k = 4096; k <= alignedMaxNumElements; k *= 2)
{
for (uint32 j = k / 2; j >= 2048; j /= 2)
{
data.LoopK = k;
data.LoopJ = j;
context->UpdateCB(_cb, &data);
context->DispatchIndirect(_outerSortCS, _dispatchArgsBuffer, indirectArgsOffset);
indirectArgsOffset += sizeof(GPUDispatchIndirectArgs);
}
context->DispatchIndirect(_innerSortCS, _dispatchArgsBuffer, indirectArgsOffset);
indirectArgsOffset += sizeof(GPUDispatchIndirectArgs);
}
context->DispatchIndirect(_innerSortCS, _dispatchArgsBuffer, indirectArgsOffset);
indirectArgsOffset += sizeof(GPUDispatchIndirectArgs);
}
context->ResetUA();

View File

@@ -18,7 +18,7 @@ private:
GPUBuffer* _dispatchArgsBuffer = nullptr;
GPUConstantBuffer* _cb;
GPUShaderProgramCS* _indirectArgsCS;
GPUShaderProgramCS* _preSortCS;
ComputeShaderPermutation<2> _preSortCS;
GPUShaderProgramCS* _innerSortCS;
GPUShaderProgramCS* _outerSortCS;
GPUShaderProgramCS* _copyIndicesCS;
@@ -46,7 +46,7 @@ public:
#if COMPILE_WITH_DEV_ENV
void OnShaderReloading(Asset* obj)
{
_preSortCS = nullptr;
_preSortCS.Clear();
_innerSortCS = nullptr;
_outerSortCS = nullptr;
invalidateResources();

View File

@@ -3,6 +3,10 @@
#include "./Flax/Common.hlsl"
#include "./Flax/Math.hlsl"
#ifndef THREAD_GROUP_SIZE
#define THREAD_GROUP_SIZE 1024
#endif
struct Item
{
float Key;
@@ -36,14 +40,14 @@ uint InsertOneBit(uint value, uint oneBitMask)
// (effectively a negation) or leave the value alone. When the KeySign is
// 1, we are sorting descending, so when A < B, they should swap. For an
// ascending sort, -A < -B should swap.
bool ShouldSwap(Item a, Item b, float keySign)
bool ShouldSwap(Item a, Item b)
{
//return (a ^ NullItem) < (b ^ NullItem);
//return (a.Key) < (b.Key);
return (a.Key * keySign) < (b.Key * keySign);
return (a.Key * KeySign) < (b.Key * KeySign);
//return asfloat(a) < asfloat(b);
//return (asfloat(a) * keySign) < (asfloat(b) * keySign);
//return (asfloat(a) * KeySign) < (asfloat(b) * KeySign);
}
#ifdef _CS_IndirectArgs
@@ -91,7 +95,7 @@ void CS_IndirectArgs(uint groupIndex : SV_GroupIndex)
RWStructuredBuffer<Item> SortBuffer : register(u0);
groupshared Item SortData[2048];
groupshared Item SortData[THREAD_GROUP_SIZE * 2];
void LoadItem(uint element, uint count)
{
@@ -106,7 +110,7 @@ void LoadItem(uint element, uint count)
item.Key = NullItemKey;
item.Value = NullItemValue;
}
SortData[element & 2047] = item;
SortData[element & (THREAD_GROUP_SIZE * 2 - 1)] = item;
}
void StoreItem(uint element, uint count)
@@ -122,23 +126,24 @@ void StoreItem(uint element, uint count)
#ifdef _CS_PreSort
META_CS(true, FEATURE_LEVEL_SM5)
[numthreads(1024, 1, 1)]
META_PERMUTATION_1(THREAD_GROUP_SIZE=1024)
META_PERMUTATION_1(THREAD_GROUP_SIZE=64)
[numthreads(THREAD_GROUP_SIZE, 1, 1)]
void CS_PreSort(uint3 groupID : SV_GroupID, uint groupIndex : SV_GroupIndex)
{
// Item index of the start of this group
const uint groupStart = groupID.x * 2048;
const uint groupStart = groupID.x * (THREAD_GROUP_SIZE * 2);
// Actual number of items that need sorting
const uint count = CounterBuffer.Load(CounterOffset);
LoadItem(groupStart + groupIndex, count);
LoadItem(groupStart + groupIndex + 1024, count);
LoadItem(groupStart + groupIndex + THREAD_GROUP_SIZE, count);
GroupMemoryBarrierWithGroupSync();
float keySign = KeySign;
UNROLL
for (uint k = 2; k <= 2048; k <<= 1)
for (uint k = 2; k <= THREAD_GROUP_SIZE * 2; k <<= 1)
{
for (uint j = k / 2; j > 0; j /= 2)
{
@@ -148,7 +153,7 @@ void CS_PreSort(uint3 groupID : SV_GroupID, uint groupIndex : SV_GroupIndex)
Item a = SortData[index1];
Item b = SortData[index2];
if (ShouldSwap(a, b, keySign))
if (ShouldSwap(a, b))
{
// Swap the items
SortData[index1] = b;
@@ -161,7 +166,7 @@ void CS_PreSort(uint3 groupID : SV_GroupID, uint groupIndex : SV_GroupIndex)
// Write sorted results to memory
StoreItem(groupStart + groupIndex, count);
StoreItem(groupStart + groupIndex + 1024, count);
StoreItem(groupStart + groupIndex + THREAD_GROUP_SIZE, count);
}
#endif
@@ -169,23 +174,22 @@ void CS_PreSort(uint3 groupID : SV_GroupID, uint groupIndex : SV_GroupIndex)
#ifdef _CS_InnerSort
META_CS(true, FEATURE_LEVEL_SM5)
[numthreads(1024, 1, 1)]
[numthreads(THREAD_GROUP_SIZE, 1, 1)]
void CS_InnerSort(uint3 groupID : SV_GroupID, uint groupIndex : SV_GroupIndex)
{
const uint count = CounterBuffer.Load(CounterOffset);
// Item index of the start of this group
const uint groupStart = groupID.x * 2048;
const uint groupStart = groupID.x * (THREAD_GROUP_SIZE * 2);
// Load from memory into LDS to prepare sort
LoadItem(groupStart + groupIndex, count);
LoadItem(groupStart + groupIndex + 1024, count);
LoadItem(groupStart + groupIndex + THREAD_GROUP_SIZE, count);
GroupMemoryBarrierWithGroupSync();
float keySign = KeySign;
UNROLL
for (uint j = 1024; j > 0; j /= 2)
for (uint j = THREAD_GROUP_SIZE; j > 0; j /= 2)
{
uint index2 = InsertOneBit(groupIndex, j);
uint index1 = index2 ^ j;
@@ -193,7 +197,7 @@ void CS_InnerSort(uint3 groupID : SV_GroupID, uint groupIndex : SV_GroupIndex)
Item a = SortData[index1];
Item b = SortData[index2];
if (ShouldSwap(a, b, keySign))
if (ShouldSwap(a, b))
{
// Swap the items
SortData[index1] = b;
@@ -204,7 +208,7 @@ void CS_InnerSort(uint3 groupID : SV_GroupID, uint groupIndex : SV_GroupIndex)
}
StoreItem(groupStart + groupIndex, count);
StoreItem(groupStart + groupIndex + 1024, count);
StoreItem(groupStart + groupIndex + THREAD_GROUP_SIZE, count);
}
#endif
@@ -229,8 +233,7 @@ void CS_OuterSort(uint3 dispatchThreadId : SV_DispatchThreadID)
Item a = SortBuffer[index1];
Item b = SortBuffer[index2];
float keySign = KeySign;
if (ShouldSwap(a, b, keySign))
if (ShouldSwap(a, b))
{
// Swap the items
SortBuffer[index1] = b;