diff options
| -rw-r--r-- | tools/gfx/d3d12/render-d3d12.cpp | 6 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.cpp | 8 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.h | 2 | ||||
| -rw-r--r-- | tools/gfx/vulkan/render-vk.cpp | 271 | ||||
| -rw-r--r-- | tools/gfx/vulkan/vk-api.h | 15 | ||||
| -rw-r--r-- | tools/gfx/vulkan/vk-util.cpp | 2 | ||||
| -rw-r--r-- | tools/gfx/vulkan/vk-util.h | 3 |
7 files changed, 273 insertions, 34 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index 12f8bafa6..9be157605 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -3512,7 +3512,7 @@ public: copyShaderIdInto( stagingBufferPtr + m_rayGenTableOffset + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i, - m_entryPointNames[i], + m_shaderGroupNames[i], m_recordOverwrites[i]); } for (uint32_t i = 0; i < m_missShaderCount; i++) @@ -3520,7 +3520,7 @@ public: copyShaderIdInto( stagingBufferPtr + m_missTableOffset + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i, - m_entryPointNames[m_rayGenShaderCount + i], + m_shaderGroupNames[m_rayGenShaderCount + i], m_recordOverwrites[m_rayGenShaderCount + i]); } for (uint32_t i = 0; i < m_hitGroupCount; i++) @@ -3528,7 +3528,7 @@ public: copyShaderIdInto( stagingBufferPtr + m_hitGroupTableOffset + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i, - m_entryPointNames[m_rayGenShaderCount + m_missShaderCount + i], + m_shaderGroupNames[m_rayGenShaderCount + m_missShaderCount + i], m_recordOverwrites[m_rayGenShaderCount + m_missShaderCount + i]); } diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index a97462c73..244146ca3 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -956,11 +956,11 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc) m_rayGenShaderCount = desc.rayGenShaderCount; m_missShaderCount = desc.missShaderCount; m_hitGroupCount = desc.hitGroupCount; - m_entryPointNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount); + m_shaderGroupNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount); m_recordOverwrites.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount); for (uint32_t i = 0; i < desc.rayGenShaderCount; i++) { - m_entryPointNames.add(desc.rayGenShaderEntryPointNames[i]); + m_shaderGroupNames.add(desc.rayGenShaderEntryPointNames[i]); if (desc.rayGenShaderRecordOverwrites) { m_recordOverwrites.add(desc.rayGenShaderRecordOverwrites[i]); @@ -972,7 +972,7 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc) } for (uint32_t i = 0; i < desc.missShaderCount; i++) { - m_entryPointNames.add(desc.missShaderEntryPointNames[i]); + m_shaderGroupNames.add(desc.missShaderEntryPointNames[i]); if (desc.missShaderRecordOverwrites) { m_recordOverwrites.add(desc.missShaderRecordOverwrites[i]); @@ -984,7 +984,7 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc) } for (uint32_t i = 0; i < desc.hitGroupCount; i++) { - m_entryPointNames.add(desc.hitGroupNames[i]); + m_shaderGroupNames.add(desc.hitGroupNames[i]); if (desc.hitGroupRecordOverwrites) { m_recordOverwrites.add(desc.hitGroupRecordOverwrites[i]); diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index 4ebd28ed2..5afc77027 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -1076,7 +1076,7 @@ class ShaderTableBase , public Slang::ComObject { public: - Slang::List<Slang::String> m_entryPointNames; + Slang::List<Slang::String> m_shaderGroupNames; Slang::List<ShaderRecordOverwrite> m_recordOverwrites; uint32_t m_rayGenShaderCount; diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index f11ca59aa..1a42f3261 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -114,6 +114,8 @@ public: createMutableRootShaderObject( IShaderProgram* program, IShaderObject** outObject) override; + virtual SLANG_NO_THROW Result SLANG_MCALL + createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable) override; virtual SLANG_NO_THROW Result SLANG_MCALL createProgram( const IShaderProgram::Desc& desc, IShaderProgram** outProgram, @@ -1013,6 +1015,17 @@ public: VkPipeline m_pipeline = VK_NULL_HANDLE; }; + class RayTracingPipelineStateImpl : public PipelineStateImpl + { + public: + Dictionary<String, Index> shaderGroupNameToIndex; + Int shaderGroupCount; + + RayTracingPipelineStateImpl(VKDevice* device) + : PipelineStateImpl(device) + {}; + }; + // In order to bind shader parameters to the correct locations, we need to // be able to describe those locations. Most shader parameters in Vulkan // simply consume a single `binding`, but we also need to deal with @@ -2399,6 +2412,7 @@ public: BreakableReference<VKDevice> m_device; Array<VkPipelineShaderStageCreateInfo, 8> m_stageCreateInfos; + Array<String, 8> m_entryPointNames; Array<ComPtr<ISlangBlob>, 8> m_codeBlobs; //< To keep storage of code in scope Array<VkShaderModule, 8> m_modules; RefPtr<RootShaderObjectLayout> m_rootObjectLayout; @@ -3885,6 +3899,119 @@ public: List<RefPtr<EntryPointShaderObject>> m_entryPoints; }; + class ShaderTableImpl : public ShaderTableBase + { + public: + uint32_t m_raygenTableSize; + uint32_t m_missTableSize; + uint32_t m_hitTableSize; + uint32_t m_callableTableSize; + + VKDevice* m_device; + + virtual RefPtr<BufferResource> createDeviceBuffer( + PipelineStateBase* pipeline, + TransientResourceHeapBase* transientHeap, + IResourceCommandEncoder* encoder) override + { + auto vkApi = m_device->m_api; + auto rtProps = vkApi.m_rtProperties; + uint32_t handleSize = rtProps.shaderGroupHandleSize; + m_raygenTableSize = (uint32_t)VulkanUtil::calcAligned(m_rayGenShaderCount * handleSize, rtProps.shaderGroupBaseAlignment); + m_missTableSize = (uint32_t)VulkanUtil::calcAligned(m_missShaderCount * handleSize, rtProps.shaderGroupBaseAlignment); + m_hitTableSize = (uint32_t)VulkanUtil::calcAligned(m_hitGroupCount * handleSize, rtProps.shaderGroupBaseAlignment); + m_callableTableSize = 0; // TODO: Are callable shaders needed? + uint32_t tableSize = m_raygenTableSize + m_missTableSize + m_hitTableSize + m_callableTableSize; + + auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline); + ComPtr<IBufferResource> bufferResource; + IBufferResource::Desc bufferDesc = {}; + bufferDesc.memoryType = MemoryType::DeviceLocal; + bufferDesc.defaultState = ResourceState::General; + bufferDesc.allowedStates = ResourceStateSet(ResourceState::General, ResourceState::CopyDestination); + bufferDesc.type = IResource::Type::Buffer; + bufferDesc.sizeInBytes = tableSize; + m_device->createBufferResource(bufferDesc, nullptr, bufferResource.writeRef()); + + TransientResourceHeapImpl* transientHeapImpl = + static_cast<TransientResourceHeapImpl*>(transientHeap); + + IBufferResource* stagingBuffer = nullptr; + transientHeapImpl->allocateStagingBuffer( + tableSize, stagingBuffer, ResourceState::General); + + assert(stagingBuffer); + void* stagingPtr = nullptr; + stagingBuffer->map(nullptr, &stagingPtr); + + List<uint8_t> handles; + auto handleCount = pipelineImpl->shaderGroupCount; + auto totalHandleSize = handleSize * handleCount; + handles.setCount(totalHandleSize); + auto result = vkApi.vkGetRayTracingShaderGroupHandlesKHR(m_device->m_device, pipelineImpl->m_pipeline, 0, (uint32_t)handleCount, totalHandleSize, handles.getBuffer()); + + uint8_t* stagingBufferPtr = (uint8_t*)stagingPtr; + auto subTablePtr = stagingBufferPtr; + Int shaderTableEntryCounter = 0; + + // Each loop calculates the copy source and destination locations by fetching the name of the shader + // group from the list of shader group names and getting its corresponding index in the buffer of handles. + for (uint32_t i = 0; i < m_rayGenShaderCount; i++) + { + auto dstHandlePtr = subTablePtr + i * handleSize; + auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++]; + auto shaderGroupIndexPtr = pipelineImpl->shaderGroupNameToIndex.TryGetValue(shaderGroupName); + if (!shaderGroupIndexPtr) + continue; + + auto shaderGroupIndex = *shaderGroupIndexPtr; + auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize; + memcpy(dstHandlePtr, srcHandlePtr, handleSize); + } + subTablePtr += m_raygenTableSize; + + for (uint32_t i = 0; i < m_missShaderCount; i++) + { + auto dstHandlePtr = subTablePtr + i * handleSize; + auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++]; + auto shaderGroupIndexPtr = pipelineImpl->shaderGroupNameToIndex.TryGetValue(shaderGroupName); + if (!shaderGroupIndexPtr) + continue; + + auto shaderGroupIndex = *shaderGroupIndexPtr; + auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize; + memcpy(dstHandlePtr, srcHandlePtr, handleSize); + } + subTablePtr += m_missTableSize; + + for (uint32_t i = 0; i < m_hitGroupCount; i++) + { + auto dstHandlePtr = subTablePtr + i * handleSize; + auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++]; + auto shaderGroupIndexPtr = pipelineImpl->shaderGroupNameToIndex.TryGetValue(shaderGroupName); + if (!shaderGroupIndexPtr) + continue; + + auto shaderGroupIndex = *shaderGroupIndexPtr; + auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize; + memcpy(dstHandlePtr, srcHandlePtr, handleSize); + } + subTablePtr += m_hitTableSize; + + // TODO: Callable shaders? + + stagingBuffer->unmap(nullptr); + encoder->copyBuffer(bufferResource, 0, stagingBuffer, 0, tableSize); + encoder->bufferBarrier( + 1, + bufferResource.readRef(), + gfx::ResourceState::CopyDestination, + gfx::ResourceState::ShaderResource); + RefPtr<BufferResource> resultPtr = static_cast<BufferResource*>(bufferResource.get()); + return _Move(resultPtr); + } + }; + class CommandBufferImpl : public ICommandBuffer , public ComObject @@ -5506,11 +5633,9 @@ public: virtual SLANG_NO_THROW void SLANG_MCALL bindPipeline(IPipelineState* pipeline, IShaderObject** outRootObject) override { - SLANG_UNUSED(pipeline); - SLANG_UNUSED(outRootObject); + setPipelineStateImpl(pipeline, outRootObject); } - // TODO: Implement after implementing createRayTracingPipelineState virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( uint32_t raygenShaderIndex, IShaderTable* shaderTable, @@ -5518,15 +5643,46 @@ public: int32_t height, int32_t depth) override { - SLANG_UNUSED(raygenShaderIndex); - SLANG_UNUSED(shaderTable); - SLANG_UNUSED(width); - SLANG_UNUSED(height); - SLANG_UNUSED(depth); + auto vkApi = m_commandBuffer->m_renderer->m_api; + auto vkCommandBuffer = m_commandBuffer->m_commandBuffer; + + flushBindingState(VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR); + + auto rtProps = vkApi.m_rtProperties; + auto shaderTableImpl = (ShaderTableImpl*)shaderTable; + auto alignedHandleSize = VulkanUtil::calcAligned(rtProps.shaderGroupHandleSize, rtProps.shaderGroupHandleAlignment); + + ResourceCommandEncoder resourceCopyEncoder; + resourceCopyEncoder.init(m_commandBuffer); + auto shaderTableBuffer = shaderTableImpl->getOrCreateBuffer(m_currentPipeline, m_commandBuffer->m_transientHeap, &resourceCopyEncoder); + + VkStridedDeviceAddressRegionKHR raygenSBT; + raygenSBT.deviceAddress = shaderTableBuffer->getDeviceAddress(); + raygenSBT.stride = VulkanUtil::calcAligned(alignedHandleSize, rtProps.shaderGroupBaseAlignment); + raygenSBT.size = raygenSBT.stride; + + VkStridedDeviceAddressRegionKHR missSBT; + missSBT.deviceAddress = raygenSBT.deviceAddress + raygenSBT.size; + missSBT.stride = alignedHandleSize; + missSBT.size = shaderTableImpl->m_missTableSize; + + VkStridedDeviceAddressRegionKHR hitSBT; + hitSBT.deviceAddress = missSBT.deviceAddress + missSBT.size; + hitSBT.stride = alignedHandleSize; + hitSBT.size = shaderTableImpl->m_hitTableSize; + + // TODO: Are callable shaders needed? + VkStridedDeviceAddressRegionKHR callableSBT; + callableSBT.deviceAddress = 0; + callableSBT.stride = 0; + callableSBT.size = 0; + + vkApi.vkCmdTraceRaysKHR(vkCommandBuffer, &raygenSBT, &missSBT, &hitSBT, &callableSBT, (uint32_t)width, (uint32_t)height, (uint32_t)depth); } virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() override { + endEncodingImpl(); } }; @@ -6811,6 +6967,10 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool extendedFeatures.rayQueryFeatures.pNext = deviceFeatures2.pNext; deviceFeatures2.pNext = &extendedFeatures.rayQueryFeatures; + // Ray tracing pipeline features + extendedFeatures.rayTracingPipelineFeatures.pNext = deviceFeatures2.pNext; + deviceFeatures2.pNext = &extendedFeatures.rayTracingPipelineFeatures; + // Acceleration structure features extendedFeatures.accelerationStructureFeatures.pNext = deviceFeatures2.pNext; deviceFeatures2.pNext = &extendedFeatures.accelerationStructureFeatures; @@ -6941,6 +7101,7 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool deviceExtensions.add(VK_KHR_SHADER_SUBGROUP_EXTENDED_TYPES_EXTENSION_NAME); m_features.add("shader-subgroup-extended-types"); } + if (extendedFeatures.accelerationStructureFeatures.accelerationStructure) { extendedFeatures.accelerationStructureFeatures.pNext = (void*)deviceCreateInfo.pNext; @@ -6949,6 +7110,15 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool deviceExtensions.add(VK_KHR_DEFERRED_HOST_OPERATIONS_EXTENSION_NAME); m_features.add("acceleration-structure"); } + + if (extendedFeatures.rayTracingPipelineFeatures.rayTracingPipeline) + { + extendedFeatures.rayTracingPipelineFeatures.pNext = (void*)deviceCreateInfo.pNext; + deviceCreateInfo.pNext = &extendedFeatures.rayTracingPipelineFeatures; + deviceExtensions.add(VK_KHR_RAY_TRACING_PIPELINE_EXTENSION_NAME); + m_features.add("ray-tracing-pipeline"); + } + if (extendedFeatures.rayQueryFeatures.rayQuery) { extendedFeatures.rayQueryFeatures.pNext = (void*)deviceCreateInfo.pNext; @@ -6957,6 +7127,7 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool m_features.add("ray-query"); m_features.add("ray-tracing"); } + if (extendedFeatures.bufferDeviceAddressFeatures.bufferDeviceAddress) { extendedFeatures.bufferDeviceAddressFeatures.pNext = (void*)deviceCreateInfo.pNext; @@ -6982,6 +7153,12 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool m_features.add("robustness2"); } + VkPhysicalDeviceProperties2 extendedProps = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2 }; + VkPhysicalDeviceRayTracingPipelinePropertiesKHR rtProps = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_RAY_TRACING_PIPELINE_PROPERTIES_KHR }; + extendedProps.pNext = &rtProps; + m_api.vkGetPhysicalDeviceProperties2(m_api.m_physicalDevice, &extendedProps); + m_api.m_rtProperties = rtProps; + uint32_t extensionCount = 0; m_api.vkEnumerateDeviceExtensionProperties(m_api.m_physicalDevice, NULL, &extensionCount, NULL); Slang::List<VkExtensionProperties> extensions; @@ -8672,14 +8849,16 @@ Result VKDevice::createProgram( // uses "main" as the name. We should introduce a compiler parameter // to control the entry point naming behavior in SPIRV-direct path // so we can remove the ad-hoc logic here. - const char* entryPointName = "main"; + auto realEntryPointName = entryPointInfo->getNameOverride(); + const char* spirvBinaryEntryPointName = "main"; if (m_desc.slang.targetFlags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY) - entryPointName = entryPointInfo->getName(); + spirvBinaryEntryPointName = realEntryPointName; shaderProgram->m_stageCreateInfos.add(compileEntryPoint( - entryPointName, + spirvBinaryEntryPointName, kernelCode, (VkShaderStageFlagBits)VulkanUtil::getShaderStage(stage), shaderModule)); + shaderProgram->m_entryPointNames.add(realEntryPointName); shaderProgram->m_modules.add(shaderModule); return SLANG_OK; }; @@ -8749,6 +8928,15 @@ Result VKDevice::createMutableRootShaderObject( return SLANG_OK; } +Result VKDevice::createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable) +{ + RefPtr<ShaderTableImpl> result = new ShaderTableImpl(); + result->m_device = this; + result->init(desc); + returnComPtr(outShaderTable, result); + return SLANG_OK; +} + VkSampleCountFlagBits translateSampleCount(uint32_t sampleCount) { switch (sampleCount) @@ -8966,6 +9154,12 @@ Result VKDevice::createGraphicsPipelineState(const GraphicsPipelineStateDesc& in auto& vkBlendDesc = colorBlendAttachments[0]; memset(&vkBlendDesc, 0, sizeof(vkBlendDesc)); vkBlendDesc.blendEnable = VK_FALSE; + vkBlendDesc.srcColorBlendFactor = VK_BLEND_FACTOR_ONE; + vkBlendDesc.dstColorBlendFactor = VK_BLEND_FACTOR_ONE; + vkBlendDesc.colorBlendOp = VK_BLEND_OP_ADD; + vkBlendDesc.srcAlphaBlendFactor = VK_BLEND_FACTOR_ONE; + vkBlendDesc.dstAlphaBlendFactor = VK_BLEND_FACTOR_ONE; + vkBlendDesc.alphaBlendOp = VK_BLEND_OP_ADD; vkBlendDesc.colorWriteMask = (VkColorComponentFlags)RenderTargetWriteMask::EnableAll; } else @@ -9101,13 +9295,14 @@ VkPipelineCreateFlags translateFlags(RayTracingPipelineFlags::Enum flags) return vkFlags; } -uint32_t findShaderIndexByName(const VkPipelineShaderStageCreateInfo* stageCreateInfos, size_t stageCount, const char* name) +uint32_t findEntryPointIndexByName(const Dictionary<String, Index>& entryPointNameToIndex, const char* name) { - // TODO: Linear search is inefficient, use a Dictionary? - for (size_t i = 0; i < stageCount; ++i) - { - if (strcmp(stageCreateInfos[i].pName, name)) return (uint32_t)i; - } + if (!name) return VK_SHADER_UNUSED_KHR; + + auto indexPtr = entryPointNameToIndex.TryGetValue(String(name)); + if (indexPtr) + return (uint32_t)*indexPtr; + // TODO: Error reporting? return VK_SHADER_UNUSED_KHR; } @@ -9128,11 +9323,39 @@ Result VKDevice::createRayTracingPipelineState(const RayTracingPipelineStateDesc raytracingPipelineInfo.pNext = nullptr; raytracingPipelineInfo.flags = translateFlags(desc.flags); - VkPipelineShaderStageCreateInfo shaderStageInfo = { VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO }; raytracingPipelineInfo.stageCount = (uint32_t)programImpl->m_stageCreateInfos.getCount(); raytracingPipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer(); + // Build Dictionary from group name to group index + Dictionary<String, Index> shaderGroupNameToIndex; + // Build Dictionary from entry point name to entry point index (stageCreateInfos index) for all hit shaders - findShaderIndexByName + Dictionary<String, Index> entryPointNameToIndex; + List<VkRayTracingShaderGroupCreateInfoKHR> shaderGroupInfos; + for (uint32_t i = 0; i < raytracingPipelineInfo.stageCount; ++i) + { + auto stageCreateInfo = programImpl->m_stageCreateInfos[i]; + auto entryPointName = programImpl->m_entryPointNames[i]; + entryPointNameToIndex.Add(entryPointName, i); + if (stageCreateInfo.stage & (VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_INTERSECTION_BIT_KHR)) + continue; + + VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR }; + shaderGroupInfo.pNext = nullptr; + shaderGroupInfo.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR; + shaderGroupInfo.generalShader = i; + shaderGroupInfo.closestHitShader = VK_SHADER_UNUSED_KHR; + shaderGroupInfo.anyHitShader = VK_SHADER_UNUSED_KHR; + shaderGroupInfo.intersectionShader = VK_SHADER_UNUSED_KHR; + shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr; + + // For groups with a single entry point, the group name is the entry point name. + auto shaderGroupName = entryPointName; + auto shaderGroupIndex = shaderGroupInfos.getCount(); + shaderGroupInfos.add(shaderGroupInfo); + shaderGroupNameToIndex.Add(shaderGroupName, shaderGroupIndex); + } + for (int32_t i = 0; i < desc.hitGroupCount; ++i) { VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR }; @@ -9142,12 +9365,14 @@ Result VKDevice::createRayTracingPipelineState(const RayTracingPipelineStateDesc shaderGroupInfo.type = (groupDesc.intersectionEntryPoint) ? VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR : VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR; shaderGroupInfo.generalShader = VK_SHADER_UNUSED_KHR; - shaderGroupInfo.closestHitShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.closestHitEntryPoint); - shaderGroupInfo.anyHitShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.anyHitEntryPoint); - shaderGroupInfo.intersectionShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.intersectionEntryPoint); + shaderGroupInfo.closestHitShader = findEntryPointIndexByName(entryPointNameToIndex, groupDesc.closestHitEntryPoint); + shaderGroupInfo.anyHitShader = findEntryPointIndexByName(entryPointNameToIndex, groupDesc.anyHitEntryPoint); + shaderGroupInfo.intersectionShader = findEntryPointIndexByName(entryPointNameToIndex, groupDesc.intersectionEntryPoint); shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr; + auto shaderGroupIndex = shaderGroupInfos.getCount(); shaderGroupInfos.add(shaderGroupInfo); + shaderGroupNameToIndex.Add(String(groupDesc.hitGroupName), shaderGroupIndex); } raytracingPipelineInfo.groupCount = (uint32_t)shaderGroupInfos.getCount(); @@ -9168,10 +9393,12 @@ Result VKDevice::createRayTracingPipelineState(const RayTracingPipelineStateDesc VkPipeline pipeline = VK_NULL_HANDLE; SLANG_VK_CHECK(m_api.vkCreateRayTracingPipelinesKHR(m_device, VK_NULL_HANDLE, pipelineCache, 1, &raytracingPipelineInfo, nullptr, &pipeline)); - RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this); + RefPtr<RayTracingPipelineStateImpl> pipelineStateImpl = new RayTracingPipelineStateImpl(this); pipelineStateImpl->m_pipeline = pipeline; pipelineStateImpl->init(desc); pipelineStateImpl->establishStrongDeviceReference(); + pipelineStateImpl->shaderGroupNameToIndex = shaderGroupNameToIndex; + pipelineStateImpl->shaderGroupCount = shaderGroupInfos.getCount(); m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); returnComPtr(outState, pipelineStateImpl); return SLANG_OK; diff --git a/tools/gfx/vulkan/vk-api.h b/tools/gfx/vulkan/vk-api.h index 8c87e6254..fa601ef35 100644 --- a/tools/gfx/vulkan/vk-api.h +++ b/tools/gfx/vulkan/vk-api.h @@ -164,6 +164,8 @@ namespace gfx { x(vkDestroySwapchainKHR) \ x(vkAcquireNextImageKHR) \ x(vkCreateRayTracingPipelinesKHR) \ + x(vkCmdTraceRaysKHR) \ + x(vkGetRayTracingShaderGroupHandlesKHR) \ /* */ #if SLANG_WINDOWS_FAMILY @@ -251,6 +253,10 @@ struct VulkanExtendedFeatureProperties // Acceleration structure features VkPhysicalDeviceAccelerationStructureFeaturesKHR accelerationStructureFeatures = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ACCELERATION_STRUCTURE_FEATURES_KHR}; + // Ray tracing pipeline features + VkPhysicalDeviceRayTracingPipelineFeaturesKHR rayTracingPipelineFeatures = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_RAY_TRACING_PIPELINE_FEATURES_KHR + }; // Ray query (inline ray-tracing) features VkPhysicalDeviceRayQueryFeaturesKHR rayQueryFeatures = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_RAY_QUERY_FEATURES_KHR}; @@ -303,10 +309,11 @@ struct VulkanApi VkDevice m_device = VK_NULL_HANDLE; VkPhysicalDevice m_physicalDevice = VK_NULL_HANDLE; - VkPhysicalDeviceProperties m_deviceProperties; - VkPhysicalDeviceFeatures m_deviceFeatures; - VkPhysicalDeviceMemoryProperties m_deviceMemoryProperties; - VulkanExtendedFeatureProperties m_extendedFeatures; + VkPhysicalDeviceProperties m_deviceProperties; + VkPhysicalDeviceRayTracingPipelinePropertiesKHR m_rtProperties; + VkPhysicalDeviceFeatures m_deviceFeatures; + VkPhysicalDeviceMemoryProperties m_deviceMemoryProperties; + VulkanExtendedFeatureProperties m_extendedFeatures; }; } // renderer_test diff --git a/tools/gfx/vulkan/vk-util.cpp b/tools/gfx/vulkan/vk-util.cpp index 633239f24..974a61422 100644 --- a/tools/gfx/vulkan/vk-util.cpp +++ b/tools/gfx/vulkan/vk-util.cpp @@ -140,6 +140,8 @@ VkShaderStageFlags VulkanUtil::getShaderStage(SlangStage stage) return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT; case SLANG_STAGE_INTERSECTION: return VK_SHADER_STAGE_INTERSECTION_BIT_KHR; + case SLANG_STAGE_MISS: + return VK_SHADER_STAGE_MISS_BIT_KHR; case SLANG_STAGE_RAY_GENERATION: return VK_SHADER_STAGE_RAYGEN_BIT_KHR; case SLANG_STAGE_VERTEX: diff --git a/tools/gfx/vulkan/vk-util.h b/tools/gfx/vulkan/vk-util.h index cdf7bcc79..4c32b1615 100644 --- a/tools/gfx/vulkan/vk-util.h +++ b/tools/gfx/vulkan/vk-util.h @@ -44,6 +44,9 @@ struct VulkanUtil static VkImageLayout getImageLayoutFromState(ResourceState state); + /// Calculate size taking into account alignment. Alignment must be a power of 2 + static UInt calcAligned(UInt size, UInt alignment) { return (size + alignment - 1) & ~(alignment - 1); } + static inline bool isDepthFormat(VkFormat format) { switch (format) |
