diff options
| author | lucy96chen <47800040+lucy96chen@users.noreply.github.com> | 2022-06-21 13:14:18 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-21 13:14:18 -0700 |
| commit | ea3800e115d4ad1ce06ec07689152616f47a0e3d (patch) | |
| tree | 16bc7af0da281af5aae7da382abf782c1edc202a /tools | |
| parent | 92dfec2320421113498ae7b5b72e78bd8b5b09a8 (diff) | |
Added a second set of shaders to the ray tracing test and added another test that uses these shaders; Fixed a bug in Vulkan's RayTracingCommandEncoder::dispatchRays() where the supplied raygen shader index wasn't being used (#2290)
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/gfx-unit-test/ray-tracing-test-shader.slang | 47 | ||||
| -rw-r--r-- | tools/gfx-unit-test/ray-tracing-test-shaders.slang | 83 | ||||
| -rw-r--r-- | tools/gfx-unit-test/ray-tracing-tests.cpp | 124 | ||||
| -rw-r--r-- | tools/gfx/vulkan/vk-command-encoder.cpp | 5 |
4 files changed, 177 insertions, 82 deletions
diff --git a/tools/gfx-unit-test/ray-tracing-test-shader.slang b/tools/gfx-unit-test/ray-tracing-test-shader.slang deleted file mode 100644 index 30d8f3bb8..000000000 --- a/tools/gfx-unit-test/ray-tracing-test-shader.slang +++ /dev/null @@ -1,47 +0,0 @@ -// ray-tracing-test-shader.slang - -struct RayPayload -{ - float4 color; -}; - -uniform RWTexture2D resultTexture; -uniform RaytracingAccelerationStructure sceneBVH; - -[shader("raygeneration")] -void rayGenShader() -{ - int2 threadIdx = DispatchRaysIndex().xy; - - float3 rayDir = float3(0, 0, 1); - float3 rayOrigin = 0; - rayOrigin.x = (threadIdx.x * 2) - 1; - rayOrigin.y = (threadIdx.y * 2) - 1; - - // Trace the ray. - RayDesc ray; - ray.Origin = rayOrigin; - ray.Direction = rayDir; - ray.TMin = 0.001; - ray.TMax = 10000.0; - RayPayload payload = { float4(0, 0, 0, 0) }; - TraceRay(sceneBVH, RAY_FLAG_NONE, ~0, 0, 0, 0, ray, payload); - - //resultTexture[threadIdx.xy] = float4(0, 0, 1, 1); - resultTexture[threadIdx.xy] = payload.color; -} - -[shader("miss")] -void missShader(inout RayPayload payload) -{ - payload.color = float4(1, 1, 1, 1); -} - -[shader("closesthit")] -void closestHitShader(inout RayPayload payload, in BuiltInTriangleIntersectionAttributes attr) -{ - uint primitiveIndex = PrimitiveIndex(); - float4 color = float4(0, 0, 0, 1); - color[primitiveIndex] = 1; - payload.color = color; -} diff --git a/tools/gfx-unit-test/ray-tracing-test-shaders.slang b/tools/gfx-unit-test/ray-tracing-test-shaders.slang new file mode 100644 index 000000000..aa2e5055f --- /dev/null +++ b/tools/gfx-unit-test/ray-tracing-test-shaders.slang @@ -0,0 +1,83 @@ +// ray-tracing-test-shaders.slang + +struct RayPayload +{ + float4 color; +}; + +uniform RWTexture2D resultTexture; +uniform RaytracingAccelerationStructure sceneBVH; + +[shader("raygeneration")] +void rayGenShaderA() +{ + int2 threadIdx = DispatchRaysIndex().xy; + + float3 rayDir = float3(0, 0, 1); + float3 rayOrigin = 0; + rayOrigin.x = (threadIdx.x * 2) - 1; + rayOrigin.y = (threadIdx.y * 2) - 1; + + // Trace the ray. + RayDesc ray; + ray.Origin = rayOrigin; + ray.Direction = rayDir; + ray.TMin = 0.001; + ray.TMax = 10000.0; + RayPayload payload = { float4(0, 0, 0, 0) }; + TraceRay(sceneBVH, RAY_FLAG_NONE, ~0, 0, 0, 0, ray, payload); + + resultTexture[threadIdx.xy] = payload.color; +} + +[shader("raygeneration")] +void rayGenShaderB() +{ + int2 threadIdx = DispatchRaysIndex().xy; + + float3 rayDir = float3(0, 0, 1); + float3 rayOrigin = 0; + rayOrigin.x = (threadIdx.x * 2) - 1; + rayOrigin.y = (threadIdx.y * 2) - 1; + + // Trace the ray. + RayDesc ray; + ray.Origin = rayOrigin; + ray.Direction = rayDir; + ray.TMin = 0.001; + ray.TMax = 10000.0; + RayPayload payload = { float4(0, 0, 0, 0) }; + TraceRay(sceneBVH, RAY_FLAG_NONE, ~0, 1, 0, 1, ray, payload); + + resultTexture[threadIdx.xy] = payload.color; +} + +[shader("miss")] +void missShaderA(inout RayPayload payload) +{ + payload.color = float4(1, 1, 1, 1); +} + +[shader("miss")] +void missShaderB(inout RayPayload payload) +{ + payload.color = float4(0, 0, 0, 1); +} + +[shader("closesthit")] +void closestHitShaderA(inout RayPayload payload, in BuiltInTriangleIntersectionAttributes attr) +{ + uint primitiveIndex = PrimitiveIndex(); + float4 color = float4(0, 0, 0, 1); + color[primitiveIndex] = 1; + payload.color = color; +} + +[shader("closesthit")] +void closestHitShaderB(inout RayPayload payload, in BuiltInTriangleIntersectionAttributes attr) +{ + uint primitiveIndex = PrimitiveIndex(); + float4 color = float4(1, 1, 1, 1); + color[primitiveIndex] = 0; + payload.color = color; +} diff --git a/tools/gfx-unit-test/ray-tracing-tests.cpp b/tools/gfx-unit-test/ray-tracing-tests.cpp index f62c056f9..ee593dc2d 100644 --- a/tools/gfx-unit-test/ray-tracing-tests.cpp +++ b/tools/gfx-unit-test/ray-tracing-tests.cpp @@ -73,7 +73,9 @@ namespace gfx_test void init(IDevice* device, UnitTestContext* context) { if (!device->hasFeature("ray-tracing")) + { SLANG_IGNORE_TEST; + } this->device = device; this->context = context; @@ -86,19 +88,26 @@ namespace gfx_test slangSession = device->getSlangSession(); ComPtr<slang::IBlob> diagnosticsBlob; - slang::IModule* module = slangSession->loadModule("ray-tracing-test-shader", diagnosticsBlob.writeRef()); + slang::IModule* module = slangSession->loadModule("ray-tracing-test-shaders", diagnosticsBlob.writeRef()); if (!module) return SLANG_FAIL; Slang::List<slang::IComponentType*> componentTypes; componentTypes.add(module); ComPtr<slang::IEntryPoint> entryPoint; - SLANG_RETURN_ON_FAIL(module->findEntryPointByName("rayGenShader", entryPoint.writeRef())); + SLANG_RETURN_ON_FAIL(module->findEntryPointByName("rayGenShaderA", entryPoint.writeRef())); + componentTypes.add(entryPoint); + SLANG_RETURN_ON_FAIL(module->findEntryPointByName("rayGenShaderB", entryPoint.writeRef())); + componentTypes.add(entryPoint); + SLANG_RETURN_ON_FAIL(module->findEntryPointByName("missShaderA", entryPoint.writeRef())); componentTypes.add(entryPoint); - SLANG_RETURN_ON_FAIL(module->findEntryPointByName("missShader", entryPoint.writeRef())); + SLANG_RETURN_ON_FAIL(module->findEntryPointByName("missShaderB", entryPoint.writeRef())); componentTypes.add(entryPoint); SLANG_RETURN_ON_FAIL( - module->findEntryPointByName("closestHitShader", entryPoint.writeRef())); + module->findEntryPointByName("closestHitShaderA", entryPoint.writeRef())); + componentTypes.add(entryPoint); + SLANG_RETURN_ON_FAIL( + module->findEntryPointByName("closestHitShaderB", entryPoint.writeRef())); componentTypes.add(entryPoint); ComPtr<slang::IComponentType> linkedProgram; @@ -334,17 +343,19 @@ namespace gfx_test queue->waitOnHost(); } - const char* hitgroupNames[] = { "hitgroup" }; + const char* hitgroupNames[] = { "hitgroupA", "hitgroupB"}; ComPtr<IShaderProgram> rayTracingProgram; GFX_CHECK_CALL_ABORT( loadShaderProgram(device, rayTracingProgram.writeRef())); RayTracingPipelineStateDesc rtpDesc = {}; rtpDesc.program = rayTracingProgram; - rtpDesc.hitGroupCount = 1; - HitGroupDesc hitGroups[1]; - hitGroups[0].closestHitEntryPoint = "closestHitShader"; + rtpDesc.hitGroupCount = 2; + HitGroupDesc hitGroups[2]; + hitGroups[0].closestHitEntryPoint = "closestHitShaderA"; hitGroups[0].hitGroupName = hitgroupNames[0]; + hitGroups[1].closestHitEntryPoint = "closestHitShaderB"; + hitGroups[1].hitGroupName = hitgroupNames[1]; rtpDesc.hitGroups = hitGroups; rtpDesc.maxRayPayloadSize = 64; rtpDesc.maxRecursion = 2; @@ -352,19 +363,38 @@ namespace gfx_test device->createRayTracingPipelineState(rtpDesc, renderPipelineState.writeRef())); SLANG_CHECK_ABORT(renderPipelineState != nullptr); + const char* raygenNames[] = { "rayGenShaderA", "rayGenShaderB" }; + const char* missNames[] = { "missShaderA", "missShaderB" }; + IShaderTable::Desc shaderTableDesc = {}; - const char* raygenName = "rayGenShader"; - const char* missName = "missShader"; shaderTableDesc.program = rayTracingProgram; - shaderTableDesc.hitGroupCount = 1; + shaderTableDesc.hitGroupCount = 2; shaderTableDesc.hitGroupNames = hitgroupNames; - shaderTableDesc.rayGenShaderCount = 1; - shaderTableDesc.rayGenShaderEntryPointNames = &raygenName; - shaderTableDesc.missShaderCount = 1; - shaderTableDesc.missShaderEntryPointNames = &missName; + shaderTableDesc.rayGenShaderCount = 2; + shaderTableDesc.rayGenShaderEntryPointNames = raygenNames; + shaderTableDesc.missShaderCount = 2; + shaderTableDesc.missShaderEntryPointNames = missNames; GFX_CHECK_CALL_ABORT(device->createShaderTable(shaderTableDesc, shaderTable.writeRef())); } + void checkTestResults(float* expectedResult) + { + ComPtr<ISlangBlob> resultBlob; + size_t rowPitch = 0; + size_t pixelSize = 0; + GFX_CHECK_CALL_ABORT(device->readTextureResource( + resultTexture, ResourceState::CopySource, resultBlob.writeRef(), &rowPitch, &pixelSize)); + + writeImage("C:/Users/lucchen/Documents/test.hdr", resultBlob, width, height, rowPitch, pixelSize); + + auto buffer = removePadding(resultBlob, width, height, rowPitch, pixelSize); + auto actualData = (float*)buffer.getBuffer(); + SLANG_CHECK(memcmp(actualData, expectedResult, sizeof(expectedResult)) == 0) + } + }; + + struct RayTracingTestA : BaseRayTracingTest + { void renderFrame() { ComPtr<ICommandBuffer> renderCommandBuffer = @@ -382,48 +412,76 @@ namespace gfx_test queue->waitOnHost(); } - void checkTestResults() + void run() { - ComPtr<ISlangBlob> resultBlob; - size_t rowPitch = 0; - size_t pixelSize = 0; - GFX_CHECK_CALL_ABORT(device->readTextureResource( - resultTexture, ResourceState::CopySource, resultBlob.writeRef(), &rowPitch, &pixelSize)); - - //writeImage("C:/Users/lucchen/Documents/test.hdr", resultBlob, width, height, rowPitch, pixelSize); + createRequiredResources(); + renderFrame(); - auto buffer = removePadding(resultBlob, width, height, rowPitch, pixelSize); float expectedResult[16] = { 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1 }; - auto actualData = (float*)buffer.getBuffer(); - SLANG_CHECK(memcmp(actualData, expectedResult, sizeof(expectedResult)) == 0) + checkTestResults(expectedResult); + } + }; + + struct RayTracingTestB : BaseRayTracingTest + { + void renderFrame() + { + ComPtr<ICommandBuffer> renderCommandBuffer = + transientHeap->createCommandBuffer(); + auto renderEncoder = renderCommandBuffer->encodeRayTracingCommands(); + IShaderObject* rootObject = nullptr; + renderEncoder->bindPipeline(renderPipelineState, &rootObject); + auto cursor = ShaderCursor(rootObject); + cursor["resultTexture"].setResource(resultTextureUAV); + cursor["sceneBVH"].setResource(TLAS); + renderEncoder->dispatchRays(1, shaderTable, width, height, 1); + renderEncoder->endEncoding(); + renderCommandBuffer->close(); + queue->executeCommandBuffer(renderCommandBuffer); + queue->waitOnHost(); } void run() { createRequiredResources(); renderFrame(); - checkTestResults(); + + float expectedResult[16] = { 0, 0, 0, 0, + 1, 1, 0, 1, + 1, 0, 1, 1, + 0, 1, 1, 1 }; + checkTestResults(expectedResult); } }; - void simpleRayTracingTestImpl(IDevice* device, UnitTestContext* context) + template <typename T> + void rayTracingTestImpl(IDevice* device, UnitTestContext* context) { - BaseRayTracingTest test; + T test; test.init(device, context); test.run(); } - SLANG_UNIT_TEST(simpleRayTracingD3D12) + SLANG_UNIT_TEST(RayTracingTestAD3D12) { - runTestImpl(simpleRayTracingTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); + runTestImpl(rayTracingTestImpl<RayTracingTestA>, unitTestContext, Slang::RenderApiFlag::D3D12); } - SLANG_UNIT_TEST(simpleRayTracingVulkan) + SLANG_UNIT_TEST(RayTracingTestAVulkan) { - runTestImpl(simpleRayTracingTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); + runTestImpl(rayTracingTestImpl<RayTracingTestA>, unitTestContext, Slang::RenderApiFlag::Vulkan); } + SLANG_UNIT_TEST(RayTracingTestBD3D12) + { + runTestImpl(rayTracingTestImpl<RayTracingTestB>, unitTestContext, Slang::RenderApiFlag::D3D12); + } + + SLANG_UNIT_TEST(RayTracingTestBVulkan) + { + runTestImpl(rayTracingTestImpl<RayTracingTestB>, unitTestContext, Slang::RenderApiFlag::Vulkan); + } } diff --git a/tools/gfx/vulkan/vk-command-encoder.cpp b/tools/gfx/vulkan/vk-command-encoder.cpp index 00915e0c4..33cd567c7 100644 --- a/tools/gfx/vulkan/vk-command-encoder.cpp +++ b/tools/gfx/vulkan/vk-command-encoder.cpp @@ -1460,14 +1460,15 @@ void RayTracingCommandEncoder::dispatchRays( m_currentPipeline, m_commandBuffer->m_transientHeap, static_cast<ResourceCommandEncoder*>(this)); + auto shaderTableAddr = shaderTableBuffer->getDeviceAddress(); VkStridedDeviceAddressRegionKHR raygenSBT; - raygenSBT.deviceAddress = shaderTableBuffer->getDeviceAddress(); raygenSBT.stride = VulkanUtil::calcAligned(alignedHandleSize, rtProps.shaderGroupBaseAlignment); + raygenSBT.deviceAddress = shaderTableAddr + raygenShaderIndex * raygenSBT.stride; raygenSBT.size = raygenSBT.stride; VkStridedDeviceAddressRegionKHR missSBT; - missSBT.deviceAddress = raygenSBT.deviceAddress + raygenSBT.size; + missSBT.deviceAddress = shaderTableAddr + shaderTableImpl->m_raygenTableSize; missSBT.stride = alignedHandleSize; missSBT.size = shaderTableImpl->m_missTableSize; |
