From f114433debfba67cbe1db239b6e92278d41ed438 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 19 Jul 2024 11:49:42 -0700 Subject: Support parameter block in metal shader objects. (#4671) * Support parameter block in metal shader objects. * Ingore parameter block tests on devices without tier2 argument buffer. * Fix warning. * Fix texture subscript test. --------- Co-authored-by: Yong He --- tools/gfx/metal/metal-command-encoder.cpp | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) (limited to 'tools/gfx/metal/metal-command-encoder.cpp') diff --git a/tools/gfx/metal/metal-command-encoder.cpp b/tools/gfx/metal/metal-command-encoder.cpp index 2447295c4..442c216aa 100644 --- a/tools/gfx/metal/metal-command-encoder.cpp +++ b/tools/gfx/metal/metal-command-encoder.cpp @@ -478,18 +478,28 @@ Result ComputeCommandEncoder::bindPipelineWithRootObject( Result ComputeCommandEncoder::dispatchCompute(int x, int y, int z) { - auto pipeline = static_cast(m_currentPipeline.Ptr()); - pipeline->ensureAPIPipelineStateCreated(); - MTL::ComputeCommandEncoder* encoder = m_commandBuffer->getMetalComputeCommandEncoder(); - encoder->setComputePipelineState(pipeline->m_computePipelineState.get()); ComputeBindingContext bindingContext; bindingContext.init(m_commandBuffer->m_device, encoder); auto program = static_cast(m_currentPipeline->m_program.get()); m_commandBuffer->m_rootObject.bindAsRoot(&bindingContext, program->m_rootObjectLayout); - encoder->dispatchThreadgroups(MTL::Size(x, y, z), pipeline->m_threadGroupSize); + auto pipeline = static_cast(m_currentPipeline.Ptr()); + RootShaderObjectImpl* rootObjectImpl = &m_commandBuffer->m_rootObject; + RefPtr newPipeline; + SLANG_RETURN_ON_FAIL(m_commandBuffer->m_device->maybeSpecializePipeline( + m_currentPipeline, rootObjectImpl, newPipeline)); + PipelineStateImpl* newPipelineImpl = static_cast(newPipeline.Ptr()); + + SLANG_RETURN_ON_FAIL(newPipelineImpl->ensureAPIPipelineStateCreated()); + m_currentPipeline = newPipelineImpl; + + m_currentPipeline->ensureAPIPipelineStateCreated(); + encoder->setComputePipelineState(m_currentPipeline->m_computePipelineState.get()); + + + encoder->dispatchThreadgroups(MTL::Size(x, y, z), m_currentPipeline->m_threadGroupSize); return SLANG_OK; } -- cgit v1.2.3