summaryrefslogtreecommitdiffstats
path: root/tools/gfx/cuda/render-cuda.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-02-05 14:36:07 -0800
committerGitHub <noreply@github.com>2021-02-05 14:36:07 -0800
commitdf7548ef62c02b9ab1cc5addecaa6b6c150f2750 (patch)
tree17081a8d5de3fd3292043aae6761d0c8960e6783 /tools/gfx/cuda/render-cuda.cpp
parent5fbaccfc1d4ac7d17d528de894d1f276e41d9ce1 (diff)
Shader-Object example (#1694)
Diffstat (limited to 'tools/gfx/cuda/render-cuda.cpp')
-rw-r--r--tools/gfx/cuda/render-cuda.cpp69
1 files changed, 36 insertions, 33 deletions
diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp
index 7d7ee8eb9..4f87bdfc9 100644
--- a/tools/gfx/cuda/render-cuda.cpp
+++ b/tools/gfx/cuda/render-cuda.cpp
@@ -242,36 +242,6 @@ public:
RefPtr<TextureCUDAResource> textureResource = nullptr;
};
-class CUDAProgramLayout;
-
-class CUDAShaderProgram : public ShaderProgramBase
-{
-public:
- CUmodule cudaModule = nullptr;
- CUfunction cudaKernel;
- String kernelName;
- RefPtr<CUDAProgramLayout> layout;
-
- ~CUDAShaderProgram()
- {
- if (cudaModule)
- cuModuleUnload(cudaModule);
- }
-};
-
-class CUDAPipelineState : public PipelineStateBase
-{
-public:
- RefPtr<CUDAShaderProgram> shaderProgram;
- void init(const ComputePipelineStateDesc& inDesc)
- {
- PipelineStateDesc pipelineDesc;
- pipelineDesc.type = PipelineType::Compute;
- pipelineDesc.compute = inDesc;
- initializeBase(pipelineDesc);
- }
-};
-
class CUDAShaderObjectLayout : public ShaderObjectLayoutBase
{
public:
@@ -578,8 +548,6 @@ public:
// TODO: the logic here is a copy-paste of `GraphicsCommonShaderObject::collectSpecializationArgs`,
// consider moving the implementation to `ShaderObjectBase` and share the logic among different implementations.
- if (!m_bindingFinalized)
- return SLANG_FAIL;
auto& subObjectRanges = getLayout()->subObjectRanges;
// The following logic is built on the assumption that all fields that involve existential types (and
// therefore require specialization) will results in a sub-object range in the type layout.
@@ -677,7 +645,42 @@ public:
entryPointObjects[index]->addRef();
return SLANG_OK;
}
+ virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override
+ {
+ SLANG_RETURN_ON_FAIL(CUDAShaderObject::collectSpecializationArgs(args));
+ for (auto& entryPoint : entryPointObjects)
+ {
+ SLANG_RETURN_ON_FAIL(entryPoint->collectSpecializationArgs(args));
+ }
+ return SLANG_OK;
+ }
+};
+class CUDAShaderProgram : public ShaderProgramBase
+{
+public:
+ CUmodule cudaModule = nullptr;
+ CUfunction cudaKernel;
+ String kernelName;
+ RefPtr<CUDAProgramLayout> layout;
+ ~CUDAShaderProgram()
+ {
+ if (cudaModule)
+ cuModuleUnload(cudaModule);
+ }
+};
+
+class CUDAPipelineState : public PipelineStateBase
+{
+public:
+ RefPtr<CUDAShaderProgram> shaderProgram;
+ void init(const ComputePipelineStateDesc& inDesc)
+ {
+ PipelineStateDesc pipelineDesc;
+ pipelineDesc.type = PipelineType::Compute;
+ pipelineDesc.compute = inDesc;
+ initializeBase(pipelineDesc);
+ }
};
class CUDARenderer : public RendererBase
@@ -802,7 +805,6 @@ private:
CUcontext m_context = nullptr;
RefPtr<CUDAPipelineState> currentPipeline = nullptr;
RefPtr<CUDARootShaderObject> currentRootObject = nullptr;
- SlangContext slangContext;
public:
~CUDARenderer()
{
@@ -1332,6 +1334,7 @@ private:
{
RefPtr<CUDAShaderProgram> cudaProgram = new CUDAShaderProgram();
cudaProgram->slangProgram = desc.slangProgram;
+ cudaProgram->layout = new CUDAProgramLayout(this, desc.slangProgram->getLayout());
*outProgram = cudaProgram.detach();
return SLANG_OK;
}