summaryrefslogtreecommitdiff
path: root/tools/gfx/cuda/render-cuda.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tools/gfx/cuda/render-cuda.cpp')
-rw-r--r--tools/gfx/cuda/render-cuda.cpp43
1 files changed, 31 insertions, 12 deletions
diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp
index 0e4ee6c13..d1e320224 100644
--- a/tools/gfx/cuda/render-cuda.cpp
+++ b/tools/gfx/cuda/render-cuda.cpp
@@ -241,6 +241,8 @@ public:
RefPtr<TextureCUDAResource> textureResource = nullptr;
};
+class CUDAProgramLayout;
+
class CUDAShaderProgram : public IShaderProgram, public RefObject
{
public:
@@ -255,6 +257,9 @@ public:
CUmodule cudaModule = nullptr;
CUfunction cudaKernel;
String kernelName;
+ ComPtr<slang::IComponentType> slangProgram;
+ RefPtr<CUDAProgramLayout> layout;
+
~CUDAShaderProgram()
{
if (cudaModule)
@@ -1260,16 +1265,6 @@ private:
return SLANG_OK;
}
- virtual SLANG_NO_THROW Result SLANG_MCALL createRootShaderObjectLayout(
- slang::ProgramLayout* layout, IShaderObjectLayout** outLayout) override
- {
- RefPtr<CUDAProgramLayout> cudaLayout;
- cudaLayout = new CUDAProgramLayout(layout);
- cudaLayout->programLayout = layout;
- *outLayout = cudaLayout.detach();
- return SLANG_OK;
- }
-
virtual SLANG_NO_THROW Result SLANG_MCALL
createShaderObject(IShaderObjectLayout* layout, IShaderObject** outObject) override
{
@@ -1280,10 +1275,13 @@ private:
}
virtual SLANG_NO_THROW Result SLANG_MCALL
- createRootShaderObject(IShaderObjectLayout* layout, IShaderObject** outObject) override
+ createRootShaderObject(IShaderProgram* program, IShaderObject** outObject) override
{
+ auto cudaProgram = dynamic_cast<CUDAShaderProgram*>(program);
+ auto cudaLayout = cudaProgram->layout;
+
RefPtr<CUDARootShaderObject> result = new CUDARootShaderObject();
- SLANG_RETURN_ON_FAIL(result->init(this, dynamic_cast<CUDAShaderObjectLayout*>(layout)));
+ SLANG_RETURN_ON_FAIL(result->init(this, cudaLayout));
*outObject = result.detach();
return SLANG_OK;
}
@@ -1300,6 +1298,11 @@ private:
virtual SLANG_NO_THROW Result SLANG_MCALL
createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) override
{
+ if( desc.kernelCount == 0 )
+ {
+ return createProgramFromSlang(this, desc, outProgram);
+ }
+
if (desc.kernelCount != 1)
return SLANG_E_INVALID_ARG;
RefPtr<CUDAShaderProgram> cudaProgram = new CUDAShaderProgram();
@@ -1307,6 +1310,22 @@ private:
SLANG_CUDA_RETURN_ON_FAIL(
cuModuleGetFunction(&cudaProgram->cudaKernel, cudaProgram->cudaModule, desc.kernels[0].entryPointName));
cudaProgram->kernelName = desc.kernels[0].entryPointName;
+
+ auto slangProgram = desc.slangProgram;
+ if( slangProgram )
+ {
+ cudaProgram->slangProgram = slangProgram;
+
+ auto slangProgramLayout = slangProgram->getLayout();
+ if(!slangProgramLayout)
+ return SLANG_FAIL;
+
+ RefPtr<CUDAProgramLayout> cudaLayout;
+ cudaLayout = new CUDAProgramLayout(slangProgramLayout);
+ cudaLayout->programLayout = slangProgramLayout;
+ cudaProgram->layout = cudaLayout;
+ }
+
*outProgram = cudaProgram.detach();
return SLANG_OK;
}