diff options
| author | Yong He <yonghe@outlook.com> | 2021-02-05 14:36:07 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-02-05 14:36:07 -0800 |
| commit | df7548ef62c02b9ab1cc5addecaa6b6c150f2750 (patch) | |
| tree | 17081a8d5de3fd3292043aae6761d0c8960e6783 /tools/gfx/cuda/render-cuda.cpp | |
| parent | 5fbaccfc1d4ac7d17d528de894d1f276e41d9ce1 (diff) | |
Shader-Object example (#1694)
Diffstat (limited to 'tools/gfx/cuda/render-cuda.cpp')
| -rw-r--r-- | tools/gfx/cuda/render-cuda.cpp | 69 |
1 files changed, 36 insertions, 33 deletions
diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp index 7d7ee8eb9..4f87bdfc9 100644 --- a/tools/gfx/cuda/render-cuda.cpp +++ b/tools/gfx/cuda/render-cuda.cpp @@ -242,36 +242,6 @@ public: RefPtr<TextureCUDAResource> textureResource = nullptr; }; -class CUDAProgramLayout; - -class CUDAShaderProgram : public ShaderProgramBase -{ -public: - CUmodule cudaModule = nullptr; - CUfunction cudaKernel; - String kernelName; - RefPtr<CUDAProgramLayout> layout; - - ~CUDAShaderProgram() - { - if (cudaModule) - cuModuleUnload(cudaModule); - } -}; - -class CUDAPipelineState : public PipelineStateBase -{ -public: - RefPtr<CUDAShaderProgram> shaderProgram; - void init(const ComputePipelineStateDesc& inDesc) - { - PipelineStateDesc pipelineDesc; - pipelineDesc.type = PipelineType::Compute; - pipelineDesc.compute = inDesc; - initializeBase(pipelineDesc); - } -}; - class CUDAShaderObjectLayout : public ShaderObjectLayoutBase { public: @@ -578,8 +548,6 @@ public: // TODO: the logic here is a copy-paste of `GraphicsCommonShaderObject::collectSpecializationArgs`, // consider moving the implementation to `ShaderObjectBase` and share the logic among different implementations. - if (!m_bindingFinalized) - return SLANG_FAIL; auto& subObjectRanges = getLayout()->subObjectRanges; // The following logic is built on the assumption that all fields that involve existential types (and // therefore require specialization) will results in a sub-object range in the type layout. @@ -677,7 +645,42 @@ public: entryPointObjects[index]->addRef(); return SLANG_OK; } + virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override + { + SLANG_RETURN_ON_FAIL(CUDAShaderObject::collectSpecializationArgs(args)); + for (auto& entryPoint : entryPointObjects) + { + SLANG_RETURN_ON_FAIL(entryPoint->collectSpecializationArgs(args)); + } + return SLANG_OK; + } +}; +class CUDAShaderProgram : public ShaderProgramBase +{ +public: + CUmodule cudaModule = nullptr; + CUfunction cudaKernel; + String kernelName; + RefPtr<CUDAProgramLayout> layout; + ~CUDAShaderProgram() + { + if (cudaModule) + cuModuleUnload(cudaModule); + } +}; + +class CUDAPipelineState : public PipelineStateBase +{ +public: + RefPtr<CUDAShaderProgram> shaderProgram; + void init(const ComputePipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::Compute; + pipelineDesc.compute = inDesc; + initializeBase(pipelineDesc); + } }; class CUDARenderer : public RendererBase @@ -802,7 +805,6 @@ private: CUcontext m_context = nullptr; RefPtr<CUDAPipelineState> currentPipeline = nullptr; RefPtr<CUDARootShaderObject> currentRootObject = nullptr; - SlangContext slangContext; public: ~CUDARenderer() { @@ -1332,6 +1334,7 @@ private: { RefPtr<CUDAShaderProgram> cudaProgram = new CUDAShaderProgram(); cudaProgram->slangProgram = desc.slangProgram; + cudaProgram->layout = new CUDAProgramLayout(this, desc.slangProgram->getLayout()); *outProgram = cudaProgram.detach(); return SLANG_OK; } |
