diff options
Diffstat (limited to 'tools/gfx/cuda/render-cuda.cpp')
| -rw-r--r-- | tools/gfx/cuda/render-cuda.cpp | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp index 0e4ee6c13..d1e320224 100644 --- a/tools/gfx/cuda/render-cuda.cpp +++ b/tools/gfx/cuda/render-cuda.cpp @@ -241,6 +241,8 @@ public: RefPtr<TextureCUDAResource> textureResource = nullptr; }; +class CUDAProgramLayout; + class CUDAShaderProgram : public IShaderProgram, public RefObject { public: @@ -255,6 +257,9 @@ public: CUmodule cudaModule = nullptr; CUfunction cudaKernel; String kernelName; + ComPtr<slang::IComponentType> slangProgram; + RefPtr<CUDAProgramLayout> layout; + ~CUDAShaderProgram() { if (cudaModule) @@ -1260,16 +1265,6 @@ private: return SLANG_OK; } - virtual SLANG_NO_THROW Result SLANG_MCALL createRootShaderObjectLayout( - slang::ProgramLayout* layout, IShaderObjectLayout** outLayout) override - { - RefPtr<CUDAProgramLayout> cudaLayout; - cudaLayout = new CUDAProgramLayout(layout); - cudaLayout->programLayout = layout; - *outLayout = cudaLayout.detach(); - return SLANG_OK; - } - virtual SLANG_NO_THROW Result SLANG_MCALL createShaderObject(IShaderObjectLayout* layout, IShaderObject** outObject) override { @@ -1280,10 +1275,13 @@ private: } virtual SLANG_NO_THROW Result SLANG_MCALL - createRootShaderObject(IShaderObjectLayout* layout, IShaderObject** outObject) override + createRootShaderObject(IShaderProgram* program, IShaderObject** outObject) override { + auto cudaProgram = dynamic_cast<CUDAShaderProgram*>(program); + auto cudaLayout = cudaProgram->layout; + RefPtr<CUDARootShaderObject> result = new CUDARootShaderObject(); - SLANG_RETURN_ON_FAIL(result->init(this, dynamic_cast<CUDAShaderObjectLayout*>(layout))); + SLANG_RETURN_ON_FAIL(result->init(this, cudaLayout)); *outObject = result.detach(); return SLANG_OK; } @@ -1300,6 +1298,11 @@ private: virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) override { + if( desc.kernelCount == 0 ) + { + return createProgramFromSlang(this, desc, outProgram); + } + if (desc.kernelCount != 1) return SLANG_E_INVALID_ARG; RefPtr<CUDAShaderProgram> cudaProgram = new CUDAShaderProgram(); @@ -1307,6 +1310,22 @@ private: SLANG_CUDA_RETURN_ON_FAIL( cuModuleGetFunction(&cudaProgram->cudaKernel, cudaProgram->cudaModule, desc.kernels[0].entryPointName)); cudaProgram->kernelName = desc.kernels[0].entryPointName; + + auto slangProgram = desc.slangProgram; + if( slangProgram ) + { + cudaProgram->slangProgram = slangProgram; + + auto slangProgramLayout = slangProgram->getLayout(); + if(!slangProgramLayout) + return SLANG_FAIL; + + RefPtr<CUDAProgramLayout> cudaLayout; + cudaLayout = new CUDAProgramLayout(slangProgramLayout); + cudaLayout->programLayout = slangProgramLayout; + cudaProgram->layout = cudaLayout; + } + *outProgram = cudaProgram.detach(); return SLANG_OK; } |
