summaryrefslogtreecommitdiffstats
path: root/tools/gfx/cuda/render-cuda.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tools/gfx/cuda/render-cuda.cpp')
-rw-r--r--tools/gfx/cuda/render-cuda.cpp45
1 files changed, 23 insertions, 22 deletions
diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp
index fccc1e3f0..ed22495cb 100644
--- a/tools/gfx/cuda/render-cuda.cpp
+++ b/tools/gfx/cuda/render-cuda.cpp
@@ -12,6 +12,7 @@
#include "slang-com-helper.h"
#include "../command-writer.h"
#include "../renderer-shared.h"
+#include "../mutable-shader-object.h"
#include "../simple-transient-resource-heap.h"
#include "../slang-context.h"
@@ -659,6 +660,9 @@ public:
}
};
+class CUDAMutableShaderObject : public MutableShaderObject< CUDAMutableShaderObject, CUDAShaderObjectLayout>
+{};
+
class CUDAEntryPointShaderObject : public CUDAShaderObject
{
public:
@@ -668,11 +672,8 @@ public:
class CUDARootShaderObject : public CUDAShaderObject
{
public:
- // Override default reference counting behavior to disable lifetime management.
- // Root objects are managed by command buffer and does not need to be freed by the user.
- SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; }
- SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; }
-
+ virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; }
+ virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; }
public:
List<RefPtr<CUDAEntryPointShaderObject>> entryPointObjects;
virtual SLANG_NO_THROW Result SLANG_MCALL
@@ -723,17 +724,9 @@ public:
}
};
-class CUDAQueryPool : public IQueryPool, public ComObject
+class CUDAQueryPool : public QueryPoolBase
{
public:
- SLANG_COM_OBJECT_IUNKNOWN_ALL;
- IQueryPool* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IQueryPool)
- return static_cast<IQueryPool*>(this);
- return nullptr;
- }
-public:
// The event object for each query. Owned by the pool.
List<CUevent> m_events;
@@ -1001,7 +994,6 @@ public:
ResourceState src,
ResourceState dst) override
{
- assert(!"Unimplemented");
}
virtual SLANG_NO_THROW void SLANG_MCALL bufferBarrier(
@@ -1010,7 +1002,6 @@ public:
ResourceState src,
ResourceState dst) override
{
- assert(!"Unimplemented");
}
virtual SLANG_NO_THROW void SLANG_MCALL
@@ -1236,11 +1227,11 @@ public:
switch (cmd.name)
{
case CommandName::SetPipelineState:
- setPipelineState(commandBuffer->getObject<IPipelineState>(cmd.operands[0]));
+ setPipelineState(commandBuffer->getObject<PipelineStateBase>(cmd.operands[0]));
break;
case CommandName::BindRootShaderObject:
bindRootShaderObject(
- commandBuffer->getObject<IShaderObject>(cmd.operands[0]));
+ commandBuffer->getObject<ShaderObjectBase>(cmd.operands[0]));
break;
case CommandName::DispatchCompute:
dispatchCompute(
@@ -1248,22 +1239,22 @@ public:
break;
case CommandName::CopyBuffer:
copyBuffer(
- commandBuffer->getObject<IBufferResource>(cmd.operands[0]),
+ commandBuffer->getObject<BufferResource>(cmd.operands[0]),
cmd.operands[1],
- commandBuffer->getObject<IBufferResource>(cmd.operands[2]),
+ commandBuffer->getObject<BufferResource>(cmd.operands[2]),
cmd.operands[3],
cmd.operands[4]);
break;
case CommandName::UploadBufferData:
uploadBufferData(
- commandBuffer->getObject<IBufferResource>(cmd.operands[0]),
+ commandBuffer->getObject<BufferResource>(cmd.operands[0]),
cmd.operands[1],
cmd.operands[2],
commandBuffer->getData<uint8_t>(cmd.operands[3]));
break;
case CommandName::WriteTimestamp:
writeTimestamp(
- commandBuffer->getObject<IQueryPool>(cmd.operands[0]),
+ commandBuffer->getObject<QueryPoolBase>(cmd.operands[0]),
(SlangInt)cmd.operands[1]);
}
}
@@ -1816,6 +1807,16 @@ public:
return SLANG_OK;
}
+ virtual Result createMutableShaderObject(
+ ShaderObjectLayoutBase* layout,
+ IShaderObject** outObject) override
+ {
+ RefPtr<CUDAMutableShaderObject> result = new CUDAMutableShaderObject();
+ SLANG_RETURN_ON_FAIL(result->init(this, dynamic_cast<CUDAShaderObjectLayout*>(layout)));
+ returnComPtr(outObject, result);
+ return SLANG_OK;
+ }
+
Result createRootShaderObject(IShaderProgram* program, ShaderObjectBase** outObject)
{
auto cudaProgram = dynamic_cast<CUDAShaderProgram*>(program);