Optimize ProbesTrace texture in DDGI to use batched probes update

This commit is contained in:
Wojciech Figat
2022-06-10 10:39:46 +02:00
parent 3a8e5e0bbe
commit 3b27ae5fa9
2 changed files with 80 additions and 71 deletions

View File

@@ -34,8 +34,9 @@ float2 Padding0;
META_CB_END
META_CB_BEGIN(1, Data1)
float3 Padding1;
float2 Padding1;
uint CascadeIndex;
uint ProbeIndexOffset;
META_CB_END
// Calculates the evenly distributed direction ray on a sphere (Spherical Fibonacci lattice)
@@ -149,10 +150,10 @@ TextureCube Skybox : register(t13);
// Compute shader for tracing rays for probes using Global SDF and Global Surface Atlas.
META_CS(true, FEATURE_LEVEL_SM5)
[numthreads(DDGI_TRACE_RAYS_GROUP_SIZE_X, 1, 1)]
void CS_TraceRays(uint3 GroupId : SV_GroupID, uint3 DispatchThreadId : SV_DispatchThreadID, uint3 GroupThreadId : SV_GroupThreadID)
void CS_TraceRays(uint3 DispatchThreadId : SV_DispatchThreadID)
{
uint rayIndex = DispatchThreadId.x;
uint probeIndex = DispatchThreadId.y;
uint probeIndex = DispatchThreadId.y + ProbeIndexOffset;
uint3 probeCoords = GetDDGIProbeCoords(DDGI, probeIndex);
probeIndex = GetDDGIScrollingProbeIndex(DDGI, CascadeIndex, probeCoords);
@@ -196,7 +197,7 @@ void CS_TraceRays(uint3 GroupId : SV_GroupID, uint3 DispatchThreadId : SV_Dispat
}
// Write into probes trace results
RWProbesTrace[uint2(rayIndex, probeIndex)] = radiance;
RWProbesTrace[uint2(rayIndex, DispatchThreadId.y)] = radiance;
}
#endif
@@ -223,24 +224,28 @@ META_CS(true, FEATURE_LEVEL_SM5)
META_PERMUTATION_1(DDGI_PROBE_UPDATE_MODE=0)
META_PERMUTATION_1(DDGI_PROBE_UPDATE_MODE=1)
[numthreads(DDGI_PROBE_RESOLUTION, DDGI_PROBE_RESOLUTION, 1)]
void CS_UpdateProbes(uint3 DispatchThreadId : SV_DispatchThreadID, uint GroupIndex : SV_GroupIndex)
void CS_UpdateProbes(uint3 GroupThreadId : SV_GroupThreadID, uint3 GroupId : SV_GroupID, uint GroupIndex : SV_GroupIndex)
{
// GroupThreadId.xy - coordinates of the probe texel: [0; DDGI_PROBE_RESOLUTION)
// GroupId.x - index of the thread group which is probe index within a batch: [0; batchSize)
// GroupIndex.x - index of the thread within a thread group: [0; DDGI_PROBE_RESOLUTION * DDGI_PROBE_RESOLUTION)
// Get probe index and atlas location in the atlas
uint probeIndex = GetDDGIProbeIndex(DDGI, DispatchThreadId.xy, DDGI_PROBE_RESOLUTION);
uint probesCount = DDGI.ProbesCounts.x * DDGI.ProbesCounts.y * DDGI.ProbesCounts.z;
bool skip = probeIndex >= probesCount;
uint2 outputCoords = uint2(1, 1) + DispatchThreadId.xy + (DispatchThreadId.xy / DDGI_PROBE_RESOLUTION) * 2;
outputCoords.y += CascadeIndex * DDGI.ProbesCounts.z * (DDGI_PROBE_RESOLUTION + 2);
// Clear probes that have been scrolled to a new positions (blending with current irradiance will happen the next frame)
uint probeIndex = GroupId.x + ProbeIndexOffset;
uint3 probeCoords = GetDDGIProbeCoords(DDGI, probeIndex);
probeIndex = GetDDGIScrollingProbeIndex(DDGI, CascadeIndex, probeCoords);
probeCoords = GetDDGIProbeCoords(DDGI, probeIndex);
uint2 outputCoords = GetDDGIProbeTexelCoords(DDGI, CascadeIndex, probeIndex) * (DDGI_PROBE_RESOLUTION + 2) + 1 + GroupThreadId.xy;
// Clear probes that have been scrolled to a new positions (blending with current irradiance will happen the next frame)
int3 probesScrollOffsets = DDGI.ProbesScrollOffsets[CascadeIndex].xyz;
int probeScrollClear = DDGI.ProbesScrollOffsets[CascadeIndex].w;
int3 probeScrollDirections = DDGI.ProbeScrollDirections[CascadeIndex].xyz;
bool skip = false;
UNROLL
for (uint planeIndex = 0; planeIndex < 3; planeIndex++)
{
if (probeScrollClear & (1 << planeIndex) && !skip)
if (probeScrollClear & (1 << planeIndex))
{
int scrollOffset = probesScrollOffsets[planeIndex];
int scrollDirection = probeScrollDirections[planeIndex];
@@ -248,12 +253,15 @@ void CS_UpdateProbes(uint3 DispatchThreadId : SV_DispatchThreadID, uint GroupInd
uint coord = (probeCount + (scrollDirection ? (scrollOffset - 1) : (scrollOffset % probeCount))) % probeCount;
if (probeCoords[planeIndex] == coord)
{
// Clear and skip scrolled probes
RWOutput[outputCoords] = float4(0, 0, 0, 0);
skip = true;
}
}
}
if (skip)
{
// Clear scrolled probe
RWOutput[outputCoords] = float4(0, 0, 0, 0);
}
// Skip disabled probes
float probeState = LoadDDGIProbeState(DDGI, ProbesState, CascadeIndex, probeIndex);
@@ -262,14 +270,14 @@ void CS_UpdateProbes(uint3 DispatchThreadId : SV_DispatchThreadID, uint GroupInd
if (!skip)
{
// Load trace rays results into shared memory to reuse across whole thread group
uint count = (uint)(ceil((float)(DDGI_TRACE_RAYS_LIMIT) / (float)(DDGI_PROBE_RESOLUTION * DDGI_PROBE_RESOLUTION)));
for (uint i = 0; i < count; i++)
// Load trace rays results into shared memory to reuse across whole thread group (raysCount per thread)
uint raysCount = (uint)(ceil((float)DDGI.RaysCount / (float)(DDGI_PROBE_RESOLUTION * DDGI_PROBE_RESOLUTION)));
uint raysStart = GroupIndex * raysCount;
raysCount = max(min(raysStart + raysCount, DDGI.RaysCount), raysStart) - raysStart;
for (uint i = 0; i < raysCount; i++)
{
uint rayIndex = (GroupIndex * count) + i;
if (rayIndex >= DDGI.RaysCount)
break;
CachedProbesTraceRadiance[rayIndex] = ProbesTrace[uint2(rayIndex, probeIndex)];
uint rayIndex = raysStart + i;
CachedProbesTraceRadiance[rayIndex] = ProbesTrace[uint2(rayIndex, GroupId.x)];
CachedProbesTraceDirection[rayIndex] = GetProbeRayDirection(DDGI, rayIndex);
}
}
@@ -278,7 +286,7 @@ void CS_UpdateProbes(uint3 DispatchThreadId : SV_DispatchThreadID, uint GroupInd
return;
// Calculate octahedral projection for probe (unwraps spherical projection into a square)
float2 octahedralCoords = GetOctahedralCoords(DispatchThreadId.xy, DDGI_PROBE_RESOLUTION);
float2 octahedralCoords = GetOctahedralCoords(GroupThreadId.xy, DDGI_PROBE_RESOLUTION);
float3 octahedralDirection = GetOctahedralDirection(octahedralCoords);
// Loop over rays