summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-03-08 14:34:53 -0800
committerGitHub <noreply@github.com>2022-03-08 14:34:53 -0800
commitdcb434a5fe801d42d1b5f385fd27d0c500687647 (patch)
tree600d63ccce42e0ca3b5c63df23013044c4cea96c
parent771f29435d664f7344bc5596056146af5d64d352 (diff)
GFX Vulkan: deferred shader compilation and pipeline creation. (#2153)
* Vulkan: deferred shader compilation and pipeline creation. * Fix 32bit build. Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/core/slang-dictionary.h4
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp66
-rw-r--r--tools/gfx/renderer-shared.cpp56
-rw-r--r--tools/gfx/renderer-shared.h4
-rw-r--r--tools/gfx/vulkan/render-vk.cpp1169
-rw-r--r--tools/gfx/vulkan/vk-util.cpp271
-rw-r--r--tools/gfx/vulkan/vk-util.h28
7 files changed, 867 insertions, 731 deletions
diff --git a/source/core/slang-dictionary.h b/source/core/slang-dictionary.h
index eef7d6908..6a8d03101 100644
--- a/source/core/slang-dictionary.h
+++ b/source/core/slang-dictionary.h
@@ -167,7 +167,7 @@ namespace Slang
return FindPositionResult(hashPos, -1);
}
numProbes++;
- hashPos = (hashPos + GetProbeOffset(numProbes)) & bucketSizeMinusOne;
+ hashPos = (hashPos + GetProbeOffset(numProbes)) % bucketSizeMinusOne;
}
if (insertPos != -1)
return FindPositionResult(-1, insertPos);
@@ -738,7 +738,7 @@ namespace Slang
return FindPositionResult(hashPos, -1);
}
numProbes++;
- hashPos = (hashPos + GetProbeOffset(numProbes)) & bucketSizeMinusOne;
+ hashPos = (hashPos + GetProbeOffset(numProbes)) % bucketSizeMinusOne;
}
if (insertPos != -1)
return FindPositionResult(-1, insertPos);
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index a6a0c7790..0fc3b2874 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -2479,68 +2479,26 @@ public:
{
SlangStage stage;
slang::EntryPointReflection* entryPointInfo;
+ String actualEntryPointNameInAPI;
List<uint8_t> code;
};
class ShaderProgramImpl : public ShaderProgramBase
{
public:
- List<ShaderBinary> m_shaders;
RefPtr<RootShaderObjectLayoutImpl> m_rootObjectLayout;
- Result compileShaders()
- {
- // For a fully specialized program, read and store its kernel code in `shaderProgram`.
- auto compileShader = [&](slang::EntryPointReflection* entryPointInfo,
- slang::IComponentType* entryPointComponent,
- SlangInt entryPointIndex)
- {
- auto stage = entryPointInfo->getStage();
- ComPtr<ISlangBlob> kernelCode;
- ComPtr<ISlangBlob> diagnostics;
- auto compileResult = entryPointComponent->getEntryPointCode(
- entryPointIndex, 0, kernelCode.writeRef(), diagnostics.writeRef());
- if (diagnostics)
- {
- getDebugCallback()->handleMessage(
- compileResult == SLANG_OK ? DebugMessageType::Warning
- : DebugMessageType::Error,
- DebugMessageSource::Slang,
- (char*)diagnostics->getBufferPointer());
- }
- SLANG_RETURN_ON_FAIL(compileResult);
- ShaderBinary shaderBin;
- shaderBin.stage = stage;
- shaderBin.entryPointInfo = entryPointInfo;
- shaderBin.code.addRange(
- reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
- (Index)kernelCode->getBufferSize());
- m_shaders.add(_Move(shaderBin));
- return SLANG_OK;
- };
+ List<ShaderBinary> m_shaders;
- if (linkedEntryPoints.getCount() == 0)
- {
- // If the user does not explicitly specify entry point components, find them from
- // `linkedEntryPoints`.
- auto programReflection = linkedProgram->getLayout();
- for (SlangUInt i = 0; i < programReflection->getEntryPointCount(); i++)
- {
- SLANG_RETURN_ON_FAIL(compileShader(
- programReflection->getEntryPointByIndex(i),
- linkedProgram,
- (SlangInt)i));
- }
- }
- else
- {
- // If the user specifies entry point components via the separated entry point array,
- // compile code from there.
- for (auto& entryPoint : linkedEntryPoints)
- {
- SLANG_RETURN_ON_FAIL(compileShader(
- entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0));
- }
- }
+ virtual Result createShaderModule(
+ slang::EntryPointReflection* entryPointInfo, ComPtr<ISlangBlob> kernelCode) override
+ {
+ ShaderBinary shaderBin;
+ shaderBin.stage = entryPointInfo->getStage();
+ shaderBin.entryPointInfo = entryPointInfo;
+ shaderBin.code.addRange(
+ reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
+ (Index)kernelCode->getBufferSize());
+ m_shaders.add(_Move(shaderBin));
return SLANG_OK;
}
};
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index b476bd284..d01f6fe72 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -827,6 +827,62 @@ void ShaderProgramBase::init(const IShaderProgram::Desc& inDesc)
}
}
+Result ShaderProgramBase::compileShaders()
+{
+ // For a fully specialized program, read and store its kernel code in `shaderProgram`.
+ auto compileShader = [&](slang::EntryPointReflection* entryPointInfo,
+ slang::IComponentType* entryPointComponent,
+ SlangInt entryPointIndex)
+ {
+ auto stage = entryPointInfo->getStage();
+ ComPtr<ISlangBlob> kernelCode;
+ ComPtr<ISlangBlob> diagnostics;
+ auto compileResult = entryPointComponent->getEntryPointCode(
+ entryPointIndex, 0, kernelCode.writeRef(), diagnostics.writeRef());
+ if (diagnostics)
+ {
+ getDebugCallback()->handleMessage(
+ compileResult == SLANG_OK ? DebugMessageType::Warning : DebugMessageType::Error,
+ DebugMessageSource::Slang,
+ (char*)diagnostics->getBufferPointer());
+ }
+ SLANG_RETURN_ON_FAIL(compileResult);
+ SLANG_RETURN_ON_FAIL(createShaderModule(entryPointInfo, kernelCode));
+ return SLANG_OK;
+ };
+
+ if (linkedEntryPoints.getCount() == 0)
+ {
+ // If the user does not explicitly specify entry point components, find them from
+ // `linkedEntryPoints`.
+ auto programReflection = linkedProgram->getLayout();
+ for (SlangUInt i = 0; i < programReflection->getEntryPointCount(); i++)
+ {
+ SLANG_RETURN_ON_FAIL(compileShader(
+ programReflection->getEntryPointByIndex(i), linkedProgram, (SlangInt)i));
+ }
+ }
+ else
+ {
+ // If the user specifies entry point components via the separated entry point array,
+ // compile code from there.
+ for (auto& entryPoint : linkedEntryPoints)
+ {
+ SLANG_RETURN_ON_FAIL(
+ compileShader(entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0));
+ }
+ }
+ return SLANG_OK;
+}
+
+Result ShaderProgramBase::createShaderModule(
+ slang::EntryPointReflection* entryPointInfo, ComPtr<ISlangBlob> kernelCode)
+{
+ SLANG_UNUSED(entryPointInfo);
+ SLANG_UNUSED(kernelCode);
+ return SLANG_OK;
+}
+
Result RendererBase::maybeSpecializePipeline(
PipelineStateBase* currentPipeline,
ShaderObjectBase* rootObject,
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 1814f97d9..5dc4da59a 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -853,6 +853,10 @@ public:
}
return false;
}
+
+ Slang::Result compileShaders();
+ virtual Slang::Result createShaderModule(
+ slang::EntryPointReflection* entryPointInfo, Slang::ComPtr<ISlangBlob> kernelCode);
};
class InputLayoutBase
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp
index 65eee6d3e..5f3c8f8df 100644
--- a/tools/gfx/vulkan/render-vk.cpp
+++ b/tools/gfx/vulkan/render-vk.cpp
@@ -1041,6 +1041,256 @@ public:
initializeBase(pipelineDesc);
}
+ Result createVKGraphicsPipelineState()
+ {
+ VkPipelineCache pipelineCache = VK_NULL_HANDLE;
+
+ auto inputLayoutImpl = (InputLayoutImpl*)desc.graphics.inputLayout;
+
+ // VertexBuffer/s
+ VkPipelineVertexInputStateCreateInfo vertexInputInfo = {
+ VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO};
+ vertexInputInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO;
+ vertexInputInfo.vertexBindingDescriptionCount = 0;
+ vertexInputInfo.vertexAttributeDescriptionCount = 0;
+
+ if (inputLayoutImpl)
+ {
+ const auto& srcAttributeDescs = inputLayoutImpl->m_attributeDescs;
+ const auto& srcStreamDescs = inputLayoutImpl->m_streamDescs;
+
+ vertexInputInfo.vertexBindingDescriptionCount = (uint32_t)srcStreamDescs.getCount();
+ vertexInputInfo.pVertexBindingDescriptions = srcStreamDescs.getBuffer();
+
+ vertexInputInfo.vertexAttributeDescriptionCount =
+ (uint32_t)srcAttributeDescs.getCount();
+ vertexInputInfo.pVertexAttributeDescriptions = srcAttributeDescs.getBuffer();
+ }
+
+ VkPipelineInputAssemblyStateCreateInfo inputAssembly = {};
+ inputAssembly.sType = VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO;
+ // All other forms of primitive toplogies are specified via dynamic state.
+ inputAssembly.topology =
+ VulkanUtil::translatePrimitiveTypeToListTopology(desc.graphics.primitiveType);
+ inputAssembly.primitiveRestartEnable = VK_FALSE; // TODO: Currently unsupported
+
+ VkViewport viewport = {};
+ viewport.x = 0.0f;
+ viewport.y = 0.0f;
+ // We are using dynamic viewport and scissor state.
+ // Here we specify an arbitrary size, actual viewport will be set at `beginRenderPass`
+ // time.
+ viewport.width = 16.0f;
+ viewport.height = 16.0f;
+ viewport.minDepth = 0.0f;
+ viewport.maxDepth = 1.0f;
+
+ VkRect2D scissor = {};
+ scissor.offset = {0, 0};
+ scissor.extent = {uint32_t(16), uint32_t(16)};
+
+ VkPipelineViewportStateCreateInfo viewportState = {};
+ viewportState.sType = VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO;
+ viewportState.viewportCount = 1;
+ viewportState.pViewports = &viewport;
+ viewportState.scissorCount = 1;
+ viewportState.pScissors = &scissor;
+
+ auto rasterizerDesc = desc.graphics.rasterizer;
+
+ VkPipelineRasterizationStateCreateInfo rasterizer = {};
+ rasterizer.sType = VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO;
+ rasterizer.depthClampEnable =
+ VK_TRUE; // TODO: Depth clipping and clamping are different between Vk and D3D12
+ rasterizer.rasterizerDiscardEnable = VK_FALSE; // TODO: Currently unsupported
+ rasterizer.polygonMode = VulkanUtil::translateFillMode(rasterizerDesc.fillMode);
+ rasterizer.cullMode = VulkanUtil::translateCullMode(rasterizerDesc.cullMode);
+ rasterizer.frontFace = VulkanUtil::translateFrontFaceMode(rasterizerDesc.frontFace);
+ rasterizer.depthBiasEnable = (rasterizerDesc.depthBias == 0) ? VK_FALSE : VK_TRUE;
+ rasterizer.depthBiasConstantFactor = (float)rasterizerDesc.depthBias;
+ rasterizer.depthBiasClamp = rasterizerDesc.depthBiasClamp;
+ rasterizer.depthBiasSlopeFactor = rasterizerDesc.slopeScaledDepthBias;
+ rasterizer.lineWidth = 1.0f; // TODO: Currently unsupported
+
+ VkPipelineRasterizationConservativeStateCreateInfoEXT conservativeRasterInfo = {};
+ conservativeRasterInfo.sType =
+ VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_CONSERVATIVE_STATE_CREATE_INFO_EXT;
+ conservativeRasterInfo.conservativeRasterizationMode =
+ VK_CONSERVATIVE_RASTERIZATION_MODE_OVERESTIMATE_EXT;
+ if (desc.graphics.rasterizer.enableConservativeRasterization)
+ {
+ rasterizer.pNext = &conservativeRasterInfo;
+ }
+
+ auto framebufferLayoutImpl =
+ static_cast<FramebufferLayoutImpl*>(desc.graphics.framebufferLayout);
+ auto forcedSampleCount = rasterizerDesc.forcedSampleCount;
+ auto blendDesc = desc.graphics.blend;
+
+ VkPipelineMultisampleStateCreateInfo multisampling = {};
+ multisampling.sType = VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO;
+ multisampling.rasterizationSamples = (forcedSampleCount == 0)
+ ? framebufferLayoutImpl->m_sampleCount
+ : VulkanUtil::translateSampleCount(forcedSampleCount);
+ multisampling.sampleShadingEnable =
+ VK_FALSE; // TODO: Should check if fragment shader needs this
+ // TODO: Sample mask is dynamic in D3D12 but PSO state in Vulkan
+ multisampling.alphaToCoverageEnable = blendDesc.alphaToCoverageEnable;
+ multisampling.alphaToOneEnable = VK_FALSE;
+
+ auto targetCount = Math::Min(
+ framebufferLayoutImpl->m_renderTargetCount, (uint32_t)blendDesc.targetCount);
+ List<VkPipelineColorBlendAttachmentState> colorBlendAttachments;
+
+ // Regardless of whether blending is enabled, Vulkan always applies the color write mask
+ // operation, so if there is no blending then we need to add an attachment that defines
+ // the color write mask to ensure colors are actually written.
+ if (targetCount == 0)
+ {
+ colorBlendAttachments.setCount(1);
+ auto& vkBlendDesc = colorBlendAttachments[0];
+ memset(&vkBlendDesc, 0, sizeof(vkBlendDesc));
+ vkBlendDesc.blendEnable = VK_FALSE;
+ vkBlendDesc.srcColorBlendFactor = VK_BLEND_FACTOR_ONE;
+ vkBlendDesc.dstColorBlendFactor = VK_BLEND_FACTOR_ONE;
+ vkBlendDesc.colorBlendOp = VK_BLEND_OP_ADD;
+ vkBlendDesc.srcAlphaBlendFactor = VK_BLEND_FACTOR_ONE;
+ vkBlendDesc.dstAlphaBlendFactor = VK_BLEND_FACTOR_ONE;
+ vkBlendDesc.alphaBlendOp = VK_BLEND_OP_ADD;
+ vkBlendDesc.colorWriteMask =
+ (VkColorComponentFlags)RenderTargetWriteMask::EnableAll;
+ }
+ else
+ {
+ colorBlendAttachments.setCount(targetCount);
+ for (UInt i = 0; i < targetCount; ++i)
+ {
+ auto& gfxBlendDesc = blendDesc.targets[i];
+ auto& vkBlendDesc = colorBlendAttachments[i];
+
+ vkBlendDesc.blendEnable = gfxBlendDesc.enableBlend;
+ vkBlendDesc.srcColorBlendFactor =
+ VulkanUtil::translateBlendFactor(gfxBlendDesc.color.srcFactor);
+ vkBlendDesc.dstColorBlendFactor =
+ VulkanUtil::translateBlendFactor(gfxBlendDesc.color.dstFactor);
+ vkBlendDesc.colorBlendOp = VulkanUtil::translateBlendOp(gfxBlendDesc.color.op);
+ vkBlendDesc.srcAlphaBlendFactor =
+ VulkanUtil::translateBlendFactor(gfxBlendDesc.alpha.srcFactor);
+ vkBlendDesc.dstAlphaBlendFactor =
+ VulkanUtil::translateBlendFactor(gfxBlendDesc.alpha.dstFactor);
+ vkBlendDesc.alphaBlendOp = VulkanUtil::translateBlendOp(gfxBlendDesc.alpha.op);
+ vkBlendDesc.colorWriteMask = (VkColorComponentFlags)gfxBlendDesc.writeMask;
+ }
+ }
+
+ VkPipelineColorBlendStateCreateInfo colorBlending = {};
+ colorBlending.sType = VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO;
+ colorBlending.logicOpEnable = VK_FALSE; // TODO: D3D12 has per attachment logic op (and
+ // both have way more than one op)
+ colorBlending.logicOp = VK_LOGIC_OP_COPY;
+ colorBlending.attachmentCount = (uint32_t)colorBlendAttachments.getCount();
+ colorBlending.pAttachments = colorBlendAttachments.getBuffer();
+ colorBlending.blendConstants[0] = 0.0f;
+ colorBlending.blendConstants[1] = 0.0f;
+ colorBlending.blendConstants[2] = 0.0f;
+ colorBlending.blendConstants[3] = 0.0f;
+
+ Array<VkDynamicState, 8> dynamicStates;
+ dynamicStates.add(VK_DYNAMIC_STATE_VIEWPORT);
+ dynamicStates.add(VK_DYNAMIC_STATE_SCISSOR);
+ dynamicStates.add(VK_DYNAMIC_STATE_STENCIL_REFERENCE);
+ dynamicStates.add(VK_DYNAMIC_STATE_BLEND_CONSTANTS);
+ if (m_device->m_api.m_extendedFeatures.extendedDynamicStateFeatures.extendedDynamicState)
+ {
+ dynamicStates.add(VK_DYNAMIC_STATE_PRIMITIVE_TOPOLOGY_EXT);
+ }
+ VkPipelineDynamicStateCreateInfo dynamicStateInfo = {};
+ dynamicStateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO;
+ dynamicStateInfo.dynamicStateCount = (uint32_t)dynamicStates.getCount();
+ dynamicStateInfo.pDynamicStates = dynamicStates.getBuffer();
+
+ VkPipelineDepthStencilStateCreateInfo depthStencilStateInfo = {};
+ depthStencilStateInfo.sType =
+ VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO;
+ depthStencilStateInfo.depthTestEnable =
+ desc.graphics.depthStencil.depthTestEnable ? 1 : 0;
+ depthStencilStateInfo.back =
+ VulkanUtil::translateStencilState(desc.graphics.depthStencil.backFace);
+ depthStencilStateInfo.front =
+ VulkanUtil::translateStencilState(desc.graphics.depthStencil.frontFace);
+ depthStencilStateInfo.back.compareMask = desc.graphics.depthStencil.stencilReadMask;
+ depthStencilStateInfo.back.writeMask = desc.graphics.depthStencil.stencilWriteMask;
+ depthStencilStateInfo.front.compareMask = desc.graphics.depthStencil.stencilReadMask;
+ depthStencilStateInfo.front.writeMask = desc.graphics.depthStencil.stencilWriteMask;
+ depthStencilStateInfo.depthBoundsTestEnable = 0; // TODO: Currently unsupported
+ depthStencilStateInfo.depthCompareOp =
+ VulkanUtil::translateComparisonFunc(desc.graphics.depthStencil.depthFunc);
+ depthStencilStateInfo.depthWriteEnable =
+ desc.graphics.depthStencil.depthWriteEnable ? 1 : 0;
+ depthStencilStateInfo.stencilTestEnable =
+ desc.graphics.depthStencil.stencilEnable ? 1 : 0;
+
+ VkGraphicsPipelineCreateInfo pipelineInfo = {
+ VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO};
+
+ auto programImpl = static_cast<ShaderProgramImpl*>(m_program.Ptr());
+ if (programImpl->m_stageCreateInfos.getCount() == 0)
+ {
+ SLANG_RETURN_ON_FAIL(programImpl->compileShaders());
+ }
+
+ pipelineInfo.sType = VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO;
+ pipelineInfo.stageCount = (uint32_t)programImpl->m_stageCreateInfos.getCount();
+ pipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer();
+ pipelineInfo.pVertexInputState = &vertexInputInfo;
+ pipelineInfo.pInputAssemblyState = &inputAssembly;
+ pipelineInfo.pViewportState = &viewportState;
+ pipelineInfo.pRasterizationState = &rasterizer;
+ pipelineInfo.pMultisampleState = &multisampling;
+ pipelineInfo.pColorBlendState = &colorBlending;
+ pipelineInfo.pDepthStencilState = &depthStencilStateInfo;
+ pipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout;
+ pipelineInfo.renderPass = framebufferLayoutImpl->m_renderPass;
+ pipelineInfo.subpass = 0;
+ pipelineInfo.basePipelineHandle = VK_NULL_HANDLE;
+ pipelineInfo.pDynamicState = &dynamicStateInfo;
+
+ SLANG_VK_CHECK(m_device->m_api.vkCreateGraphicsPipelines(
+ m_device->m_device, pipelineCache, 1, &pipelineInfo, nullptr, &m_pipeline));
+
+ return SLANG_OK;
+ }
+
+ Result createVKComputePipelineState()
+ {
+ auto programImpl = static_cast<ShaderProgramImpl*>(m_program.Ptr());
+ if (programImpl->m_stageCreateInfos.getCount() == 0)
+ {
+ SLANG_RETURN_ON_FAIL(programImpl->compileShaders());
+ }
+
+ VkPipelineCache pipelineCache = VK_NULL_HANDLE;
+
+ VkComputePipelineCreateInfo computePipelineInfo = {
+ VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO};
+ computePipelineInfo.stage = programImpl->m_stageCreateInfos[0];
+ computePipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout;
+ SLANG_VK_CHECK(m_device->m_api.vkCreateComputePipelines(
+ m_device->m_device, pipelineCache, 1, &computePipelineInfo, nullptr, &m_pipeline));
+ return SLANG_OK;
+ }
+
+ Result ensureAPIPipelineStateCreated();
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL getNativeHandle(InteropHandle* outHandle) override
+ {
+ SLANG_RETURN_ON_FAIL(ensureAPIPipelineStateCreated());
+ outHandle->api = InteropHandleAPI::Vulkan;
+ outHandle->handleValue = 0;
+ memcpy(&outHandle->handleValue, &m_pipeline, sizeof(m_pipeline));
+ return SLANG_OK;
+ }
+
BreakableReference<VKDevice> m_device;
VkPipeline m_pipeline = VK_NULL_HANDLE;
@@ -1054,7 +1304,157 @@ public:
RayTracingPipelineStateImpl(VKDevice* device)
: PipelineStateImpl(device)
- {};
+ {}
+
+ static inline VkPipelineCreateFlags translateRayTracingPipelineFlags(RayTracingPipelineFlags::Enum flags)
+ {
+ VkPipelineCreateFlags vkFlags = 0;
+ if (flags & RayTracingPipelineFlags::Enum::SkipTriangles)
+ vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR;
+ if (flags & RayTracingPipelineFlags::Enum::SkipProcedurals)
+ vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR;
+
+ return vkFlags;
+ }
+
+ uint32_t findEntryPointIndexByName(
+ const Dictionary<String, Index>& entryPointNameToIndex, const char* name)
+ {
+ if (!name)
+ return VK_SHADER_UNUSED_KHR;
+
+ auto indexPtr = entryPointNameToIndex.TryGetValue(String(name));
+ if (indexPtr)
+ return (uint32_t)*indexPtr;
+ // TODO: Error reporting?
+ return VK_SHADER_UNUSED_KHR;
+ }
+
+ Result createVKRayTracingPipelineState()
+ {
+ auto programImpl = static_cast<ShaderProgramImpl*>(m_program.Ptr());
+ if (programImpl->m_stageCreateInfos.getCount() == 0)
+ {
+ SLANG_RETURN_ON_FAIL(programImpl->compileShaders());
+ }
+
+ VkRayTracingPipelineCreateInfoKHR raytracingPipelineInfo = {
+ VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR};
+ raytracingPipelineInfo.pNext = nullptr;
+ raytracingPipelineInfo.flags = translateRayTracingPipelineFlags(desc.rayTracing.flags);
+
+ raytracingPipelineInfo.stageCount =
+ (uint32_t)programImpl->m_stageCreateInfos.getCount();
+ raytracingPipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer();
+
+ // Build Dictionary from entry point name to entry point index (stageCreateInfos index)
+ // for all hit shaders - findShaderIndexByName
+ Dictionary<String, Index> entryPointNameToIndex;
+
+ List<VkRayTracingShaderGroupCreateInfoKHR> shaderGroupInfos;
+ for (uint32_t i = 0; i < raytracingPipelineInfo.stageCount; ++i)
+ {
+ auto stageCreateInfo = programImpl->m_stageCreateInfos[i];
+ auto entryPointName = programImpl->m_entryPointNames[i];
+ entryPointNameToIndex.Add(entryPointName, i);
+ if (stageCreateInfo.stage &
+ (VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR |
+ VK_SHADER_STAGE_INTERSECTION_BIT_KHR))
+ continue;
+
+ VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = {
+ VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR};
+ shaderGroupInfo.pNext = nullptr;
+ shaderGroupInfo.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
+ shaderGroupInfo.generalShader = i;
+ shaderGroupInfo.closestHitShader = VK_SHADER_UNUSED_KHR;
+ shaderGroupInfo.anyHitShader = VK_SHADER_UNUSED_KHR;
+ shaderGroupInfo.intersectionShader = VK_SHADER_UNUSED_KHR;
+ shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr;
+
+ // For groups with a single entry point, the group name is the entry point name.
+ auto shaderGroupName = entryPointName;
+ auto shaderGroupIndex = shaderGroupInfos.getCount();
+ shaderGroupInfos.add(shaderGroupInfo);
+ shaderGroupNameToIndex.Add(shaderGroupName, shaderGroupIndex);
+ }
+
+ for (int32_t i = 0; i < desc.rayTracing.hitGroupDescs.getCount(); ++i)
+ {
+ VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = {
+ VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR};
+ auto& groupDesc = desc.rayTracing.hitGroupDescs[i];
+
+ shaderGroupInfo.pNext = nullptr;
+ shaderGroupInfo.type =
+ (groupDesc.intersectionEntryPoint)
+ ? VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR
+ : VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
+ shaderGroupInfo.generalShader = VK_SHADER_UNUSED_KHR;
+ shaderGroupInfo.closestHitShader = findEntryPointIndexByName(
+ entryPointNameToIndex, groupDesc.closestHitEntryPoint);
+ shaderGroupInfo.anyHitShader =
+ findEntryPointIndexByName(entryPointNameToIndex, groupDesc.anyHitEntryPoint);
+ shaderGroupInfo.intersectionShader = findEntryPointIndexByName(
+ entryPointNameToIndex, groupDesc.intersectionEntryPoint);
+ shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr;
+
+ auto shaderGroupIndex = shaderGroupInfos.getCount();
+ shaderGroupInfos.add(shaderGroupInfo);
+ shaderGroupNameToIndex.Add(String(groupDesc.hitGroupName), shaderGroupIndex);
+ }
+
+ raytracingPipelineInfo.groupCount = (uint32_t)shaderGroupInfos.getCount();
+ raytracingPipelineInfo.pGroups = shaderGroupInfos.getBuffer();
+
+ raytracingPipelineInfo.maxPipelineRayRecursionDepth =
+ (uint32_t)desc.rayTracing.maxRecursion;
+
+ raytracingPipelineInfo.pLibraryInfo = nullptr;
+ raytracingPipelineInfo.pLibraryInterface = nullptr;
+
+ raytracingPipelineInfo.pDynamicState = nullptr;
+
+ raytracingPipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout;
+ raytracingPipelineInfo.basePipelineHandle = VK_NULL_HANDLE;
+ raytracingPipelineInfo.basePipelineIndex = 0;
+
+ VkPipelineCache pipelineCache = VK_NULL_HANDLE;
+ SLANG_VK_CHECK(m_device->m_api.vkCreateRayTracingPipelinesKHR(
+ m_device->m_device,
+ VK_NULL_HANDLE,
+ pipelineCache,
+ 1,
+ &raytracingPipelineInfo,
+ nullptr,
+ &m_pipeline));
+ shaderGroupCount = shaderGroupInfos.getCount();
+ return SLANG_OK;
+ }
+
+ Result ensureAPIPipelineStateCreated()
+ {
+ if (m_pipeline)
+ return SLANG_OK;
+
+ switch (desc.type)
+ {
+ case PipelineType::RayTracing:
+ return createVKRayTracingPipelineState();
+ default:
+ SLANG_UNREACHABLE("Unknown pipeline type.");
+ return SLANG_FAIL;
+ }
+ }
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL getNativeHandle(InteropHandle* outHandle) override
+ {
+ SLANG_RETURN_ON_FAIL(ensureAPIPipelineStateCreated());
+ outHandle->api = InteropHandleAPI::Vulkan;
+ outHandle->handleValue = 0;
+ memcpy(&outHandle->handleValue, &m_pipeline, sizeof(m_pipeline));
+ return SLANG_OK;
+ }
};
// In order to bind shader parameters to the correct locations, we need to
@@ -2447,6 +2847,62 @@ public:
Array<ComPtr<ISlangBlob>, 8> m_codeBlobs; //< To keep storage of code in scope
Array<VkShaderModule, 8> m_modules;
RefPtr<RootShaderObjectLayout> m_rootObjectLayout;
+
+ VkPipelineShaderStageCreateInfo compileEntryPoint(
+ const char* entryPointName,
+ ISlangBlob* code,
+ VkShaderStageFlagBits stage,
+ VkShaderModule& outShaderModule)
+ {
+ char const* dataBegin = (char const*)code->getBufferPointer();
+ char const* dataEnd = (char const*)code->getBufferPointer() + code->getBufferSize();
+
+ // We need to make a copy of the code, since the Slang compiler
+ // will free the memory after a compile request is closed.
+
+ VkShaderModuleCreateInfo moduleCreateInfo = {
+ VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO};
+ moduleCreateInfo.pCode = (uint32_t*)code->getBufferPointer();
+ moduleCreateInfo.codeSize = code->getBufferSize();
+
+ VkShaderModule module;
+ SLANG_VK_CHECK(
+ m_device->m_api.vkCreateShaderModule(m_device->m_device, &moduleCreateInfo, nullptr, &module));
+ outShaderModule = module;
+
+ VkPipelineShaderStageCreateInfo shaderStageCreateInfo = {
+ VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO};
+ shaderStageCreateInfo.stage = stage;
+
+ shaderStageCreateInfo.module = module;
+ shaderStageCreateInfo.pName = entryPointName;
+
+ return shaderStageCreateInfo;
+ }
+
+ virtual Result createShaderModule(
+ slang::EntryPointReflection* entryPointInfo, ComPtr<ISlangBlob> kernelCode) override
+ {
+ m_codeBlobs.add(kernelCode);
+ VkShaderModule shaderModule;
+ // HACK: our direct-spirv-emit path generates SPIRV that respects
+ // the original entry point name, while the glslang path always
+ // uses "main" as the name. We should introduce a compiler parameter
+ // to control the entry point naming behavior in SPIRV-direct path
+ // so we can remove the ad-hoc logic here.
+ auto realEntryPointName = entryPointInfo->getNameOverride();
+ const char* spirvBinaryEntryPointName = "main";
+ if (m_device->m_desc.slang.targetFlags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY)
+ spirvBinaryEntryPointName = realEntryPointName;
+ m_stageCreateInfos.add(compileEntryPoint(
+ spirvBinaryEntryPointName,
+ kernelCode,
+ (VkShaderStageFlagBits)VulkanUtil::getShaderStage(entryPointInfo->getStage()),
+ shaderModule));
+ m_entryPointNames.add(realEntryPointName);
+ m_modules.add(shaderModule);
+ return SLANG_OK;
+ }
};
class CommandBufferImpl;
@@ -2570,6 +3026,9 @@ public:
m_currentPipeline, &m_commandBuffer->m_rootObject, newPipeline);
PipelineStateImpl* newPipelineImpl = static_cast<PipelineStateImpl*>(newPipeline.Ptr());
+ newPipelineImpl->ensureAPIPipelineStateCreated();
+ m_currentPipeline = newPipelineImpl;
+
bindRootShaderObjectImpl(pipelineBindPoint);
auto pipelineBindPointId = getBindPointIndex(pipelineBindPoint);
@@ -5736,9 +6195,10 @@ public:
auto shaderTableImpl = (ShaderTableImpl*)shaderTable;
auto alignedHandleSize = VulkanUtil::calcAligned(rtProps.shaderGroupHandleSize, rtProps.shaderGroupHandleAlignment);
- ResourceCommandEncoder resourceCopyEncoder;
- resourceCopyEncoder.init(m_commandBuffer);
- auto shaderTableBuffer = shaderTableImpl->getOrCreateBuffer(m_currentPipeline, m_commandBuffer->m_transientHeap, &resourceCopyEncoder);
+ auto shaderTableBuffer = shaderTableImpl->getOrCreateBuffer(
+ m_currentPipeline,
+ m_commandBuffer->m_transientHeap,
+ static_cast<ResourceCommandEncoder*>(this));
VkStridedDeviceAddressRegionKHR raygenSBT;
raygenSBT.deviceAddress = shaderTableBuffer->getDeviceAddress();
@@ -6088,7 +6548,8 @@ public:
(uint32_t)count,
sizeof(uint64_t) * count,
data,
- sizeof(uint64_t), 0));
+ sizeof(uint64_t),
+ VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT));
return SLANG_OK;
}
public:
@@ -6496,13 +6957,6 @@ public:
VkBool32 handleDebugMessage(VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT objType, uint64_t srcObject,
size_t location, int32_t msgCode, const char* pLayerPrefix, const char* pMsg);
- /// Note that the outShaderModule value should be cleaned up when no longer needed by caller
- /// via vkShaderModuleDestroy()
- VkPipelineShaderStageCreateInfo compileEntryPoint(
- const char* entryPointName,
- ISlangBlob* code,
- VkShaderStageFlagBits stage,
- VkShaderModule& outShaderModule);
static VKAPI_ATTR VkBool32 VKAPI_CALL debugMessageCallback(VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT objType, uint64_t srcObject,
size_t location, int32_t msgCode, const char* pLayerPrefix, const char* pMsg, void* pUserData);
@@ -6798,35 +7252,6 @@ VkBool32 VKDevice::handleDebugMessage(VkDebugReportFlagsEXT flags, VkDebugReport
return ((VKDevice*)pUserData)->handleDebugMessage(flags, objType, srcObject, location, msgCode, pLayerPrefix, pMsg);
}
-VkPipelineShaderStageCreateInfo VKDevice::compileEntryPoint(
- const char* entryPointName,
- ISlangBlob* code,
- VkShaderStageFlagBits stage,
- VkShaderModule& outShaderModule)
-{
- char const* dataBegin = (char const*) code->getBufferPointer();
- char const* dataEnd = (char const*)code->getBufferPointer() + code->getBufferSize();
-
- // We need to make a copy of the code, since the Slang compiler
- // will free the memory after a compile request is closed.
-
- VkShaderModuleCreateInfo moduleCreateInfo = { VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO };
- moduleCreateInfo.pCode = (uint32_t*)code->getBufferPointer();
- moduleCreateInfo.codeSize = code->getBufferSize();
-
- VkShaderModule module;
- SLANG_VK_CHECK(m_api.vkCreateShaderModule(m_device, &moduleCreateInfo, nullptr, &module));
- outShaderModule = module;
-
- VkPipelineShaderStageCreateInfo shaderStageCreateInfo = { VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO };
- shaderStageCreateInfo.stage = stage;
-
- shaderStageCreateInfo.module = module;
- shaderStageCreateInfo.pName = entryPointName;
-
- return shaderStageCreateInfo;
-}
-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!! Renderer interface !!!!!!!!!!!!!!!!!!!!!!!!!!
Result VKDevice::getNativeDeviceHandles(InteropHandles* outHandles)
@@ -8424,131 +8849,16 @@ Result VKDevice::createBufferFromNativeHandle(InteropHandle handle, const IBuffe
return SLANG_OK;
}
-VkFilter translateFilterMode(TextureFilteringMode mode)
-{
- switch (mode)
- {
- default:
- return VkFilter(0);
-
-#define CASE(SRC, DST) \
- case TextureFilteringMode::SRC: return VK_FILTER_##DST
-
- CASE(Point, NEAREST);
- CASE(Linear, LINEAR);
-
-#undef CASE
- }
-}
-
-VkSamplerMipmapMode translateMipFilterMode(TextureFilteringMode mode)
-{
- switch (mode)
- {
- default:
- return VkSamplerMipmapMode(0);
-
-#define CASE(SRC, DST) \
- case TextureFilteringMode::SRC: return VK_SAMPLER_MIPMAP_MODE_##DST
-
- CASE(Point, NEAREST);
- CASE(Linear, LINEAR);
-
-#undef CASE
- }
-}
-
-VkSamplerAddressMode translateAddressingMode(TextureAddressingMode mode)
-{
- switch (mode)
- {
- default:
- return VkSamplerAddressMode(0);
-
-#define CASE(SRC, DST) \
- case TextureAddressingMode::SRC: return VK_SAMPLER_ADDRESS_MODE_##DST
-
- CASE(Wrap, REPEAT);
- CASE(ClampToEdge, CLAMP_TO_EDGE);
- CASE(ClampToBorder, CLAMP_TO_BORDER);
- CASE(MirrorRepeat, MIRRORED_REPEAT);
- CASE(MirrorOnce, MIRROR_CLAMP_TO_EDGE);
-
-#undef CASE
- }
-}
-
-static VkCompareOp translateComparisonFunc(ComparisonFunc func)
-{
- switch (func)
- {
- default:
- // TODO: need to report failures
- return VK_COMPARE_OP_ALWAYS;
-
-#define CASE(FROM, TO) \
- case ComparisonFunc::FROM: return VK_COMPARE_OP_##TO
-
- CASE(Never, NEVER);
- CASE(Less, LESS);
- CASE(Equal, EQUAL);
- CASE(LessEqual, LESS_OR_EQUAL);
- CASE(Greater, GREATER);
- CASE(NotEqual, NOT_EQUAL);
- CASE(GreaterEqual, GREATER_OR_EQUAL);
- CASE(Always, ALWAYS);
-#undef CASE
- }
-}
-
-static VkStencilOp translateStencilOp(StencilOp op)
-{
- switch (op)
- {
- case StencilOp::DecrementSaturate:
- return VK_STENCIL_OP_DECREMENT_AND_CLAMP;
- case StencilOp::DecrementWrap:
- return VK_STENCIL_OP_DECREMENT_AND_WRAP;
- case StencilOp::IncrementSaturate:
- return VK_STENCIL_OP_INCREMENT_AND_CLAMP;
- case StencilOp::IncrementWrap:
- return VK_STENCIL_OP_INCREMENT_AND_WRAP;
- case StencilOp::Invert:
- return VK_STENCIL_OP_INVERT;
- case StencilOp::Keep:
- return VK_STENCIL_OP_KEEP;
- case StencilOp::Replace:
- return VK_STENCIL_OP_REPLACE;
- case StencilOp::Zero:
- return VK_STENCIL_OP_ZERO;
- default:
- return VK_STENCIL_OP_KEEP;
- }
-}
-
-static VkStencilOpState translateStencilState(DepthStencilOpDesc desc)
-{
- VkStencilOpState rs;
- rs.compareMask = 0xFF;
- rs.compareOp = translateComparisonFunc(desc.stencilFunc);
- rs.depthFailOp = translateStencilOp(desc.stencilDepthFailOp);
- rs.failOp = translateStencilOp(desc.stencilFailOp);
- rs.passOp = translateStencilOp(desc.stencilPassOp);
- rs.reference = 0;
- rs.writeMask = 0xFF;
- return rs;
-}
-
Result VKDevice::createSamplerState(ISamplerState::Desc const& desc, ISamplerState** outSampler)
{
VkSamplerCreateInfo samplerInfo = { VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO };
- samplerInfo.magFilter = translateFilterMode(desc.minFilter);
- samplerInfo.minFilter = translateFilterMode(desc.magFilter);
+ samplerInfo.magFilter = VulkanUtil::translateFilterMode(desc.minFilter);
+ samplerInfo.minFilter = VulkanUtil::translateFilterMode(desc.magFilter);
- samplerInfo.addressModeU = translateAddressingMode(desc.addressU);
- samplerInfo.addressModeV = translateAddressingMode(desc.addressV);
- samplerInfo.addressModeW = translateAddressingMode(desc.addressW);
+ samplerInfo.addressModeU = VulkanUtil::translateAddressingMode(desc.addressU);
+ samplerInfo.addressModeV = VulkanUtil::translateAddressingMode(desc.addressV);
+ samplerInfo.addressModeW = VulkanUtil::translateAddressingMode(desc.addressW);
samplerInfo.anisotropyEnable = desc.maxAnisotropy > 1;
samplerInfo.maxAnisotropy = (float) desc.maxAnisotropy;
@@ -8558,8 +8868,8 @@ Result VKDevice::createSamplerState(ISamplerState::Desc const& desc, ISamplerSta
samplerInfo.unnormalizedCoordinates = VK_FALSE;
samplerInfo.compareEnable = desc.reductionOp == TextureReductionOp::Comparison;
- samplerInfo.compareOp = translateComparisonFunc(desc.comparisonFunc);
- samplerInfo.mipmapMode = translateMipFilterMode(desc.mipFilter);
+ samplerInfo.compareOp = VulkanUtil::translateComparisonFunc(desc.comparisonFunc);
+ samplerInfo.mipmapMode = VulkanUtil::translateMipFilterMode(desc.mipFilter);
samplerInfo.minLod = Math::Max(0.0f, desc.minLOD);
samplerInfo.maxLod = Math::Clamp(desc.maxLOD, samplerInfo.minLod, VK_LOD_CLAMP_NONE);
@@ -8953,73 +9263,7 @@ Result VKDevice::createProgram(
shaderProgram->linkedProgram,
shaderProgram->linkedProgram->getLayout(),
shaderProgram->m_rootObjectLayout.writeRef());
- if (shaderProgram->isSpecializable())
- {
- // For a specializable program, we don't invoke any actual slang compilation yet.
- returnComPtr(outProgram, shaderProgram);
- return SLANG_OK;
- }
- // For a fully specialized program, create `VkShaderModule`s for each shader stage.
- auto compileShader = [&](slang::EntryPointReflection* entryPointInfo,
- slang::IComponentType* component,
- SlangInt entryPointIndex)
- {
- auto stage = entryPointInfo->getStage();
- ComPtr<ISlangBlob> kernelCode;
- ComPtr<ISlangBlob> diagnostics;
- auto compileResult = component->getEntryPointCode(
- entryPointIndex, 0, kernelCode.writeRef(), diagnostics.writeRef());
- if (diagnostics)
- {
- getDebugCallback()->handleMessage(
- compileResult == SLANG_OK ? DebugMessageType::Warning : DebugMessageType::Error,
- DebugMessageSource::Slang,
- (char*)diagnostics->getBufferPointer());
- if (outDiagnosticBlob)
- returnComPtr(outDiagnosticBlob, diagnostics);
- }
- SLANG_RETURN_ON_FAIL(compileResult);
- shaderProgram->m_codeBlobs.add(kernelCode);
- VkShaderModule shaderModule;
- // HACK: our direct-spirv-emit path generates SPIRV that respects
- // the original entry point name, while the glslang path always
- // uses "main" as the name. We should introduce a compiler parameter
- // to control the entry point naming behavior in SPIRV-direct path
- // so we can remove the ad-hoc logic here.
- auto realEntryPointName = entryPointInfo->getNameOverride();
- const char* spirvBinaryEntryPointName = "main";
- if (m_desc.slang.targetFlags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY)
- spirvBinaryEntryPointName = realEntryPointName;
- shaderProgram->m_stageCreateInfos.add(compileEntryPoint(
- spirvBinaryEntryPointName,
- kernelCode,
- (VkShaderStageFlagBits)VulkanUtil::getShaderStage(stage),
- shaderModule));
- shaderProgram->m_entryPointNames.add(realEntryPointName);
- shaderProgram->m_modules.add(shaderModule);
- return SLANG_OK;
- };
- if (shaderProgram->linkedEntryPoints.getCount() == 0)
- {
- // If the user does not explicitly specify entry point components, find them from
- // `linkedEntryPoints`.
- auto programReflection = shaderProgram->linkedProgram->getLayout();
- for (SlangUInt i = 0; i < programReflection->getEntryPointCount(); i++)
- {
- auto entryPointInfo = programReflection->getEntryPointByIndex(i);
- SLANG_RETURN_ON_FAIL(compileShader(entryPointInfo, shaderProgram->linkedProgram, (SlangInt)i));
- }
- }
- else
- {
- // If the user specifies entry point components via the separated entry point array, compile
- // code from there.
- for (auto& entryPoint : shaderProgram->linkedEntryPoints)
- {
- SLANG_RETURN_ON_FAIL(compileShader(entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0));
- }
- }
returnComPtr(outProgram, shaderProgram);
return SLANG_OK;
}
@@ -9075,356 +9319,22 @@ Result VKDevice::createShaderTable(const IShaderTable::Desc& desc, IShaderTable*
return SLANG_OK;
}
-VkSampleCountFlagBits translateSampleCount(uint32_t sampleCount)
-{
- switch (sampleCount)
- {
- case 1: return VK_SAMPLE_COUNT_1_BIT;
- case 2: return VK_SAMPLE_COUNT_2_BIT;
- case 4: return VK_SAMPLE_COUNT_4_BIT;
- case 8: return VK_SAMPLE_COUNT_8_BIT;
- case 16: return VK_SAMPLE_COUNT_16_BIT;
- case 32: return VK_SAMPLE_COUNT_32_BIT;
- case 64: return VK_SAMPLE_COUNT_64_BIT;
- default:
- assert(!"Unsupported sample count");
- return VK_SAMPLE_COUNT_1_BIT;
- }
-}
-
-VkCullModeFlags translateCullMode(CullMode cullMode)
-{
- switch (cullMode)
- {
- case CullMode::None: return VK_CULL_MODE_NONE;
- case CullMode::Front: return VK_CULL_MODE_FRONT_BIT;
- case CullMode::Back: return VK_CULL_MODE_BACK_BIT;
- default:
- assert(!"Unsupported cull mode");
- return VK_CULL_MODE_NONE;
- }
-}
-
-VkFrontFace translateFrontFaceMode(FrontFaceMode frontFaceMode)
-{
- switch (frontFaceMode)
- {
- case FrontFaceMode::CounterClockwise:
- return VK_FRONT_FACE_COUNTER_CLOCKWISE;
- case FrontFaceMode::Clockwise:
- return VK_FRONT_FACE_CLOCKWISE;
- default:
- assert(!"Unsupported front face mode");
- return VK_FRONT_FACE_CLOCKWISE;
- }
-}
-
-VkPolygonMode translateFillMode(FillMode fillMode)
-{
- switch (fillMode)
- {
- case FillMode::Solid: return VK_POLYGON_MODE_FILL;
- case FillMode::Wireframe: return VK_POLYGON_MODE_LINE;
- default:
- assert(!"Unsupported fill mode");
- return VK_POLYGON_MODE_FILL;
- }
-}
-
-VkBlendFactor translateBlendFactor(BlendFactor blendFactor)
-{
- switch (blendFactor)
- {
- case BlendFactor::Zero: return VK_BLEND_FACTOR_ZERO;
- case BlendFactor::One: return VK_BLEND_FACTOR_ONE;
- case BlendFactor::SrcColor: return VK_BLEND_FACTOR_SRC_COLOR;
- case BlendFactor::InvSrcColor: return VK_BLEND_FACTOR_ONE_MINUS_SRC_COLOR;
- case BlendFactor::SrcAlpha: return VK_BLEND_FACTOR_SRC_ALPHA;
- case BlendFactor::InvSrcAlpha: return VK_BLEND_FACTOR_ONE_MINUS_SRC_ALPHA;
- case BlendFactor::DestAlpha: return VK_BLEND_FACTOR_DST_ALPHA;
- case BlendFactor::InvDestAlpha: return VK_BLEND_FACTOR_ONE_MINUS_DST_ALPHA;
- case BlendFactor::DestColor: return VK_BLEND_FACTOR_DST_COLOR;
- case BlendFactor::InvDestColor: return VK_BLEND_FACTOR_ONE_MINUS_DST_ALPHA;
- case BlendFactor::SrcAlphaSaturate: return VK_BLEND_FACTOR_SRC_ALPHA_SATURATE;
- case BlendFactor::BlendColor: return VK_BLEND_FACTOR_CONSTANT_COLOR;
- case BlendFactor::InvBlendColor: return VK_BLEND_FACTOR_ONE_MINUS_CONSTANT_COLOR;
- case BlendFactor::SecondarySrcColor: return VK_BLEND_FACTOR_SRC1_COLOR;
- case BlendFactor::InvSecondarySrcColor: return VK_BLEND_FACTOR_ONE_MINUS_SRC1_COLOR;
- case BlendFactor::SecondarySrcAlpha: return VK_BLEND_FACTOR_SRC1_ALPHA;
- case BlendFactor::InvSecondarySrcAlpha: return VK_BLEND_FACTOR_ONE_MINUS_SRC1_ALPHA;
-
- default:
- assert(!"Unsupported blend factor");
- return VK_BLEND_FACTOR_ONE;
- }
-}
-
-VkBlendOp translateBlendOp(BlendOp op)
-{
- switch (op)
- {
- case BlendOp::Add: return VK_BLEND_OP_ADD;
- case BlendOp::Subtract: return VK_BLEND_OP_SUBTRACT;
- case BlendOp::ReverseSubtract: return VK_BLEND_OP_REVERSE_SUBTRACT;
- case BlendOp::Min: return VK_BLEND_OP_MIN;
- case BlendOp::Max: return VK_BLEND_OP_MAX;
- default:
- assert(!"Unsupported blend op");
- return VK_BLEND_OP_ADD;
- }
-}
-
-VkPrimitiveTopology translatePrimitiveTypeToListTopology(PrimitiveType primitiveType)
-{
- switch (primitiveType)
- {
- case PrimitiveType::Point: return VK_PRIMITIVE_TOPOLOGY_POINT_LIST;
- case PrimitiveType::Line: return VK_PRIMITIVE_TOPOLOGY_LINE_LIST;
- case PrimitiveType::Triangle: return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST;
- case PrimitiveType::Patch: return VK_PRIMITIVE_TOPOLOGY_PATCH_LIST;
- default:
- assert(!"unknown topology type.");
- return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST;
- }
-}
-
Result VKDevice::createGraphicsPipelineState(const GraphicsPipelineStateDesc& inDesc, IPipelineState** outState)
{
GraphicsPipelineStateDesc desc = inDesc;
- auto programImpl = static_cast<ShaderProgramImpl*>(desc.program);
-
- if (!programImpl->m_rootObjectLayout->m_pipelineLayout)
- {
- RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this);
- pipelineStateImpl->init(desc);
- pipelineStateImpl->establishStrongDeviceReference();
- m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl);
- returnComPtr(outState, pipelineStateImpl);
- return SLANG_OK;
- }
-
- VkPipelineCache pipelineCache = VK_NULL_HANDLE;
-
- auto inputLayoutImpl = (InputLayoutImpl*) desc.inputLayout;
-
- // VertexBuffer/s
- VkPipelineVertexInputStateCreateInfo vertexInputInfo = { VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO };
- vertexInputInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO;
- vertexInputInfo.vertexBindingDescriptionCount = 0;
- vertexInputInfo.vertexAttributeDescriptionCount = 0;
-
- if (inputLayoutImpl)
- {
- const auto& srcAttributeDescs = inputLayoutImpl->m_attributeDescs;
- const auto& srcStreamDescs = inputLayoutImpl->m_streamDescs;
-
- vertexInputInfo.vertexBindingDescriptionCount = (uint32_t)srcStreamDescs.getCount();
- vertexInputInfo.pVertexBindingDescriptions = srcStreamDescs.getBuffer();
-
- vertexInputInfo.vertexAttributeDescriptionCount = (uint32_t)srcAttributeDescs.getCount();
- vertexInputInfo.pVertexAttributeDescriptions = srcAttributeDescs.getBuffer();
- }
-
- VkPipelineInputAssemblyStateCreateInfo inputAssembly = {};
- inputAssembly.sType = VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO;
- // All other forms of primitive toplogies are specified via dynamic state.
- inputAssembly.topology = translatePrimitiveTypeToListTopology(inDesc.primitiveType);
- inputAssembly.primitiveRestartEnable = VK_FALSE; // TODO: Currently unsupported
-
- VkViewport viewport = {};
- viewport.x = 0.0f;
- viewport.y = 0.0f;
- // We are using dynamic viewport and scissor state.
- // Here we specify an arbitrary size, actual viewport will be set at `beginRenderPass` time.
- viewport.width = 16.0f;
- viewport.height = 16.0f;
- viewport.minDepth = 0.0f;
- viewport.maxDepth = 1.0f;
-
- VkRect2D scissor = {};
- scissor.offset = { 0, 0 };
- scissor.extent = { uint32_t(16), uint32_t(16) };
-
- VkPipelineViewportStateCreateInfo viewportState = {};
- viewportState.sType = VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO;
- viewportState.viewportCount = 1;
- viewportState.pViewports = &viewport;
- viewportState.scissorCount = 1;
- viewportState.pScissors = &scissor;
-
- auto rasterizerDesc = desc.rasterizer;
-
- VkPipelineRasterizationStateCreateInfo rasterizer = {};
- rasterizer.sType = VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO;
- rasterizer.depthClampEnable = VK_TRUE; // TODO: Depth clipping and clamping are different between Vk and D3D12
- rasterizer.rasterizerDiscardEnable = VK_FALSE; // TODO: Currently unsupported
- rasterizer.polygonMode = translateFillMode(rasterizerDesc.fillMode);
- rasterizer.cullMode = translateCullMode(rasterizerDesc.cullMode);
- rasterizer.frontFace = translateFrontFaceMode(rasterizerDesc.frontFace);
- rasterizer.depthBiasEnable = (rasterizerDesc.depthBias == 0) ? VK_FALSE : VK_TRUE;
- rasterizer.depthBiasConstantFactor = (float)rasterizerDesc.depthBias;
- rasterizer.depthBiasClamp = rasterizerDesc.depthBiasClamp;
- rasterizer.depthBiasSlopeFactor = rasterizerDesc.slopeScaledDepthBias;
- rasterizer.lineWidth = 1.0f; // TODO: Currently unsupported
-
- VkPipelineRasterizationConservativeStateCreateInfoEXT conservativeRasterInfo = {};
- conservativeRasterInfo.sType =
- VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_CONSERVATIVE_STATE_CREATE_INFO_EXT;
- conservativeRasterInfo.conservativeRasterizationMode =
- VK_CONSERVATIVE_RASTERIZATION_MODE_OVERESTIMATE_EXT;
- if (desc.rasterizer.enableConservativeRasterization)
- {
- rasterizer.pNext = &conservativeRasterInfo;
- }
-
- auto framebufferLayoutImpl = static_cast<FramebufferLayoutImpl*>(desc.framebufferLayout);
- auto forcedSampleCount = rasterizerDesc.forcedSampleCount;
- auto blendDesc = desc.blend;
-
- VkPipelineMultisampleStateCreateInfo multisampling = {};
- multisampling.sType = VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO;
- multisampling.rasterizationSamples =
- (forcedSampleCount == 0) ? framebufferLayoutImpl->m_sampleCount : translateSampleCount(forcedSampleCount);
- multisampling.sampleShadingEnable = VK_FALSE; // TODO: Should check if fragment shader needs this
- // TODO: Sample mask is dynamic in D3D12 but PSO state in Vulkan
- multisampling.alphaToCoverageEnable = blendDesc.alphaToCoverageEnable;
- multisampling.alphaToOneEnable = VK_FALSE;
-
- auto targetCount = Math::Min(framebufferLayoutImpl->m_renderTargetCount, (uint32_t)blendDesc.targetCount);
- List<VkPipelineColorBlendAttachmentState> colorBlendAttachments;
-
- // Regardless of whether blending is enabled, Vulkan always applies the color write mask operation,
- // so if there is no blending then we need to add an attachment that defines the color write mask
- // to ensure colors are actually written.
- if (targetCount == 0)
- {
- colorBlendAttachments.setCount(1);
- auto& vkBlendDesc = colorBlendAttachments[0];
- memset(&vkBlendDesc, 0, sizeof(vkBlendDesc));
- vkBlendDesc.blendEnable = VK_FALSE;
- vkBlendDesc.srcColorBlendFactor = VK_BLEND_FACTOR_ONE;
- vkBlendDesc.dstColorBlendFactor = VK_BLEND_FACTOR_ONE;
- vkBlendDesc.colorBlendOp = VK_BLEND_OP_ADD;
- vkBlendDesc.srcAlphaBlendFactor = VK_BLEND_FACTOR_ONE;
- vkBlendDesc.dstAlphaBlendFactor = VK_BLEND_FACTOR_ONE;
- vkBlendDesc.alphaBlendOp = VK_BLEND_OP_ADD;
- vkBlendDesc.colorWriteMask = (VkColorComponentFlags)RenderTargetWriteMask::EnableAll;
- }
- else
- {
- colorBlendAttachments.setCount(targetCount);
- for (UInt i = 0; i < targetCount; ++i)
- {
- auto& gfxBlendDesc = blendDesc.targets[i];
- auto& vkBlendDesc = colorBlendAttachments[i];
-
- vkBlendDesc.blendEnable = gfxBlendDesc.enableBlend;
- vkBlendDesc.srcColorBlendFactor = translateBlendFactor(gfxBlendDesc.color.srcFactor);
- vkBlendDesc.dstColorBlendFactor = translateBlendFactor(gfxBlendDesc.color.dstFactor);
- vkBlendDesc.colorBlendOp = translateBlendOp(gfxBlendDesc.color.op);
- vkBlendDesc.srcAlphaBlendFactor = translateBlendFactor(gfxBlendDesc.alpha.srcFactor);
- vkBlendDesc.dstAlphaBlendFactor = translateBlendFactor(gfxBlendDesc.alpha.dstFactor);
- vkBlendDesc.alphaBlendOp = translateBlendOp(gfxBlendDesc.alpha.op);
- vkBlendDesc.colorWriteMask = (VkColorComponentFlags)gfxBlendDesc.writeMask;
- }
- }
-
- VkPipelineColorBlendStateCreateInfo colorBlending = {};
- colorBlending.sType = VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO;
- colorBlending.logicOpEnable = VK_FALSE; // TODO: D3D12 has per attachment logic op (and both have way more than one op)
- colorBlending.logicOp = VK_LOGIC_OP_COPY;
- colorBlending.attachmentCount = (uint32_t)colorBlendAttachments.getCount();
- colorBlending.pAttachments = colorBlendAttachments.getBuffer();
- colorBlending.blendConstants[0] = 0.0f;
- colorBlending.blendConstants[1] = 0.0f;
- colorBlending.blendConstants[2] = 0.0f;
- colorBlending.blendConstants[3] = 0.0f;
-
- Array<VkDynamicState, 8> dynamicStates;
- dynamicStates.add(VK_DYNAMIC_STATE_VIEWPORT);
- dynamicStates.add(VK_DYNAMIC_STATE_SCISSOR);
- dynamicStates.add(VK_DYNAMIC_STATE_STENCIL_REFERENCE);
- dynamicStates.add(VK_DYNAMIC_STATE_BLEND_CONSTANTS);
- if (m_api.m_extendedFeatures.extendedDynamicStateFeatures.extendedDynamicState)
- {
- dynamicStates.add(VK_DYNAMIC_STATE_PRIMITIVE_TOPOLOGY_EXT);
- }
- VkPipelineDynamicStateCreateInfo dynamicStateInfo = {};
- dynamicStateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO;
- dynamicStateInfo.dynamicStateCount = (uint32_t)dynamicStates.getCount();
- dynamicStateInfo.pDynamicStates = dynamicStates.getBuffer();
-
- VkPipelineDepthStencilStateCreateInfo depthStencilStateInfo = {};
- depthStencilStateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO;
- depthStencilStateInfo.depthTestEnable = inDesc.depthStencil.depthTestEnable ? 1 : 0;
- depthStencilStateInfo.back = translateStencilState(inDesc.depthStencil.backFace);
- depthStencilStateInfo.front = translateStencilState(inDesc.depthStencil.frontFace);
- depthStencilStateInfo.back.compareMask = inDesc.depthStencil.stencilReadMask;
- depthStencilStateInfo.back.writeMask = inDesc.depthStencil.stencilWriteMask;
- depthStencilStateInfo.front.compareMask = inDesc.depthStencil.stencilReadMask;
- depthStencilStateInfo.front.writeMask = inDesc.depthStencil.stencilWriteMask;
- depthStencilStateInfo.depthBoundsTestEnable = 0; // TODO: Currently unsupported
- depthStencilStateInfo.depthCompareOp = translateComparisonFunc(inDesc.depthStencil.depthFunc);
- depthStencilStateInfo.depthWriteEnable = inDesc.depthStencil.depthWriteEnable ? 1 : 0;
- depthStencilStateInfo.stencilTestEnable = inDesc.depthStencil.stencilEnable ? 1 : 0;
-
- VkGraphicsPipelineCreateInfo pipelineInfo = { VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO };
-
- pipelineInfo.sType = VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO;
- pipelineInfo.stageCount = (uint32_t)programImpl->m_stageCreateInfos.getCount();
- pipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer();
- pipelineInfo.pVertexInputState = &vertexInputInfo;
- pipelineInfo.pInputAssemblyState = &inputAssembly;
- pipelineInfo.pViewportState = &viewportState;
- pipelineInfo.pRasterizationState = &rasterizer;
- pipelineInfo.pMultisampleState = &multisampling;
- pipelineInfo.pColorBlendState = &colorBlending;
- pipelineInfo.pDepthStencilState = &depthStencilStateInfo;
- pipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout;
- pipelineInfo.renderPass = framebufferLayoutImpl->m_renderPass;
- pipelineInfo.subpass = 0;
- pipelineInfo.basePipelineHandle = VK_NULL_HANDLE;
- pipelineInfo.pDynamicState = &dynamicStateInfo;
-
- VkPipeline pipeline = VK_NULL_HANDLE;
- SLANG_VK_CHECK(m_api.vkCreateGraphicsPipelines(m_device, pipelineCache, 1, &pipelineInfo, nullptr, &pipeline));
-
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this);
- pipelineStateImpl->m_pipeline = pipeline;
pipelineStateImpl->init(desc);
pipelineStateImpl->establishStrongDeviceReference();
m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl);
returnComPtr(outState, pipelineStateImpl);
+
return SLANG_OK;
}
Result VKDevice::createComputePipelineState(const ComputePipelineStateDesc& inDesc, IPipelineState** outState)
{
ComputePipelineStateDesc desc = inDesc;
- auto programImpl = static_cast<ShaderProgramImpl*>(desc.program);
- if (!programImpl->m_rootObjectLayout->m_pipelineLayout)
- {
- RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this);
- pipelineStateImpl->init(desc);
- m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl);
- pipelineStateImpl->establishStrongDeviceReference();
- returnComPtr(outState, pipelineStateImpl);
- return SLANG_OK;
- }
-
- VkPipelineCache pipelineCache = VK_NULL_HANDLE;
-
- VkPipeline pipeline = VK_NULL_HANDLE;
-
- VkComputePipelineCreateInfo computePipelineInfo = {
- VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO };
- computePipelineInfo.stage = programImpl->m_stageCreateInfos[0];
- computePipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout;
- SLANG_VK_CHECK(m_api.vkCreateComputePipelines(
- m_device, pipelineCache, 1, &computePipelineInfo, nullptr, &pipeline));
-
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this);
- pipelineStateImpl->m_pipeline = pipeline;
pipelineStateImpl->init(desc);
m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl);
pipelineStateImpl->establishStrongDeviceReference();
@@ -9432,122 +9342,12 @@ Result VKDevice::createComputePipelineState(const ComputePipelineStateDesc& inDe
return SLANG_OK;
}
-VkPipelineCreateFlags translateFlags(RayTracingPipelineFlags::Enum flags)
-{
- VkPipelineCreateFlags vkFlags = 0;
- if (flags & RayTracingPipelineFlags::Enum::SkipTriangles)
- vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR;
- if (flags & RayTracingPipelineFlags::Enum::SkipProcedurals)
- vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR;
-
- return vkFlags;
-}
-
-uint32_t findEntryPointIndexByName(const Dictionary<String, Index>& entryPointNameToIndex, const char* name)
-{
- if (!name) return VK_SHADER_UNUSED_KHR;
-
- auto indexPtr = entryPointNameToIndex.TryGetValue(String(name));
- if (indexPtr)
- return (uint32_t)*indexPtr;
- // TODO: Error reporting?
- return VK_SHADER_UNUSED_KHR;
-}
-
Result VKDevice::createRayTracingPipelineState(const RayTracingPipelineStateDesc& desc, IPipelineState** outState)
{
- auto programImpl = static_cast<ShaderProgramImpl*>(desc.program);
- if (!programImpl->m_rootObjectLayout->m_pipelineLayout)
- {
- RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this);
- pipelineStateImpl->init(desc);
- m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl);
- pipelineStateImpl->establishStrongDeviceReference();
- returnComPtr(outState, pipelineStateImpl);
- return SLANG_OK;
- }
-
- VkRayTracingPipelineCreateInfoKHR raytracingPipelineInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR };
- raytracingPipelineInfo.pNext = nullptr;
- raytracingPipelineInfo.flags = translateFlags(desc.flags);
-
- raytracingPipelineInfo.stageCount = (uint32_t)programImpl->m_stageCreateInfos.getCount();
- raytracingPipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer();
-
- // Build Dictionary from group name to group index
- Dictionary<String, Index> shaderGroupNameToIndex;
- // Build Dictionary from entry point name to entry point index (stageCreateInfos index) for all hit shaders - findShaderIndexByName
- Dictionary<String, Index> entryPointNameToIndex;
-
- List<VkRayTracingShaderGroupCreateInfoKHR> shaderGroupInfos;
- for (uint32_t i = 0; i < raytracingPipelineInfo.stageCount; ++i)
- {
- auto stageCreateInfo = programImpl->m_stageCreateInfos[i];
- auto entryPointName = programImpl->m_entryPointNames[i];
- entryPointNameToIndex.Add(entryPointName, i);
- if (stageCreateInfo.stage & (VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_INTERSECTION_BIT_KHR))
- continue;
-
- VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR };
- shaderGroupInfo.pNext = nullptr;
- shaderGroupInfo.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
- shaderGroupInfo.generalShader = i;
- shaderGroupInfo.closestHitShader = VK_SHADER_UNUSED_KHR;
- shaderGroupInfo.anyHitShader = VK_SHADER_UNUSED_KHR;
- shaderGroupInfo.intersectionShader = VK_SHADER_UNUSED_KHR;
- shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr;
-
- // For groups with a single entry point, the group name is the entry point name.
- auto shaderGroupName = entryPointName;
- auto shaderGroupIndex = shaderGroupInfos.getCount();
- shaderGroupInfos.add(shaderGroupInfo);
- shaderGroupNameToIndex.Add(shaderGroupName, shaderGroupIndex);
- }
-
- for (int32_t i = 0; i < desc.hitGroupCount; ++i)
- {
- VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR };
- auto& groupDesc = desc.hitGroups[i];
-
- shaderGroupInfo.pNext = nullptr;
- shaderGroupInfo.type = (groupDesc.intersectionEntryPoint)
- ? VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR : VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
- shaderGroupInfo.generalShader = VK_SHADER_UNUSED_KHR;
- shaderGroupInfo.closestHitShader = findEntryPointIndexByName(entryPointNameToIndex, groupDesc.closestHitEntryPoint);
- shaderGroupInfo.anyHitShader = findEntryPointIndexByName(entryPointNameToIndex, groupDesc.anyHitEntryPoint);
- shaderGroupInfo.intersectionShader = findEntryPointIndexByName(entryPointNameToIndex, groupDesc.intersectionEntryPoint);
- shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr;
-
- auto shaderGroupIndex = shaderGroupInfos.getCount();
- shaderGroupInfos.add(shaderGroupInfo);
- shaderGroupNameToIndex.Add(String(groupDesc.hitGroupName), shaderGroupIndex);
- }
-
- raytracingPipelineInfo.groupCount = (uint32_t)shaderGroupInfos.getCount();
- raytracingPipelineInfo.pGroups = shaderGroupInfos.getBuffer();
-
- raytracingPipelineInfo.maxPipelineRayRecursionDepth = (uint32_t)desc.maxRecursion;
-
- raytracingPipelineInfo.pLibraryInfo = nullptr;
- raytracingPipelineInfo.pLibraryInterface = nullptr;
-
- raytracingPipelineInfo.pDynamicState = nullptr;
-
- raytracingPipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout;
- raytracingPipelineInfo.basePipelineHandle = VK_NULL_HANDLE;
- raytracingPipelineInfo.basePipelineIndex = 0;
-
- VkPipelineCache pipelineCache = VK_NULL_HANDLE;
- VkPipeline pipeline = VK_NULL_HANDLE;
- SLANG_VK_CHECK(m_api.vkCreateRayTracingPipelinesKHR(m_device, VK_NULL_HANDLE, pipelineCache, 1, &raytracingPipelineInfo, nullptr, &pipeline));
-
RefPtr<RayTracingPipelineStateImpl> pipelineStateImpl = new RayTracingPipelineStateImpl(this);
- pipelineStateImpl->m_pipeline = pipeline;
pipelineStateImpl->init(desc);
- pipelineStateImpl->establishStrongDeviceReference();
- pipelineStateImpl->shaderGroupNameToIndex = shaderGroupNameToIndex;
- pipelineStateImpl->shaderGroupCount = shaderGroupInfos.getCount();
m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl);
+ pipelineStateImpl->establishStrongDeviceReference();
returnComPtr(outState, pipelineStateImpl);
return SLANG_OK;
}
@@ -9592,4 +9392,23 @@ Result VKDevice::waitForFences(
return result == VK_SUCCESS ? SLANG_OK : SLANG_FAIL;
}
+Result VKDevice::PipelineStateImpl::ensureAPIPipelineStateCreated()
+{
+ if (m_pipeline)
+ return SLANG_OK;
+
+ switch (desc.type)
+ {
+ case PipelineType::Compute:
+ return createVKComputePipelineState();
+ case PipelineType::Graphics:
+ return createVKGraphicsPipelineState();
+ case PipelineType::RayTracing:
+ return static_cast<RayTracingPipelineStateImpl*>(this)->createVKRayTracingPipelineState();
+ default:
+ SLANG_UNREACHABLE("Unknown pipeline type.");
+ return SLANG_FAIL;
+ }
+}
+
} // renderer_test
diff --git a/tools/gfx/vulkan/vk-util.cpp b/tools/gfx/vulkan/vk-util.cpp
index 974a61422..8e2a187a3 100644
--- a/tools/gfx/vulkan/vk-util.cpp
+++ b/tools/gfx/vulkan/vk-util.cpp
@@ -183,6 +183,277 @@ VkImageLayout VulkanUtil::getImageLayoutFromState(ResourceState state)
return VkImageLayout();
}
+VkSampleCountFlagBits VulkanUtil::translateSampleCount(uint32_t sampleCount)
+{
+ switch (sampleCount)
+ {
+ case 1:
+ return VK_SAMPLE_COUNT_1_BIT;
+ case 2:
+ return VK_SAMPLE_COUNT_2_BIT;
+ case 4:
+ return VK_SAMPLE_COUNT_4_BIT;
+ case 8:
+ return VK_SAMPLE_COUNT_8_BIT;
+ case 16:
+ return VK_SAMPLE_COUNT_16_BIT;
+ case 32:
+ return VK_SAMPLE_COUNT_32_BIT;
+ case 64:
+ return VK_SAMPLE_COUNT_64_BIT;
+ default:
+ assert(!"Unsupported sample count");
+ return VK_SAMPLE_COUNT_1_BIT;
+ }
+}
+
+VkCullModeFlags VulkanUtil::translateCullMode(CullMode cullMode)
+{
+ switch (cullMode)
+ {
+ case CullMode::None:
+ return VK_CULL_MODE_NONE;
+ case CullMode::Front:
+ return VK_CULL_MODE_FRONT_BIT;
+ case CullMode::Back:
+ return VK_CULL_MODE_BACK_BIT;
+ default:
+ assert(!"Unsupported cull mode");
+ return VK_CULL_MODE_NONE;
+ }
+}
+
+VkFrontFace VulkanUtil::translateFrontFaceMode(FrontFaceMode frontFaceMode)
+{
+ switch (frontFaceMode)
+ {
+ case FrontFaceMode::CounterClockwise:
+ return VK_FRONT_FACE_COUNTER_CLOCKWISE;
+ case FrontFaceMode::Clockwise:
+ return VK_FRONT_FACE_CLOCKWISE;
+ default:
+ assert(!"Unsupported front face mode");
+ return VK_FRONT_FACE_CLOCKWISE;
+ }
+}
+
+VkPolygonMode VulkanUtil::translateFillMode(FillMode fillMode)
+{
+ switch (fillMode)
+ {
+ case FillMode::Solid:
+ return VK_POLYGON_MODE_FILL;
+ case FillMode::Wireframe:
+ return VK_POLYGON_MODE_LINE;
+ default:
+ assert(!"Unsupported fill mode");
+ return VK_POLYGON_MODE_FILL;
+ }
+}
+
+VkBlendFactor VulkanUtil::translateBlendFactor(BlendFactor blendFactor)
+{
+ switch (blendFactor)
+ {
+ case BlendFactor::Zero:
+ return VK_BLEND_FACTOR_ZERO;
+ case BlendFactor::One:
+ return VK_BLEND_FACTOR_ONE;
+ case BlendFactor::SrcColor:
+ return VK_BLEND_FACTOR_SRC_COLOR;
+ case BlendFactor::InvSrcColor:
+ return VK_BLEND_FACTOR_ONE_MINUS_SRC_COLOR;
+ case BlendFactor::SrcAlpha:
+ return VK_BLEND_FACTOR_SRC_ALPHA;
+ case BlendFactor::InvSrcAlpha:
+ return VK_BLEND_FACTOR_ONE_MINUS_SRC_ALPHA;
+ case BlendFactor::DestAlpha:
+ return VK_BLEND_FACTOR_DST_ALPHA;
+ case BlendFactor::InvDestAlpha:
+ return VK_BLEND_FACTOR_ONE_MINUS_DST_ALPHA;
+ case BlendFactor::DestColor:
+ return VK_BLEND_FACTOR_DST_COLOR;
+ case BlendFactor::InvDestColor:
+ return VK_BLEND_FACTOR_ONE_MINUS_DST_ALPHA;
+ case BlendFactor::SrcAlphaSaturate:
+ return VK_BLEND_FACTOR_SRC_ALPHA_SATURATE;
+ case BlendFactor::BlendColor:
+ return VK_BLEND_FACTOR_CONSTANT_COLOR;
+ case BlendFactor::InvBlendColor:
+ return VK_BLEND_FACTOR_ONE_MINUS_CONSTANT_COLOR;
+ case BlendFactor::SecondarySrcColor:
+ return VK_BLEND_FACTOR_SRC1_COLOR;
+ case BlendFactor::InvSecondarySrcColor:
+ return VK_BLEND_FACTOR_ONE_MINUS_SRC1_COLOR;
+ case BlendFactor::SecondarySrcAlpha:
+ return VK_BLEND_FACTOR_SRC1_ALPHA;
+ case BlendFactor::InvSecondarySrcAlpha:
+ return VK_BLEND_FACTOR_ONE_MINUS_SRC1_ALPHA;
+
+ default:
+ assert(!"Unsupported blend factor");
+ return VK_BLEND_FACTOR_ONE;
+ }
+}
+
+VkBlendOp VulkanUtil::translateBlendOp(BlendOp op)
+{
+ switch (op)
+ {
+ case BlendOp::Add:
+ return VK_BLEND_OP_ADD;
+ case BlendOp::Subtract:
+ return VK_BLEND_OP_SUBTRACT;
+ case BlendOp::ReverseSubtract:
+ return VK_BLEND_OP_REVERSE_SUBTRACT;
+ case BlendOp::Min:
+ return VK_BLEND_OP_MIN;
+ case BlendOp::Max:
+ return VK_BLEND_OP_MAX;
+ default:
+ assert(!"Unsupported blend op");
+ return VK_BLEND_OP_ADD;
+ }
+}
+
+VkPrimitiveTopology VulkanUtil::translatePrimitiveTypeToListTopology(
+ PrimitiveType primitiveType)
+{
+ switch (primitiveType)
+ {
+ case PrimitiveType::Point:
+ return VK_PRIMITIVE_TOPOLOGY_POINT_LIST;
+ case PrimitiveType::Line:
+ return VK_PRIMITIVE_TOPOLOGY_LINE_LIST;
+ case PrimitiveType::Triangle:
+ return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST;
+ case PrimitiveType::Patch:
+ return VK_PRIMITIVE_TOPOLOGY_PATCH_LIST;
+ default:
+ assert(!"unknown topology type.");
+ return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST;
+ }
+}
+
+VkStencilOp VulkanUtil::translateStencilOp(StencilOp op)
+{
+ switch (op)
+ {
+ case StencilOp::DecrementSaturate:
+ return VK_STENCIL_OP_DECREMENT_AND_CLAMP;
+ case StencilOp::DecrementWrap:
+ return VK_STENCIL_OP_DECREMENT_AND_WRAP;
+ case StencilOp::IncrementSaturate:
+ return VK_STENCIL_OP_INCREMENT_AND_CLAMP;
+ case StencilOp::IncrementWrap:
+ return VK_STENCIL_OP_INCREMENT_AND_WRAP;
+ case StencilOp::Invert:
+ return VK_STENCIL_OP_INVERT;
+ case StencilOp::Keep:
+ return VK_STENCIL_OP_KEEP;
+ case StencilOp::Replace:
+ return VK_STENCIL_OP_REPLACE;
+ case StencilOp::Zero:
+ return VK_STENCIL_OP_ZERO;
+ default:
+ return VK_STENCIL_OP_KEEP;
+ }
+}
+
+VkFilter VulkanUtil::translateFilterMode(TextureFilteringMode mode)
+{
+ switch (mode)
+ {
+ default:
+ return VkFilter(0);
+
+#define CASE(SRC, DST) \
+ case TextureFilteringMode::SRC: \
+ return VK_FILTER_##DST
+
+ CASE(Point, NEAREST);
+ CASE(Linear, LINEAR);
+
+#undef CASE
+ }
+}
+
+VkSamplerMipmapMode VulkanUtil::translateMipFilterMode(TextureFilteringMode mode)
+{
+ switch (mode)
+ {
+ default:
+ return VkSamplerMipmapMode(0);
+
+#define CASE(SRC, DST) \
+ case TextureFilteringMode::SRC: \
+ return VK_SAMPLER_MIPMAP_MODE_##DST
+
+ CASE(Point, NEAREST);
+ CASE(Linear, LINEAR);
+
+#undef CASE
+ }
+}
+
+VkSamplerAddressMode VulkanUtil::translateAddressingMode(TextureAddressingMode mode)
+{
+ switch (mode)
+ {
+ default:
+ return VkSamplerAddressMode(0);
+
+#define CASE(SRC, DST) \
+ case TextureAddressingMode::SRC: \
+ return VK_SAMPLER_ADDRESS_MODE_##DST
+
+ CASE(Wrap, REPEAT);
+ CASE(ClampToEdge, CLAMP_TO_EDGE);
+ CASE(ClampToBorder, CLAMP_TO_BORDER);
+ CASE(MirrorRepeat, MIRRORED_REPEAT);
+ CASE(MirrorOnce, MIRROR_CLAMP_TO_EDGE);
+
+#undef CASE
+ }
+}
+
+VkCompareOp VulkanUtil::translateComparisonFunc(ComparisonFunc func)
+{
+ switch (func)
+ {
+ default:
+ // TODO: need to report failures
+ return VK_COMPARE_OP_ALWAYS;
+
+#define CASE(FROM, TO) \
+ case ComparisonFunc::FROM: \
+ return VK_COMPARE_OP_##TO
+
+ CASE(Never, NEVER);
+ CASE(Less, LESS);
+ CASE(Equal, EQUAL);
+ CASE(LessEqual, LESS_OR_EQUAL);
+ CASE(Greater, GREATER);
+ CASE(NotEqual, NOT_EQUAL);
+ CASE(GreaterEqual, GREATER_OR_EQUAL);
+ CASE(Always, ALWAYS);
+#undef CASE
+ }
+}
+
+VkStencilOpState VulkanUtil::translateStencilState(DepthStencilOpDesc desc)
+{
+ VkStencilOpState rs;
+ rs.compareMask = 0xFF;
+ rs.compareOp = translateComparisonFunc(desc.stencilFunc);
+ rs.depthFailOp = translateStencilOp(desc.stencilDepthFailOp);
+ rs.failOp = translateStencilOp(desc.stencilFailOp);
+ rs.passOp = translateStencilOp(desc.stencilPassOp);
+ rs.reference = 0;
+ rs.writeMask = 0xFF;
+ return rs;
+}
+
/* static */Slang::Result VulkanUtil::handleFail(VkResult res)
{
if (res != VK_SUCCESS)
diff --git a/tools/gfx/vulkan/vk-util.h b/tools/gfx/vulkan/vk-util.h
index 4c32b1615..05533dad1 100644
--- a/tools/gfx/vulkan/vk-util.h
+++ b/tools/gfx/vulkan/vk-util.h
@@ -72,6 +72,34 @@ struct VulkanUtil
}
return false;
}
+
+ static VkSampleCountFlagBits translateSampleCount(uint32_t sampleCount);
+
+ static VkCullModeFlags translateCullMode(CullMode cullMode);
+
+ static VkFrontFace translateFrontFaceMode(FrontFaceMode frontFaceMode);
+
+ static VkPolygonMode translateFillMode(FillMode fillMode);
+
+ static VkBlendFactor translateBlendFactor(BlendFactor blendFactor);
+
+ static VkBlendOp translateBlendOp(BlendOp op);
+
+ static VkPrimitiveTopology translatePrimitiveTypeToListTopology(
+ PrimitiveType primitiveType);
+
+ static VkStencilOp translateStencilOp(StencilOp op);
+
+ static VkFilter translateFilterMode(TextureFilteringMode mode);
+
+ static VkSamplerMipmapMode translateMipFilterMode(TextureFilteringMode mode);
+
+ static VkSamplerAddressMode translateAddressingMode(TextureAddressingMode mode);
+
+ static VkCompareOp translateComparisonFunc(ComparisonFunc func);
+
+ static VkStencilOpState translateStencilState(DepthStencilOpDesc desc);
+
};
struct AccelerationStructureBuildGeometryInfoBuilder