summaryrefslogtreecommitdiff
path: root/tools/gfx/metal/metal-command-encoder.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tools/gfx/metal/metal-command-encoder.cpp')
-rw-r--r--tools/gfx/metal/metal-command-encoder.cpp20
1 files changed, 15 insertions, 5 deletions
diff --git a/tools/gfx/metal/metal-command-encoder.cpp b/tools/gfx/metal/metal-command-encoder.cpp
index 2447295c4..442c216aa 100644
--- a/tools/gfx/metal/metal-command-encoder.cpp
+++ b/tools/gfx/metal/metal-command-encoder.cpp
@@ -478,18 +478,28 @@ Result ComputeCommandEncoder::bindPipelineWithRootObject(
Result ComputeCommandEncoder::dispatchCompute(int x, int y, int z)
{
- auto pipeline = static_cast<PipelineStateImpl*>(m_currentPipeline.Ptr());
- pipeline->ensureAPIPipelineStateCreated();
-
MTL::ComputeCommandEncoder* encoder = m_commandBuffer->getMetalComputeCommandEncoder();
- encoder->setComputePipelineState(pipeline->m_computePipelineState.get());
ComputeBindingContext bindingContext;
bindingContext.init(m_commandBuffer->m_device, encoder);
auto program = static_cast<ShaderProgramImpl*>(m_currentPipeline->m_program.get());
m_commandBuffer->m_rootObject.bindAsRoot(&bindingContext, program->m_rootObjectLayout);
- encoder->dispatchThreadgroups(MTL::Size(x, y, z), pipeline->m_threadGroupSize);
+ auto pipeline = static_cast<PipelineStateImpl*>(m_currentPipeline.Ptr());
+ RootShaderObjectImpl* rootObjectImpl = &m_commandBuffer->m_rootObject;
+ RefPtr<PipelineStateBase> newPipeline;
+ SLANG_RETURN_ON_FAIL(m_commandBuffer->m_device->maybeSpecializePipeline(
+ m_currentPipeline, rootObjectImpl, newPipeline));
+ PipelineStateImpl* newPipelineImpl = static_cast<PipelineStateImpl*>(newPipeline.Ptr());
+
+ SLANG_RETURN_ON_FAIL(newPipelineImpl->ensureAPIPipelineStateCreated());
+ m_currentPipeline = newPipelineImpl;
+
+ m_currentPipeline->ensureAPIPipelineStateCreated();
+ encoder->setComputePipelineState(m_currentPipeline->m_computePipelineState.get());
+
+
+ encoder->dispatchThreadgroups(MTL::Size(x, y, z), m_currentPipeline->m_threadGroupSize);
return SLANG_OK;
}