diff options
Diffstat (limited to 'tools/gfx/cuda/render-cuda.cpp')
| -rw-r--r-- | tools/gfx/cuda/render-cuda.cpp | 237 |
1 files changed, 130 insertions, 107 deletions
diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp index 057674550..7d7ee8eb9 100644 --- a/tools/gfx/cuda/render-cuda.cpp +++ b/tools/gfx/cuda/render-cuda.cpp @@ -244,21 +244,12 @@ public: class CUDAProgramLayout; -class CUDAShaderProgram : public IShaderProgram, public RefObject +class CUDAShaderProgram : public ShaderProgramBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IShaderProgram* getInterface(const Guid& guid) - { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderProgram) - return static_cast<IShaderProgram*>(this); - return nullptr; - } -public: CUmodule cudaModule = nullptr; CUfunction cudaKernel; String kernelName; - ComPtr<slang::IComponentType> slangProgram; RefPtr<CUDAProgramLayout> layout; ~CUDAShaderProgram() @@ -268,33 +259,22 @@ public: } }; -class CUDAPipelineState : public IPipelineState, public RefObject +class CUDAPipelineState : public PipelineStateBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IPipelineState* getInterface(const Guid& guid) + RefPtr<CUDAShaderProgram> shaderProgram; + void init(const ComputePipelineStateDesc& inDesc) { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState) - return static_cast<IPipelineState*>(this); - return nullptr; + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::Compute; + pipelineDesc.compute = inDesc; + initializeBase(pipelineDesc); } -public: - RefPtr<CUDAShaderProgram> shaderProgram; }; -class CUDAShaderObjectLayout : public IShaderObjectLayout, public RefObject +class CUDAShaderObjectLayout : public ShaderObjectLayoutBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IShaderObjectLayout* getInterface(const Guid& guid) - { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObjectLayout) - return static_cast<IShaderObjectLayout*>(this); - return nullptr; - } -public: - slang::TypeLayoutReflection* typeLayout = nullptr; - struct BindingRangeInfo { slang::BindingType bindingType; @@ -335,30 +315,32 @@ public: } } - CUDAShaderObjectLayout(slang::TypeLayoutReflection* layout) + CUDAShaderObjectLayout(RendererBase* renderer, slang::TypeLayoutReflection* layout) { + initBase(renderer, layout); + Index subObjectCount = 0; - typeLayout = unwrapParameterGroups(layout); + m_elementTypeLayout = unwrapParameterGroups(layout); // Compute the binding ranges that are used to store // the logical contents of the object in memory. These will relate // to the descriptor ranges in the various sets, but not always // in a one-to-one fashion. - SlangInt bindingRangeCount = typeLayout->getBindingRangeCount(); + SlangInt bindingRangeCount = m_elementTypeLayout->getBindingRangeCount(); for (SlangInt r = 0; r < bindingRangeCount; ++r) { - slang::BindingType slangBindingType = typeLayout->getBindingRangeType(r); - SlangInt count = typeLayout->getBindingRangeBindingCount(r); + slang::BindingType slangBindingType = m_elementTypeLayout->getBindingRangeType(r); + SlangInt count = m_elementTypeLayout->getBindingRangeBindingCount(r); slang::TypeLayoutReflection* slangLeafTypeLayout = - typeLayout->getBindingRangeLeafTypeLayout(r); + m_elementTypeLayout->getBindingRangeLeafTypeLayout(r); - SlangInt descriptorSetIndex = typeLayout->getBindingRangeDescriptorSetIndex(r); + SlangInt descriptorSetIndex = m_elementTypeLayout->getBindingRangeDescriptorSetIndex(r); SlangInt rangeIndexInDescriptorSet = - typeLayout->getBindingRangeFirstDescriptorRangeIndex(r); + m_elementTypeLayout->getBindingRangeFirstDescriptorRangeIndex(r); - auto uniformOffset = typeLayout->getDescriptorSetDescriptorRangeIndexOffset( + auto uniformOffset = m_elementTypeLayout->getDescriptorSetDescriptorRangeIndexOffset( descriptorSetIndex, rangeIndexInDescriptorSet); Index baseIndex = 0; @@ -383,13 +365,13 @@ public: m_bindingRanges.add(bindingRangeInfo); } - SlangInt subObjectRangeCount = typeLayout->getSubObjectRangeCount(); + SlangInt subObjectRangeCount = m_elementTypeLayout->getSubObjectRangeCount(); for (SlangInt r = 0; r < subObjectRangeCount; ++r) { - SlangInt bindingRangeIndex = typeLayout->getSubObjectRangeBindingRangeIndex(r); - auto slangBindingType = typeLayout->getBindingRangeType(bindingRangeIndex); + SlangInt bindingRangeIndex = m_elementTypeLayout->getSubObjectRangeBindingRangeIndex(r); + auto slangBindingType = m_elementTypeLayout->getBindingRangeType(bindingRangeIndex); slang::TypeLayoutReflection* slangLeafTypeLayout = - typeLayout->getBindingRangeLeafTypeLayout(bindingRangeIndex); + m_elementTypeLayout->getBindingRangeLeafTypeLayout(bindingRangeIndex); // A sub-object range can either represent a sub-object of a known // type, like a `ConstantBuffer<Foo>` or `ParameterBlock<Foo>` @@ -402,7 +384,7 @@ public: if (slangBindingType != slang::BindingType::ExistentialValue) { subObjectLayout = - new CUDAShaderObjectLayout(slangLeafTypeLayout->getElementTypeLayout()); + new CUDAShaderObjectLayout(renderer, slangLeafTypeLayout->getElementTypeLayout()); } SubObjectRangeInfo subObjectRange; @@ -418,13 +400,14 @@ class CUDAProgramLayout : public CUDAShaderObjectLayout public: slang::ProgramLayout* programLayout = nullptr; List<RefPtr<CUDAShaderObjectLayout>> entryPointLayouts; - CUDAProgramLayout(slang::ProgramLayout* inProgramLayout) - : CUDAShaderObjectLayout(inProgramLayout->getGlobalParamsTypeLayout()) + CUDAProgramLayout(RendererBase* renderer, slang::ProgramLayout* inProgramLayout) + : CUDAShaderObjectLayout(renderer, inProgramLayout->getGlobalParamsTypeLayout()) , programLayout(inProgramLayout) { for (UInt i =0; i< programLayout->getEntryPointCount(); i++) { entryPointLayouts.add(new CUDAShaderObjectLayout( + renderer, programLayout->getEntryPointByIndex(i)->getTypeLayout())); } @@ -450,26 +433,21 @@ public: } }; -class CUDAShaderObject : public IShaderObject, public RefObject +class CUDAShaderObject : public ShaderObjectBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IShaderObject* getInterface(const Guid& guid) - { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObject) - return static_cast<IShaderObject*>(this); - return nullptr; - } - -public: RefPtr<MemoryCUDAResource> bufferResource; - RefPtr<CUDAShaderObjectLayout> layout; List<RefPtr<CUDAShaderObject>> objects; List<RefPtr<CUDAResourceView>> resources; virtual SLANG_NO_THROW Result SLANG_MCALL init(IRenderer* renderer, CUDAShaderObjectLayout* typeLayout); + CUDAShaderObjectLayout* getLayout() + { + return static_cast<CUDAShaderObjectLayout*>(m_layout.Ptr()); + } + virtual SLANG_NO_THROW Result SLANG_MCALL initBuffer(IRenderer* renderer, size_t bufferSize) { BufferResource::Desc bufferDesc; @@ -494,7 +472,7 @@ public: virtual SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL getElementTypeLayout() override { - return layout->typeLayout; + return getLayout()->getElementTypeLayout(); } virtual SLANG_NO_THROW UInt SLANG_MCALL getEntryPointCount() override { return 0; } @@ -519,7 +497,7 @@ public: getObject(ShaderOffset const& offset, IShaderObject** object) { auto subObjectIndex = - layout->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex; + getLayout()->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex; if (subObjectIndex >= objects.getCount()) { *object = nullptr; @@ -533,10 +511,10 @@ public: setObject(ShaderOffset const& offset, IShaderObject* object) { auto subObjectIndex = - layout->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex; + getLayout()->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex; SLANG_ASSERT( offset.uniformOffset == - layout->m_bindingRanges[offset.bindingRangeIndex].uniformOffset + + getLayout()->m_bindingRanges[offset.bindingRangeIndex].uniformOffset + offset.bindingArrayIndex * sizeof(void*)); auto cudaObject = dynamic_cast<CUDAShaderObject*>(object); if (subObjectIndex >= objects.getCount()) @@ -593,6 +571,56 @@ public: setResource(offset, textureView); return SLANG_OK; } + + // Appends all types that are used to specialize the element type of this shader object in `args` list. + virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override + { + // TODO: the logic here is a copy-paste of `GraphicsCommonShaderObject::collectSpecializationArgs`, + // consider moving the implementation to `ShaderObjectBase` and share the logic among different implementations. + + if (!m_bindingFinalized) + return SLANG_FAIL; + auto& subObjectRanges = getLayout()->subObjectRanges; + // The following logic is built on the assumption that all fields that involve existential types (and + // therefore require specialization) will results in a sub-object range in the type layout. + // This allows us to simply scan the sub-object ranges to find out all specialization arguments. + for (Index subObjIndex = 0; subObjIndex < subObjectRanges.getCount(); subObjIndex++) + { + // Retrieve the corresponding binding range of the sub object. + auto bindingRange = getLayout()->m_bindingRanges[subObjectRanges[subObjIndex].bindingRangeIndex]; + switch (bindingRange.bindingType) + { + case slang::BindingType::ExistentialValue: + { + // A binding type of `ExistentialValue` means the sub-object represents a interface-typed field. + // In this case the specialization argument for this field is the actual specialized type of the bound + // shader object. If the shader object's type is an ordinary type without existential fields, then the + // type argument will simply be the ordinary type. But if the sub object's type is itself a specialized + // type, we need to make sure to use that type as the specialization argument. + + // TODO: need to implement the case where the field is an array of existential values. + SLANG_ASSERT(bindingRange.count == 1); + ExtendedShaderObjectType specializedSubObjType; + SLANG_RETURN_ON_FAIL(objects[subObjIndex]->getSpecializedShaderObjectType(&specializedSubObjType)); + args.add(specializedSubObjType); + break; + } + case slang::BindingType::ParameterBlock: + case slang::BindingType::ConstantBuffer: + // Currently we only handle the case where the field's type is + // `ParameterBlock<SomeStruct>` or `ConstantBuffer<SomeStruct>`, where `SomeStruct` is a struct type + // (not directly an interface type). In this case, we just recursively collect the specialization arguments + // from the bound sub object. + SLANG_RETURN_ON_FAIL(objects[subObjIndex]->collectSpecializationArgs(args)); + // TODO: we need to handle the case where the field is of the form `ParameterBlock<IFoo>`. We should treat + // this case the same way as the `ExistentialValue` case here, but currently we lack a mechanism to distinguish + // the two scenarios. + break; + } + // TODO: need to handle another case where specialization happens on resources fields e.g. `StructuredBuffer<IFoo>`. + } + return SLANG_OK; + } }; class CUDAEntryPointShaderObject : public CUDAShaderObject @@ -652,17 +680,8 @@ public: }; -class CUDARenderer : public IRenderer, public RefObject +class CUDARenderer : public RendererBase { -public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IRenderer* getInterface(const Guid& guid) - { - return (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IRenderer) - ? static_cast<IRenderer*>(this) - : nullptr; - } - private: static const CUDAReportStyle reportType = CUDAReportStyle::Normal; static int _calcSMCountPerMultiProcessor(int major, int minor) @@ -781,12 +800,14 @@ private: int m_deviceIndex = -1; CUdevice m_device = 0; CUcontext m_context = nullptr; - CUDAPipelineState* currentPipeline = nullptr; - CUDARootShaderObject* currentRootObject = nullptr; + RefPtr<CUDAPipelineState> currentPipeline = nullptr; + RefPtr<CUDARootShaderObject> currentRootObject = nullptr; SlangContext slangContext; public: ~CUDARenderer() { + currentPipeline = nullptr; + currentRootObject = nullptr; if (m_context) { cuCtxDestroy(m_context); @@ -796,6 +817,8 @@ private: { SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_PTX, "sm_5_1")); + SLANG_RETURN_ON_FAIL(RendererBase::initialize(desc, inWindowHandle)); + SLANG_RETURN_ON_FAIL(_initCuda(reportType)); SLANG_RETURN_ON_FAIL(_findMaxFlopsDeviceIndex(&m_deviceIndex)); @@ -813,13 +836,6 @@ private: return SLANG_OK; } - virtual SLANG_NO_THROW Result SLANG_MCALL getSlangSession(slang::ISession** outSlangSession) override - { - *outSlangSession = slangContext.session.get(); - slangContext.session->addRef(); - return SLANG_OK; - } - virtual SLANG_NO_THROW Result SLANG_MCALL createTextureResource( IResource::Usage initialUsage, const ITextureResource::Desc& desc, @@ -1271,7 +1287,7 @@ private: slang::TypeLayoutReflection* typeLayout, IShaderObjectLayout** outLayout) override { RefPtr<CUDAShaderObjectLayout> cudaLayout; - cudaLayout = new CUDAShaderObjectLayout(typeLayout); + cudaLayout = new CUDAShaderObjectLayout(this, typeLayout); *outLayout = cudaLayout.detach(); return SLANG_OK; } @@ -1309,6 +1325,17 @@ private: virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) override { + // If this is a specializable program, we just keep a reference to the slang program and + // don't actually create any kernels. This program will be specialized later when we know + // the shader object bindings. + if (desc.slangProgram && desc.slangProgram->getSpecializationParamCount() != 0) + { + RefPtr<CUDAShaderProgram> cudaProgram = new CUDAShaderProgram(); + cudaProgram->slangProgram = desc.slangProgram; + *outProgram = cudaProgram.detach(); + return SLANG_OK; + } + if( desc.kernelCount == 0 ) { return createProgramFromSlang(this, desc, outProgram); @@ -1332,7 +1359,7 @@ private: return SLANG_FAIL; RefPtr<CUDAProgramLayout> cudaLayout; - cudaLayout = new CUDAProgramLayout(slangProgramLayout); + cudaLayout = new CUDAProgramLayout(this, slangProgramLayout); cudaLayout->programLayout = slangProgramLayout; cudaProgram->layout = cudaLayout; } @@ -1346,6 +1373,7 @@ private: { RefPtr<CUDAPipelineState> state = new CUDAPipelineState(); state->shaderProgram = dynamic_cast<CUDAShaderProgram*>(desc.program); + state->init(desc); *outState = state.detach(); return Result(); } @@ -1360,18 +1388,19 @@ private: SLANG_UNUSED(buffer); } - virtual SLANG_NO_THROW void SLANG_MCALL - setPipelineState(PipelineType pipelineType, IPipelineState* state) override + virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) override { - SLANG_ASSERT(pipelineType == PipelineType::Compute); currentPipeline = dynamic_cast<CUDAPipelineState*>(state); } virtual SLANG_NO_THROW void SLANG_MCALL dispatchCompute(int x, int y, int z) override { + // Specialize the compute kernel based on the shader object bindings. + maybeSpecializePipeline(currentRootObject); + // Find out thread group size from program reflection. auto& kernelName = currentPipeline->shaderProgram->kernelName; - auto programLayout = dynamic_cast<CUDAProgramLayout*>(currentRootObject->layout.Ptr()); + auto programLayout = static_cast<CUDAProgramLayout*>(currentRootObject->getLayout()); int kernelId = programLayout->getKernelIndex(kernelName.getUnownedSlice()); SLANG_ASSERT(kernelId != -1); UInt threadGroupSize[3]; @@ -1451,21 +1480,12 @@ private: return RendererType::CUDA; } -public: - // Unused public interfaces. These functions are not supported on CUDA. - SLANG_NO_THROW Result SLANG_MCALL getFeatures( - const char** outFeatures, UInt bufferSize, UInt* outFeatureCount) - { - if (outFeatureCount) - *outFeatureCount = 0; - return SLANG_OK; - } - - SLANG_NO_THROW bool SLANG_MCALL hasFeature(const char* featureName) + virtual PipelineStateBase* getCurrentPipeline() override { - return false; + return currentPipeline; } +public: virtual SLANG_NO_THROW void SLANG_MCALL setClearColor(const float color[4]) override { SLANG_UNUSED(color); @@ -1602,8 +1622,8 @@ public: SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout* typeLayout) { - this->layout = typeLayout; - + m_layout = typeLayout; + // If the layout tells us that there is any uniform data, // then we need to allocate a constant buffer to hold that data. // @@ -1613,8 +1633,8 @@ SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout* // TODO: When/where do we bind this constant buffer into // a descriptor set for later use? // - auto slangLayout = layout->typeLayout; - size_t uniformSize = layout->typeLayout->getSize(); + auto slangLayout = getLayout()->getElementTypeLayout(); + size_t uniformSize = slangLayout->getSize(); if (uniformSize) { initBuffer(renderer, uniformSize); @@ -1626,7 +1646,7 @@ SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout* Index subObjectCount = slangLayout->getSubObjectRangeCount(); objects.setCount(subObjectCount); - for (auto subObjectRange : layout->subObjectRanges) + for (auto subObjectRange : getLayout()->subObjectRanges) { RefPtr<CUDAShaderObjectLayout> subObjectLayout = subObjectRange.layout; @@ -1643,7 +1663,7 @@ SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout* // in each entry in this range, based on the layout // information we already have. - auto& bindingRangeInfo = layout->m_bindingRanges[subObjectRange.bindingRangeIndex]; + auto& bindingRangeInfo = getLayout()->m_bindingRanges[subObjectRange.bindingRangeIndex]; for (Index i = 0; i < bindingRangeInfo.count; ++i) { RefPtr<CUDAShaderObject> subObject = new CUDAShaderObject(); @@ -1671,15 +1691,18 @@ SlangResult CUDARootShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayo return SLANG_OK; } -SlangResult SLANG_MCALL createCUDARenderer(IRenderer** outRenderer) +SlangResult SLANG_MCALL createCUDARenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer) { - *outRenderer = new CUDARenderer(); - (*outRenderer)->addRef(); + RefPtr<CUDARenderer> result = new CUDARenderer(); + SLANG_RETURN_ON_FAIL(result->initialize(*desc, windowHandle)); + *outRenderer = result.detach(); return SLANG_OK; } #else -SlangResult SLANG_MCALL createCUDARenderer(IRenderer** outRenderer) +SlangResult SLANG_MCALL createCUDARenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer) { + SLANG_UNUSED(desc); + SLANG_UNUSED(windowHandle); *outRenderer = nullptr; return SLANG_OK; } |
