Optimize GPU particles Bitonic sort to use separate buffers for indices and keys to avoid additional buffer copy

This commit is contained in:
Wojtek Figat
2025-08-08 18:24:44 +02:00
parent 519a9c0a14
commit 854f3acd4c
9 changed files with 115 additions and 169 deletions

View File

@@ -8,7 +8,7 @@
GPU_CB_STRUCT(Data {
float NullItemKey;
uint32 NullItemValue;
uint32 NullItemIndex;
uint32 CounterOffset;
uint32 MaxIterations;
uint32 LoopK;
@@ -47,7 +47,6 @@ bool BitonicSort::Init()
bool BitonicSort::setupResources()
{
// Check if shader has not been loaded
if (!_shader->IsLoaded())
return true;
const auto shader = _shader->GetShader();
@@ -59,14 +58,12 @@ bool BitonicSort::setupResources()
_preSortCS.Get(shader, "CS_PreSort");
_innerSortCS = shader->GetCS("CS_InnerSort");
_outerSortCS = shader->GetCS("CS_OuterSort");
_copyIndicesCS = shader->GetCS("CS_CopyIndices");
return false;
}
void BitonicSort::Dispose()
{
// Base
RendererPass::Dispose();
// Cleanup
@@ -76,17 +73,16 @@ void BitonicSort::Dispose()
_preSortCS.Clear();
_innerSortCS = nullptr;
_outerSortCS = nullptr;
_copyIndicesCS = nullptr;
_shader = nullptr;
}
void BitonicSort::Sort(GPUContext* context, GPUBuffer* sortingKeysBuffer, GPUBuffer* countBuffer, uint32 counterOffset, bool sortAscending, GPUBuffer* sortedIndicesBuffer, uint32 maxElements)
void BitonicSort::Sort(GPUContext* context, GPUBuffer* indicesBuffer, GPUBuffer* keysBuffer, GPUBuffer* countBuffer, uint32 counterOffset, bool sortAscending, int32 maxElements)
{
ASSERT(context && sortingKeysBuffer && countBuffer);
ASSERT(context && indicesBuffer && keysBuffer && countBuffer);
if (checkIfSkipPass())
return;
PROFILE_GPU_CPU("Bitonic Sort");
uint32 maxNumElements = sortingKeysBuffer->GetSize() / sizeof(uint64);
uint32 maxNumElements = indicesBuffer->GetElementsCount();
if (maxElements > 0 && maxElements < maxNumElements)
maxNumElements = maxElements;
const uint32 alignedMaxNumElements = Math::RoundUpToPowerOf2(maxNumElements);
@@ -96,7 +92,7 @@ void BitonicSort::Sort(GPUContext* context, GPUBuffer* sortingKeysBuffer, GPUBuf
Data data;
data.CounterOffset = counterOffset;
data.NullItemKey = sortAscending ? MAX_float : -MAX_float;
data.NullItemValue = 0;
data.NullItemIndex = 0;
data.KeySign = sortAscending ? -1.0f : 1.0f;
data.MaxIterations = maxIterations;
data.LoopK = 0;
@@ -110,7 +106,8 @@ void BitonicSort::Sort(GPUContext* context, GPUBuffer* sortingKeysBuffer, GPUBuf
{
// Use pre-sort with smaller thread group size (eg. for small particle emitters sorting)
const int32 permutation = maxNumElements < 128 ? 1 : 0;
context->BindUA(0, sortingKeysBuffer->View());
context->BindUA(0, indicesBuffer->View());
context->BindUA(1, keysBuffer->View());
context->Dispatch(_preSortCS.Get(permutation), 1, 1, 1);
}
else
@@ -120,7 +117,8 @@ void BitonicSort::Sort(GPUContext* context, GPUBuffer* sortingKeysBuffer, GPUBuf
context->Dispatch(_indirectArgsCS, 1, 1, 1);
// Pre-Sort the buffer up to k = 2048 (this also pads the list with invalid indices that will drift to the end of the sorted list)
context->BindUA(0, sortingKeysBuffer->View());
context->BindUA(0, indicesBuffer->View());
context->BindUA(1, keysBuffer->View());
context->DispatchIndirect(_preSortCS.Get(0), _dispatchArgsBuffer, 0);
// We have already pre-sorted up through k = 2048 when first writing our list, so we continue sorting with k = 4096
@@ -144,27 +142,4 @@ void BitonicSort::Sort(GPUContext* context, GPUBuffer* sortingKeysBuffer, GPUBuf
}
context->ResetUA();
if (sortedIndicesBuffer)
{
// Copy indices to another buffer
#if !BUILD_RELEASE
switch (sortedIndicesBuffer->GetDescription().Format)
{
case PixelFormat::R32_UInt:
case PixelFormat::R16_UInt:
case PixelFormat::R8_UInt:
break;
default:
LOG(Warning, "Invalid format {0} of sortedIndicesBuffer for BitonicSort. It needs to be UInt type.", (int32)sortedIndicesBuffer->GetDescription().Format);
}
#endif
context->BindSR(1, sortingKeysBuffer->View());
context->BindUA(0, sortedIndicesBuffer->View());
// TODO: use indirect dispatch to match the items count for copy
context->Dispatch(_copyIndicesCS, (alignedMaxNumElements + 1023) / 1024, 1, 1);
}
context->ResetUA();
context->ResetSR();
}