diff options
| author | Yong He <yonghe@outlook.com> | 2022-03-08 14:34:53 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-03-08 14:34:53 -0800 |
| commit | dcb434a5fe801d42d1b5f385fd27d0c500687647 (patch) | |
| tree | 600d63ccce42e0ca3b5c63df23013044c4cea96c | |
| parent | 771f29435d664f7344bc5596056146af5d64d352 (diff) | |
GFX Vulkan: deferred shader compilation and pipeline creation. (#2153)
* Vulkan: deferred shader compilation and pipeline creation.
* Fix 32bit build.
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/core/slang-dictionary.h | 4 | ||||
| -rw-r--r-- | tools/gfx/d3d12/render-d3d12.cpp | 66 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.cpp | 56 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.h | 4 | ||||
| -rw-r--r-- | tools/gfx/vulkan/render-vk.cpp | 1169 | ||||
| -rw-r--r-- | tools/gfx/vulkan/vk-util.cpp | 271 | ||||
| -rw-r--r-- | tools/gfx/vulkan/vk-util.h | 28 |
7 files changed, 867 insertions, 731 deletions
diff --git a/source/core/slang-dictionary.h b/source/core/slang-dictionary.h index eef7d6908..6a8d03101 100644 --- a/source/core/slang-dictionary.h +++ b/source/core/slang-dictionary.h @@ -167,7 +167,7 @@ namespace Slang return FindPositionResult(hashPos, -1); } numProbes++; - hashPos = (hashPos + GetProbeOffset(numProbes)) & bucketSizeMinusOne; + hashPos = (hashPos + GetProbeOffset(numProbes)) % bucketSizeMinusOne; } if (insertPos != -1) return FindPositionResult(-1, insertPos); @@ -738,7 +738,7 @@ namespace Slang return FindPositionResult(hashPos, -1); } numProbes++; - hashPos = (hashPos + GetProbeOffset(numProbes)) & bucketSizeMinusOne; + hashPos = (hashPos + GetProbeOffset(numProbes)) % bucketSizeMinusOne; } if (insertPos != -1) return FindPositionResult(-1, insertPos); diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index a6a0c7790..0fc3b2874 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -2479,68 +2479,26 @@ public: { SlangStage stage; slang::EntryPointReflection* entryPointInfo; + String actualEntryPointNameInAPI; List<uint8_t> code; }; class ShaderProgramImpl : public ShaderProgramBase { public: - List<ShaderBinary> m_shaders; RefPtr<RootShaderObjectLayoutImpl> m_rootObjectLayout; - Result compileShaders() - { - // For a fully specialized program, read and store its kernel code in `shaderProgram`. - auto compileShader = [&](slang::EntryPointReflection* entryPointInfo, - slang::IComponentType* entryPointComponent, - SlangInt entryPointIndex) - { - auto stage = entryPointInfo->getStage(); - ComPtr<ISlangBlob> kernelCode; - ComPtr<ISlangBlob> diagnostics; - auto compileResult = entryPointComponent->getEntryPointCode( - entryPointIndex, 0, kernelCode.writeRef(), diagnostics.writeRef()); - if (diagnostics) - { - getDebugCallback()->handleMessage( - compileResult == SLANG_OK ? DebugMessageType::Warning - : DebugMessageType::Error, - DebugMessageSource::Slang, - (char*)diagnostics->getBufferPointer()); - } - SLANG_RETURN_ON_FAIL(compileResult); - ShaderBinary shaderBin; - shaderBin.stage = stage; - shaderBin.entryPointInfo = entryPointInfo; - shaderBin.code.addRange( - reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()), - (Index)kernelCode->getBufferSize()); - m_shaders.add(_Move(shaderBin)); - return SLANG_OK; - }; + List<ShaderBinary> m_shaders; - if (linkedEntryPoints.getCount() == 0) - { - // If the user does not explicitly specify entry point components, find them from - // `linkedEntryPoints`. - auto programReflection = linkedProgram->getLayout(); - for (SlangUInt i = 0; i < programReflection->getEntryPointCount(); i++) - { - SLANG_RETURN_ON_FAIL(compileShader( - programReflection->getEntryPointByIndex(i), - linkedProgram, - (SlangInt)i)); - } - } - else - { - // If the user specifies entry point components via the separated entry point array, - // compile code from there. - for (auto& entryPoint : linkedEntryPoints) - { - SLANG_RETURN_ON_FAIL(compileShader( - entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0)); - } - } + virtual Result createShaderModule( + slang::EntryPointReflection* entryPointInfo, ComPtr<ISlangBlob> kernelCode) override + { + ShaderBinary shaderBin; + shaderBin.stage = entryPointInfo->getStage(); + shaderBin.entryPointInfo = entryPointInfo; + shaderBin.code.addRange( + reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()), + (Index)kernelCode->getBufferSize()); + m_shaders.add(_Move(shaderBin)); return SLANG_OK; } }; diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index b476bd284..d01f6fe72 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -827,6 +827,62 @@ void ShaderProgramBase::init(const IShaderProgram::Desc& inDesc) } } +Result ShaderProgramBase::compileShaders() +{ + // For a fully specialized program, read and store its kernel code in `shaderProgram`. + auto compileShader = [&](slang::EntryPointReflection* entryPointInfo, + slang::IComponentType* entryPointComponent, + SlangInt entryPointIndex) + { + auto stage = entryPointInfo->getStage(); + ComPtr<ISlangBlob> kernelCode; + ComPtr<ISlangBlob> diagnostics; + auto compileResult = entryPointComponent->getEntryPointCode( + entryPointIndex, 0, kernelCode.writeRef(), diagnostics.writeRef()); + if (diagnostics) + { + getDebugCallback()->handleMessage( + compileResult == SLANG_OK ? DebugMessageType::Warning : DebugMessageType::Error, + DebugMessageSource::Slang, + (char*)diagnostics->getBufferPointer()); + } + SLANG_RETURN_ON_FAIL(compileResult); + SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCode)); + return SLANG_OK; + }; + + if (linkedEntryPoints.getCount() == 0) + { + // If the user does not explicitly specify entry point components, find them from + // `linkedEntryPoints`. + auto programReflection = linkedProgram->getLayout(); + for (SlangUInt i = 0; i < programReflection->getEntryPointCount(); i++) + { + SLANG_RETURN_ON_FAIL(compileShader( + programReflection->getEntryPointByIndex(i), linkedProgram, (SlangInt)i)); + } + } + else + { + // If the user specifies entry point components via the separated entry point array, + // compile code from there. + for (auto& entryPoint : linkedEntryPoints) + { + SLANG_RETURN_ON_FAIL( + compileShader(entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0)); + } + } + return SLANG_OK; +} + +Result ShaderProgramBase::createShaderModule( + slang::EntryPointReflection* entryPointInfo, ComPtr<ISlangBlob> kernelCode) +{ + SLANG_UNUSED(entryPointInfo); + SLANG_UNUSED(kernelCode); + return SLANG_OK; +} + Result RendererBase::maybeSpecializePipeline( PipelineStateBase* currentPipeline, ShaderObjectBase* rootObject, diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index 1814f97d9..5dc4da59a 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -853,6 +853,10 @@ public: } return false; } + + Slang::Result compileShaders(); + virtual Slang::Result createShaderModule( + slang::EntryPointReflection* entryPointInfo, Slang::ComPtr<ISlangBlob> kernelCode); }; class InputLayoutBase diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index 65eee6d3e..5f3c8f8df 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -1041,6 +1041,256 @@ public: initializeBase(pipelineDesc); } + Result createVKGraphicsPipelineState() + { + VkPipelineCache pipelineCache = VK_NULL_HANDLE; + + auto inputLayoutImpl = (InputLayoutImpl*)desc.graphics.inputLayout; + + // VertexBuffer/s + VkPipelineVertexInputStateCreateInfo vertexInputInfo = { + VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO}; + vertexInputInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO; + vertexInputInfo.vertexBindingDescriptionCount = 0; + vertexInputInfo.vertexAttributeDescriptionCount = 0; + + if (inputLayoutImpl) + { + const auto& srcAttributeDescs = inputLayoutImpl->m_attributeDescs; + const auto& srcStreamDescs = inputLayoutImpl->m_streamDescs; + + vertexInputInfo.vertexBindingDescriptionCount = (uint32_t)srcStreamDescs.getCount(); + vertexInputInfo.pVertexBindingDescriptions = srcStreamDescs.getBuffer(); + + vertexInputInfo.vertexAttributeDescriptionCount = + (uint32_t)srcAttributeDescs.getCount(); + vertexInputInfo.pVertexAttributeDescriptions = srcAttributeDescs.getBuffer(); + } + + VkPipelineInputAssemblyStateCreateInfo inputAssembly = {}; + inputAssembly.sType = VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO; + // All other forms of primitive toplogies are specified via dynamic state. + inputAssembly.topology = + VulkanUtil::translatePrimitiveTypeToListTopology(desc.graphics.primitiveType); + inputAssembly.primitiveRestartEnable = VK_FALSE; // TODO: Currently unsupported + + VkViewport viewport = {}; + viewport.x = 0.0f; + viewport.y = 0.0f; + // We are using dynamic viewport and scissor state. + // Here we specify an arbitrary size, actual viewport will be set at `beginRenderPass` + // time. + viewport.width = 16.0f; + viewport.height = 16.0f; + viewport.minDepth = 0.0f; + viewport.maxDepth = 1.0f; + + VkRect2D scissor = {}; + scissor.offset = {0, 0}; + scissor.extent = {uint32_t(16), uint32_t(16)}; + + VkPipelineViewportStateCreateInfo viewportState = {}; + viewportState.sType = VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO; + viewportState.viewportCount = 1; + viewportState.pViewports = &viewport; + viewportState.scissorCount = 1; + viewportState.pScissors = &scissor; + + auto rasterizerDesc = desc.graphics.rasterizer; + + VkPipelineRasterizationStateCreateInfo rasterizer = {}; + rasterizer.sType = VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO; + rasterizer.depthClampEnable = + VK_TRUE; // TODO: Depth clipping and clamping are different between Vk and D3D12 + rasterizer.rasterizerDiscardEnable = VK_FALSE; // TODO: Currently unsupported + rasterizer.polygonMode = VulkanUtil::translateFillMode(rasterizerDesc.fillMode); + rasterizer.cullMode = VulkanUtil::translateCullMode(rasterizerDesc.cullMode); + rasterizer.frontFace = VulkanUtil::translateFrontFaceMode(rasterizerDesc.frontFace); + rasterizer.depthBiasEnable = (rasterizerDesc.depthBias == 0) ? VK_FALSE : VK_TRUE; + rasterizer.depthBiasConstantFactor = (float)rasterizerDesc.depthBias; + rasterizer.depthBiasClamp = rasterizerDesc.depthBiasClamp; + rasterizer.depthBiasSlopeFactor = rasterizerDesc.slopeScaledDepthBias; + rasterizer.lineWidth = 1.0f; // TODO: Currently unsupported + + VkPipelineRasterizationConservativeStateCreateInfoEXT conservativeRasterInfo = {}; + conservativeRasterInfo.sType = + VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_CONSERVATIVE_STATE_CREATE_INFO_EXT; + conservativeRasterInfo.conservativeRasterizationMode = + VK_CONSERVATIVE_RASTERIZATION_MODE_OVERESTIMATE_EXT; + if (desc.graphics.rasterizer.enableConservativeRasterization) + { + rasterizer.pNext = &conservativeRasterInfo; + } + + auto framebufferLayoutImpl = + static_cast<FramebufferLayoutImpl*>(desc.graphics.framebufferLayout); + auto forcedSampleCount = rasterizerDesc.forcedSampleCount; + auto blendDesc = desc.graphics.blend; + + VkPipelineMultisampleStateCreateInfo multisampling = {}; + multisampling.sType = VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO; + multisampling.rasterizationSamples = (forcedSampleCount == 0) + ? framebufferLayoutImpl->m_sampleCount + : VulkanUtil::translateSampleCount(forcedSampleCount); + multisampling.sampleShadingEnable = + VK_FALSE; // TODO: Should check if fragment shader needs this + // TODO: Sample mask is dynamic in D3D12 but PSO state in Vulkan + multisampling.alphaToCoverageEnable = blendDesc.alphaToCoverageEnable; + multisampling.alphaToOneEnable = VK_FALSE; + + auto targetCount = Math::Min( + framebufferLayoutImpl->m_renderTargetCount, (uint32_t)blendDesc.targetCount); + List<VkPipelineColorBlendAttachmentState> colorBlendAttachments; + + // Regardless of whether blending is enabled, Vulkan always applies the color write mask + // operation, so if there is no blending then we need to add an attachment that defines + // the color write mask to ensure colors are actually written. + if (targetCount == 0) + { + colorBlendAttachments.setCount(1); + auto& vkBlendDesc = colorBlendAttachments[0]; + memset(&vkBlendDesc, 0, sizeof(vkBlendDesc)); + vkBlendDesc.blendEnable = VK_FALSE; + vkBlendDesc.srcColorBlendFactor = VK_BLEND_FACTOR_ONE; + vkBlendDesc.dstColorBlendFactor = VK_BLEND_FACTOR_ONE; + vkBlendDesc.colorBlendOp = VK_BLEND_OP_ADD; + vkBlendDesc.srcAlphaBlendFactor = VK_BLEND_FACTOR_ONE; + vkBlendDesc.dstAlphaBlendFactor = VK_BLEND_FACTOR_ONE; + vkBlendDesc.alphaBlendOp = VK_BLEND_OP_ADD; + vkBlendDesc.colorWriteMask = + (VkColorComponentFlags)RenderTargetWriteMask::EnableAll; + } + else + { + colorBlendAttachments.setCount(targetCount); + for (UInt i = 0; i < targetCount; ++i) + { + auto& gfxBlendDesc = blendDesc.targets[i]; + auto& vkBlendDesc = colorBlendAttachments[i]; + + vkBlendDesc.blendEnable = gfxBlendDesc.enableBlend; + vkBlendDesc.srcColorBlendFactor = + VulkanUtil::translateBlendFactor(gfxBlendDesc.color.srcFactor); + vkBlendDesc.dstColorBlendFactor = + VulkanUtil::translateBlendFactor(gfxBlendDesc.color.dstFactor); + vkBlendDesc.colorBlendOp = VulkanUtil::translateBlendOp(gfxBlendDesc.color.op); + vkBlendDesc.srcAlphaBlendFactor = + VulkanUtil::translateBlendFactor(gfxBlendDesc.alpha.srcFactor); + vkBlendDesc.dstAlphaBlendFactor = + VulkanUtil::translateBlendFactor(gfxBlendDesc.alpha.dstFactor); + vkBlendDesc.alphaBlendOp = VulkanUtil::translateBlendOp(gfxBlendDesc.alpha.op); + vkBlendDesc.colorWriteMask = (VkColorComponentFlags)gfxBlendDesc.writeMask; + } + } + + VkPipelineColorBlendStateCreateInfo colorBlending = {}; + colorBlending.sType = VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO; + colorBlending.logicOpEnable = VK_FALSE; // TODO: D3D12 has per attachment logic op (and + // both have way more than one op) + colorBlending.logicOp = VK_LOGIC_OP_COPY; + colorBlending.attachmentCount = (uint32_t)colorBlendAttachments.getCount(); + colorBlending.pAttachments = colorBlendAttachments.getBuffer(); + colorBlending.blendConstants[0] = 0.0f; + colorBlending.blendConstants[1] = 0.0f; + colorBlending.blendConstants[2] = 0.0f; + colorBlending.blendConstants[3] = 0.0f; + + Array<VkDynamicState, 8> dynamicStates; + dynamicStates.add(VK_DYNAMIC_STATE_VIEWPORT); + dynamicStates.add(VK_DYNAMIC_STATE_SCISSOR); + dynamicStates.add(VK_DYNAMIC_STATE_STENCIL_REFERENCE); + dynamicStates.add(VK_DYNAMIC_STATE_BLEND_CONSTANTS); + if (m_device->m_api.m_extendedFeatures.extendedDynamicStateFeatures.extendedDynamicState) + { + dynamicStates.add(VK_DYNAMIC_STATE_PRIMITIVE_TOPOLOGY_EXT); + } + VkPipelineDynamicStateCreateInfo dynamicStateInfo = {}; + dynamicStateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO; + dynamicStateInfo.dynamicStateCount = (uint32_t)dynamicStates.getCount(); + dynamicStateInfo.pDynamicStates = dynamicStates.getBuffer(); + + VkPipelineDepthStencilStateCreateInfo depthStencilStateInfo = {}; + depthStencilStateInfo.sType = + VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO; + depthStencilStateInfo.depthTestEnable = + desc.graphics.depthStencil.depthTestEnable ? 1 : 0; + depthStencilStateInfo.back = + VulkanUtil::translateStencilState(desc.graphics.depthStencil.backFace); + depthStencilStateInfo.front = + VulkanUtil::translateStencilState(desc.graphics.depthStencil.frontFace); + depthStencilStateInfo.back.compareMask = desc.graphics.depthStencil.stencilReadMask; + depthStencilStateInfo.back.writeMask = desc.graphics.depthStencil.stencilWriteMask; + depthStencilStateInfo.front.compareMask = desc.graphics.depthStencil.stencilReadMask; + depthStencilStateInfo.front.writeMask = desc.graphics.depthStencil.stencilWriteMask; + depthStencilStateInfo.depthBoundsTestEnable = 0; // TODO: Currently unsupported + depthStencilStateInfo.depthCompareOp = + VulkanUtil::translateComparisonFunc(desc.graphics.depthStencil.depthFunc); + depthStencilStateInfo.depthWriteEnable = + desc.graphics.depthStencil.depthWriteEnable ? 1 : 0; + depthStencilStateInfo.stencilTestEnable = + desc.graphics.depthStencil.stencilEnable ? 1 : 0; + + VkGraphicsPipelineCreateInfo pipelineInfo = { + VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO}; + + auto programImpl = static_cast<ShaderProgramImpl*>(m_program.Ptr()); + if (programImpl->m_stageCreateInfos.getCount() == 0) + { + SLANG_RETURN_ON_FAIL(programImpl->compileShaders()); + } + + pipelineInfo.sType = VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO; + pipelineInfo.stageCount = (uint32_t)programImpl->m_stageCreateInfos.getCount(); + pipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer(); + pipelineInfo.pVertexInputState = &vertexInputInfo; + pipelineInfo.pInputAssemblyState = &inputAssembly; + pipelineInfo.pViewportState = &viewportState; + pipelineInfo.pRasterizationState = &rasterizer; + pipelineInfo.pMultisampleState = &multisampling; + pipelineInfo.pColorBlendState = &colorBlending; + pipelineInfo.pDepthStencilState = &depthStencilStateInfo; + pipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout; + pipelineInfo.renderPass = framebufferLayoutImpl->m_renderPass; + pipelineInfo.subpass = 0; + pipelineInfo.basePipelineHandle = VK_NULL_HANDLE; + pipelineInfo.pDynamicState = &dynamicStateInfo; + + SLANG_VK_CHECK(m_device->m_api.vkCreateGraphicsPipelines( + m_device->m_device, pipelineCache, 1, &pipelineInfo, nullptr, &m_pipeline)); + + return SLANG_OK; + } + + Result createVKComputePipelineState() + { + auto programImpl = static_cast<ShaderProgramImpl*>(m_program.Ptr()); + if (programImpl->m_stageCreateInfos.getCount() == 0) + { + SLANG_RETURN_ON_FAIL(programImpl->compileShaders()); + } + + VkPipelineCache pipelineCache = VK_NULL_HANDLE; + + VkComputePipelineCreateInfo computePipelineInfo = { + VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO}; + computePipelineInfo.stage = programImpl->m_stageCreateInfos[0]; + computePipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout; + SLANG_VK_CHECK(m_device->m_api.vkCreateComputePipelines( + m_device->m_device, pipelineCache, 1, &computePipelineInfo, nullptr, &m_pipeline)); + return SLANG_OK; + } + + Result ensureAPIPipelineStateCreated(); + + virtual SLANG_NO_THROW Result SLANG_MCALL getNativeHandle(InteropHandle* outHandle) override + { + SLANG_RETURN_ON_FAIL(ensureAPIPipelineStateCreated()); + outHandle->api = InteropHandleAPI::Vulkan; + outHandle->handleValue = 0; + memcpy(&outHandle->handleValue, &m_pipeline, sizeof(m_pipeline)); + return SLANG_OK; + } + BreakableReference<VKDevice> m_device; VkPipeline m_pipeline = VK_NULL_HANDLE; @@ -1054,7 +1304,157 @@ public: RayTracingPipelineStateImpl(VKDevice* device) : PipelineStateImpl(device) - {}; + {} + + static inline VkPipelineCreateFlags translateRayTracingPipelineFlags(RayTracingPipelineFlags::Enum flags) + { + VkPipelineCreateFlags vkFlags = 0; + if (flags & RayTracingPipelineFlags::Enum::SkipTriangles) + vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR; + if (flags & RayTracingPipelineFlags::Enum::SkipProcedurals) + vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR; + + return vkFlags; + } + + uint32_t findEntryPointIndexByName( + const Dictionary<String, Index>& entryPointNameToIndex, const char* name) + { + if (!name) + return VK_SHADER_UNUSED_KHR; + + auto indexPtr = entryPointNameToIndex.TryGetValue(String(name)); + if (indexPtr) + return (uint32_t)*indexPtr; + // TODO: Error reporting? + return VK_SHADER_UNUSED_KHR; + } + + Result createVKRayTracingPipelineState() + { + auto programImpl = static_cast<ShaderProgramImpl*>(m_program.Ptr()); + if (programImpl->m_stageCreateInfos.getCount() == 0) + { + SLANG_RETURN_ON_FAIL(programImpl->compileShaders()); + } + + VkRayTracingPipelineCreateInfoKHR raytracingPipelineInfo = { + VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR}; + raytracingPipelineInfo.pNext = nullptr; + raytracingPipelineInfo.flags = translateRayTracingPipelineFlags(desc.rayTracing.flags); + + raytracingPipelineInfo.stageCount = + (uint32_t)programImpl->m_stageCreateInfos.getCount(); + raytracingPipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer(); + + // Build Dictionary from entry point name to entry point index (stageCreateInfos index) + // for all hit shaders - findShaderIndexByName + Dictionary<String, Index> entryPointNameToIndex; + + List<VkRayTracingShaderGroupCreateInfoKHR> shaderGroupInfos; + for (uint32_t i = 0; i < raytracingPipelineInfo.stageCount; ++i) + { + auto stageCreateInfo = programImpl->m_stageCreateInfos[i]; + auto entryPointName = programImpl->m_entryPointNames[i]; + entryPointNameToIndex.Add(entryPointName, i); + if (stageCreateInfo.stage & + (VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | + VK_SHADER_STAGE_INTERSECTION_BIT_KHR)) + continue; + + VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { + VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR}; + shaderGroupInfo.pNext = nullptr; + shaderGroupInfo.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR; + shaderGroupInfo.generalShader = i; + shaderGroupInfo.closestHitShader = VK_SHADER_UNUSED_KHR; + shaderGroupInfo.anyHitShader = VK_SHADER_UNUSED_KHR; + shaderGroupInfo.intersectionShader = VK_SHADER_UNUSED_KHR; + shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr; + + // For groups with a single entry point, the group name is the entry point name. + auto shaderGroupName = entryPointName; + auto shaderGroupIndex = shaderGroupInfos.getCount(); + shaderGroupInfos.add(shaderGroupInfo); + shaderGroupNameToIndex.Add(shaderGroupName, shaderGroupIndex); + } + + for (int32_t i = 0; i < desc.rayTracing.hitGroupDescs.getCount(); ++i) + { + VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { + VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR}; + auto& groupDesc = desc.rayTracing.hitGroupDescs[i]; + + shaderGroupInfo.pNext = nullptr; + shaderGroupInfo.type = + (groupDesc.intersectionEntryPoint) + ? VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR + : VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR; + shaderGroupInfo.generalShader = VK_SHADER_UNUSED_KHR; + shaderGroupInfo.closestHitShader = findEntryPointIndexByName( + entryPointNameToIndex, groupDesc.closestHitEntryPoint); + shaderGroupInfo.anyHitShader = + findEntryPointIndexByName(entryPointNameToIndex, groupDesc.anyHitEntryPoint); + shaderGroupInfo.intersectionShader = findEntryPointIndexByName( + entryPointNameToIndex, groupDesc.intersectionEntryPoint); + shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr; + + auto shaderGroupIndex = shaderGroupInfos.getCount(); + shaderGroupInfos.add(shaderGroupInfo); + shaderGroupNameToIndex.Add(String(groupDesc.hitGroupName), shaderGroupIndex); + } + + raytracingPipelineInfo.groupCount = (uint32_t)shaderGroupInfos.getCount(); + raytracingPipelineInfo.pGroups = shaderGroupInfos.getBuffer(); + + raytracingPipelineInfo.maxPipelineRayRecursionDepth = + (uint32_t)desc.rayTracing.maxRecursion; + + raytracingPipelineInfo.pLibraryInfo = nullptr; + raytracingPipelineInfo.pLibraryInterface = nullptr; + + raytracingPipelineInfo.pDynamicState = nullptr; + + raytracingPipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout; + raytracingPipelineInfo.basePipelineHandle = VK_NULL_HANDLE; + raytracingPipelineInfo.basePipelineIndex = 0; + + VkPipelineCache pipelineCache = VK_NULL_HANDLE; + SLANG_VK_CHECK(m_device->m_api.vkCreateRayTracingPipelinesKHR( + m_device->m_device, + VK_NULL_HANDLE, + pipelineCache, + 1, + &raytracingPipelineInfo, + nullptr, + &m_pipeline)); + shaderGroupCount = shaderGroupInfos.getCount(); + return SLANG_OK; + } + + Result ensureAPIPipelineStateCreated() + { + if (m_pipeline) + return SLANG_OK; + + switch (desc.type) + { + case PipelineType::RayTracing: + return createVKRayTracingPipelineState(); + default: + SLANG_UNREACHABLE("Unknown pipeline type."); + return SLANG_FAIL; + } + } + + virtual SLANG_NO_THROW Result SLANG_MCALL getNativeHandle(InteropHandle* outHandle) override + { + SLANG_RETURN_ON_FAIL(ensureAPIPipelineStateCreated()); + outHandle->api = InteropHandleAPI::Vulkan; + outHandle->handleValue = 0; + memcpy(&outHandle->handleValue, &m_pipeline, sizeof(m_pipeline)); + return SLANG_OK; + } }; // In order to bind shader parameters to the correct locations, we need to @@ -2447,6 +2847,62 @@ public: Array<ComPtr<ISlangBlob>, 8> m_codeBlobs; //< To keep storage of code in scope Array<VkShaderModule, 8> m_modules; RefPtr<RootShaderObjectLayout> m_rootObjectLayout; + + VkPipelineShaderStageCreateInfo compileEntryPoint( + const char* entryPointName, + ISlangBlob* code, + VkShaderStageFlagBits stage, + VkShaderModule& outShaderModule) + { + char const* dataBegin = (char const*)code->getBufferPointer(); + char const* dataEnd = (char const*)code->getBufferPointer() + code->getBufferSize(); + + // We need to make a copy of the code, since the Slang compiler + // will free the memory after a compile request is closed. + + VkShaderModuleCreateInfo moduleCreateInfo = { + VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO}; + moduleCreateInfo.pCode = (uint32_t*)code->getBufferPointer(); + moduleCreateInfo.codeSize = code->getBufferSize(); + + VkShaderModule module; + SLANG_VK_CHECK( + m_device->m_api.vkCreateShaderModule(m_device->m_device, &moduleCreateInfo, nullptr, &module)); + outShaderModule = module; + + VkPipelineShaderStageCreateInfo shaderStageCreateInfo = { + VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO}; + shaderStageCreateInfo.stage = stage; + + shaderStageCreateInfo.module = module; + shaderStageCreateInfo.pName = entryPointName; + + return shaderStageCreateInfo; + } + + virtual Result createShaderModule( + slang::EntryPointReflection* entryPointInfo, ComPtr<ISlangBlob> kernelCode) override + { + m_codeBlobs.add(kernelCode); + VkShaderModule shaderModule; + // HACK: our direct-spirv-emit path generates SPIRV that respects + // the original entry point name, while the glslang path always + // uses "main" as the name. We should introduce a compiler parameter + // to control the entry point naming behavior in SPIRV-direct path + // so we can remove the ad-hoc logic here. + auto realEntryPointName = entryPointInfo->getNameOverride(); + const char* spirvBinaryEntryPointName = "main"; + if (m_device->m_desc.slang.targetFlags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY) + spirvBinaryEntryPointName = realEntryPointName; + m_stageCreateInfos.add(compileEntryPoint( + spirvBinaryEntryPointName, + kernelCode, + (VkShaderStageFlagBits)VulkanUtil::getShaderStage(entryPointInfo->getStage()), + shaderModule)); + m_entryPointNames.add(realEntryPointName); + m_modules.add(shaderModule); + return SLANG_OK; + } }; class CommandBufferImpl; @@ -2570,6 +3026,9 @@ public: m_currentPipeline, &m_commandBuffer->m_rootObject, newPipeline); PipelineStateImpl* newPipelineImpl = static_cast<PipelineStateImpl*>(newPipeline.Ptr()); + newPipelineImpl->ensureAPIPipelineStateCreated(); + m_currentPipeline = newPipelineImpl; + bindRootShaderObjectImpl(pipelineBindPoint); auto pipelineBindPointId = getBindPointIndex(pipelineBindPoint); @@ -5736,9 +6195,10 @@ public: auto shaderTableImpl = (ShaderTableImpl*)shaderTable; auto alignedHandleSize = VulkanUtil::calcAligned(rtProps.shaderGroupHandleSize, rtProps.shaderGroupHandleAlignment); - ResourceCommandEncoder resourceCopyEncoder; - resourceCopyEncoder.init(m_commandBuffer); - auto shaderTableBuffer = shaderTableImpl->getOrCreateBuffer(m_currentPipeline, m_commandBuffer->m_transientHeap, &resourceCopyEncoder); + auto shaderTableBuffer = shaderTableImpl->getOrCreateBuffer( + m_currentPipeline, + m_commandBuffer->m_transientHeap, + static_cast<ResourceCommandEncoder*>(this)); VkStridedDeviceAddressRegionKHR raygenSBT; raygenSBT.deviceAddress = shaderTableBuffer->getDeviceAddress(); @@ -6088,7 +6548,8 @@ public: (uint32_t)count, sizeof(uint64_t) * count, data, - sizeof(uint64_t), 0)); + sizeof(uint64_t), + VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT)); return SLANG_OK; } public: @@ -6496,13 +6957,6 @@ public: VkBool32 handleDebugMessage(VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT objType, uint64_t srcObject, size_t location, int32_t msgCode, const char* pLayerPrefix, const char* pMsg); - /// Note that the outShaderModule value should be cleaned up when no longer needed by caller - /// via vkShaderModuleDestroy() - VkPipelineShaderStageCreateInfo compileEntryPoint( - const char* entryPointName, - ISlangBlob* code, - VkShaderStageFlagBits stage, - VkShaderModule& outShaderModule); static VKAPI_ATTR VkBool32 VKAPI_CALL debugMessageCallback(VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT objType, uint64_t srcObject, size_t location, int32_t msgCode, const char* pLayerPrefix, const char* pMsg, void* pUserData); @@ -6798,35 +7252,6 @@ VkBool32 VKDevice::handleDebugMessage(VkDebugReportFlagsEXT flags, VkDebugReport return ((VKDevice*)pUserData)->handleDebugMessage(flags, objType, srcObject, location, msgCode, pLayerPrefix, pMsg); } -VkPipelineShaderStageCreateInfo VKDevice::compileEntryPoint( - const char* entryPointName, - ISlangBlob* code, - VkShaderStageFlagBits stage, - VkShaderModule& outShaderModule) -{ - char const* dataBegin = (char const*) code->getBufferPointer(); - char const* dataEnd = (char const*)code->getBufferPointer() + code->getBufferSize(); - - // We need to make a copy of the code, since the Slang compiler - // will free the memory after a compile request is closed. - - VkShaderModuleCreateInfo moduleCreateInfo = { VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO }; - moduleCreateInfo.pCode = (uint32_t*)code->getBufferPointer(); - moduleCreateInfo.codeSize = code->getBufferSize(); - - VkShaderModule module; - SLANG_VK_CHECK(m_api.vkCreateShaderModule(m_device, &moduleCreateInfo, nullptr, &module)); - outShaderModule = module; - - VkPipelineShaderStageCreateInfo shaderStageCreateInfo = { VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO }; - shaderStageCreateInfo.stage = stage; - - shaderStageCreateInfo.module = module; - shaderStageCreateInfo.pName = entryPointName; - - return shaderStageCreateInfo; -} - // !!!!!!!!!!!!!!!!!!!!!!!!!!!! Renderer interface !!!!!!!!!!!!!!!!!!!!!!!!!! Result VKDevice::getNativeDeviceHandles(InteropHandles* outHandles) @@ -8424,131 +8849,16 @@ Result VKDevice::createBufferFromNativeHandle(InteropHandle handle, const IBuffe return SLANG_OK; } -VkFilter translateFilterMode(TextureFilteringMode mode) -{ - switch (mode) - { - default: - return VkFilter(0); - -#define CASE(SRC, DST) \ - case TextureFilteringMode::SRC: return VK_FILTER_##DST - - CASE(Point, NEAREST); - CASE(Linear, LINEAR); - -#undef CASE - } -} - -VkSamplerMipmapMode translateMipFilterMode(TextureFilteringMode mode) -{ - switch (mode) - { - default: - return VkSamplerMipmapMode(0); - -#define CASE(SRC, DST) \ - case TextureFilteringMode::SRC: return VK_SAMPLER_MIPMAP_MODE_##DST - - CASE(Point, NEAREST); - CASE(Linear, LINEAR); - -#undef CASE - } -} - -VkSamplerAddressMode translateAddressingMode(TextureAddressingMode mode) -{ - switch (mode) - { - default: - return VkSamplerAddressMode(0); - -#define CASE(SRC, DST) \ - case TextureAddressingMode::SRC: return VK_SAMPLER_ADDRESS_MODE_##DST - - CASE(Wrap, REPEAT); - CASE(ClampToEdge, CLAMP_TO_EDGE); - CASE(ClampToBorder, CLAMP_TO_BORDER); - CASE(MirrorRepeat, MIRRORED_REPEAT); - CASE(MirrorOnce, MIRROR_CLAMP_TO_EDGE); - -#undef CASE - } -} - -static VkCompareOp translateComparisonFunc(ComparisonFunc func) -{ - switch (func) - { - default: - // TODO: need to report failures - return VK_COMPARE_OP_ALWAYS; - -#define CASE(FROM, TO) \ - case ComparisonFunc::FROM: return VK_COMPARE_OP_##TO - - CASE(Never, NEVER); - CASE(Less, LESS); - CASE(Equal, EQUAL); - CASE(LessEqual, LESS_OR_EQUAL); - CASE(Greater, GREATER); - CASE(NotEqual, NOT_EQUAL); - CASE(GreaterEqual, GREATER_OR_EQUAL); - CASE(Always, ALWAYS); -#undef CASE - } -} - -static VkStencilOp translateStencilOp(StencilOp op) -{ - switch (op) - { - case StencilOp::DecrementSaturate: - return VK_STENCIL_OP_DECREMENT_AND_CLAMP; - case StencilOp::DecrementWrap: - return VK_STENCIL_OP_DECREMENT_AND_WRAP; - case StencilOp::IncrementSaturate: - return VK_STENCIL_OP_INCREMENT_AND_CLAMP; - case StencilOp::IncrementWrap: - return VK_STENCIL_OP_INCREMENT_AND_WRAP; - case StencilOp::Invert: - return VK_STENCIL_OP_INVERT; - case StencilOp::Keep: - return VK_STENCIL_OP_KEEP; - case StencilOp::Replace: - return VK_STENCIL_OP_REPLACE; - case StencilOp::Zero: - return VK_STENCIL_OP_ZERO; - default: - return VK_STENCIL_OP_KEEP; - } -} - -static VkStencilOpState translateStencilState(DepthStencilOpDesc desc) -{ - VkStencilOpState rs; - rs.compareMask = 0xFF; - rs.compareOp = translateComparisonFunc(desc.stencilFunc); - rs.depthFailOp = translateStencilOp(desc.stencilDepthFailOp); - rs.failOp = translateStencilOp(desc.stencilFailOp); - rs.passOp = translateStencilOp(desc.stencilPassOp); - rs.reference = 0; - rs.writeMask = 0xFF; - return rs; -} - Result VKDevice::createSamplerState(ISamplerState::Desc const& desc, ISamplerState** outSampler) { VkSamplerCreateInfo samplerInfo = { VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO }; - samplerInfo.magFilter = translateFilterMode(desc.minFilter); - samplerInfo.minFilter = translateFilterMode(desc.magFilter); + samplerInfo.magFilter = VulkanUtil::translateFilterMode(desc.minFilter); + samplerInfo.minFilter = VulkanUtil::translateFilterMode(desc.magFilter); - samplerInfo.addressModeU = translateAddressingMode(desc.addressU); - samplerInfo.addressModeV = translateAddressingMode(desc.addressV); - samplerInfo.addressModeW = translateAddressingMode(desc.addressW); + samplerInfo.addressModeU = VulkanUtil::translateAddressingMode(desc.addressU); + samplerInfo.addressModeV = VulkanUtil::translateAddressingMode(desc.addressV); + samplerInfo.addressModeW = VulkanUtil::translateAddressingMode(desc.addressW); samplerInfo.anisotropyEnable = desc.maxAnisotropy > 1; samplerInfo.maxAnisotropy = (float) desc.maxAnisotropy; @@ -8558,8 +8868,8 @@ Result VKDevice::createSamplerState(ISamplerState::Desc const& desc, ISamplerSta samplerInfo.unnormalizedCoordinates = VK_FALSE; samplerInfo.compareEnable = desc.reductionOp == TextureReductionOp::Comparison; - samplerInfo.compareOp = translateComparisonFunc(desc.comparisonFunc); - samplerInfo.mipmapMode = translateMipFilterMode(desc.mipFilter); + samplerInfo.compareOp = VulkanUtil::translateComparisonFunc(desc.comparisonFunc); + samplerInfo.mipmapMode = VulkanUtil::translateMipFilterMode(desc.mipFilter); samplerInfo.minLod = Math::Max(0.0f, desc.minLOD); samplerInfo.maxLod = Math::Clamp(desc.maxLOD, samplerInfo.minLod, VK_LOD_CLAMP_NONE); @@ -8953,73 +9263,7 @@ Result VKDevice::createProgram( shaderProgram->linkedProgram, shaderProgram->linkedProgram->getLayout(), shaderProgram->m_rootObjectLayout.writeRef()); - if (shaderProgram->isSpecializable()) - { - // For a specializable program, we don't invoke any actual slang compilation yet. - returnComPtr(outProgram, shaderProgram); - return SLANG_OK; - } - // For a fully specialized program, create `VkShaderModule`s for each shader stage. - auto compileShader = [&](slang::EntryPointReflection* entryPointInfo, - slang::IComponentType* component, - SlangInt entryPointIndex) - { - auto stage = entryPointInfo->getStage(); - ComPtr<ISlangBlob> kernelCode; - ComPtr<ISlangBlob> diagnostics; - auto compileResult = component->getEntryPointCode( - entryPointIndex, 0, kernelCode.writeRef(), diagnostics.writeRef()); - if (diagnostics) - { - getDebugCallback()->handleMessage( - compileResult == SLANG_OK ? DebugMessageType::Warning : DebugMessageType::Error, - DebugMessageSource::Slang, - (char*)diagnostics->getBufferPointer()); - if (outDiagnosticBlob) - returnComPtr(outDiagnosticBlob, diagnostics); - } - SLANG_RETURN_ON_FAIL(compileResult); - shaderProgram->m_codeBlobs.add(kernelCode); - VkShaderModule shaderModule; - // HACK: our direct-spirv-emit path generates SPIRV that respects - // the original entry point name, while the glslang path always - // uses "main" as the name. We should introduce a compiler parameter - // to control the entry point naming behavior in SPIRV-direct path - // so we can remove the ad-hoc logic here. - auto realEntryPointName = entryPointInfo->getNameOverride(); - const char* spirvBinaryEntryPointName = "main"; - if (m_desc.slang.targetFlags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY) - spirvBinaryEntryPointName = realEntryPointName; - shaderProgram->m_stageCreateInfos.add(compileEntryPoint( - spirvBinaryEntryPointName, - kernelCode, - (VkShaderStageFlagBits)VulkanUtil::getShaderStage(stage), - shaderModule)); - shaderProgram->m_entryPointNames.add(realEntryPointName); - shaderProgram->m_modules.add(shaderModule); - return SLANG_OK; - }; - if (shaderProgram->linkedEntryPoints.getCount() == 0) - { - // If the user does not explicitly specify entry point components, find them from - // `linkedEntryPoints`. - auto programReflection = shaderProgram->linkedProgram->getLayout(); - for (SlangUInt i = 0; i < programReflection->getEntryPointCount(); i++) - { - auto entryPointInfo = programReflection->getEntryPointByIndex(i); - SLANG_RETURN_ON_FAIL(compileShader(entryPointInfo, shaderProgram->linkedProgram, (SlangInt)i)); - } - } - else - { - // If the user specifies entry point components via the separated entry point array, compile - // code from there. - for (auto& entryPoint : shaderProgram->linkedEntryPoints) - { - SLANG_RETURN_ON_FAIL(compileShader(entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0)); - } - } returnComPtr(outProgram, shaderProgram); return SLANG_OK; } @@ -9075,356 +9319,22 @@ Result VKDevice::createShaderTable(const IShaderTable::Desc& desc, IShaderTable* return SLANG_OK; } -VkSampleCountFlagBits translateSampleCount(uint32_t sampleCount) -{ - switch (sampleCount) - { - case 1: return VK_SAMPLE_COUNT_1_BIT; - case 2: return VK_SAMPLE_COUNT_2_BIT; - case 4: return VK_SAMPLE_COUNT_4_BIT; - case 8: return VK_SAMPLE_COUNT_8_BIT; - case 16: return VK_SAMPLE_COUNT_16_BIT; - case 32: return VK_SAMPLE_COUNT_32_BIT; - case 64: return VK_SAMPLE_COUNT_64_BIT; - default: - assert(!"Unsupported sample count"); - return VK_SAMPLE_COUNT_1_BIT; - } -} - -VkCullModeFlags translateCullMode(CullMode cullMode) -{ - switch (cullMode) - { - case CullMode::None: return VK_CULL_MODE_NONE; - case CullMode::Front: return VK_CULL_MODE_FRONT_BIT; - case CullMode::Back: return VK_CULL_MODE_BACK_BIT; - default: - assert(!"Unsupported cull mode"); - return VK_CULL_MODE_NONE; - } -} - -VkFrontFace translateFrontFaceMode(FrontFaceMode frontFaceMode) -{ - switch (frontFaceMode) - { - case FrontFaceMode::CounterClockwise: - return VK_FRONT_FACE_COUNTER_CLOCKWISE; - case FrontFaceMode::Clockwise: - return VK_FRONT_FACE_CLOCKWISE; - default: - assert(!"Unsupported front face mode"); - return VK_FRONT_FACE_CLOCKWISE; - } -} - -VkPolygonMode translateFillMode(FillMode fillMode) -{ - switch (fillMode) - { - case FillMode::Solid: return VK_POLYGON_MODE_FILL; - case FillMode::Wireframe: return VK_POLYGON_MODE_LINE; - default: - assert(!"Unsupported fill mode"); - return VK_POLYGON_MODE_FILL; - } -} - -VkBlendFactor translateBlendFactor(BlendFactor blendFactor) -{ - switch (blendFactor) - { - case BlendFactor::Zero: return VK_BLEND_FACTOR_ZERO; - case BlendFactor::One: return VK_BLEND_FACTOR_ONE; - case BlendFactor::SrcColor: return VK_BLEND_FACTOR_SRC_COLOR; - case BlendFactor::InvSrcColor: return VK_BLEND_FACTOR_ONE_MINUS_SRC_COLOR; - case BlendFactor::SrcAlpha: return VK_BLEND_FACTOR_SRC_ALPHA; - case BlendFactor::InvSrcAlpha: return VK_BLEND_FACTOR_ONE_MINUS_SRC_ALPHA; - case BlendFactor::DestAlpha: return VK_BLEND_FACTOR_DST_ALPHA; - case BlendFactor::InvDestAlpha: return VK_BLEND_FACTOR_ONE_MINUS_DST_ALPHA; - case BlendFactor::DestColor: return VK_BLEND_FACTOR_DST_COLOR; - case BlendFactor::InvDestColor: return VK_BLEND_FACTOR_ONE_MINUS_DST_ALPHA; - case BlendFactor::SrcAlphaSaturate: return VK_BLEND_FACTOR_SRC_ALPHA_SATURATE; - case BlendFactor::BlendColor: return VK_BLEND_FACTOR_CONSTANT_COLOR; - case BlendFactor::InvBlendColor: return VK_BLEND_FACTOR_ONE_MINUS_CONSTANT_COLOR; - case BlendFactor::SecondarySrcColor: return VK_BLEND_FACTOR_SRC1_COLOR; - case BlendFactor::InvSecondarySrcColor: return VK_BLEND_FACTOR_ONE_MINUS_SRC1_COLOR; - case BlendFactor::SecondarySrcAlpha: return VK_BLEND_FACTOR_SRC1_ALPHA; - case BlendFactor::InvSecondarySrcAlpha: return VK_BLEND_FACTOR_ONE_MINUS_SRC1_ALPHA; - - default: - assert(!"Unsupported blend factor"); - return VK_BLEND_FACTOR_ONE; - } -} - -VkBlendOp translateBlendOp(BlendOp op) -{ - switch (op) - { - case BlendOp::Add: return VK_BLEND_OP_ADD; - case BlendOp::Subtract: return VK_BLEND_OP_SUBTRACT; - case BlendOp::ReverseSubtract: return VK_BLEND_OP_REVERSE_SUBTRACT; - case BlendOp::Min: return VK_BLEND_OP_MIN; - case BlendOp::Max: return VK_BLEND_OP_MAX; - default: - assert(!"Unsupported blend op"); - return VK_BLEND_OP_ADD; - } -} - -VkPrimitiveTopology translatePrimitiveTypeToListTopology(PrimitiveType primitiveType) -{ - switch (primitiveType) - { - case PrimitiveType::Point: return VK_PRIMITIVE_TOPOLOGY_POINT_LIST; - case PrimitiveType::Line: return VK_PRIMITIVE_TOPOLOGY_LINE_LIST; - case PrimitiveType::Triangle: return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST; - case PrimitiveType::Patch: return VK_PRIMITIVE_TOPOLOGY_PATCH_LIST; - default: - assert(!"unknown topology type."); - return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST; - } -} - Result VKDevice::createGraphicsPipelineState(const GraphicsPipelineStateDesc& inDesc, IPipelineState** outState) { GraphicsPipelineStateDesc desc = inDesc; - auto programImpl = static_cast<ShaderProgramImpl*>(desc.program); - - if (!programImpl->m_rootObjectLayout->m_pipelineLayout) - { - RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this); - pipelineStateImpl->init(desc); - pipelineStateImpl->establishStrongDeviceReference(); - m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); - returnComPtr(outState, pipelineStateImpl); - return SLANG_OK; - } - - VkPipelineCache pipelineCache = VK_NULL_HANDLE; - - auto inputLayoutImpl = (InputLayoutImpl*) desc.inputLayout; - - // VertexBuffer/s - VkPipelineVertexInputStateCreateInfo vertexInputInfo = { VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO }; - vertexInputInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO; - vertexInputInfo.vertexBindingDescriptionCount = 0; - vertexInputInfo.vertexAttributeDescriptionCount = 0; - - if (inputLayoutImpl) - { - const auto& srcAttributeDescs = inputLayoutImpl->m_attributeDescs; - const auto& srcStreamDescs = inputLayoutImpl->m_streamDescs; - - vertexInputInfo.vertexBindingDescriptionCount = (uint32_t)srcStreamDescs.getCount(); - vertexInputInfo.pVertexBindingDescriptions = srcStreamDescs.getBuffer(); - - vertexInputInfo.vertexAttributeDescriptionCount = (uint32_t)srcAttributeDescs.getCount(); - vertexInputInfo.pVertexAttributeDescriptions = srcAttributeDescs.getBuffer(); - } - - VkPipelineInputAssemblyStateCreateInfo inputAssembly = {}; - inputAssembly.sType = VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO; - // All other forms of primitive toplogies are specified via dynamic state. - inputAssembly.topology = translatePrimitiveTypeToListTopology(inDesc.primitiveType); - inputAssembly.primitiveRestartEnable = VK_FALSE; // TODO: Currently unsupported - - VkViewport viewport = {}; - viewport.x = 0.0f; - viewport.y = 0.0f; - // We are using dynamic viewport and scissor state. - // Here we specify an arbitrary size, actual viewport will be set at `beginRenderPass` time. - viewport.width = 16.0f; - viewport.height = 16.0f; - viewport.minDepth = 0.0f; - viewport.maxDepth = 1.0f; - - VkRect2D scissor = {}; - scissor.offset = { 0, 0 }; - scissor.extent = { uint32_t(16), uint32_t(16) }; - - VkPipelineViewportStateCreateInfo viewportState = {}; - viewportState.sType = VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO; - viewportState.viewportCount = 1; - viewportState.pViewports = &viewport; - viewportState.scissorCount = 1; - viewportState.pScissors = &scissor; - - auto rasterizerDesc = desc.rasterizer; - - VkPipelineRasterizationStateCreateInfo rasterizer = {}; - rasterizer.sType = VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO; - rasterizer.depthClampEnable = VK_TRUE; // TODO: Depth clipping and clamping are different between Vk and D3D12 - rasterizer.rasterizerDiscardEnable = VK_FALSE; // TODO: Currently unsupported - rasterizer.polygonMode = translateFillMode(rasterizerDesc.fillMode); - rasterizer.cullMode = translateCullMode(rasterizerDesc.cullMode); - rasterizer.frontFace = translateFrontFaceMode(rasterizerDesc.frontFace); - rasterizer.depthBiasEnable = (rasterizerDesc.depthBias == 0) ? VK_FALSE : VK_TRUE; - rasterizer.depthBiasConstantFactor = (float)rasterizerDesc.depthBias; - rasterizer.depthBiasClamp = rasterizerDesc.depthBiasClamp; - rasterizer.depthBiasSlopeFactor = rasterizerDesc.slopeScaledDepthBias; - rasterizer.lineWidth = 1.0f; // TODO: Currently unsupported - - VkPipelineRasterizationConservativeStateCreateInfoEXT conservativeRasterInfo = {}; - conservativeRasterInfo.sType = - VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_CONSERVATIVE_STATE_CREATE_INFO_EXT; - conservativeRasterInfo.conservativeRasterizationMode = - VK_CONSERVATIVE_RASTERIZATION_MODE_OVERESTIMATE_EXT; - if (desc.rasterizer.enableConservativeRasterization) - { - rasterizer.pNext = &conservativeRasterInfo; - } - - auto framebufferLayoutImpl = static_cast<FramebufferLayoutImpl*>(desc.framebufferLayout); - auto forcedSampleCount = rasterizerDesc.forcedSampleCount; - auto blendDesc = desc.blend; - - VkPipelineMultisampleStateCreateInfo multisampling = {}; - multisampling.sType = VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO; - multisampling.rasterizationSamples = - (forcedSampleCount == 0) ? framebufferLayoutImpl->m_sampleCount : translateSampleCount(forcedSampleCount); - multisampling.sampleShadingEnable = VK_FALSE; // TODO: Should check if fragment shader needs this - // TODO: Sample mask is dynamic in D3D12 but PSO state in Vulkan - multisampling.alphaToCoverageEnable = blendDesc.alphaToCoverageEnable; - multisampling.alphaToOneEnable = VK_FALSE; - - auto targetCount = Math::Min(framebufferLayoutImpl->m_renderTargetCount, (uint32_t)blendDesc.targetCount); - List<VkPipelineColorBlendAttachmentState> colorBlendAttachments; - - // Regardless of whether blending is enabled, Vulkan always applies the color write mask operation, - // so if there is no blending then we need to add an attachment that defines the color write mask - // to ensure colors are actually written. - if (targetCount == 0) - { - colorBlendAttachments.setCount(1); - auto& vkBlendDesc = colorBlendAttachments[0]; - memset(&vkBlendDesc, 0, sizeof(vkBlendDesc)); - vkBlendDesc.blendEnable = VK_FALSE; - vkBlendDesc.srcColorBlendFactor = VK_BLEND_FACTOR_ONE; - vkBlendDesc.dstColorBlendFactor = VK_BLEND_FACTOR_ONE; - vkBlendDesc.colorBlendOp = VK_BLEND_OP_ADD; - vkBlendDesc.srcAlphaBlendFactor = VK_BLEND_FACTOR_ONE; - vkBlendDesc.dstAlphaBlendFactor = VK_BLEND_FACTOR_ONE; - vkBlendDesc.alphaBlendOp = VK_BLEND_OP_ADD; - vkBlendDesc.colorWriteMask = (VkColorComponentFlags)RenderTargetWriteMask::EnableAll; - } - else - { - colorBlendAttachments.setCount(targetCount); - for (UInt i = 0; i < targetCount; ++i) - { - auto& gfxBlendDesc = blendDesc.targets[i]; - auto& vkBlendDesc = colorBlendAttachments[i]; - - vkBlendDesc.blendEnable = gfxBlendDesc.enableBlend; - vkBlendDesc.srcColorBlendFactor = translateBlendFactor(gfxBlendDesc.color.srcFactor); - vkBlendDesc.dstColorBlendFactor = translateBlendFactor(gfxBlendDesc.color.dstFactor); - vkBlendDesc.colorBlendOp = translateBlendOp(gfxBlendDesc.color.op); - vkBlendDesc.srcAlphaBlendFactor = translateBlendFactor(gfxBlendDesc.alpha.srcFactor); - vkBlendDesc.dstAlphaBlendFactor = translateBlendFactor(gfxBlendDesc.alpha.dstFactor); - vkBlendDesc.alphaBlendOp = translateBlendOp(gfxBlendDesc.alpha.op); - vkBlendDesc.colorWriteMask = (VkColorComponentFlags)gfxBlendDesc.writeMask; - } - } - - VkPipelineColorBlendStateCreateInfo colorBlending = {}; - colorBlending.sType = VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO; - colorBlending.logicOpEnable = VK_FALSE; // TODO: D3D12 has per attachment logic op (and both have way more than one op) - colorBlending.logicOp = VK_LOGIC_OP_COPY; - colorBlending.attachmentCount = (uint32_t)colorBlendAttachments.getCount(); - colorBlending.pAttachments = colorBlendAttachments.getBuffer(); - colorBlending.blendConstants[0] = 0.0f; - colorBlending.blendConstants[1] = 0.0f; - colorBlending.blendConstants[2] = 0.0f; - colorBlending.blendConstants[3] = 0.0f; - - Array<VkDynamicState, 8> dynamicStates; - dynamicStates.add(VK_DYNAMIC_STATE_VIEWPORT); - dynamicStates.add(VK_DYNAMIC_STATE_SCISSOR); - dynamicStates.add(VK_DYNAMIC_STATE_STENCIL_REFERENCE); - dynamicStates.add(VK_DYNAMIC_STATE_BLEND_CONSTANTS); - if (m_api.m_extendedFeatures.extendedDynamicStateFeatures.extendedDynamicState) - { - dynamicStates.add(VK_DYNAMIC_STATE_PRIMITIVE_TOPOLOGY_EXT); - } - VkPipelineDynamicStateCreateInfo dynamicStateInfo = {}; - dynamicStateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO; - dynamicStateInfo.dynamicStateCount = (uint32_t)dynamicStates.getCount(); - dynamicStateInfo.pDynamicStates = dynamicStates.getBuffer(); - - VkPipelineDepthStencilStateCreateInfo depthStencilStateInfo = {}; - depthStencilStateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO; - depthStencilStateInfo.depthTestEnable = inDesc.depthStencil.depthTestEnable ? 1 : 0; - depthStencilStateInfo.back = translateStencilState(inDesc.depthStencil.backFace); - depthStencilStateInfo.front = translateStencilState(inDesc.depthStencil.frontFace); - depthStencilStateInfo.back.compareMask = inDesc.depthStencil.stencilReadMask; - depthStencilStateInfo.back.writeMask = inDesc.depthStencil.stencilWriteMask; - depthStencilStateInfo.front.compareMask = inDesc.depthStencil.stencilReadMask; - depthStencilStateInfo.front.writeMask = inDesc.depthStencil.stencilWriteMask; - depthStencilStateInfo.depthBoundsTestEnable = 0; // TODO: Currently unsupported - depthStencilStateInfo.depthCompareOp = translateComparisonFunc(inDesc.depthStencil.depthFunc); - depthStencilStateInfo.depthWriteEnable = inDesc.depthStencil.depthWriteEnable ? 1 : 0; - depthStencilStateInfo.stencilTestEnable = inDesc.depthStencil.stencilEnable ? 1 : 0; - - VkGraphicsPipelineCreateInfo pipelineInfo = { VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO }; - - pipelineInfo.sType = VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO; - pipelineInfo.stageCount = (uint32_t)programImpl->m_stageCreateInfos.getCount(); - pipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer(); - pipelineInfo.pVertexInputState = &vertexInputInfo; - pipelineInfo.pInputAssemblyState = &inputAssembly; - pipelineInfo.pViewportState = &viewportState; - pipelineInfo.pRasterizationState = &rasterizer; - pipelineInfo.pMultisampleState = &multisampling; - pipelineInfo.pColorBlendState = &colorBlending; - pipelineInfo.pDepthStencilState = &depthStencilStateInfo; - pipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout; - pipelineInfo.renderPass = framebufferLayoutImpl->m_renderPass; - pipelineInfo.subpass = 0; - pipelineInfo.basePipelineHandle = VK_NULL_HANDLE; - pipelineInfo.pDynamicState = &dynamicStateInfo; - - VkPipeline pipeline = VK_NULL_HANDLE; - SLANG_VK_CHECK(m_api.vkCreateGraphicsPipelines(m_device, pipelineCache, 1, &pipelineInfo, nullptr, &pipeline)); - RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this); - pipelineStateImpl->m_pipeline = pipeline; pipelineStateImpl->init(desc); pipelineStateImpl->establishStrongDeviceReference(); m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); returnComPtr(outState, pipelineStateImpl); + return SLANG_OK; } Result VKDevice::createComputePipelineState(const ComputePipelineStateDesc& inDesc, IPipelineState** outState) { ComputePipelineStateDesc desc = inDesc; - auto programImpl = static_cast<ShaderProgramImpl*>(desc.program); - if (!programImpl->m_rootObjectLayout->m_pipelineLayout) - { - RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this); - pipelineStateImpl->init(desc); - m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); - pipelineStateImpl->establishStrongDeviceReference(); - returnComPtr(outState, pipelineStateImpl); - return SLANG_OK; - } - - VkPipelineCache pipelineCache = VK_NULL_HANDLE; - - VkPipeline pipeline = VK_NULL_HANDLE; - - VkComputePipelineCreateInfo computePipelineInfo = { - VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO }; - computePipelineInfo.stage = programImpl->m_stageCreateInfos[0]; - computePipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout; - SLANG_VK_CHECK(m_api.vkCreateComputePipelines( - m_device, pipelineCache, 1, &computePipelineInfo, nullptr, &pipeline)); - RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this); - pipelineStateImpl->m_pipeline = pipeline; pipelineStateImpl->init(desc); m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); pipelineStateImpl->establishStrongDeviceReference(); @@ -9432,122 +9342,12 @@ Result VKDevice::createComputePipelineState(const ComputePipelineStateDesc& inDe return SLANG_OK; } -VkPipelineCreateFlags translateFlags(RayTracingPipelineFlags::Enum flags) -{ - VkPipelineCreateFlags vkFlags = 0; - if (flags & RayTracingPipelineFlags::Enum::SkipTriangles) - vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR; - if (flags & RayTracingPipelineFlags::Enum::SkipProcedurals) - vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR; - - return vkFlags; -} - -uint32_t findEntryPointIndexByName(const Dictionary<String, Index>& entryPointNameToIndex, const char* name) -{ - if (!name) return VK_SHADER_UNUSED_KHR; - - auto indexPtr = entryPointNameToIndex.TryGetValue(String(name)); - if (indexPtr) - return (uint32_t)*indexPtr; - // TODO: Error reporting? - return VK_SHADER_UNUSED_KHR; -} - Result VKDevice::createRayTracingPipelineState(const RayTracingPipelineStateDesc& desc, IPipelineState** outState) { - auto programImpl = static_cast<ShaderProgramImpl*>(desc.program); - if (!programImpl->m_rootObjectLayout->m_pipelineLayout) - { - RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this); - pipelineStateImpl->init(desc); - m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); - pipelineStateImpl->establishStrongDeviceReference(); - returnComPtr(outState, pipelineStateImpl); - return SLANG_OK; - } - - VkRayTracingPipelineCreateInfoKHR raytracingPipelineInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR }; - raytracingPipelineInfo.pNext = nullptr; - raytracingPipelineInfo.flags = translateFlags(desc.flags); - - raytracingPipelineInfo.stageCount = (uint32_t)programImpl->m_stageCreateInfos.getCount(); - raytracingPipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer(); - - // Build Dictionary from group name to group index - Dictionary<String, Index> shaderGroupNameToIndex; - // Build Dictionary from entry point name to entry point index (stageCreateInfos index) for all hit shaders - findShaderIndexByName - Dictionary<String, Index> entryPointNameToIndex; - - List<VkRayTracingShaderGroupCreateInfoKHR> shaderGroupInfos; - for (uint32_t i = 0; i < raytracingPipelineInfo.stageCount; ++i) - { - auto stageCreateInfo = programImpl->m_stageCreateInfos[i]; - auto entryPointName = programImpl->m_entryPointNames[i]; - entryPointNameToIndex.Add(entryPointName, i); - if (stageCreateInfo.stage & (VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_INTERSECTION_BIT_KHR)) - continue; - - VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR }; - shaderGroupInfo.pNext = nullptr; - shaderGroupInfo.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR; - shaderGroupInfo.generalShader = i; - shaderGroupInfo.closestHitShader = VK_SHADER_UNUSED_KHR; - shaderGroupInfo.anyHitShader = VK_SHADER_UNUSED_KHR; - shaderGroupInfo.intersectionShader = VK_SHADER_UNUSED_KHR; - shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr; - - // For groups with a single entry point, the group name is the entry point name. - auto shaderGroupName = entryPointName; - auto shaderGroupIndex = shaderGroupInfos.getCount(); - shaderGroupInfos.add(shaderGroupInfo); - shaderGroupNameToIndex.Add(shaderGroupName, shaderGroupIndex); - } - - for (int32_t i = 0; i < desc.hitGroupCount; ++i) - { - VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR }; - auto& groupDesc = desc.hitGroups[i]; - - shaderGroupInfo.pNext = nullptr; - shaderGroupInfo.type = (groupDesc.intersectionEntryPoint) - ? VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR : VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR; - shaderGroupInfo.generalShader = VK_SHADER_UNUSED_KHR; - shaderGroupInfo.closestHitShader = findEntryPointIndexByName(entryPointNameToIndex, groupDesc.closestHitEntryPoint); - shaderGroupInfo.anyHitShader = findEntryPointIndexByName(entryPointNameToIndex, groupDesc.anyHitEntryPoint); - shaderGroupInfo.intersectionShader = findEntryPointIndexByName(entryPointNameToIndex, groupDesc.intersectionEntryPoint); - shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr; - - auto shaderGroupIndex = shaderGroupInfos.getCount(); - shaderGroupInfos.add(shaderGroupInfo); - shaderGroupNameToIndex.Add(String(groupDesc.hitGroupName), shaderGroupIndex); - } - - raytracingPipelineInfo.groupCount = (uint32_t)shaderGroupInfos.getCount(); - raytracingPipelineInfo.pGroups = shaderGroupInfos.getBuffer(); - - raytracingPipelineInfo.maxPipelineRayRecursionDepth = (uint32_t)desc.maxRecursion; - - raytracingPipelineInfo.pLibraryInfo = nullptr; - raytracingPipelineInfo.pLibraryInterface = nullptr; - - raytracingPipelineInfo.pDynamicState = nullptr; - - raytracingPipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout; - raytracingPipelineInfo.basePipelineHandle = VK_NULL_HANDLE; - raytracingPipelineInfo.basePipelineIndex = 0; - - VkPipelineCache pipelineCache = VK_NULL_HANDLE; - VkPipeline pipeline = VK_NULL_HANDLE; - SLANG_VK_CHECK(m_api.vkCreateRayTracingPipelinesKHR(m_device, VK_NULL_HANDLE, pipelineCache, 1, &raytracingPipelineInfo, nullptr, &pipeline)); - RefPtr<RayTracingPipelineStateImpl> pipelineStateImpl = new RayTracingPipelineStateImpl(this); - pipelineStateImpl->m_pipeline = pipeline; pipelineStateImpl->init(desc); - pipelineStateImpl->establishStrongDeviceReference(); - pipelineStateImpl->shaderGroupNameToIndex = shaderGroupNameToIndex; - pipelineStateImpl->shaderGroupCount = shaderGroupInfos.getCount(); m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); + pipelineStateImpl->establishStrongDeviceReference(); returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } @@ -9592,4 +9392,23 @@ Result VKDevice::waitForFences( return result == VK_SUCCESS ? SLANG_OK : SLANG_FAIL; } +Result VKDevice::PipelineStateImpl::ensureAPIPipelineStateCreated() +{ + if (m_pipeline) + return SLANG_OK; + + switch (desc.type) + { + case PipelineType::Compute: + return createVKComputePipelineState(); + case PipelineType::Graphics: + return createVKGraphicsPipelineState(); + case PipelineType::RayTracing: + return static_cast<RayTracingPipelineStateImpl*>(this)->createVKRayTracingPipelineState(); + default: + SLANG_UNREACHABLE("Unknown pipeline type."); + return SLANG_FAIL; + } +} + } // renderer_test diff --git a/tools/gfx/vulkan/vk-util.cpp b/tools/gfx/vulkan/vk-util.cpp index 974a61422..8e2a187a3 100644 --- a/tools/gfx/vulkan/vk-util.cpp +++ b/tools/gfx/vulkan/vk-util.cpp @@ -183,6 +183,277 @@ VkImageLayout VulkanUtil::getImageLayoutFromState(ResourceState state) return VkImageLayout(); } +VkSampleCountFlagBits VulkanUtil::translateSampleCount(uint32_t sampleCount) +{ + switch (sampleCount) + { + case 1: + return VK_SAMPLE_COUNT_1_BIT; + case 2: + return VK_SAMPLE_COUNT_2_BIT; + case 4: + return VK_SAMPLE_COUNT_4_BIT; + case 8: + return VK_SAMPLE_COUNT_8_BIT; + case 16: + return VK_SAMPLE_COUNT_16_BIT; + case 32: + return VK_SAMPLE_COUNT_32_BIT; + case 64: + return VK_SAMPLE_COUNT_64_BIT; + default: + assert(!"Unsupported sample count"); + return VK_SAMPLE_COUNT_1_BIT; + } +} + +VkCullModeFlags VulkanUtil::translateCullMode(CullMode cullMode) +{ + switch (cullMode) + { + case CullMode::None: + return VK_CULL_MODE_NONE; + case CullMode::Front: + return VK_CULL_MODE_FRONT_BIT; + case CullMode::Back: + return VK_CULL_MODE_BACK_BIT; + default: + assert(!"Unsupported cull mode"); + return VK_CULL_MODE_NONE; + } +} + +VkFrontFace VulkanUtil::translateFrontFaceMode(FrontFaceMode frontFaceMode) +{ + switch (frontFaceMode) + { + case FrontFaceMode::CounterClockwise: + return VK_FRONT_FACE_COUNTER_CLOCKWISE; + case FrontFaceMode::Clockwise: + return VK_FRONT_FACE_CLOCKWISE; + default: + assert(!"Unsupported front face mode"); + return VK_FRONT_FACE_CLOCKWISE; + } +} + +VkPolygonMode VulkanUtil::translateFillMode(FillMode fillMode) +{ + switch (fillMode) + { + case FillMode::Solid: + return VK_POLYGON_MODE_FILL; + case FillMode::Wireframe: + return VK_POLYGON_MODE_LINE; + default: + assert(!"Unsupported fill mode"); + return VK_POLYGON_MODE_FILL; + } +} + +VkBlendFactor VulkanUtil::translateBlendFactor(BlendFactor blendFactor) +{ + switch (blendFactor) + { + case BlendFactor::Zero: + return VK_BLEND_FACTOR_ZERO; + case BlendFactor::One: + return VK_BLEND_FACTOR_ONE; + case BlendFactor::SrcColor: + return VK_BLEND_FACTOR_SRC_COLOR; + case BlendFactor::InvSrcColor: + return VK_BLEND_FACTOR_ONE_MINUS_SRC_COLOR; + case BlendFactor::SrcAlpha: + return VK_BLEND_FACTOR_SRC_ALPHA; + case BlendFactor::InvSrcAlpha: + return VK_BLEND_FACTOR_ONE_MINUS_SRC_ALPHA; + case BlendFactor::DestAlpha: + return VK_BLEND_FACTOR_DST_ALPHA; + case BlendFactor::InvDestAlpha: + return VK_BLEND_FACTOR_ONE_MINUS_DST_ALPHA; + case BlendFactor::DestColor: + return VK_BLEND_FACTOR_DST_COLOR; + case BlendFactor::InvDestColor: + return VK_BLEND_FACTOR_ONE_MINUS_DST_ALPHA; + case BlendFactor::SrcAlphaSaturate: + return VK_BLEND_FACTOR_SRC_ALPHA_SATURATE; + case BlendFactor::BlendColor: + return VK_BLEND_FACTOR_CONSTANT_COLOR; + case BlendFactor::InvBlendColor: + return VK_BLEND_FACTOR_ONE_MINUS_CONSTANT_COLOR; + case BlendFactor::SecondarySrcColor: + return VK_BLEND_FACTOR_SRC1_COLOR; + case BlendFactor::InvSecondarySrcColor: + return VK_BLEND_FACTOR_ONE_MINUS_SRC1_COLOR; + case BlendFactor::SecondarySrcAlpha: + return VK_BLEND_FACTOR_SRC1_ALPHA; + case BlendFactor::InvSecondarySrcAlpha: + return VK_BLEND_FACTOR_ONE_MINUS_SRC1_ALPHA; + + default: + assert(!"Unsupported blend factor"); + return VK_BLEND_FACTOR_ONE; + } +} + +VkBlendOp VulkanUtil::translateBlendOp(BlendOp op) +{ + switch (op) + { + case BlendOp::Add: + return VK_BLEND_OP_ADD; + case BlendOp::Subtract: + return VK_BLEND_OP_SUBTRACT; + case BlendOp::ReverseSubtract: + return VK_BLEND_OP_REVERSE_SUBTRACT; + case BlendOp::Min: + return VK_BLEND_OP_MIN; + case BlendOp::Max: + return VK_BLEND_OP_MAX; + default: + assert(!"Unsupported blend op"); + return VK_BLEND_OP_ADD; + } +} + +VkPrimitiveTopology VulkanUtil::translatePrimitiveTypeToListTopology( + PrimitiveType primitiveType) +{ + switch (primitiveType) + { + case PrimitiveType::Point: + return VK_PRIMITIVE_TOPOLOGY_POINT_LIST; + case PrimitiveType::Line: + return VK_PRIMITIVE_TOPOLOGY_LINE_LIST; + case PrimitiveType::Triangle: + return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST; + case PrimitiveType::Patch: + return VK_PRIMITIVE_TOPOLOGY_PATCH_LIST; + default: + assert(!"unknown topology type."); + return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST; + } +} + +VkStencilOp VulkanUtil::translateStencilOp(StencilOp op) +{ + switch (op) + { + case StencilOp::DecrementSaturate: + return VK_STENCIL_OP_DECREMENT_AND_CLAMP; + case StencilOp::DecrementWrap: + return VK_STENCIL_OP_DECREMENT_AND_WRAP; + case StencilOp::IncrementSaturate: + return VK_STENCIL_OP_INCREMENT_AND_CLAMP; + case StencilOp::IncrementWrap: + return VK_STENCIL_OP_INCREMENT_AND_WRAP; + case StencilOp::Invert: + return VK_STENCIL_OP_INVERT; + case StencilOp::Keep: + return VK_STENCIL_OP_KEEP; + case StencilOp::Replace: + return VK_STENCIL_OP_REPLACE; + case StencilOp::Zero: + return VK_STENCIL_OP_ZERO; + default: + return VK_STENCIL_OP_KEEP; + } +} + +VkFilter VulkanUtil::translateFilterMode(TextureFilteringMode mode) +{ + switch (mode) + { + default: + return VkFilter(0); + +#define CASE(SRC, DST) \ + case TextureFilteringMode::SRC: \ + return VK_FILTER_##DST + + CASE(Point, NEAREST); + CASE(Linear, LINEAR); + +#undef CASE + } +} + +VkSamplerMipmapMode VulkanUtil::translateMipFilterMode(TextureFilteringMode mode) +{ + switch (mode) + { + default: + return VkSamplerMipmapMode(0); + +#define CASE(SRC, DST) \ + case TextureFilteringMode::SRC: \ + return VK_SAMPLER_MIPMAP_MODE_##DST + + CASE(Point, NEAREST); + CASE(Linear, LINEAR); + +#undef CASE + } +} + +VkSamplerAddressMode VulkanUtil::translateAddressingMode(TextureAddressingMode mode) +{ + switch (mode) + { + default: + return VkSamplerAddressMode(0); + +#define CASE(SRC, DST) \ + case TextureAddressingMode::SRC: \ + return VK_SAMPLER_ADDRESS_MODE_##DST + + CASE(Wrap, REPEAT); + CASE(ClampToEdge, CLAMP_TO_EDGE); + CASE(ClampToBorder, CLAMP_TO_BORDER); + CASE(MirrorRepeat, MIRRORED_REPEAT); + CASE(MirrorOnce, MIRROR_CLAMP_TO_EDGE); + +#undef CASE + } +} + +VkCompareOp VulkanUtil::translateComparisonFunc(ComparisonFunc func) +{ + switch (func) + { + default: + // TODO: need to report failures + return VK_COMPARE_OP_ALWAYS; + +#define CASE(FROM, TO) \ + case ComparisonFunc::FROM: \ + return VK_COMPARE_OP_##TO + + CASE(Never, NEVER); + CASE(Less, LESS); + CASE(Equal, EQUAL); + CASE(LessEqual, LESS_OR_EQUAL); + CASE(Greater, GREATER); + CASE(NotEqual, NOT_EQUAL); + CASE(GreaterEqual, GREATER_OR_EQUAL); + CASE(Always, ALWAYS); +#undef CASE + } +} + +VkStencilOpState VulkanUtil::translateStencilState(DepthStencilOpDesc desc) +{ + VkStencilOpState rs; + rs.compareMask = 0xFF; + rs.compareOp = translateComparisonFunc(desc.stencilFunc); + rs.depthFailOp = translateStencilOp(desc.stencilDepthFailOp); + rs.failOp = translateStencilOp(desc.stencilFailOp); + rs.passOp = translateStencilOp(desc.stencilPassOp); + rs.reference = 0; + rs.writeMask = 0xFF; + return rs; +} + /* static */Slang::Result VulkanUtil::handleFail(VkResult res) { if (res != VK_SUCCESS) diff --git a/tools/gfx/vulkan/vk-util.h b/tools/gfx/vulkan/vk-util.h index 4c32b1615..05533dad1 100644 --- a/tools/gfx/vulkan/vk-util.h +++ b/tools/gfx/vulkan/vk-util.h @@ -72,6 +72,34 @@ struct VulkanUtil } return false; } + + static VkSampleCountFlagBits translateSampleCount(uint32_t sampleCount); + + static VkCullModeFlags translateCullMode(CullMode cullMode); + + static VkFrontFace translateFrontFaceMode(FrontFaceMode frontFaceMode); + + static VkPolygonMode translateFillMode(FillMode fillMode); + + static VkBlendFactor translateBlendFactor(BlendFactor blendFactor); + + static VkBlendOp translateBlendOp(BlendOp op); + + static VkPrimitiveTopology translatePrimitiveTypeToListTopology( + PrimitiveType primitiveType); + + static VkStencilOp translateStencilOp(StencilOp op); + + static VkFilter translateFilterMode(TextureFilteringMode mode); + + static VkSamplerMipmapMode translateMipFilterMode(TextureFilteringMode mode); + + static VkSamplerAddressMode translateAddressingMode(TextureAddressingMode mode); + + static VkCompareOp translateComparisonFunc(ComparisonFunc func); + + static VkStencilOpState translateStencilState(DepthStencilOpDesc desc); + }; struct AccelerationStructureBuildGeometryInfoBuilder |
