summaryrefslogtreecommitdiffstats
path: root/tools/gfx/metal/metal-command-encoder.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-07-19 11:49:42 -0700
committerGitHub <noreply@github.com>2024-07-19 11:49:42 -0700
commitf114433debfba67cbe1db239b6e92278d41ed438 (patch)
tree3a8ff78deb657d203c87bd22bc2ee83575e834f6 /tools/gfx/metal/metal-command-encoder.cpp
parentadf758c8c4032afcd96d995840bd697d2adef34c (diff)
Support parameter block in metal shader objects. (#4671)
* Support parameter block in metal shader objects. * Ingore parameter block tests on devices without tier2 argument buffer. * Fix warning. * Fix texture subscript test. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tools/gfx/metal/metal-command-encoder.cpp')
-rw-r--r--tools/gfx/metal/metal-command-encoder.cpp20
1 files changed, 15 insertions, 5 deletions
diff --git a/tools/gfx/metal/metal-command-encoder.cpp b/tools/gfx/metal/metal-command-encoder.cpp
index 2447295c4..442c216aa 100644
--- a/tools/gfx/metal/metal-command-encoder.cpp
+++ b/tools/gfx/metal/metal-command-encoder.cpp
@@ -478,18 +478,28 @@ Result ComputeCommandEncoder::bindPipelineWithRootObject(
Result ComputeCommandEncoder::dispatchCompute(int x, int y, int z)
{
- auto pipeline = static_cast<PipelineStateImpl*>(m_currentPipeline.Ptr());
- pipeline->ensureAPIPipelineStateCreated();
-
MTL::ComputeCommandEncoder* encoder = m_commandBuffer->getMetalComputeCommandEncoder();
- encoder->setComputePipelineState(pipeline->m_computePipelineState.get());
ComputeBindingContext bindingContext;
bindingContext.init(m_commandBuffer->m_device, encoder);
auto program = static_cast<ShaderProgramImpl*>(m_currentPipeline->m_program.get());
m_commandBuffer->m_rootObject.bindAsRoot(&bindingContext, program->m_rootObjectLayout);
- encoder->dispatchThreadgroups(MTL::Size(x, y, z), pipeline->m_threadGroupSize);
+ auto pipeline = static_cast<PipelineStateImpl*>(m_currentPipeline.Ptr());
+ RootShaderObjectImpl* rootObjectImpl = &m_commandBuffer->m_rootObject;
+ RefPtr<PipelineStateBase> newPipeline;
+ SLANG_RETURN_ON_FAIL(m_commandBuffer->m_device->maybeSpecializePipeline(
+ m_currentPipeline, rootObjectImpl, newPipeline));
+ PipelineStateImpl* newPipelineImpl = static_cast<PipelineStateImpl*>(newPipeline.Ptr());
+
+ SLANG_RETURN_ON_FAIL(newPipelineImpl->ensureAPIPipelineStateCreated());
+ m_currentPipeline = newPipelineImpl;
+
+ m_currentPipeline->ensureAPIPipelineStateCreated();
+ encoder->setComputePipelineState(m_currentPipeline->m_computePipelineState.get());
+
+
+ encoder->dispatchThreadgroups(MTL::Size(x, y, z), m_currentPipeline->m_threadGroupSize);
return SLANG_OK;
}