Add SPIR-V compression with LZ4 of Vulkan shaders (35% avg smaller)

This commit is contained in:
Wojtek Figat
2026-03-07 23:24:40 +01:00
parent 3038c56af4
commit c4342b0a20
4 changed files with 44 additions and 7 deletions

View File

@@ -12,6 +12,7 @@
#include "Engine/Core/Types/DataContainer.h"
#include "Engine/Serialization/MemoryReadStream.h"
#include "Engine/Profiler/ProfilerMemory.h"
#include <ThirdParty/LZ4/lz4.h>
#if PLATFORM_DESKTOP
#define VULKAN_UNIFORM_RING_BUFFER_SIZE (24 * 1024 * 1024)
@@ -108,8 +109,28 @@ GPUShaderProgram* GPUShaderVulkan::CreateGPUShaderProgram(ShaderStage type, cons
// Extract the SPIR-V bytecode
BytesContainer spirv;
ASSERT(header->Type == SpirvShaderHeader::Types::SPIRV);
spirv.Link(bytecode);
switch (header->Type)
{
case SpirvShaderHeader::Types::SPIRV:
spirv.Link(bytecode);
break;
case SpirvShaderHeader::Types::SPIRV_LZ4:
{
int32 originalSize = *(int32*)bytecode.Get();
bytecode = bytecode.Slice(sizeof(int32));
spirv.Allocate(originalSize);
const int32 res = LZ4_decompress_safe((const char*)bytecode.Get(), (char*)spirv.Get(), bytecode.Length(), originalSize);
if (res <= 0)
{
LOG(Error, "Failed to decompress shader");
return nullptr;
}
break;
}
default:
LOG(Error, "Invalid shader program format");
return nullptr;
}
// Create shader module from SPIR-V bytecode
VkShaderModule shaderModule = VK_NULL_HANDLE;

View File

@@ -132,6 +132,11 @@ struct SpirvShaderHeader
/// The WGSL shader code compressed with LZ4.
/// </summary>
WGSL_LZ4 = 2,
/// <summary>
/// The SPIR-V byte code compressed with LZ4.
/// </summary>
SPIRV_LZ4 = 3,
};
/// <summary>

View File

@@ -9,6 +9,7 @@
#include "Engine/Serialization/MemoryWriteStream.h"
#include "Engine/Graphics/Config.h"
#include "Engine/GraphicsDevice/Vulkan/Types.h"
#include <ThirdParty/LZ4/lz4.h>
// Use glslang for HLSL to SPIR-V translation
// Source: https://github.com/KhronosGroup/glslang
@@ -939,8 +940,6 @@ bool ShaderCompilerVulkan::OnCompileBegin()
//_globalMacros.Add({ "VULKAN", "1" }); // glslang compiler adds VULKAN define if EShMsgVulkanRules flag is specified
// TODO: handle options->TreatWarningsAsErrors
return false;
}
@@ -967,9 +966,22 @@ void ShaderCompilerVulkan::InitCodegen(ShaderCompilationContext* context, glslan
bool ShaderCompilerVulkan::Write(ShaderCompilationContext* context, ShaderFunctionMeta& meta, int32 permutationIndex, const ShaderBindings& bindings, struct SpirvShaderHeader& header, std::vector<unsigned int>& spirv)
{
int32 spirvBytesCount = (int32)spirv.size() * sizeof(unsigned);
// Compress
const int32 srcSize = (int32)spirv.size() * sizeof(unsigned);
const int32 maxSize = LZ4_compressBound(srcSize);
Array<byte> spirvCompressed;
spirvCompressed.Resize(maxSize + sizeof(int32));
const int32 dstSize = LZ4_compress_default((const char*)&spirv[0], (char*)spirvCompressed.Get() + sizeof(int32), srcSize, maxSize);
if (dstSize > 0 && dstSize < (int32)(srcSize * 0.8f)) // Expect 20% or more compression ratio to use it (to avoid decompressing if the gain is not big enough)
{
spirvCompressed.Resize(dstSize + sizeof(int32));
*(int32*)spirvCompressed.Get() = srcSize; // Store original size in the beginning to decompress it
header.Type = SpirvShaderHeader::Types::SPIRV_LZ4;
return WriteShaderFunctionPermutation(_context, meta, permutationIndex, bindings, &header, sizeof(header), spirvCompressed.Get(), spirvCompressed.Count());
}
header.Type = SpirvShaderHeader::Types::SPIRV;
return WriteShaderFunctionPermutation(_context, meta, permutationIndex, bindings, &header, sizeof(header), &spirv[0], spirvBytesCount);
return WriteShaderFunctionPermutation(_context, meta, permutationIndex, bindings, &header, sizeof(header), &spirv[0], srcSize);
}
#endif

View File

@@ -105,7 +105,6 @@ bool ShaderCompilerWebGPU::Write(ShaderCompilationContext* context, ShaderFuncti
return WriteShaderFunctionPermutation(_context, meta, permutationIndex, bindings, &header, sizeof(header), wgslCompressed.Get(), wgslCompressed.Count());
}
header.Type = SpirvShaderHeader::Types::WGSL;
return WriteShaderFunctionPermutation(_context, meta, permutationIndex, bindings, &header, sizeof(header), wgsl.Get(), wgsl.Length() + 1);
}