summary refs log tree commit diff
path: root/tools/gfx/cuda/render-cuda.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tools/gfx/cuda/render-cuda.cpp')
-rw-r--r-- tools/gfx/cuda/render-cuda.cpp 237
1 files changed, 130 insertions, 107 deletions
diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp
index 057674550..7d7ee8eb9 100644
--- a/tools/gfx/cuda/render-cuda.cpp
+++ b/tools/gfx/cuda/render-cuda.cpp
@@ -244,21 +244,12 @@ public:
class CUDAProgramLayout;
-class CUDAShaderProgram : public IShaderProgram, public RefObject
+class CUDAShaderProgram : public ShaderProgramBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IShaderProgram* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderProgram)
- return static_cast<IShaderProgram*>(this);
- return nullptr;
- }
-public:
CUmodule cudaModule = nullptr;
CUfunction cudaKernel;
String kernelName;
- ComPtr<slang::IComponentType> slangProgram;
RefPtr<CUDAProgramLayout> layout;
~CUDAShaderProgram()
@@ -268,33 +259,22 @@ public:
}
};
-class CUDAPipelineState : public IPipelineState, public RefObject
+class CUDAPipelineState : public PipelineStateBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IPipelineState* getInterface(const Guid& guid)
+ RefPtr<CUDAShaderProgram> shaderProgram;
+ void init(const ComputePipelineStateDesc& inDesc)
{
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState)
- return static_cast<IPipelineState*>(this);
- return nullptr;
+ PipelineStateDesc pipelineDesc;
+ pipelineDesc.type = PipelineType::Compute;
+ pipelineDesc.compute = inDesc;
+ initializeBase(pipelineDesc);
}
-public:
- RefPtr<CUDAShaderProgram> shaderProgram;
};
-class CUDAShaderObjectLayout : public IShaderObjectLayout, public RefObject
+class CUDAShaderObjectLayout : public ShaderObjectLayoutBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IShaderObjectLayout* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObjectLayout)
- return static_cast<IShaderObjectLayout*>(this);
- return nullptr;
- }
-public:
- slang::TypeLayoutReflection* typeLayout = nullptr;
-
struct BindingRangeInfo
{
slang::BindingType bindingType;
@@ -335,30 +315,32 @@ public:
}
}
- CUDAShaderObjectLayout(slang::TypeLayoutReflection* layout)
+ CUDAShaderObjectLayout(RendererBase* renderer, slang::TypeLayoutReflection* layout)
{
+ initBase(renderer, layout);
+
Index subObjectCount = 0;
- typeLayout = unwrapParameterGroups(layout);
+ m_elementTypeLayout = unwrapParameterGroups(layout);
// Compute the binding ranges that are used to store
// the logical contents of the object in memory. These will relate
// to the descriptor ranges in the various sets, but not always
// in a one-to-one fashion.
- SlangInt bindingRangeCount = typeLayout->getBindingRangeCount();
+ SlangInt bindingRangeCount = m_elementTypeLayout->getBindingRangeCount();
for (SlangInt r = 0; r < bindingRangeCount; ++r)
{
- slang::BindingType slangBindingType = typeLayout->getBindingRangeType(r);
- SlangInt count = typeLayout->getBindingRangeBindingCount(r);
+ slang::BindingType slangBindingType = m_elementTypeLayout->getBindingRangeType(r);
+ SlangInt count = m_elementTypeLayout->getBindingRangeBindingCount(r);
slang::TypeLayoutReflection* slangLeafTypeLayout =
- typeLayout->getBindingRangeLeafTypeLayout(r);
+ m_elementTypeLayout->getBindingRangeLeafTypeLayout(r);
- SlangInt descriptorSetIndex = typeLayout->getBindingRangeDescriptorSetIndex(r);
+ SlangInt descriptorSetIndex = m_elementTypeLayout->getBindingRangeDescriptorSetIndex(r);
SlangInt rangeIndexInDescriptorSet =
- typeLayout->getBindingRangeFirstDescriptorRangeIndex(r);
+ m_elementTypeLayout->getBindingRangeFirstDescriptorRangeIndex(r);
- auto uniformOffset = typeLayout->getDescriptorSetDescriptorRangeIndexOffset(
+ auto uniformOffset = m_elementTypeLayout->getDescriptorSetDescriptorRangeIndexOffset(
descriptorSetIndex, rangeIndexInDescriptorSet);
Index baseIndex = 0;
@@ -383,13 +365,13 @@ public:
m_bindingRanges.add(bindingRangeInfo);
}
- SlangInt subObjectRangeCount = typeLayout->getSubObjectRangeCount();
+ SlangInt subObjectRangeCount = m_elementTypeLayout->getSubObjectRangeCount();
for (SlangInt r = 0; r < subObjectRangeCount; ++r)
{
- SlangInt bindingRangeIndex = typeLayout->getSubObjectRangeBindingRangeIndex(r);
- auto slangBindingType = typeLayout->getBindingRangeType(bindingRangeIndex);
+ SlangInt bindingRangeIndex = m_elementTypeLayout->getSubObjectRangeBindingRangeIndex(r);
+ auto slangBindingType = m_elementTypeLayout->getBindingRangeType(bindingRangeIndex);
slang::TypeLayoutReflection* slangLeafTypeLayout =
- typeLayout->getBindingRangeLeafTypeLayout(bindingRangeIndex);
+ m_elementTypeLayout->getBindingRangeLeafTypeLayout(bindingRangeIndex);
// A sub-object range can either represent a sub-object of a known
// type, like a `ConstantBuffer<Foo>` or `ParameterBlock<Foo>`
@@ -402,7 +384,7 @@ public:
if (slangBindingType != slang::BindingType::ExistentialValue)
{
subObjectLayout =
- new CUDAShaderObjectLayout(slangLeafTypeLayout->getElementTypeLayout());
+ new CUDAShaderObjectLayout(renderer, slangLeafTypeLayout->getElementTypeLayout());
}
SubObjectRangeInfo subObjectRange;
@@ -418,13 +400,14 @@ class CUDAProgramLayout : public CUDAShaderObjectLayout
public:
slang::ProgramLayout* programLayout = nullptr;
List<RefPtr<CUDAShaderObjectLayout>> entryPointLayouts;
- CUDAProgramLayout(slang::ProgramLayout* inProgramLayout)
- : CUDAShaderObjectLayout(inProgramLayout->getGlobalParamsTypeLayout())
+ CUDAProgramLayout(RendererBase* renderer, slang::ProgramLayout* inProgramLayout)
+ : CUDAShaderObjectLayout(renderer, inProgramLayout->getGlobalParamsTypeLayout())
, programLayout(inProgramLayout)
{
for (UInt i =0; i< programLayout->getEntryPointCount(); i++)
{
entryPointLayouts.add(new CUDAShaderObjectLayout(
+ renderer,
programLayout->getEntryPointByIndex(i)->getTypeLayout()));
}
@@ -450,26 +433,21 @@ public:
}
};
-class CUDAShaderObject : public IShaderObject, public RefObject
+class CUDAShaderObject : public ShaderObjectBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IShaderObject* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObject)
- return static_cast<IShaderObject*>(this);
- return nullptr;
- }
-
-public:
RefPtr<MemoryCUDAResource> bufferResource;
- RefPtr<CUDAShaderObjectLayout> layout;
List<RefPtr<CUDAShaderObject>> objects;
List<RefPtr<CUDAResourceView>> resources;
virtual SLANG_NO_THROW Result SLANG_MCALL
init(IRenderer* renderer, CUDAShaderObjectLayout* typeLayout);
+ CUDAShaderObjectLayout* getLayout()
+ {
+ return static_cast<CUDAShaderObjectLayout*>(m_layout.Ptr());
+ }
+
virtual SLANG_NO_THROW Result SLANG_MCALL initBuffer(IRenderer* renderer, size_t bufferSize)
{
BufferResource::Desc bufferDesc;
@@ -494,7 +472,7 @@ public:
virtual SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL getElementTypeLayout() override
{
- return layout->typeLayout;
+ return getLayout()->getElementTypeLayout();
}
virtual SLANG_NO_THROW UInt SLANG_MCALL getEntryPointCount() override { return 0; }
@@ -519,7 +497,7 @@ public:
getObject(ShaderOffset const& offset, IShaderObject** object)
{
auto subObjectIndex =
- layout->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex;
+ getLayout()->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex;
if (subObjectIndex >= objects.getCount())
{
*object = nullptr;
@@ -533,10 +511,10 @@ public:
setObject(ShaderOffset const& offset, IShaderObject* object)
{
auto subObjectIndex =
- layout->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex;
+ getLayout()->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex;
SLANG_ASSERT(
offset.uniformOffset ==
- layout->m_bindingRanges[offset.bindingRangeIndex].uniformOffset +
+ getLayout()->m_bindingRanges[offset.bindingRangeIndex].uniformOffset +
offset.bindingArrayIndex * sizeof(void*));
auto cudaObject = dynamic_cast<CUDAShaderObject*>(object);
if (subObjectIndex >= objects.getCount())
@@ -593,6 +571,56 @@ public:
setResource(offset, textureView);
return SLANG_OK;
}
+
+ // Appends all types that are used to specialize the element type of this shader object to the `args` list.
+ virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override
+ {
+ // TODO: the logic here is a copy-paste of `GraphicsCommonShaderObject::collectSpecializationArgs`,
+ // consider moving the implementation to `ShaderObjectBase` and sharing the logic among different implementations.
+
+ if (!m_bindingFinalized)
+ return SLANG_FAIL;
+ auto& subObjectRanges = getLayout()->subObjectRanges;
+ // The following logic is built on the assumption that all fields that involve existential types (and
+ // therefore require specialization) will result in a sub-object range in the type layout.
+ // This allows us to simply scan the sub-object ranges to find out all specialization arguments.
+ for (Index subObjIndex = 0; subObjIndex < subObjectRanges.getCount(); subObjIndex++)
+ {
+ // Retrieve the corresponding binding range of the sub object.
+ auto bindingRange = getLayout()->m_bindingRanges[subObjectRanges[subObjIndex].bindingRangeIndex];
+ switch (bindingRange.bindingType)
+ {
+ case slang::BindingType::ExistentialValue:
+ {
+ // A binding type of `ExistentialValue` means the sub-object represents an interface-typed field.
+ // In this case the specialization argument for this field is the actual specialized type of the bound
+ // shader object. If the shader object's type is an ordinary type without existential fields, then the
+ // type argument will simply be the ordinary type. But if the sub object's type is itself a specialized
+ // type, we need to make sure to use that type as the specialization argument.
+
+ // TODO: need to implement the case where the field is an array of existential values.
+ SLANG_ASSERT(bindingRange.count == 1);
+ ExtendedShaderObjectType specializedSubObjType;
+ SLANG_RETURN_ON_FAIL(objects[subObjIndex]->getSpecializedShaderObjectType(&specializedSubObjType));
+ args.add(specializedSubObjType);
+ break;
+ }
+ case slang::BindingType::ParameterBlock:
+ case slang::BindingType::ConstantBuffer:
+ // Currently we only handle the case where the field's type is
+ // `ParameterBlock<SomeStruct>` or `ConstantBuffer<SomeStruct>`, where `SomeStruct` is a struct type
+ // (not directly an interface type). In this case, we just recursively collect the specialization arguments
+ // from the bound sub object.
+ SLANG_RETURN_ON_FAIL(objects[subObjIndex]->collectSpecializationArgs(args));
+ // TODO: we need to handle the case where the field is of the form `ParameterBlock<IFoo>`. We should treat
+ // this case the same way as the `ExistentialValue` case here, but currently we lack a mechanism to distinguish
+ // the two scenarios.
+ break;
+ }
+ // TODO: need to handle another case where specialization happens on resource fields e.g. `StructuredBuffer<IFoo>`.
+ }
+ return SLANG_OK;
+ }
};
class CUDAEntryPointShaderObject : public CUDAShaderObject
@@ -652,17 +680,8 @@ public:
};
-class CUDARenderer : public IRenderer, public RefObject
+class CUDARenderer : public RendererBase
{
-public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IRenderer* getInterface(const Guid& guid)
- {
- return (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IRenderer)
- ? static_cast<IRenderer*>(this)
- : nullptr;
- }
-
private:
static const CUDAReportStyle reportType = CUDAReportStyle::Normal;
static int _calcSMCountPerMultiProcessor(int major, int minor)
@@ -781,12 +800,14 @@ private:
int m_deviceIndex = -1;
CUdevice m_device = 0;
CUcontext m_context = nullptr;
- CUDAPipelineState* currentPipeline = nullptr;
- CUDARootShaderObject* currentRootObject = nullptr;
+ RefPtr<CUDAPipelineState> currentPipeline = nullptr;
+ RefPtr<CUDARootShaderObject> currentRootObject = nullptr;
SlangContext slangContext;
public:
~CUDARenderer()
{
+ currentPipeline = nullptr;
+ currentRootObject = nullptr;
if (m_context)
{
cuCtxDestroy(m_context);
@@ -796,6 +817,8 @@ private:
{
SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_PTX, "sm_5_1"));
+ SLANG_RETURN_ON_FAIL(RendererBase::initialize(desc, inWindowHandle));
+
SLANG_RETURN_ON_FAIL(_initCuda(reportType));
SLANG_RETURN_ON_FAIL(_findMaxFlopsDeviceIndex(&m_deviceIndex));
@@ -813,13 +836,6 @@ private:
return SLANG_OK;
}
- virtual SLANG_NO_THROW Result SLANG_MCALL getSlangSession(slang::ISession** outSlangSession) override
- {
- *outSlangSession = slangContext.session.get();
- slangContext.session->addRef();
- return SLANG_OK;
- }
-
virtual SLANG_NO_THROW Result SLANG_MCALL createTextureResource(
IResource::Usage initialUsage,
const ITextureResource::Desc& desc,
@@ -1271,7 +1287,7 @@ private:
slang::TypeLayoutReflection* typeLayout, IShaderObjectLayout** outLayout) override
{
RefPtr<CUDAShaderObjectLayout> cudaLayout;
- cudaLayout = new CUDAShaderObjectLayout(typeLayout);
+ cudaLayout = new CUDAShaderObjectLayout(this, typeLayout);
*outLayout = cudaLayout.detach();
return SLANG_OK;
}
@@ -1309,6 +1325,17 @@ private:
virtual SLANG_NO_THROW Result SLANG_MCALL
createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) override
{
+ // If this is a specializable program, we just keep a reference to the slang program and
+ // don't actually create any kernels. This program will be specialized later when we know
+ // the shader object bindings.
+ if (desc.slangProgram && desc.slangProgram->getSpecializationParamCount() != 0)
+ {
+ RefPtr<CUDAShaderProgram> cudaProgram = new CUDAShaderProgram();
+ cudaProgram->slangProgram = desc.slangProgram;
+ *outProgram = cudaProgram.detach();
+ return SLANG_OK;
+ }
+
if( desc.kernelCount == 0 )
{
return createProgramFromSlang(this, desc, outProgram);
@@ -1332,7 +1359,7 @@ private:
return SLANG_FAIL;
RefPtr<CUDAProgramLayout> cudaLayout;
- cudaLayout = new CUDAProgramLayout(slangProgramLayout);
+ cudaLayout = new CUDAProgramLayout(this, slangProgramLayout);
cudaLayout->programLayout = slangProgramLayout;
cudaProgram->layout = cudaLayout;
}
@@ -1346,6 +1373,7 @@ private:
{
RefPtr<CUDAPipelineState> state = new CUDAPipelineState();
state->shaderProgram = dynamic_cast<CUDAShaderProgram*>(desc.program);
+ state->init(desc);
*outState = state.detach();
return Result();
}
@@ -1360,18 +1388,19 @@ private:
SLANG_UNUSED(buffer);
}
- virtual SLANG_NO_THROW void SLANG_MCALL
- setPipelineState(PipelineType pipelineType, IPipelineState* state) override
+ virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) override
{
- SLANG_ASSERT(pipelineType == PipelineType::Compute);
currentPipeline = dynamic_cast<CUDAPipelineState*>(state);
}
virtual SLANG_NO_THROW void SLANG_MCALL dispatchCompute(int x, int y, int z) override
{
+ // Specialize the compute kernel based on the shader object bindings.
+ maybeSpecializePipeline(currentRootObject);
+
// Find out thread group size from program reflection.
auto& kernelName = currentPipeline->shaderProgram->kernelName;
- auto programLayout = dynamic_cast<CUDAProgramLayout*>(currentRootObject->layout.Ptr());
+ auto programLayout = static_cast<CUDAProgramLayout*>(currentRootObject->getLayout());
int kernelId = programLayout->getKernelIndex(kernelName.getUnownedSlice());
SLANG_ASSERT(kernelId != -1);
UInt threadGroupSize[3];
@@ -1451,21 +1480,12 @@ private:
return RendererType::CUDA;
}
-public:
- // Unused public interfaces. These functions are not supported on CUDA.
- SLANG_NO_THROW Result SLANG_MCALL getFeatures(
- const char** outFeatures, UInt bufferSize, UInt* outFeatureCount)
- {
- if (outFeatureCount)
- *outFeatureCount = 0;
- return SLANG_OK;
- }
-
- SLANG_NO_THROW bool SLANG_MCALL hasFeature(const char* featureName)
+ virtual PipelineStateBase* getCurrentPipeline() override
{
- return false;
+ return currentPipeline;
}
+public:
virtual SLANG_NO_THROW void SLANG_MCALL setClearColor(const float color[4]) override
{
SLANG_UNUSED(color);
@@ -1602,8 +1622,8 @@ public:
SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout* typeLayout)
{
- this->layout = typeLayout;
-
+ m_layout = typeLayout;
+
// If the layout tells us that there is any uniform data,
// then we need to allocate a constant buffer to hold that data.
//
@@ -1613,8 +1633,8 @@ SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout*
// TODO: When/where do we bind this constant buffer into
// a descriptor set for later use?
//
- auto slangLayout = layout->typeLayout;
- size_t uniformSize = layout->typeLayout->getSize();
+ auto slangLayout = getLayout()->getElementTypeLayout();
+ size_t uniformSize = slangLayout->getSize();
if (uniformSize)
{
initBuffer(renderer, uniformSize);
@@ -1626,7 +1646,7 @@ SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout*
Index subObjectCount = slangLayout->getSubObjectRangeCount();
objects.setCount(subObjectCount);
- for (auto subObjectRange : layout->subObjectRanges)
+ for (auto subObjectRange : getLayout()->subObjectRanges)
{
RefPtr<CUDAShaderObjectLayout> subObjectLayout = subObjectRange.layout;
@@ -1643,7 +1663,7 @@ SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout*
// in each entry in this range, based on the layout
// information we already have.
- auto& bindingRangeInfo = layout->m_bindingRanges[subObjectRange.bindingRangeIndex];
+ auto& bindingRangeInfo = getLayout()->m_bindingRanges[subObjectRange.bindingRangeIndex];
for (Index i = 0; i < bindingRangeInfo.count; ++i)
{
RefPtr<CUDAShaderObject> subObject = new CUDAShaderObject();
@@ -1671,15 +1691,18 @@ SlangResult CUDARootShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayo
return SLANG_OK;
}
-SlangResult SLANG_MCALL createCUDARenderer(IRenderer** outRenderer)
+SlangResult SLANG_MCALL createCUDARenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer)
{
- *outRenderer = new CUDARenderer();
- (*outRenderer)->addRef();
+ RefPtr<CUDARenderer> result = new CUDARenderer();
+ SLANG_RETURN_ON_FAIL(result->initialize(*desc, windowHandle));
+ *outRenderer = result.detach();
return SLANG_OK;
}
#else
-SlangResult SLANG_MCALL createCUDARenderer(IRenderer** outRenderer)
+SlangResult SLANG_MCALL createCUDARenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer)
{
+ SLANG_UNUSED(desc);
+ SLANG_UNUSED(windowHandle);
*outRenderer = nullptr;
return SLANG_OK;
}