summaryrefslogtreecommitdiffstats
path: root/tools
diff options
context:
space:
mode:
authorlucy96chen <47800040+lucy96chen@users.noreply.github.com>2022-06-21 13:14:18 -0700
committerGitHub <noreply@github.com>2022-06-21 13:14:18 -0700
commitea3800e115d4ad1ce06ec07689152616f47a0e3d (patch)
tree16bc7af0da281af5aae7da382abf782c1edc202a /tools
parent92dfec2320421113498ae7b5b72e78bd8b5b09a8 (diff)
Added a second set of shaders to the ray tracing test and added another test that uses these shaders; Fixed a bug in Vulkan's RayTracingCommandEncoder::dispatchRays() where the supplied raygen shader index wasn't being used (#2290)
Diffstat (limited to 'tools')
-rw-r--r--tools/gfx-unit-test/ray-tracing-test-shader.slang47
-rw-r--r--tools/gfx-unit-test/ray-tracing-test-shaders.slang83
-rw-r--r--tools/gfx-unit-test/ray-tracing-tests.cpp124
-rw-r--r--tools/gfx/vulkan/vk-command-encoder.cpp5
4 files changed, 177 insertions, 82 deletions
diff --git a/tools/gfx-unit-test/ray-tracing-test-shader.slang b/tools/gfx-unit-test/ray-tracing-test-shader.slang
deleted file mode 100644
index 30d8f3bb8..000000000
--- a/tools/gfx-unit-test/ray-tracing-test-shader.slang
+++ /dev/null
@@ -1,47 +0,0 @@
-// ray-tracing-test-shader.slang
-
-struct RayPayload
-{
- float4 color;
-};
-
-uniform RWTexture2D resultTexture;
-uniform RaytracingAccelerationStructure sceneBVH;
-
-[shader("raygeneration")]
-void rayGenShader()
-{
- int2 threadIdx = DispatchRaysIndex().xy;
-
- float3 rayDir = float3(0, 0, 1);
- float3 rayOrigin = 0;
- rayOrigin.x = (threadIdx.x * 2) - 1;
- rayOrigin.y = (threadIdx.y * 2) - 1;
-
- // Trace the ray.
- RayDesc ray;
- ray.Origin = rayOrigin;
- ray.Direction = rayDir;
- ray.TMin = 0.001;
- ray.TMax = 10000.0;
- RayPayload payload = { float4(0, 0, 0, 0) };
- TraceRay(sceneBVH, RAY_FLAG_NONE, ~0, 0, 0, 0, ray, payload);
-
- //resultTexture[threadIdx.xy] = float4(0, 0, 1, 1);
- resultTexture[threadIdx.xy] = payload.color;
-}
-
-[shader("miss")]
-void missShader(inout RayPayload payload)
-{
- payload.color = float4(1, 1, 1, 1);
-}
-
-[shader("closesthit")]
-void closestHitShader(inout RayPayload payload, in BuiltInTriangleIntersectionAttributes attr)
-{
- uint primitiveIndex = PrimitiveIndex();
- float4 color = float4(0, 0, 0, 1);
- color[primitiveIndex] = 1;
- payload.color = color;
-}
diff --git a/tools/gfx-unit-test/ray-tracing-test-shaders.slang b/tools/gfx-unit-test/ray-tracing-test-shaders.slang
new file mode 100644
index 000000000..aa2e5055f
--- /dev/null
+++ b/tools/gfx-unit-test/ray-tracing-test-shaders.slang
@@ -0,0 +1,83 @@
+// ray-tracing-test-shaders.slang
+
+struct RayPayload
+{
+ float4 color;
+};
+
+uniform RWTexture2D resultTexture;
+uniform RaytracingAccelerationStructure sceneBVH;
+
+[shader("raygeneration")]
+void rayGenShaderA()
+{
+ int2 threadIdx = DispatchRaysIndex().xy;
+
+ float3 rayDir = float3(0, 0, 1);
+ float3 rayOrigin = 0;
+ rayOrigin.x = (threadIdx.x * 2) - 1;
+ rayOrigin.y = (threadIdx.y * 2) - 1;
+
+ // Trace the ray.
+ RayDesc ray;
+ ray.Origin = rayOrigin;
+ ray.Direction = rayDir;
+ ray.TMin = 0.001;
+ ray.TMax = 10000.0;
+ RayPayload payload = { float4(0, 0, 0, 0) };
+ TraceRay(sceneBVH, RAY_FLAG_NONE, ~0, 0, 0, 0, ray, payload);
+
+ resultTexture[threadIdx.xy] = payload.color;
+}
+
+[shader("raygeneration")]
+void rayGenShaderB()
+{
+ int2 threadIdx = DispatchRaysIndex().xy;
+
+ float3 rayDir = float3(0, 0, 1);
+ float3 rayOrigin = 0;
+ rayOrigin.x = (threadIdx.x * 2) - 1;
+ rayOrigin.y = (threadIdx.y * 2) - 1;
+
+ // Trace the ray.
+ RayDesc ray;
+ ray.Origin = rayOrigin;
+ ray.Direction = rayDir;
+ ray.TMin = 0.001;
+ ray.TMax = 10000.0;
+ RayPayload payload = { float4(0, 0, 0, 0) };
+ TraceRay(sceneBVH, RAY_FLAG_NONE, ~0, 1, 0, 1, ray, payload);
+
+ resultTexture[threadIdx.xy] = payload.color;
+}
+
+[shader("miss")]
+void missShaderA(inout RayPayload payload)
+{
+ payload.color = float4(1, 1, 1, 1);
+}
+
+[shader("miss")]
+void missShaderB(inout RayPayload payload)
+{
+ payload.color = float4(0, 0, 0, 1);
+}
+
+[shader("closesthit")]
+void closestHitShaderA(inout RayPayload payload, in BuiltInTriangleIntersectionAttributes attr)
+{
+ uint primitiveIndex = PrimitiveIndex();
+ float4 color = float4(0, 0, 0, 1);
+ color[primitiveIndex] = 1;
+ payload.color = color;
+}
+
+[shader("closesthit")]
+void closestHitShaderB(inout RayPayload payload, in BuiltInTriangleIntersectionAttributes attr)
+{
+ uint primitiveIndex = PrimitiveIndex();
+ float4 color = float4(1, 1, 1, 1);
+ color[primitiveIndex] = 0;
+ payload.color = color;
+}
diff --git a/tools/gfx-unit-test/ray-tracing-tests.cpp b/tools/gfx-unit-test/ray-tracing-tests.cpp
index f62c056f9..ee593dc2d 100644
--- a/tools/gfx-unit-test/ray-tracing-tests.cpp
+++ b/tools/gfx-unit-test/ray-tracing-tests.cpp
@@ -73,7 +73,9 @@ namespace gfx_test
void init(IDevice* device, UnitTestContext* context)
{
if (!device->hasFeature("ray-tracing"))
+ {
SLANG_IGNORE_TEST;
+ }
this->device = device;
this->context = context;
@@ -86,19 +88,26 @@ namespace gfx_test
slangSession = device->getSlangSession();
ComPtr<slang::IBlob> diagnosticsBlob;
- slang::IModule* module = slangSession->loadModule("ray-tracing-test-shader", diagnosticsBlob.writeRef());
+ slang::IModule* module = slangSession->loadModule("ray-tracing-test-shaders", diagnosticsBlob.writeRef());
if (!module)
return SLANG_FAIL;
Slang::List<slang::IComponentType*> componentTypes;
componentTypes.add(module);
ComPtr<slang::IEntryPoint> entryPoint;
- SLANG_RETURN_ON_FAIL(module->findEntryPointByName("rayGenShader", entryPoint.writeRef()));
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("rayGenShaderA", entryPoint.writeRef()));
+ componentTypes.add(entryPoint);
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("rayGenShaderB", entryPoint.writeRef()));
+ componentTypes.add(entryPoint);
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("missShaderA", entryPoint.writeRef()));
componentTypes.add(entryPoint);
- SLANG_RETURN_ON_FAIL(module->findEntryPointByName("missShader", entryPoint.writeRef()));
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("missShaderB", entryPoint.writeRef()));
componentTypes.add(entryPoint);
SLANG_RETURN_ON_FAIL(
- module->findEntryPointByName("closestHitShader", entryPoint.writeRef()));
+ module->findEntryPointByName("closestHitShaderA", entryPoint.writeRef()));
+ componentTypes.add(entryPoint);
+ SLANG_RETURN_ON_FAIL(
+ module->findEntryPointByName("closestHitShaderB", entryPoint.writeRef()));
componentTypes.add(entryPoint);
ComPtr<slang::IComponentType> linkedProgram;
@@ -334,17 +343,19 @@ namespace gfx_test
queue->waitOnHost();
}
- const char* hitgroupNames[] = { "hitgroup" };
+ const char* hitgroupNames[] = { "hitgroupA", "hitgroupB"};
ComPtr<IShaderProgram> rayTracingProgram;
GFX_CHECK_CALL_ABORT(
loadShaderProgram(device, rayTracingProgram.writeRef()));
RayTracingPipelineStateDesc rtpDesc = {};
rtpDesc.program = rayTracingProgram;
- rtpDesc.hitGroupCount = 1;
- HitGroupDesc hitGroups[1];
- hitGroups[0].closestHitEntryPoint = "closestHitShader";
+ rtpDesc.hitGroupCount = 2;
+ HitGroupDesc hitGroups[2];
+ hitGroups[0].closestHitEntryPoint = "closestHitShaderA";
hitGroups[0].hitGroupName = hitgroupNames[0];
+ hitGroups[1].closestHitEntryPoint = "closestHitShaderB";
+ hitGroups[1].hitGroupName = hitgroupNames[1];
rtpDesc.hitGroups = hitGroups;
rtpDesc.maxRayPayloadSize = 64;
rtpDesc.maxRecursion = 2;
@@ -352,19 +363,38 @@ namespace gfx_test
device->createRayTracingPipelineState(rtpDesc, renderPipelineState.writeRef()));
SLANG_CHECK_ABORT(renderPipelineState != nullptr);
+ const char* raygenNames[] = { "rayGenShaderA", "rayGenShaderB" };
+ const char* missNames[] = { "missShaderA", "missShaderB" };
+
IShaderTable::Desc shaderTableDesc = {};
- const char* raygenName = "rayGenShader";
- const char* missName = "missShader";
shaderTableDesc.program = rayTracingProgram;
- shaderTableDesc.hitGroupCount = 1;
+ shaderTableDesc.hitGroupCount = 2;
shaderTableDesc.hitGroupNames = hitgroupNames;
- shaderTableDesc.rayGenShaderCount = 1;
- shaderTableDesc.rayGenShaderEntryPointNames = &raygenName;
- shaderTableDesc.missShaderCount = 1;
- shaderTableDesc.missShaderEntryPointNames = &missName;
+ shaderTableDesc.rayGenShaderCount = 2;
+ shaderTableDesc.rayGenShaderEntryPointNames = raygenNames;
+ shaderTableDesc.missShaderCount = 2;
+ shaderTableDesc.missShaderEntryPointNames = missNames;
GFX_CHECK_CALL_ABORT(device->createShaderTable(shaderTableDesc, shaderTable.writeRef()));
}
+ void checkTestResults(float* expectedResult)
+ {
+ ComPtr<ISlangBlob> resultBlob;
+ size_t rowPitch = 0;
+ size_t pixelSize = 0;
+ GFX_CHECK_CALL_ABORT(device->readTextureResource(
+ resultTexture, ResourceState::CopySource, resultBlob.writeRef(), &rowPitch, &pixelSize));
+
+ writeImage("C:/Users/lucchen/Documents/test.hdr", resultBlob, width, height, rowPitch, pixelSize);
+
+ auto buffer = removePadding(resultBlob, width, height, rowPitch, pixelSize);
+ auto actualData = (float*)buffer.getBuffer();
+ SLANG_CHECK(memcmp(actualData, expectedResult, sizeof(expectedResult)) == 0)
+ }
+ };
+
+ struct RayTracingTestA : BaseRayTracingTest
+ {
void renderFrame()
{
ComPtr<ICommandBuffer> renderCommandBuffer =
@@ -382,48 +412,76 @@ namespace gfx_test
queue->waitOnHost();
}
- void checkTestResults()
+ void run()
{
- ComPtr<ISlangBlob> resultBlob;
- size_t rowPitch = 0;
- size_t pixelSize = 0;
- GFX_CHECK_CALL_ABORT(device->readTextureResource(
- resultTexture, ResourceState::CopySource, resultBlob.writeRef(), &rowPitch, &pixelSize));
-
- //writeImage("C:/Users/lucchen/Documents/test.hdr", resultBlob, width, height, rowPitch, pixelSize);
+ createRequiredResources();
+ renderFrame();
- auto buffer = removePadding(resultBlob, width, height, rowPitch, pixelSize);
float expectedResult[16] = { 1, 1, 1, 1,
0, 0, 1, 1,
0, 1, 0, 1,
1, 0, 0, 1 };
- auto actualData = (float*)buffer.getBuffer();
- SLANG_CHECK(memcmp(actualData, expectedResult, sizeof(expectedResult)) == 0)
+ checkTestResults(expectedResult);
+ }
+ };
+
+ struct RayTracingTestB : BaseRayTracingTest
+ {
+ void renderFrame()
+ {
+ ComPtr<ICommandBuffer> renderCommandBuffer =
+ transientHeap->createCommandBuffer();
+ auto renderEncoder = renderCommandBuffer->encodeRayTracingCommands();
+ IShaderObject* rootObject = nullptr;
+ renderEncoder->bindPipeline(renderPipelineState, &rootObject);
+ auto cursor = ShaderCursor(rootObject);
+ cursor["resultTexture"].setResource(resultTextureUAV);
+ cursor["sceneBVH"].setResource(TLAS);
+ renderEncoder->dispatchRays(1, shaderTable, width, height, 1);
+ renderEncoder->endEncoding();
+ renderCommandBuffer->close();
+ queue->executeCommandBuffer(renderCommandBuffer);
+ queue->waitOnHost();
}
void run()
{
createRequiredResources();
renderFrame();
- checkTestResults();
+
+ float expectedResult[16] = { 0, 0, 0, 0,
+ 1, 1, 0, 1,
+ 1, 0, 1, 1,
+ 0, 1, 1, 1 };
+ checkTestResults(expectedResult);
}
};
- void simpleRayTracingTestImpl(IDevice* device, UnitTestContext* context)
+ template <typename T>
+ void rayTracingTestImpl(IDevice* device, UnitTestContext* context)
{
- BaseRayTracingTest test;
+ T test;
test.init(device, context);
test.run();
}
- SLANG_UNIT_TEST(simpleRayTracingD3D12)
+ SLANG_UNIT_TEST(RayTracingTestAD3D12)
{
- runTestImpl(simpleRayTracingTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
+ runTestImpl(rayTracingTestImpl<RayTracingTestA>, unitTestContext, Slang::RenderApiFlag::D3D12);
}
- SLANG_UNIT_TEST(simpleRayTracingVulkan)
+ SLANG_UNIT_TEST(RayTracingTestAVulkan)
{
- runTestImpl(simpleRayTracingTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ runTestImpl(rayTracingTestImpl<RayTracingTestA>, unitTestContext, Slang::RenderApiFlag::Vulkan);
}
+ SLANG_UNIT_TEST(RayTracingTestBD3D12)
+ {
+ runTestImpl(rayTracingTestImpl<RayTracingTestB>, unitTestContext, Slang::RenderApiFlag::D3D12);
+ }
+
+ SLANG_UNIT_TEST(RayTracingTestBVulkan)
+ {
+ runTestImpl(rayTracingTestImpl<RayTracingTestB>, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ }
}
diff --git a/tools/gfx/vulkan/vk-command-encoder.cpp b/tools/gfx/vulkan/vk-command-encoder.cpp
index 00915e0c4..33cd567c7 100644
--- a/tools/gfx/vulkan/vk-command-encoder.cpp
+++ b/tools/gfx/vulkan/vk-command-encoder.cpp
@@ -1460,14 +1460,15 @@ void RayTracingCommandEncoder::dispatchRays(
m_currentPipeline,
m_commandBuffer->m_transientHeap,
static_cast<ResourceCommandEncoder*>(this));
+ auto shaderTableAddr = shaderTableBuffer->getDeviceAddress();
VkStridedDeviceAddressRegionKHR raygenSBT;
- raygenSBT.deviceAddress = shaderTableBuffer->getDeviceAddress();
raygenSBT.stride = VulkanUtil::calcAligned(alignedHandleSize, rtProps.shaderGroupBaseAlignment);
+ raygenSBT.deviceAddress = shaderTableAddr + raygenShaderIndex * raygenSBT.stride;
raygenSBT.size = raygenSBT.stride;
VkStridedDeviceAddressRegionKHR missSBT;
- missSBT.deviceAddress = raygenSBT.deviceAddress + raygenSBT.size;
+ missSBT.deviceAddress = shaderTableAddr + shaderTableImpl->m_raygenTableSize;
missSBT.stride = alignedHandleSize;
missSBT.size = shaderTableImpl->m_missTableSize;