Optimize ProbesTrace texture in DDGI to use batched probes update

This commit is contained in:
Wojciech Figat
2022-06-10 10:39:46 +02:00
parent 3a8e5e0bbe
commit 3b27ae5fa9
2 changed files with 80 additions and 71 deletions

View File

@@ -31,6 +31,7 @@
// "Dynamic Diffuse Global Illumination with Ray-Traced Irradiance Fields", https://jcgt.org/published/0008/02/01/
// This must match HLSL
#define DDGI_TRACE_RAYS_PROBES_COUNT_LIMIT 4096 // Maximum amount of probes to update at once during rays tracing and blending
#define DDGI_TRACE_RAYS_GROUP_SIZE_X 32
#define DDGI_TRACE_RAYS_LIMIT 512 // Limit of rays per-probe (runtime value can be smaller)
#define DDGI_PROBE_RESOLUTION_IRRADIANCE 6 // Resolution (in texels) for probe irradiance data (excluding 1px padding on each side)
@@ -52,8 +53,10 @@ PACK_STRUCT(struct Data0
PACK_STRUCT(struct Data1
{
Vector3 Padding1;
uint32 CascadeIndex; // TODO: use push constants on Vulkan or root signature data on DX12 to reduce overhead of changing single DWORD
// TODO: use push constants on Vulkan or root signature data on DX12 to reduce overhead of changing single DWORD
Vector2 Padding1;
uint32 CascadeIndex;
uint32 ProbeIndexOffset;
});
class DDGICustomBuffer : public RenderBuffers::CustomBuffer
@@ -335,7 +338,7 @@ bool DynamicDiffuseGlobalIlluminationPass::Render(RenderContext& renderContext,
// TODO rethink probes data placement in memory -> what if we get [50x50x30] resolution? That's 75000 probes! Use sparse storage with active-only probes
#define INIT_TEXTURE(texture, format, width, height) desc.Format = format; desc.Width = width; desc.Height = height; ddgiData.texture = RenderTargetPool::Get(desc); if (!ddgiData.texture) return true; memUsage += ddgiData.texture->GetMemoryUsage()
desc.Flags = GPUTextureFlags::ShaderResource | GPUTextureFlags::UnorderedAccess;
INIT_TEXTURE(ProbesTrace, PixelFormat::R16G16B16A16_Float, probeRaysCount, probesCountTotal); // TODO: limit to 4k probes for a single batch to trace
INIT_TEXTURE(ProbesTrace, PixelFormat::R16G16B16A16_Float, probeRaysCount, Math::Min(probesCountCascade, DDGI_TRACE_RAYS_PROBES_COUNT_LIMIT));
INIT_TEXTURE(ProbesState, PixelFormat::R16G16B16A16_Float, probesCountTotalX, probesCountTotalY); // TODO: optimize to a RGBA32 (pos offset can be normalized to [0-0.5] range of ProbesSpacing and packed with state flag)
INIT_TEXTURE(ProbesIrradiance, PixelFormat::R11G11B10_Float, probesCountTotalX * (DDGI_PROBE_RESOLUTION_IRRADIANCE + 2), probesCountTotalY * (DDGI_PROBE_RESOLUTION_IRRADIANCE + 2));
INIT_TEXTURE(ProbesDistance, PixelFormat::R16G16_Float, probesCountTotalX * (DDGI_PROBE_RESOLUTION_DISTANCE + 2), probesCountTotalY * (DDGI_PROBE_RESOLUTION_DISTANCE + 2));
@@ -481,59 +484,57 @@ bool DynamicDiffuseGlobalIlluminationPass::Render(RenderContext& renderContext,
if (cascadeSkipUpdate[cascadeIndex])
continue;
anyDirty = true;
Data1 data;
data.CascadeIndex = cascadeIndex;
context->UpdateCB(_cb1, &data);
context->BindCB(1, _cb1);
// TODO: run probes tracing+update in 4k batches
// Trace rays from probes
// Update probes in batches so ProbesTrace texture can be smaller
for (int32 probesOffset = 0; probesOffset < probesCountCascade; probesOffset += DDGI_TRACE_RAYS_PROBES_COUNT_LIMIT)
{
PROFILE_GPU_CPU("Trace Rays");
uint32 probesBatchSize = Math::Min(probesCountCascade - probesOffset, DDGI_TRACE_RAYS_PROBES_COUNT_LIMIT);
Data1 data;
data.CascadeIndex = cascadeIndex;
data.ProbeIndexOffset = probesOffset;
context->UpdateCB(_cb1, &data);
context->BindCB(1, _cb1);
// Global SDF with Global Surface Atlas software raytracing (thread X - per probe ray, thread Y - per probe)
ASSERT_LOW_LAYER((probeRaysCount % DDGI_TRACE_RAYS_GROUP_SIZE_X) == 0);
for (int32 i = 0; i < 4; i++)
// Trace rays from probes
{
context->BindSR(i, bindingDataSDF.Cascades[i]->ViewVolume());
context->BindSR(i + 4, bindingDataSDF.CascadeMips[i]->ViewVolume());
}
context->BindSR(8, bindingDataSurfaceAtlas.Chunks ? bindingDataSurfaceAtlas.Chunks->View() : nullptr);
context->BindSR(9, bindingDataSurfaceAtlas.CulledObjects ? bindingDataSurfaceAtlas.CulledObjects->View() : nullptr);
context->BindSR(10, bindingDataSurfaceAtlas.AtlasDepth->View());
context->BindSR(11, bindingDataSurfaceAtlas.AtlasLighting->View());
context->BindSR(12, ddgiData.Result.ProbesState);
context->BindSR(13, skybox);
context->BindUA(0, ddgiData.ProbesTrace->View());
context->Dispatch(_csTraceRays, probeRaysCount / DDGI_TRACE_RAYS_GROUP_SIZE_X, probesCountCascade, 1);
context->ResetUA();
context->ResetSR();
PROFILE_GPU_CPU("Trace Rays");
// Global SDF with Global Surface Atlas software raytracing (thread X - per probe ray, thread Y - per probe)
ASSERT_LOW_LAYER((probeRaysCount % DDGI_TRACE_RAYS_GROUP_SIZE_X) == 0);
for (int32 i = 0; i < 4; i++)
{
context->BindSR(i, bindingDataSDF.Cascades[i]->ViewVolume());
context->BindSR(i + 4, bindingDataSDF.CascadeMips[i]->ViewVolume());
}
context->BindSR(8, bindingDataSurfaceAtlas.Chunks ? bindingDataSurfaceAtlas.Chunks->View() : nullptr);
context->BindSR(9, bindingDataSurfaceAtlas.CulledObjects ? bindingDataSurfaceAtlas.CulledObjects->View() : nullptr);
context->BindSR(10, bindingDataSurfaceAtlas.AtlasDepth->View());
context->BindSR(11, bindingDataSurfaceAtlas.AtlasLighting->View());
context->BindSR(12, ddgiData.Result.ProbesState);
context->BindSR(13, skybox);
context->BindUA(0, ddgiData.ProbesTrace->View());
context->Dispatch(_csTraceRays, probeRaysCount / DDGI_TRACE_RAYS_GROUP_SIZE_X, probesBatchSize, 1);
context->ResetUA();
context->ResetSR();
#if 0
// Probes trace debug preview
context->SetViewportAndScissors(renderContext.View.ScreenSize.X, renderContext.View.ScreenSize.Y);
context->SetRenderTarget(lightBuffer);
context->Draw(ddgiData.ProbesTrace);
return false;
// Probes trace debug preview
context->SetViewportAndScissors(renderContext.View.ScreenSize.X, renderContext.View.ScreenSize.Y);
context->SetRenderTarget(lightBuffer);
context->Draw(ddgiData.ProbesTrace);
return false;
#endif
}
}
context->BindSR(0, ddgiData.Result.ProbesState);
context->BindSR(1, ddgiData.ProbesTrace->View());
// Update probes irradiance texture
{
PROFILE_GPU_CPU("Update Irradiance");
context->BindUA(0, ddgiData.Result.ProbesIrradiance);
context->Dispatch(_csUpdateProbesIrradiance, probesCountCascadeX, probesCountCascadeY, 1);
}
// Update probes distance texture
{
PROFILE_GPU_CPU("Update Distance");
context->BindUA(0, ddgiData.Result.ProbesDistance);
context->Dispatch(_csUpdateProbesDistance, probesCountCascadeX, probesCountCascadeY, 1);
// Update probes irradiance and distance textures (one thread-group per probe)
{
PROFILE_GPU_CPU("Update Probes");
context->BindSR(0, ddgiData.Result.ProbesState);
context->BindSR(1, ddgiData.ProbesTrace->View());
context->BindUA(0, ddgiData.Result.ProbesIrradiance);
context->Dispatch(_csUpdateProbesIrradiance, probesBatchSize, 1, 1);
context->BindUA(0, ddgiData.Result.ProbesDistance);
context->Dispatch(_csUpdateProbesDistance, probesBatchSize, 1, 1);
}
}
}

View File

@@ -34,8 +34,9 @@ float2 Padding0;
META_CB_END
META_CB_BEGIN(1, Data1)
float3 Padding1;
float2 Padding1;
uint CascadeIndex;
uint ProbeIndexOffset;
META_CB_END
// Calculates the evenly distributed direction ray on a sphere (Spherical Fibonacci lattice)
@@ -149,10 +150,10 @@ TextureCube Skybox : register(t13);
// Compute shader for tracing rays for probes using Global SDF and Global Surface Atlas.
META_CS(true, FEATURE_LEVEL_SM5)
[numthreads(DDGI_TRACE_RAYS_GROUP_SIZE_X, 1, 1)]
void CS_TraceRays(uint3 GroupId : SV_GroupID, uint3 DispatchThreadId : SV_DispatchThreadID, uint3 GroupThreadId : SV_GroupThreadID)
void CS_TraceRays(uint3 DispatchThreadId : SV_DispatchThreadID)
{
uint rayIndex = DispatchThreadId.x;
uint probeIndex = DispatchThreadId.y;
uint probeIndex = DispatchThreadId.y + ProbeIndexOffset;
uint3 probeCoords = GetDDGIProbeCoords(DDGI, probeIndex);
probeIndex = GetDDGIScrollingProbeIndex(DDGI, CascadeIndex, probeCoords);
@@ -196,7 +197,7 @@ void CS_TraceRays(uint3 GroupId : SV_GroupID, uint3 DispatchThreadId : SV_Dispat
}
// Write into probes trace results
RWProbesTrace[uint2(rayIndex, probeIndex)] = radiance;
RWProbesTrace[uint2(rayIndex, DispatchThreadId.y)] = radiance;
}
#endif
@@ -223,24 +224,28 @@ META_CS(true, FEATURE_LEVEL_SM5)
META_PERMUTATION_1(DDGI_PROBE_UPDATE_MODE=0)
META_PERMUTATION_1(DDGI_PROBE_UPDATE_MODE=1)
[numthreads(DDGI_PROBE_RESOLUTION, DDGI_PROBE_RESOLUTION, 1)]
void CS_UpdateProbes(uint3 DispatchThreadId : SV_DispatchThreadID, uint GroupIndex : SV_GroupIndex)
void CS_UpdateProbes(uint3 GroupThreadId : SV_GroupThreadID, uint3 GroupId : SV_GroupID, uint GroupIndex : SV_GroupIndex)
{
// GroupThreadId.xy - coordinates of the probe texel: [0; DDGI_PROBE_RESOLUTION)
// GroupId.x - index of the thread group which is probe index within a batch: [0; batchSize)
// GroupIndex - flattened index of the thread within a thread group: [0; DDGI_PROBE_RESOLUTION * DDGI_PROBE_RESOLUTION)
// Get probe index and its location in the atlas
uint probeIndex = GetDDGIProbeIndex(DDGI, DispatchThreadId.xy, DDGI_PROBE_RESOLUTION);
uint probesCount = DDGI.ProbesCounts.x * DDGI.ProbesCounts.y * DDGI.ProbesCounts.z;
bool skip = probeIndex >= probesCount;
uint2 outputCoords = uint2(1, 1) + DispatchThreadId.xy + (DispatchThreadId.xy / DDGI_PROBE_RESOLUTION) * 2;
outputCoords.y += CascadeIndex * DDGI.ProbesCounts.z * (DDGI_PROBE_RESOLUTION + 2);
// Clear probes that have been scrolled to new positions (blending with current irradiance will happen the next frame)
uint probeIndex = GroupId.x + ProbeIndexOffset;
uint3 probeCoords = GetDDGIProbeCoords(DDGI, probeIndex);
probeIndex = GetDDGIScrollingProbeIndex(DDGI, CascadeIndex, probeCoords);
probeCoords = GetDDGIProbeCoords(DDGI, probeIndex);
uint2 outputCoords = GetDDGIProbeTexelCoords(DDGI, CascadeIndex, probeIndex) * (DDGI_PROBE_RESOLUTION + 2) + 1 + GroupThreadId.xy;
// Clear probes that have been scrolled to new positions (blending with current irradiance will happen the next frame)
int3 probesScrollOffsets = DDGI.ProbesScrollOffsets[CascadeIndex].xyz;
int probeScrollClear = DDGI.ProbesScrollOffsets[CascadeIndex].w;
int3 probeScrollDirections = DDGI.ProbeScrollDirections[CascadeIndex].xyz;
bool skip = false;
UNROLL
for (uint planeIndex = 0; planeIndex < 3; planeIndex++)
{
if (probeScrollClear & (1 << planeIndex) && !skip)
if (probeScrollClear & (1 << planeIndex))
{
int scrollOffset = probesScrollOffsets[planeIndex];
int scrollDirection = probeScrollDirections[planeIndex];
@@ -248,12 +253,15 @@ void CS_UpdateProbes(uint3 DispatchThreadId : SV_DispatchThreadID, uint GroupInd
uint coord = (probeCount + (scrollDirection ? (scrollOffset - 1) : (scrollOffset % probeCount))) % probeCount;
if (probeCoords[planeIndex] == coord)
{
// Clear and skip scrolled probes
RWOutput[outputCoords] = float4(0, 0, 0, 0);
skip = true;
}
}
}
if (skip)
{
// Clear scrolled probe
RWOutput[outputCoords] = float4(0, 0, 0, 0);
}
// Skip disabled probes
float probeState = LoadDDGIProbeState(DDGI, ProbesState, CascadeIndex, probeIndex);
@@ -262,14 +270,14 @@ void CS_UpdateProbes(uint3 DispatchThreadId : SV_DispatchThreadID, uint GroupInd
if (!skip)
{
// Load trace rays results into shared memory to reuse across whole thread group
uint count = (uint)(ceil((float)(DDGI_TRACE_RAYS_LIMIT) / (float)(DDGI_PROBE_RESOLUTION * DDGI_PROBE_RESOLUTION)));
for (uint i = 0; i < count; i++)
// Load trace rays results into shared memory to reuse across whole thread group (raysCount per thread)
uint raysCount = (uint)(ceil((float)DDGI.RaysCount / (float)(DDGI_PROBE_RESOLUTION * DDGI_PROBE_RESOLUTION)));
uint raysStart = GroupIndex * raysCount;
raysCount = max(min(raysStart + raysCount, DDGI.RaysCount), raysStart) - raysStart;
for (uint i = 0; i < raysCount; i++)
{
uint rayIndex = (GroupIndex * count) + i;
if (rayIndex >= DDGI.RaysCount)
break;
CachedProbesTraceRadiance[rayIndex] = ProbesTrace[uint2(rayIndex, probeIndex)];
uint rayIndex = raysStart + i;
CachedProbesTraceRadiance[rayIndex] = ProbesTrace[uint2(rayIndex, GroupId.x)];
CachedProbesTraceDirection[rayIndex] = GetProbeRayDirection(DDGI, rayIndex);
}
}
@@ -278,7 +286,7 @@ void CS_UpdateProbes(uint3 DispatchThreadId : SV_DispatchThreadID, uint GroupInd
return;
// Calculate octahedral projection for probe (unwraps spherical projection into a square)
float2 octahedralCoords = GetOctahedralCoords(DispatchThreadId.xy, DDGI_PROBE_RESOLUTION);
float2 octahedralCoords = GetOctahedralCoords(GroupThreadId.xy, DDGI_PROBE_RESOLUTION);
float3 octahedralDirection = GetOctahedralDirection(octahedralCoords);
// Loop over rays