diff options
| author | Yong He <yonghe@outlook.com> | 2024-07-19 11:49:42 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-19 11:49:42 -0700 |
| commit | f114433debfba67cbe1db239b6e92278d41ed438 (patch) | |
| tree | 3a8ff78deb657d203c87bd22bc2ee83575e834f6 /tools/gfx/metal/metal-command-encoder.cpp | |
| parent | adf758c8c4032afcd96d995840bd697d2adef34c (diff) | |
Support parameter block in metal shader objects. (#4671)
* Support parameter block in metal shader objects.
* Ingore parameter block tests on devices without tier2 argument buffer.
* Fix warning.
* Fix texture subscript test.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tools/gfx/metal/metal-command-encoder.cpp')
| -rw-r--r-- | tools/gfx/metal/metal-command-encoder.cpp | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/tools/gfx/metal/metal-command-encoder.cpp b/tools/gfx/metal/metal-command-encoder.cpp index 2447295c4..442c216aa 100644 --- a/tools/gfx/metal/metal-command-encoder.cpp +++ b/tools/gfx/metal/metal-command-encoder.cpp @@ -478,18 +478,28 @@ Result ComputeCommandEncoder::bindPipelineWithRootObject( Result ComputeCommandEncoder::dispatchCompute(int x, int y, int z) { - auto pipeline = static_cast<PipelineStateImpl*>(m_currentPipeline.Ptr()); - pipeline->ensureAPIPipelineStateCreated(); - MTL::ComputeCommandEncoder* encoder = m_commandBuffer->getMetalComputeCommandEncoder(); - encoder->setComputePipelineState(pipeline->m_computePipelineState.get()); ComputeBindingContext bindingContext; bindingContext.init(m_commandBuffer->m_device, encoder); auto program = static_cast<ShaderProgramImpl*>(m_currentPipeline->m_program.get()); m_commandBuffer->m_rootObject.bindAsRoot(&bindingContext, program->m_rootObjectLayout); - encoder->dispatchThreadgroups(MTL::Size(x, y, z), pipeline->m_threadGroupSize); + auto pipeline = static_cast<PipelineStateImpl*>(m_currentPipeline.Ptr()); + RootShaderObjectImpl* rootObjectImpl = &m_commandBuffer->m_rootObject; + RefPtr<PipelineStateBase> newPipeline; + SLANG_RETURN_ON_FAIL(m_commandBuffer->m_device->maybeSpecializePipeline( + m_currentPipeline, rootObjectImpl, newPipeline)); + PipelineStateImpl* newPipelineImpl = static_cast<PipelineStateImpl*>(newPipeline.Ptr()); + + SLANG_RETURN_ON_FAIL(newPipelineImpl->ensureAPIPipelineStateCreated()); + m_currentPipeline = newPipelineImpl; + + m_currentPipeline->ensureAPIPipelineStateCreated(); + encoder->setComputePipelineState(m_currentPipeline->m_computePipelineState.get()); + + + encoder->dispatchThreadgroups(MTL::Size(x, y, z), m_currentPipeline->m_threadGroupSize); return SLANG_OK; } |
