From df7548ef62c02b9ab1cc5addecaa6b6c150f2750 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 5 Feb 2021 14:36:07 -0800 Subject: Shader-Object example (#1694) --- tools/gfx/cuda/render-cuda.cpp | 69 ++++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 33 deletions(-) (limited to 'tools/gfx/cuda/render-cuda.cpp') diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp index 7d7ee8eb9..4f87bdfc9 100644 --- a/tools/gfx/cuda/render-cuda.cpp +++ b/tools/gfx/cuda/render-cuda.cpp @@ -242,36 +242,6 @@ public: RefPtr textureResource = nullptr; }; -class CUDAProgramLayout; - -class CUDAShaderProgram : public ShaderProgramBase -{ -public: - CUmodule cudaModule = nullptr; - CUfunction cudaKernel; - String kernelName; - RefPtr layout; - - ~CUDAShaderProgram() - { - if (cudaModule) - cuModuleUnload(cudaModule); - } -}; - -class CUDAPipelineState : public PipelineStateBase -{ -public: - RefPtr shaderProgram; - void init(const ComputePipelineStateDesc& inDesc) - { - PipelineStateDesc pipelineDesc; - pipelineDesc.type = PipelineType::Compute; - pipelineDesc.compute = inDesc; - initializeBase(pipelineDesc); - } -}; - class CUDAShaderObjectLayout : public ShaderObjectLayoutBase { public: @@ -578,8 +548,6 @@ public: // TODO: the logic here is a copy-paste of `GraphicsCommonShaderObject::collectSpecializationArgs`, // consider moving the implementation to `ShaderObjectBase` and share the logic among different implementations. - if (!m_bindingFinalized) - return SLANG_FAIL; auto& subObjectRanges = getLayout()->subObjectRanges; // The following logic is built on the assumption that all fields that involve existential types (and // therefore require specialization) will results in a sub-object range in the type layout. @@ -677,7 +645,42 @@ public: entryPointObjects[index]->addRef(); return SLANG_OK; } + virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override + { + SLANG_RETURN_ON_FAIL(CUDAShaderObject::collectSpecializationArgs(args)); + for (auto& entryPoint : entryPointObjects) + { + SLANG_RETURN_ON_FAIL(entryPoint->collectSpecializationArgs(args)); + } + return SLANG_OK; + } +}; +class CUDAShaderProgram : public ShaderProgramBase +{ +public: + CUmodule cudaModule = nullptr; + CUfunction cudaKernel; + String kernelName; + RefPtr layout; + ~CUDAShaderProgram() + { + if (cudaModule) + cuModuleUnload(cudaModule); + } +}; + +class CUDAPipelineState : public PipelineStateBase +{ +public: + RefPtr shaderProgram; + void init(const ComputePipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::Compute; + pipelineDesc.compute = inDesc; + initializeBase(pipelineDesc); + } }; class CUDARenderer : public RendererBase @@ -802,7 +805,6 @@ private: CUcontext m_context = nullptr; RefPtr currentPipeline = nullptr; RefPtr currentRootObject = nullptr; - SlangContext slangContext; public: ~CUDARenderer() { @@ -1332,6 +1334,7 @@ private: { RefPtr cudaProgram = new CUDAShaderProgram(); cudaProgram->slangProgram = desc.slangProgram; + cudaProgram->layout = new CUDAProgramLayout(this, desc.slangProgram->getLayout()); *outProgram = cudaProgram.detach(); return SLANG_OK; } -- cgit v1.2.3