summaryrefslogtreecommitdiff
path: root/tools/gfx-unit-test/ray-tracing-tests.cpp
diff options
context:
space:
mode:
authorlucy96chen <47800040+lucy96chen@users.noreply.github.com>2022-06-21 13:14:18 -0700
committerGitHub <noreply@github.com>2022-06-21 13:14:18 -0700
commitea3800e115d4ad1ce06ec07689152616f47a0e3d (patch)
tree16bc7af0da281af5aae7da382abf782c1edc202a /tools/gfx-unit-test/ray-tracing-tests.cpp
parent92dfec2320421113498ae7b5b72e78bd8b5b09a8 (diff)
Added a second set of shaders to the ray tracing test and added another test that uses these shaders; Fixed a bug in Vulkan's RayTracingCommandEncoder::dispatchRays() where the supplied raygen shader index wasn't being used (#2290)
Diffstat (limited to 'tools/gfx-unit-test/ray-tracing-tests.cpp')
-rw-r--r--tools/gfx-unit-test/ray-tracing-tests.cpp124
1 files changed, 91 insertions, 33 deletions
diff --git a/tools/gfx-unit-test/ray-tracing-tests.cpp b/tools/gfx-unit-test/ray-tracing-tests.cpp
index f62c056f9..ee593dc2d 100644
--- a/tools/gfx-unit-test/ray-tracing-tests.cpp
+++ b/tools/gfx-unit-test/ray-tracing-tests.cpp
@@ -73,7 +73,9 @@ namespace gfx_test
void init(IDevice* device, UnitTestContext* context)
{
if (!device->hasFeature("ray-tracing"))
+ {
SLANG_IGNORE_TEST;
+ }
this->device = device;
this->context = context;
@@ -86,19 +88,26 @@ namespace gfx_test
slangSession = device->getSlangSession();
ComPtr<slang::IBlob> diagnosticsBlob;
- slang::IModule* module = slangSession->loadModule("ray-tracing-test-shader", diagnosticsBlob.writeRef());
+ slang::IModule* module = slangSession->loadModule("ray-tracing-test-shaders", diagnosticsBlob.writeRef());
if (!module)
return SLANG_FAIL;
Slang::List<slang::IComponentType*> componentTypes;
componentTypes.add(module);
ComPtr<slang::IEntryPoint> entryPoint;
- SLANG_RETURN_ON_FAIL(module->findEntryPointByName("rayGenShader", entryPoint.writeRef()));
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("rayGenShaderA", entryPoint.writeRef()));
+ componentTypes.add(entryPoint);
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("rayGenShaderB", entryPoint.writeRef()));
+ componentTypes.add(entryPoint);
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("missShaderA", entryPoint.writeRef()));
componentTypes.add(entryPoint);
- SLANG_RETURN_ON_FAIL(module->findEntryPointByName("missShader", entryPoint.writeRef()));
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("missShaderB", entryPoint.writeRef()));
componentTypes.add(entryPoint);
SLANG_RETURN_ON_FAIL(
- module->findEntryPointByName("closestHitShader", entryPoint.writeRef()));
+ module->findEntryPointByName("closestHitShaderA", entryPoint.writeRef()));
+ componentTypes.add(entryPoint);
+ SLANG_RETURN_ON_FAIL(
+ module->findEntryPointByName("closestHitShaderB", entryPoint.writeRef()));
componentTypes.add(entryPoint);
ComPtr<slang::IComponentType> linkedProgram;
@@ -334,17 +343,19 @@ namespace gfx_test
queue->waitOnHost();
}
- const char* hitgroupNames[] = { "hitgroup" };
+ const char* hitgroupNames[] = { "hitgroupA", "hitgroupB"};
ComPtr<IShaderProgram> rayTracingProgram;
GFX_CHECK_CALL_ABORT(
loadShaderProgram(device, rayTracingProgram.writeRef()));
RayTracingPipelineStateDesc rtpDesc = {};
rtpDesc.program = rayTracingProgram;
- rtpDesc.hitGroupCount = 1;
- HitGroupDesc hitGroups[1];
- hitGroups[0].closestHitEntryPoint = "closestHitShader";
+ rtpDesc.hitGroupCount = 2;
+ HitGroupDesc hitGroups[2];
+ hitGroups[0].closestHitEntryPoint = "closestHitShaderA";
hitGroups[0].hitGroupName = hitgroupNames[0];
+ hitGroups[1].closestHitEntryPoint = "closestHitShaderB";
+ hitGroups[1].hitGroupName = hitgroupNames[1];
rtpDesc.hitGroups = hitGroups;
rtpDesc.maxRayPayloadSize = 64;
rtpDesc.maxRecursion = 2;
@@ -352,19 +363,38 @@ namespace gfx_test
device->createRayTracingPipelineState(rtpDesc, renderPipelineState.writeRef()));
SLANG_CHECK_ABORT(renderPipelineState != nullptr);
+ const char* raygenNames[] = { "rayGenShaderA", "rayGenShaderB" };
+ const char* missNames[] = { "missShaderA", "missShaderB" };
+
IShaderTable::Desc shaderTableDesc = {};
- const char* raygenName = "rayGenShader";
- const char* missName = "missShader";
shaderTableDesc.program = rayTracingProgram;
- shaderTableDesc.hitGroupCount = 1;
+ shaderTableDesc.hitGroupCount = 2;
shaderTableDesc.hitGroupNames = hitgroupNames;
- shaderTableDesc.rayGenShaderCount = 1;
- shaderTableDesc.rayGenShaderEntryPointNames = &raygenName;
- shaderTableDesc.missShaderCount = 1;
- shaderTableDesc.missShaderEntryPointNames = &missName;
+ shaderTableDesc.rayGenShaderCount = 2;
+ shaderTableDesc.rayGenShaderEntryPointNames = raygenNames;
+ shaderTableDesc.missShaderCount = 2;
+ shaderTableDesc.missShaderEntryPointNames = missNames;
GFX_CHECK_CALL_ABORT(device->createShaderTable(shaderTableDesc, shaderTable.writeRef()));
}
+ void checkTestResults(float* expectedResult)
+ {
+ ComPtr<ISlangBlob> resultBlob;
+ size_t rowPitch = 0;
+ size_t pixelSize = 0;
+ GFX_CHECK_CALL_ABORT(device->readTextureResource(
+ resultTexture, ResourceState::CopySource, resultBlob.writeRef(), &rowPitch, &pixelSize));
+
+ writeImage("C:/Users/lucchen/Documents/test.hdr", resultBlob, width, height, rowPitch, pixelSize);
+
+ auto buffer = removePadding(resultBlob, width, height, rowPitch, pixelSize);
+ auto actualData = (float*)buffer.getBuffer();
+ SLANG_CHECK(memcmp(actualData, expectedResult, sizeof(expectedResult)) == 0)
+ }
+ };
+
+ struct RayTracingTestA : BaseRayTracingTest
+ {
void renderFrame()
{
ComPtr<ICommandBuffer> renderCommandBuffer =
@@ -382,48 +412,76 @@ namespace gfx_test
queue->waitOnHost();
}
- void checkTestResults()
+ void run()
{
- ComPtr<ISlangBlob> resultBlob;
- size_t rowPitch = 0;
- size_t pixelSize = 0;
- GFX_CHECK_CALL_ABORT(device->readTextureResource(
- resultTexture, ResourceState::CopySource, resultBlob.writeRef(), &rowPitch, &pixelSize));
-
- //writeImage("C:/Users/lucchen/Documents/test.hdr", resultBlob, width, height, rowPitch, pixelSize);
+ createRequiredResources();
+ renderFrame();
- auto buffer = removePadding(resultBlob, width, height, rowPitch, pixelSize);
float expectedResult[16] = { 1, 1, 1, 1,
0, 0, 1, 1,
0, 1, 0, 1,
1, 0, 0, 1 };
- auto actualData = (float*)buffer.getBuffer();
- SLANG_CHECK(memcmp(actualData, expectedResult, sizeof(expectedResult)) == 0)
+ checkTestResults(expectedResult);
+ }
+ };
+
+ struct RayTracingTestB : BaseRayTracingTest
+ {
+ void renderFrame()
+ {
+ ComPtr<ICommandBuffer> renderCommandBuffer =
+ transientHeap->createCommandBuffer();
+ auto renderEncoder = renderCommandBuffer->encodeRayTracingCommands();
+ IShaderObject* rootObject = nullptr;
+ renderEncoder->bindPipeline(renderPipelineState, &rootObject);
+ auto cursor = ShaderCursor(rootObject);
+ cursor["resultTexture"].setResource(resultTextureUAV);
+ cursor["sceneBVH"].setResource(TLAS);
+ renderEncoder->dispatchRays(1, shaderTable, width, height, 1);
+ renderEncoder->endEncoding();
+ renderCommandBuffer->close();
+ queue->executeCommandBuffer(renderCommandBuffer);
+ queue->waitOnHost();
}
void run()
{
createRequiredResources();
renderFrame();
- checkTestResults();
+
+ float expectedResult[16] = { 0, 0, 0, 0,
+ 1, 1, 0, 1,
+ 1, 0, 1, 1,
+ 0, 1, 1, 1 };
+ checkTestResults(expectedResult);
}
};
- void simpleRayTracingTestImpl(IDevice* device, UnitTestContext* context)
+ template <typename T>
+ void rayTracingTestImpl(IDevice* device, UnitTestContext* context)
{
- BaseRayTracingTest test;
+ T test;
test.init(device, context);
test.run();
}
- SLANG_UNIT_TEST(simpleRayTracingD3D12)
+ SLANG_UNIT_TEST(RayTracingTestAD3D12)
{
- runTestImpl(simpleRayTracingTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
+ runTestImpl(rayTracingTestImpl<RayTracingTestA>, unitTestContext, Slang::RenderApiFlag::D3D12);
}
- SLANG_UNIT_TEST(simpleRayTracingVulkan)
+ SLANG_UNIT_TEST(RayTracingTestAVulkan)
{
- runTestImpl(simpleRayTracingTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ runTestImpl(rayTracingTestImpl<RayTracingTestA>, unitTestContext, Slang::RenderApiFlag::Vulkan);
}
+ SLANG_UNIT_TEST(RayTracingTestBD3D12)
+ {
+ runTestImpl(rayTracingTestImpl<RayTracingTestB>, unitTestContext, Slang::RenderApiFlag::D3D12);
+ }
+
+ SLANG_UNIT_TEST(RayTracingTestBVulkan)
+ {
+ runTestImpl(rayTracingTestImpl<RayTracingTestB>, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ }
}