diff options
Diffstat (limited to 'tools/gfx/metal/metal-command-encoder.cpp')
| -rw-r--r-- | tools/gfx/metal/metal-command-encoder.cpp | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/tools/gfx/metal/metal-command-encoder.cpp b/tools/gfx/metal/metal-command-encoder.cpp index 2447295c4..442c216aa 100644 --- a/tools/gfx/metal/metal-command-encoder.cpp +++ b/tools/gfx/metal/metal-command-encoder.cpp @@ -478,18 +478,28 @@ Result ComputeCommandEncoder::bindPipelineWithRootObject( Result ComputeCommandEncoder::dispatchCompute(int x, int y, int z) { - auto pipeline = static_cast<PipelineStateImpl*>(m_currentPipeline.Ptr()); - pipeline->ensureAPIPipelineStateCreated(); - MTL::ComputeCommandEncoder* encoder = m_commandBuffer->getMetalComputeCommandEncoder(); - encoder->setComputePipelineState(pipeline->m_computePipelineState.get()); ComputeBindingContext bindingContext; bindingContext.init(m_commandBuffer->m_device, encoder); auto program = static_cast<ShaderProgramImpl*>(m_currentPipeline->m_program.get()); m_commandBuffer->m_rootObject.bindAsRoot(&bindingContext, program->m_rootObjectLayout); - encoder->dispatchThreadgroups(MTL::Size(x, y, z), pipeline->m_threadGroupSize); + auto pipeline = static_cast<PipelineStateImpl*>(m_currentPipeline.Ptr()); + RootShaderObjectImpl* rootObjectImpl = &m_commandBuffer->m_rootObject; + RefPtr<PipelineStateBase> newPipeline; + SLANG_RETURN_ON_FAIL(m_commandBuffer->m_device->maybeSpecializePipeline( + m_currentPipeline, rootObjectImpl, newPipeline)); + PipelineStateImpl* newPipelineImpl = static_cast<PipelineStateImpl*>(newPipeline.Ptr()); + + SLANG_RETURN_ON_FAIL(newPipelineImpl->ensureAPIPipelineStateCreated()); + m_currentPipeline = newPipelineImpl; + + m_currentPipeline->ensureAPIPipelineStateCreated(); + encoder->setComputePipelineState(m_currentPipeline->m_computePipelineState.get()); + + + encoder->dispatchThreadgroups(MTL::Size(x, y, z), m_currentPipeline->m_threadGroupSize); return SLANG_OK; } |
