diff options
| author | Yong He <yonghe@outlook.com> | 2025-02-23 10:31:05 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-23 10:31:05 -0800 |
| commit | 51ad07d1fbffd41c758eba172aa77ebba3204924 (patch) | |
| tree | fadd788714c4ad37830846b0274d56b5ae1eff56 /tools | |
| parent | 0101e5ab59a1678ed7212913c3880edfaf039537 (diff) | |
Improve performance when compiling small shaders. (#6396)
Improve performance when compiling small shaders.
Avoid copying witness table entries that are not getting used during linking.
Avoid copying auto-diff related decorations and derivative functions during linking, if the user modules doesn't use autodiff.
Cache operator overload resolution results on global session, so each new Session doesn't need to repetitively run through overload resolution from scratch.
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/slang-unit-test/unit-test-compile-benchmark.cpp | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/tools/slang-unit-test/unit-test-compile-benchmark.cpp b/tools/slang-unit-test/unit-test-compile-benchmark.cpp new file mode 100644 index 000000000..e38edc6ad --- /dev/null +++ b/tools/slang-unit-test/unit-test-compile-benchmark.cpp @@ -0,0 +1,113 @@ +// unit-test-glsl-compile.cpp + +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "../../tools/platform/performance-counter.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include <stdio.h> +#include <stdlib.h> + +using namespace Slang; + +// Test the compilation API for cross-compiling glsl source to SPIRV. + +SLANG_UNIT_TEST(compileBenchmark) +{ + const char* userSourceBody = R"( +// shader.slang + +struct PushConstantCompute +{ + uint64_t bufferAddress; + uint numVertices; +}; + +struct Vertex +{ + float3 position; +}; + + +[[vk::push_constant]] +ConstantBuffer<PushConstantCompute> pushConst; + +[shader("compute")] +[numthreads(256, 1, 1)] +void main(uint3 threadIdx : SV_DispatchThreadID) +{ + uint index = threadIdx.x; + + if(index >= pushConst.numVertices) + return; + + Vertex* vertices = (Vertex*)pushConst.bufferAddress; + + float angle = (index + 1) * 2.3f; + + float3 vertex = vertices[index].position; + + float cosAngle = cos(angle); + float sinAngle = sin(angle); + float3x3 rotationMatrix = float3x3( + cosAngle, -sinAngle, 0.0, + sinAngle, cosAngle, 0.0, + 0.0, 0.0, 1.0 + ); + + float3 rotatedVertex = mul(rotationMatrix, vertex); + + vertices[index].position = rotatedVertex; +} + )"; + ComPtr<slang::IGlobalSession> globalSession; + SlangGlobalSessionDesc globalDesc = {}; + globalDesc.enableGLSL = false; + SLANG_CHECK(slang_createGlobalSession2(&globalDesc, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_SPIRV; + targetDesc.profile = globalSession->findProfile("spirv_1_5"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + + auto start = platform::PerformanceCounter::now(); + for (int pass = 0; pass < 100; pass++) + { + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint( + "main", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + + slang::IComponentType* componentTypes[2] = {module, entryPoint.get()}; + ComPtr<slang::IComponentType> composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); + + ComPtr<slang::IComponentType> linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + ComPtr<slang::IBlob> code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + } + auto time = platform::PerformanceCounter::getElapsedTimeInSeconds(start); + getTestReporter()->addExecutionTime(time); +} |
