diff options
31 files changed, 1242 insertions, 544 deletions
diff --git a/.gitignore b/.gitignore index e941540c1..fad8a1030 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,9 @@ *.obj *.slang-module *.zip +*.ini .clang-format + bin/ intermediate/ build.*/ diff --git a/examples/gpu-printing/main.cpp b/examples/gpu-printing/main.cpp index 6a4aabf3a..8dc0b0f3d 100644 --- a/examples/gpu-printing/main.cpp +++ b/examples/gpu-printing/main.cpp @@ -118,14 +118,12 @@ Result execute() windowDesc.height = gWindowHeight; gWindow = createWindow(windowDesc); - gfxGetCreateFunc(gfx::RendererType::DirectX11)(gRenderer.writeRef()); IRenderer::Desc rendererDesc; + rendererDesc.rendererType = gfx::RendererType::DirectX11; rendererDesc.width = gWindowWidth; rendererDesc.height = gWindowHeight; - { - Result res = gRenderer->initialize(rendererDesc, getPlatformWindowHandle(gWindow)); - if(SLANG_FAILED(res)) return res; - } + Result res = gfxCreateRenderer(&rendererDesc, getPlatformWindowHandle(gWindow), gRenderer.writeRef()); + if(SLANG_FAILED(res)) return res; gSlangSession = createSlangSession(gRenderer); gSlangModule = compileShaderModuleFromFile(gSlangSession, "kernels.slang"); @@ -191,7 +189,7 @@ Result execute() gDescriptorSet->setResource(0, 0, printBufferView); gRenderer->setDescriptorSet(PipelineType::Compute, gPipelineLayout, 0, gDescriptorSet); - gRenderer->setPipelineState(PipelineType::Compute, gPipelineState); + gRenderer->setPipelineState(gPipelineState); gRenderer->dispatchCompute(1, 1, 1); // TODO: need to copy from the print buffer to a staging buffer... diff --git a/examples/hello-world/main.cpp b/examples/hello-world/main.cpp index e8e3c45e0..91c9c0627 100644 --- a/examples/hello-world/main.cpp +++ b/examples/hello-world/main.cpp @@ -240,14 +240,12 @@ Slang::Result initialize() // A future version of this example may support multiple // platforms/APIs. // - gfxGetCreateFunc(gfx::RendererType::DirectX11)(gRenderer.writeRef()); - IRenderer::Desc rendererDesc; + IRenderer::Desc rendererDesc = {}; + rendererDesc.rendererType = gfx::RendererType::DirectX11; rendererDesc.width = gWindowWidth; rendererDesc.height = gWindowHeight; - { - gfx::Result res = gRenderer->initialize(rendererDesc, getPlatformWindowHandle(gWindow)); - if(SLANG_FAILED(res)) return res; - } + gfx::Result res = gfxCreateRenderer(&rendererDesc, getPlatformWindowHandle(gWindow), gRenderer.writeRef()); + if(SLANG_FAILED(res)) return res; // Now we will create objects needed to configur the "input assembler" // (IA) stage of the D3D pipeline. @@ -414,7 +412,7 @@ void renderFrame() // PSO, binding our root shader object to it (which references // the `Uniforms` buffer that will filled in above). // - gRenderer->setPipelineState(PipelineType::Graphics, gPipelineState); + gRenderer->setPipelineState(gPipelineState); gRenderer->bindRootShaderObject(PipelineType::Graphics, gRootObject); // We also need to set up a few pieces of fixed-function pipeline diff --git a/examples/heterogeneous-hello-world/main.cpp b/examples/heterogeneous-hello-world/main.cpp index 8610a5fa2..163b17deb 100644 --- a/examples/heterogeneous-hello-world/main.cpp +++ b/examples/heterogeneous-hello-world/main.cpp @@ -43,7 +43,6 @@ using namespace gfx; // We create global ref pointers to avoid dereferencing values // ComPtr<gfx::IShaderProgram> gShaderProgram; -Slang::RefPtr<gfx::ApplicationContext> gAppContext; Slang::ComPtr<gfx::IRenderer> gRenderer; ComPtr<gfx::IBufferResource> gStructuredBuffer; @@ -123,14 +122,12 @@ gfx::IRenderer* createRenderer( // A future version of this example may support multiple // platforms/APIs. // - gfxGetCreateFunc(gfx::RendererType::DirectX11)(gRenderer.writeRef()); - IRenderer::Desc rendererDesc; + IRenderer::Desc rendererDesc = {}; + rendererDesc.rendererType = gfx::RendererType::DirectX11; rendererDesc.width = windowWidth; rendererDesc.height = windowHeight; - { - Result res = gRenderer->initialize(rendererDesc, getPlatformWindowHandle(window)); - if (SLANG_FAILED(res)) return nullptr; - } + Result res = gfxCreateRenderer(&rendererDesc, getPlatformWindowHandle(window), gRenderer.writeRef()); + if (SLANG_FAILED(res)) return nullptr; return gRenderer; } @@ -249,7 +246,7 @@ void dispatchComputation( unsigned int gridDimsZ) { - gRenderer->setPipelineState(PipelineType::Compute, gPipelineState); + gRenderer->setPipelineState(gPipelineState); gRenderer->setDescriptorSet(PipelineType::Compute, gPipelineLayout, 0, gDescriptorSet); gRenderer->dispatchCompute(gridDimsX, gridDimsY, gridDimsZ); diff --git a/examples/model-viewer/main.cpp b/examples/model-viewer/main.cpp index 384cc5eac..f830d4044 100644 --- a/examples/model-viewer/main.cpp +++ b/examples/model-viewer/main.cpp @@ -1254,7 +1254,7 @@ public: // we simply bind its PSO into the GPU state, and // remember the variant we've selected. // - renderer->setPipelineState(PipelineType::Graphics, variant->pipelineState); + renderer->setPipelineState(variant->pipelineState); currentEffectVariant = variant; } @@ -2050,11 +2050,11 @@ Result initialize() windowDesc.userData = this; gWindow = createWindow(windowDesc); - gfxGetCreateFunc(gfx::RendererType::DirectX11)(gRenderer.writeRef()); - IRenderer::Desc rendererDesc; + IRenderer::Desc rendererDesc = {}; + rendererDesc.rendererType = gfx::RendererType::DirectX11; rendererDesc.width = gWindowWidth; rendererDesc.height = gWindowHeight; - gRenderer->initialize(rendererDesc, getPlatformWindowHandle(gWindow)); + gfxCreateRenderer(&rendererDesc, getPlatformWindowHandle(gWindow), gRenderer.writeRef()); InputElementDesc inputElements[] = { {"POSITION", 0, Format::RGB_Float32, offsetof(Model::Vertex, position) }, diff --git a/examples/shader-toy/main.cpp b/examples/shader-toy/main.cpp index a1408e38e..2bbc59113 100644 --- a/examples/shader-toy/main.cpp +++ b/examples/shader-toy/main.cpp @@ -347,14 +347,12 @@ Result initialize() windowDesc.userData = this; gWindow = createWindow(windowDesc); - gfxGetCreateFunc(gfx::RendererType::DirectX11)(gRenderer.writeRef()); IRenderer::Desc rendererDesc; + rendererDesc.rendererType = RendererType::DirectX11; rendererDesc.width = gWindowWidth; rendererDesc.height = gWindowHeight; - { - Result res = gRenderer->initialize(rendererDesc, getPlatformWindowHandle(gWindow)); - if(SLANG_FAILED(res)) return res; - } + Result res = gfxCreateRenderer(&rendererDesc, getPlatformWindowHandle(gWindow), gRenderer.writeRef()); + if(SLANG_FAILED(res)) return res; int constantBufferSize = sizeof(Uniforms); @@ -477,7 +475,7 @@ void renderFrame() gRenderer->unmap(gConstantBuffer); } - gRenderer->setPipelineState(PipelineType::Graphics, gPipelineState); + gRenderer->setPipelineState(gPipelineState); gRenderer->setDescriptorSet(PipelineType::Graphics, gPipelineLayout, 0, gDescriptorSet); gRenderer->setVertexBuffer(0, gVertexBuffer, sizeof(FullScreenTriangle::Vertex)); @@ -3991,6 +3991,10 @@ namespace slang SlangInt targetIndex = 0, IBlob** outDiagnostics = nullptr) = 0; + /** Get the number of (unspecialized) specialization parameters for the component type. + */ + virtual SLANG_NO_THROW SlangInt SLANG_MCALL getSpecializationParamCount() = 0; + /** Get the compiled code for the entry point at `entryPointIndex` for the chosen `targetIndex` Entry point code can only be computed for a component type that diff --git a/source/core/slang-dictionary.h b/source/core/slang-dictionary.h index 4c352d55b..51b61de60 100644 --- a/source/core/slang-dictionary.h +++ b/source/core/slang-dictionary.h @@ -135,15 +135,17 @@ namespace Slang } }; - inline int GetHashPos(TKey& key) const + template<typename KeyType> + inline int GetHashPos(KeyType& key) const { SLANG_ASSERT(bucketSizeMinusOne > 0); const unsigned int hash = (unsigned int)getHashCode(key); return (hash * 2654435761u) % (unsigned int)(bucketSizeMinusOne); } - FindPositionResult FindPosition(const TKey& key) const + template<typename KeyType> + FindPositionResult FindPosition(const KeyType& key) const { - int hashPos = GetHashPos(const_cast<TKey&>(key)); + int hashPos = GetHashPos(const_cast<KeyType&>(key)); int insertPos = -1; int numProbes = 0; while (numProbes <= bucketSizeMinusOne) @@ -380,14 +382,16 @@ namespace Slang throw InvalidOperationException("Inconsistent find result returned. This is a bug in Dictionary implementation."); } - bool ContainsKey(const TKey& key) const + template<typename KeyType> + bool ContainsKey(const KeyType& key) const { if (bucketSizeMinusOne == -1) return false; auto pos = FindPosition(key); return pos.ObjectPosition != -1; } - bool TryGetValue(const TKey& key, TValue& value) const + template<typename KeyType> + bool TryGetValue(const KeyType& key, TValue& value) const { if (bucketSizeMinusOne == -1) return false; @@ -399,7 +403,8 @@ namespace Slang } return false; } - TValue* TryGetValue(const TKey& key) const + template<typename KeyType> + TValue* TryGetValue(const KeyType& key) const { if (bucketSizeMinusOne == -1) return nullptr; diff --git a/source/core/slang-short-list.h b/source/core/slang-short-list.h index 82ad4fe1e..7d51a8abf 100644 --- a/source/core/slang-short-list.h +++ b/source/core/slang-short-list.h @@ -47,6 +47,13 @@ namespace Slang return *this; } + ThisType& operator=(const ThisType& other) + { + clearAndDeallocate(); + addRange(other); + return *this; + } + ThisType& operator=(ThisType&& list) { // Could just do a swap here, and memory would be freed on rhs dtor diff --git a/source/core/slang-smart-pointer.h b/source/core/slang-smart-pointer.h index 53ac010ed..eb9fe7be5 100644 --- a/source/core/slang-smart-pointer.h +++ b/source/core/slang-smart-pointer.h @@ -114,9 +114,9 @@ namespace Slang template <typename U> RefPtr(RefPtr<U> const& p, typename EnableIf<IsConvertible<T*, U*>::Value, void>::type * = 0) - : pointer((U*) p) + : pointer(static_cast<U*>(p)) { - addReference((U*) p); + addReference(static_cast<U*>(p)); } #if 0 @@ -200,7 +200,7 @@ namespace Slang ~RefPtr() { - releaseReference((Slang::RefObject*) pointer); + releaseReference(static_cast<Slang::RefObject*>(pointer)); } T& operator*() const diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index ac2d72862..90b1e30f3 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -318,9 +318,6 @@ namespace Slang /// Get one of the global shader parametesr linked into this component type. virtual ShaderParamInfo getShaderParam(Index index) = 0; - /// Get the number of (unspecialized) specialization parameters for the component type. - virtual Index getSpecializationParamCount() = 0; - /// Get the specialization parameter at `index`. virtual SpecializationParam const& getSpecializationParam(Index index) = 0; @@ -511,7 +508,7 @@ namespace Slang Index getShaderParamCount() SLANG_OVERRIDE; ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE; - Index getSpecializationParamCount() SLANG_OVERRIDE; + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE; SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; Index getRequirementCount() SLANG_OVERRIDE; @@ -598,7 +595,7 @@ namespace Slang Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); } ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_base->getShaderParam(index); } - Index getSpecializationParamCount() SLANG_OVERRIDE { return 0; } + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; } SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); static SpecializationParam dummy; return dummy; } Index getRequirementCount() SLANG_OVERRIDE; @@ -759,7 +756,7 @@ namespace Slang String mangledName); /// Get the number of existential type parameters for the entry point. - Index getSpecializationParamCount() SLANG_OVERRIDE; + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE; /// Get the existential type parameter at `index`. SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE; @@ -969,7 +966,7 @@ namespace Slang Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); } ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; } - Index getSpecializationParamCount() SLANG_OVERRIDE { return m_specializationParams.getCount(); } + SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return m_specializationParams.getCount(); } SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { return m_specializationParams[index]; } Index getRequirementCount() SLANG_OVERRIDE; diff --git a/tests/compute/dynamic-dispatch-11.slang b/tests/compute/dynamic-dispatch-11.slang index daebff39a..964431aaf 100644 --- a/tests/compute/dynamic-dispatch-11.slang +++ b/tests/compute/dynamic-dispatch-11.slang @@ -1,9 +1,9 @@ // Test using interface typed shader parameters with dynamic dispatch. -//TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj -//TEST(compute):COMPARE_COMPUTE:-vk -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE:-vk -shaderobj //TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj [anyValueSize(8)] interface IInterface diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp index 057674550..7d7ee8eb9 100644 --- a/tools/gfx/cuda/render-cuda.cpp +++ b/tools/gfx/cuda/render-cuda.cpp @@ -244,21 +244,12 @@ public: class CUDAProgramLayout; -class CUDAShaderProgram : public IShaderProgram, public RefObject +class CUDAShaderProgram : public ShaderProgramBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IShaderProgram* getInterface(const Guid& guid) - { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderProgram) - return static_cast<IShaderProgram*>(this); - return nullptr; - } -public: CUmodule cudaModule = nullptr; CUfunction cudaKernel; String kernelName; - ComPtr<slang::IComponentType> slangProgram; RefPtr<CUDAProgramLayout> layout; ~CUDAShaderProgram() @@ -268,33 +259,22 @@ public: } }; -class CUDAPipelineState : public IPipelineState, public RefObject +class CUDAPipelineState : public PipelineStateBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IPipelineState* getInterface(const Guid& guid) + RefPtr<CUDAShaderProgram> shaderProgram; + void init(const ComputePipelineStateDesc& inDesc) { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState) - return static_cast<IPipelineState*>(this); - return nullptr; + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::Compute; + pipelineDesc.compute = inDesc; + initializeBase(pipelineDesc); } -public: - RefPtr<CUDAShaderProgram> shaderProgram; }; -class CUDAShaderObjectLayout : public IShaderObjectLayout, public RefObject +class CUDAShaderObjectLayout : public ShaderObjectLayoutBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IShaderObjectLayout* getInterface(const Guid& guid) - { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObjectLayout) - return static_cast<IShaderObjectLayout*>(this); - return nullptr; - } -public: - slang::TypeLayoutReflection* typeLayout = nullptr; - struct BindingRangeInfo { slang::BindingType bindingType; @@ -335,30 +315,32 @@ public: } } - CUDAShaderObjectLayout(slang::TypeLayoutReflection* layout) + CUDAShaderObjectLayout(RendererBase* renderer, slang::TypeLayoutReflection* layout) { + initBase(renderer, layout); + Index subObjectCount = 0; - typeLayout = unwrapParameterGroups(layout); + m_elementTypeLayout = unwrapParameterGroups(layout); // Compute the binding ranges that are used to store // the logical contents of the object in memory. These will relate // to the descriptor ranges in the various sets, but not always // in a one-to-one fashion. - SlangInt bindingRangeCount = typeLayout->getBindingRangeCount(); + SlangInt bindingRangeCount = m_elementTypeLayout->getBindingRangeCount(); for (SlangInt r = 0; r < bindingRangeCount; ++r) { - slang::BindingType slangBindingType = typeLayout->getBindingRangeType(r); - SlangInt count = typeLayout->getBindingRangeBindingCount(r); + slang::BindingType slangBindingType = m_elementTypeLayout->getBindingRangeType(r); + SlangInt count = m_elementTypeLayout->getBindingRangeBindingCount(r); slang::TypeLayoutReflection* slangLeafTypeLayout = - typeLayout->getBindingRangeLeafTypeLayout(r); + m_elementTypeLayout->getBindingRangeLeafTypeLayout(r); - SlangInt descriptorSetIndex = typeLayout->getBindingRangeDescriptorSetIndex(r); + SlangInt descriptorSetIndex = m_elementTypeLayout->getBindingRangeDescriptorSetIndex(r); SlangInt rangeIndexInDescriptorSet = - typeLayout->getBindingRangeFirstDescriptorRangeIndex(r); + m_elementTypeLayout->getBindingRangeFirstDescriptorRangeIndex(r); - auto uniformOffset = typeLayout->getDescriptorSetDescriptorRangeIndexOffset( + auto uniformOffset = m_elementTypeLayout->getDescriptorSetDescriptorRangeIndexOffset( descriptorSetIndex, rangeIndexInDescriptorSet); Index baseIndex = 0; @@ -383,13 +365,13 @@ public: m_bindingRanges.add(bindingRangeInfo); } - SlangInt subObjectRangeCount = typeLayout->getSubObjectRangeCount(); + SlangInt subObjectRangeCount = m_elementTypeLayout->getSubObjectRangeCount(); for (SlangInt r = 0; r < subObjectRangeCount; ++r) { - SlangInt bindingRangeIndex = typeLayout->getSubObjectRangeBindingRangeIndex(r); - auto slangBindingType = typeLayout->getBindingRangeType(bindingRangeIndex); + SlangInt bindingRangeIndex = m_elementTypeLayout->getSubObjectRangeBindingRangeIndex(r); + auto slangBindingType = m_elementTypeLayout->getBindingRangeType(bindingRangeIndex); slang::TypeLayoutReflection* slangLeafTypeLayout = - typeLayout->getBindingRangeLeafTypeLayout(bindingRangeIndex); + m_elementTypeLayout->getBindingRangeLeafTypeLayout(bindingRangeIndex); // A sub-object range can either represent a sub-object of a known // type, like a `ConstantBuffer<Foo>` or `ParameterBlock<Foo>` @@ -402,7 +384,7 @@ public: if (slangBindingType != slang::BindingType::ExistentialValue) { subObjectLayout = - new CUDAShaderObjectLayout(slangLeafTypeLayout->getElementTypeLayout()); + new CUDAShaderObjectLayout(renderer, slangLeafTypeLayout->getElementTypeLayout()); } SubObjectRangeInfo subObjectRange; @@ -418,13 +400,14 @@ class CUDAProgramLayout : public CUDAShaderObjectLayout public: slang::ProgramLayout* programLayout = nullptr; List<RefPtr<CUDAShaderObjectLayout>> entryPointLayouts; - CUDAProgramLayout(slang::ProgramLayout* inProgramLayout) - : CUDAShaderObjectLayout(inProgramLayout->getGlobalParamsTypeLayout()) + CUDAProgramLayout(RendererBase* renderer, slang::ProgramLayout* inProgramLayout) + : CUDAShaderObjectLayout(renderer, inProgramLayout->getGlobalParamsTypeLayout()) , programLayout(inProgramLayout) { for (UInt i =0; i< programLayout->getEntryPointCount(); i++) { entryPointLayouts.add(new CUDAShaderObjectLayout( + renderer, programLayout->getEntryPointByIndex(i)->getTypeLayout())); } @@ -450,26 +433,21 @@ public: } }; -class CUDAShaderObject : public IShaderObject, public RefObject +class CUDAShaderObject : public ShaderObjectBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IShaderObject* getInterface(const Guid& guid) - { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObject) - return static_cast<IShaderObject*>(this); - return nullptr; - } - -public: RefPtr<MemoryCUDAResource> bufferResource; - RefPtr<CUDAShaderObjectLayout> layout; List<RefPtr<CUDAShaderObject>> objects; List<RefPtr<CUDAResourceView>> resources; virtual SLANG_NO_THROW Result SLANG_MCALL init(IRenderer* renderer, CUDAShaderObjectLayout* typeLayout); + CUDAShaderObjectLayout* getLayout() + { + return static_cast<CUDAShaderObjectLayout*>(m_layout.Ptr()); + } + virtual SLANG_NO_THROW Result SLANG_MCALL initBuffer(IRenderer* renderer, size_t bufferSize) { BufferResource::Desc bufferDesc; @@ -494,7 +472,7 @@ public: virtual SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL getElementTypeLayout() override { - return layout->typeLayout; + return getLayout()->getElementTypeLayout(); } virtual SLANG_NO_THROW UInt SLANG_MCALL getEntryPointCount() override { return 0; } @@ -519,7 +497,7 @@ public: getObject(ShaderOffset const& offset, IShaderObject** object) { auto subObjectIndex = - layout->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex; + getLayout()->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex; if (subObjectIndex >= objects.getCount()) { *object = nullptr; @@ -533,10 +511,10 @@ public: setObject(ShaderOffset const& offset, IShaderObject* object) { auto subObjectIndex = - layout->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex; + getLayout()->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex; SLANG_ASSERT( offset.uniformOffset == - layout->m_bindingRanges[offset.bindingRangeIndex].uniformOffset + + getLayout()->m_bindingRanges[offset.bindingRangeIndex].uniformOffset + offset.bindingArrayIndex * sizeof(void*)); auto cudaObject = dynamic_cast<CUDAShaderObject*>(object); if (subObjectIndex >= objects.getCount()) @@ -593,6 +571,56 @@ public: setResource(offset, textureView); return SLANG_OK; } + + // Appends all types that are used to specialize the element type of this shader object in `args` list. + virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override + { + // TODO: the logic here is a copy-paste of `GraphicsCommonShaderObject::collectSpecializationArgs`, + // consider moving the implementation to `ShaderObjectBase` and share the logic among different implementations. + + if (!m_bindingFinalized) + return SLANG_FAIL; + auto& subObjectRanges = getLayout()->subObjectRanges; + // The following logic is built on the assumption that all fields that involve existential types (and + // therefore require specialization) will results in a sub-object range in the type layout. + // This allows us to simply scan the sub-object ranges to find out all specialization arguments. + for (Index subObjIndex = 0; subObjIndex < subObjectRanges.getCount(); subObjIndex++) + { + // Retrieve the corresponding binding range of the sub object. + auto bindingRange = getLayout()->m_bindingRanges[subObjectRanges[subObjIndex].bindingRangeIndex]; + switch (bindingRange.bindingType) + { + case slang::BindingType::ExistentialValue: + { + // A binding type of `ExistentialValue` means the sub-object represents a interface-typed field. + // In this case the specialization argument for this field is the actual specialized type of the bound + // shader object. If the shader object's type is an ordinary type without existential fields, then the + // type argument will simply be the ordinary type. But if the sub object's type is itself a specialized + // type, we need to make sure to use that type as the specialization argument. + + // TODO: need to implement the case where the field is an array of existential values. + SLANG_ASSERT(bindingRange.count == 1); + ExtendedShaderObjectType specializedSubObjType; + SLANG_RETURN_ON_FAIL(objects[subObjIndex]->getSpecializedShaderObjectType(&specializedSubObjType)); + args.add(specializedSubObjType); + break; + } + case slang::BindingType::ParameterBlock: + case slang::BindingType::ConstantBuffer: + // Currently we only handle the case where the field's type is + // `ParameterBlock<SomeStruct>` or `ConstantBuffer<SomeStruct>`, where `SomeStruct` is a struct type + // (not directly an interface type). In this case, we just recursively collect the specialization arguments + // from the bound sub object. + SLANG_RETURN_ON_FAIL(objects[subObjIndex]->collectSpecializationArgs(args)); + // TODO: we need to handle the case where the field is of the form `ParameterBlock<IFoo>`. We should treat + // this case the same way as the `ExistentialValue` case here, but currently we lack a mechanism to distinguish + // the two scenarios. + break; + } + // TODO: need to handle another case where specialization happens on resources fields e.g. `StructuredBuffer<IFoo>`. + } + return SLANG_OK; + } }; class CUDAEntryPointShaderObject : public CUDAShaderObject @@ -652,17 +680,8 @@ public: }; -class CUDARenderer : public IRenderer, public RefObject +class CUDARenderer : public RendererBase { -public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IRenderer* getInterface(const Guid& guid) - { - return (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IRenderer) - ? static_cast<IRenderer*>(this) - : nullptr; - } - private: static const CUDAReportStyle reportType = CUDAReportStyle::Normal; static int _calcSMCountPerMultiProcessor(int major, int minor) @@ -781,12 +800,14 @@ private: int m_deviceIndex = -1; CUdevice m_device = 0; CUcontext m_context = nullptr; - CUDAPipelineState* currentPipeline = nullptr; - CUDARootShaderObject* currentRootObject = nullptr; + RefPtr<CUDAPipelineState> currentPipeline = nullptr; + RefPtr<CUDARootShaderObject> currentRootObject = nullptr; SlangContext slangContext; public: ~CUDARenderer() { + currentPipeline = nullptr; + currentRootObject = nullptr; if (m_context) { cuCtxDestroy(m_context); @@ -796,6 +817,8 @@ private: { SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_PTX, "sm_5_1")); + SLANG_RETURN_ON_FAIL(RendererBase::initialize(desc, inWindowHandle)); + SLANG_RETURN_ON_FAIL(_initCuda(reportType)); SLANG_RETURN_ON_FAIL(_findMaxFlopsDeviceIndex(&m_deviceIndex)); @@ -813,13 +836,6 @@ private: return SLANG_OK; } - virtual SLANG_NO_THROW Result SLANG_MCALL getSlangSession(slang::ISession** outSlangSession) override - { - *outSlangSession = slangContext.session.get(); - slangContext.session->addRef(); - return SLANG_OK; - } - virtual SLANG_NO_THROW Result SLANG_MCALL createTextureResource( IResource::Usage initialUsage, const ITextureResource::Desc& desc, @@ -1271,7 +1287,7 @@ private: slang::TypeLayoutReflection* typeLayout, IShaderObjectLayout** outLayout) override { RefPtr<CUDAShaderObjectLayout> cudaLayout; - cudaLayout = new CUDAShaderObjectLayout(typeLayout); + cudaLayout = new CUDAShaderObjectLayout(this, typeLayout); *outLayout = cudaLayout.detach(); return SLANG_OK; } @@ -1309,6 +1325,17 @@ private: virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) override { + // If this is a specializable program, we just keep a reference to the slang program and + // don't actually create any kernels. This program will be specialized later when we know + // the shader object bindings. + if (desc.slangProgram && desc.slangProgram->getSpecializationParamCount() != 0) + { + RefPtr<CUDAShaderProgram> cudaProgram = new CUDAShaderProgram(); + cudaProgram->slangProgram = desc.slangProgram; + *outProgram = cudaProgram.detach(); + return SLANG_OK; + } + if( desc.kernelCount == 0 ) { return createProgramFromSlang(this, desc, outProgram); @@ -1332,7 +1359,7 @@ private: return SLANG_FAIL; RefPtr<CUDAProgramLayout> cudaLayout; - cudaLayout = new CUDAProgramLayout(slangProgramLayout); + cudaLayout = new CUDAProgramLayout(this, slangProgramLayout); cudaLayout->programLayout = slangProgramLayout; cudaProgram->layout = cudaLayout; } @@ -1346,6 +1373,7 @@ private: { RefPtr<CUDAPipelineState> state = new CUDAPipelineState(); state->shaderProgram = dynamic_cast<CUDAShaderProgram*>(desc.program); + state->init(desc); *outState = state.detach(); return Result(); } @@ -1360,18 +1388,19 @@ private: SLANG_UNUSED(buffer); } - virtual SLANG_NO_THROW void SLANG_MCALL - setPipelineState(PipelineType pipelineType, IPipelineState* state) override + virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) override { - SLANG_ASSERT(pipelineType == PipelineType::Compute); currentPipeline = dynamic_cast<CUDAPipelineState*>(state); } virtual SLANG_NO_THROW void SLANG_MCALL dispatchCompute(int x, int y, int z) override { + // Specialize the compute kernel based on the shader object bindings. + maybeSpecializePipeline(currentRootObject); + // Find out thread group size from program reflection. auto& kernelName = currentPipeline->shaderProgram->kernelName; - auto programLayout = dynamic_cast<CUDAProgramLayout*>(currentRootObject->layout.Ptr()); + auto programLayout = static_cast<CUDAProgramLayout*>(currentRootObject->getLayout()); int kernelId = programLayout->getKernelIndex(kernelName.getUnownedSlice()); SLANG_ASSERT(kernelId != -1); UInt threadGroupSize[3]; @@ -1451,21 +1480,12 @@ private: return RendererType::CUDA; } -public: - // Unused public interfaces. These functions are not supported on CUDA. - SLANG_NO_THROW Result SLANG_MCALL getFeatures( - const char** outFeatures, UInt bufferSize, UInt* outFeatureCount) - { - if (outFeatureCount) - *outFeatureCount = 0; - return SLANG_OK; - } - - SLANG_NO_THROW bool SLANG_MCALL hasFeature(const char* featureName) + virtual PipelineStateBase* getCurrentPipeline() override { - return false; + return currentPipeline; } +public: virtual SLANG_NO_THROW void SLANG_MCALL setClearColor(const float color[4]) override { SLANG_UNUSED(color); @@ -1602,8 +1622,8 @@ public: SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout* typeLayout) { - this->layout = typeLayout; - + m_layout = typeLayout; + // If the layout tells us that there is any uniform data, // then we need to allocate a constant buffer to hold that data. // @@ -1613,8 +1633,8 @@ SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout* // TODO: When/where do we bind this constant buffer into // a descriptor set for later use? // - auto slangLayout = layout->typeLayout; - size_t uniformSize = layout->typeLayout->getSize(); + auto slangLayout = getLayout()->getElementTypeLayout(); + size_t uniformSize = slangLayout->getSize(); if (uniformSize) { initBuffer(renderer, uniformSize); @@ -1626,7 +1646,7 @@ SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout* Index subObjectCount = slangLayout->getSubObjectRangeCount(); objects.setCount(subObjectCount); - for (auto subObjectRange : layout->subObjectRanges) + for (auto subObjectRange : getLayout()->subObjectRanges) { RefPtr<CUDAShaderObjectLayout> subObjectLayout = subObjectRange.layout; @@ -1643,7 +1663,7 @@ SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout* // in each entry in this range, based on the layout // information we already have. - auto& bindingRangeInfo = layout->m_bindingRanges[subObjectRange.bindingRangeIndex]; + auto& bindingRangeInfo = getLayout()->m_bindingRanges[subObjectRange.bindingRangeIndex]; for (Index i = 0; i < bindingRangeInfo.count; ++i) { RefPtr<CUDAShaderObject> subObject = new CUDAShaderObject(); @@ -1671,15 +1691,18 @@ SlangResult CUDARootShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayo return SLANG_OK; } -SlangResult SLANG_MCALL createCUDARenderer(IRenderer** outRenderer) +SlangResult SLANG_MCALL createCUDARenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer) { - *outRenderer = new CUDARenderer(); - (*outRenderer)->addRef(); + RefPtr<CUDARenderer> result = new CUDARenderer(); + SLANG_RETURN_ON_FAIL(result->initialize(*desc, windowHandle)); + *outRenderer = result.detach(); return SLANG_OK; } #else -SlangResult SLANG_MCALL createCUDARenderer(IRenderer** outRenderer) +SlangResult SLANG_MCALL createCUDARenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer) { + SLANG_UNUSED(desc); + SLANG_UNUSED(windowHandle); *outRenderer = nullptr; return SLANG_OK; } diff --git a/tools/gfx/cuda/render-cuda.h b/tools/gfx/cuda/render-cuda.h index e209af02c..39d5b60f8 100644 --- a/tools/gfx/cuda/render-cuda.h +++ b/tools/gfx/cuda/render-cuda.h @@ -1,11 +1,9 @@ #pragma once -#include <cstdint> -#include "slang.h" +#include "../renderer-shared.h" namespace gfx { -class IRenderer; -SlangResult SLANG_MCALL createCUDARenderer(IRenderer** outRenderer); +SlangResult SLANG_MCALL createCUDARenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer); } diff --git a/tools/gfx/d3d11/render-d3d11.cpp b/tools/gfx/d3d11/render-d3d11.cpp index 3df5e25bc..079d89a59 100644 --- a/tools/gfx/d3d11/render-d3d11.cpp +++ b/tools/gfx/d3d11/render-d3d11.cpp @@ -59,7 +59,9 @@ public: kMaxUAVs = 64, kMaxRTVs = 8, }; - + + ~D3D11Renderer() {} + // Renderer implementation virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc, void* inWindowHandle) override; virtual SLANG_NO_THROW void SLANG_MCALL setClearColor(const float color[4]) override; @@ -132,8 +134,7 @@ public: setViewports(UInt count, Viewport const* viewports) override; virtual SLANG_NO_THROW void SLANG_MCALL setScissorRects(UInt count, ScissorRect const* rects) override; - virtual SLANG_NO_THROW void SLANG_MCALL - setPipelineState(PipelineType pipelineType, IPipelineState* state) override; + virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) override; virtual SLANG_NO_THROW void SLANG_MCALL draw(UInt vertexCount, UInt startVertex) override; virtual SLANG_NO_THROW void SLANG_MCALL drawIndexed(UInt indexCount, UInt startIndex, UInt baseVertex) override; @@ -144,9 +145,10 @@ public: { return RendererType::DirectX11; } - - ~D3D11Renderer() {} - + virtual PipelineStateBase* getCurrentPipeline() override + { + return m_currentPipelineState; + } protected: class ScopeNVAPI @@ -444,17 +446,9 @@ public: ComPtr<ID3D11InputLayout> m_layout; }; - class PipelineStateImpl : public IPipelineState, public RefObject + class PipelineStateImpl : public PipelineStateBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IPipelineState* getInterface(const Guid& guid) - { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState) - return static_cast<IPipelineState*>(this); - return nullptr; - } - public: RefPtr<ShaderProgramImpl> m_program; RefPtr<PipelineLayoutImpl> m_pipelineLayout; }; @@ -473,11 +467,26 @@ public: UINT m_stencilRef; float m_blendColor[4]; UINT m_sampleMask; + + void init(const GraphicsPipelineStateDesc& inDesc) + { + PipelineStateBase::PipelineStateDesc pipelineDesc; + pipelineDesc.graphics = inDesc; + pipelineDesc.type = PipelineType::Graphics; + initializeBase(pipelineDesc); + } }; class ComputePipelineStateImpl : public PipelineStateImpl { public: + void init(const ComputePipelineStateDesc& inDesc) + { + PipelineStateBase::PipelineStateDesc pipelineDesc; + pipelineDesc.compute = inDesc; + pipelineDesc.type = PipelineType::Compute; + initializeBase(pipelineDesc); + } }; /// Capture a texture to a file @@ -506,8 +515,7 @@ public: bool m_renderTargetBindingsDirty = false; - ComPtr<GraphicsPipelineStateImpl> m_currentGraphicsState; - ComPtr<ComputePipelineStateImpl> m_currentComputeState; + ComPtr<PipelineStateImpl> m_currentPipelineState; ComPtr<ID3D11RenderTargetView> m_rtvBindings[kMaxRTVs]; ComPtr<ID3D11DepthStencilView> m_dsvBinding; @@ -521,10 +529,11 @@ public: bool m_nvapi = false; }; -SlangResult SLANG_MCALL createD3D11Renderer(IRenderer** outRenderer) +SlangResult SLANG_MCALL createD3D11Renderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer) { - *outRenderer = new D3D11Renderer(); - (*outRenderer)->addRef(); + RefPtr<D3D11Renderer> result = new D3D11Renderer(); + SLANG_RETURN_ON_FAIL(result->initialize(*desc, windowHandle)); + *outRenderer = result.detach(); return SLANG_OK; } @@ -666,6 +675,8 @@ SlangResult D3D11Renderer::initialize(const Desc& desc, void* inWindowHandle) { SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_DXBC, "sm_5_0")); + SLANG_RETURN_ON_FAIL(GraphicsAPIRenderer::initialize(desc, inWindowHandle)); + auto windowHandle = (HWND)inWindowHandle; m_desc = desc; @@ -1698,8 +1709,10 @@ void D3D11Renderer::setScissorRects(UInt count, ScissorRect const* rects) } -void D3D11Renderer::setPipelineState(PipelineType pipelineType, IPipelineState* state) +void D3D11Renderer::setPipelineState(IPipelineState* state) { + auto pipelineType = static_cast<PipelineStateBase*>(state)->desc.type; + switch(pipelineType) { default: @@ -1722,8 +1735,8 @@ void D3D11Renderer::setPipelineState(PipelineType pipelineType, IPipelineState* m_immediateContext->IASetInputLayout(stateImpl->m_inputLayout->m_layout); // VS - - m_immediateContext->VSSetShader(programImpl->m_vertexShader, nullptr, 0); + if (programImpl->m_vertexShader) + m_immediateContext->VSSetShader(programImpl->m_vertexShader, nullptr, 0); // HS @@ -1736,15 +1749,15 @@ void D3D11Renderer::setPipelineState(PipelineType pipelineType, IPipelineState* m_immediateContext->RSSetState(stateImpl->m_rasterizerState); // PS - - m_immediateContext->PSSetShader(programImpl->m_pixelShader, nullptr, 0); + if (programImpl->m_pixelShader) + m_immediateContext->PSSetShader(programImpl->m_pixelShader, nullptr, 0); // OM m_immediateContext->OMSetBlendState(stateImpl->m_blendState, stateImpl->m_blendColor, stateImpl->m_sampleMask); m_immediateContext->OMSetDepthStencilState(stateImpl->m_depthStencilState, stateImpl->m_stencilRef); - m_currentGraphicsState = stateImpl; + m_currentPipelineState = stateImpl; } break; @@ -1756,8 +1769,7 @@ void D3D11Renderer::setPipelineState(PipelineType pipelineType, IPipelineState* // CS m_immediateContext->CSSetShader(programImpl->m_computeShader, nullptr, 0); - - m_currentComputeState = stateImpl; + m_currentPipelineState = stateImpl; } break; } @@ -2083,7 +2095,7 @@ Result D3D11Renderer::createGraphicsPipelineState(const GraphicsPipelineStateDes state->m_blendColor[2] = 0; state->m_blendColor[3] = 0; state->m_sampleMask = 0xFFFFFFFF; - + state->init(desc); *outState = state.detach(); return SLANG_OK; } @@ -2099,7 +2111,7 @@ Result D3D11Renderer::createComputePipelineState(const ComputePipelineStateDesc& RefPtr<ComputePipelineStateImpl> state = new ComputePipelineStateImpl(); state->m_program = programImpl; state->m_pipelineLayout = pipelineLayoutImpl; - + state->init(desc); *outState = state.detach(); return SLANG_OK; } @@ -2303,7 +2315,7 @@ void D3D11Renderer::_flushGraphicsState() { m_targetBindingsDirty[pipelineType] = false; - auto pipelineState = m_currentGraphicsState.get(); + auto pipelineState = static_cast<GraphicsPipelineStateImpl*>(m_currentPipelineState.get()); auto rtvCount = pipelineState->m_rtvCount; auto uavCount = pipelineState->m_pipelineLayout->m_uavCount; @@ -2326,7 +2338,7 @@ void D3D11Renderer::_flushComputeState() { m_targetBindingsDirty[pipelineType] = false; - auto pipelineState = m_currentComputeState.get(); + auto pipelineState = static_cast<ComputePipelineStateImpl*>(m_currentPipelineState.get()); auto uavCount = pipelineState->m_pipelineLayout->m_uavCount; diff --git a/tools/gfx/d3d11/render-d3d11.h b/tools/gfx/d3d11/render-d3d11.h index 96a6a981c..1c81b688c 100644 --- a/tools/gfx/d3d11/render-d3d11.h +++ b/tools/gfx/d3d11/render-d3d11.h @@ -1,13 +1,11 @@ // render-d3d11.h #pragma once -#include <cstdint> -#include "slang.h" +#include "../renderer-shared.h" -namespace gfx { +namespace gfx +{ -class IRenderer; - -SlangResult SLANG_MCALL createD3D11Renderer(IRenderer** outRenderer); +SlangResult SLANG_MCALL createD3D11Renderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer); } // gfx diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index 407f0dbf2..de7cbd2e2 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -140,8 +140,7 @@ public: setViewports(UInt count, Viewport const* viewports) override; virtual SLANG_NO_THROW void SLANG_MCALL setScissorRects(UInt count, ScissorRect const* rects) override; - virtual SLANG_NO_THROW void SLANG_MCALL - setPipelineState(PipelineType pipelineType, IPipelineState* state) override; + virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) override; virtual SLANG_NO_THROW void SLANG_MCALL draw(UInt vertexCount, UInt startVertex) override; virtual SLANG_NO_THROW void SLANG_MCALL drawIndexed(UInt indexCount, UInt startIndex, UInt baseVertex) override; @@ -152,7 +151,10 @@ public: { return RendererType::DirectX12; } - + virtual PipelineStateBase* getCurrentPipeline() override + { + return m_currentPipelineState; + } ~D3D12Renderer(); protected: @@ -540,20 +542,25 @@ protected: D3D12DescriptorHeap m_cpuViewHeap; ///< Cbv, Srv, Uav D3D12DescriptorHeap m_cpuSamplerHeap; ///< Heap for samplers - class PipelineStateImpl : public IPipelineState, public RefObject + class PipelineStateImpl : public PipelineStateBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IPipelineState* getInterface(const Guid& guid) - { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState) - return static_cast<IPipelineState*>(this); - return nullptr; - } - public: - PipelineType m_pipelineType; RefPtr<PipelineLayoutImpl> m_pipelineLayout; ComPtr<ID3D12PipelineState> m_pipelineState; + void init(const GraphicsPipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::Graphics; + pipelineDesc.graphics = inDesc; + initializeBase(pipelineDesc); + } + void init(const ComputePipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::Compute; + pipelineDesc.compute = inDesc; + initializeBase(pipelineDesc); + } }; struct BoundVertexBuffer @@ -760,10 +767,11 @@ protected: bool m_nvapi = false; }; -SlangResult SLANG_MCALL createD3D12Renderer(IRenderer** outRenderer) +SlangResult SLANG_MCALL createD3D12Renderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer) { - *outRenderer = new D3D12Renderer(); - (*outRenderer)->addRef(); + RefPtr<D3D12Renderer> result = new D3D12Renderer(); + SLANG_RETURN_ON_FAIL(result->initialize(*desc, windowHandle)); + *outRenderer = result.detach(); return SLANG_OK; } @@ -1160,7 +1168,7 @@ Result D3D12Renderer::_bindRenderState(PipelineStateImpl* pipelineStateImpl, ID3 { // TODO: we should only set some of this state as needed... - auto pipelineTypeIndex = (int) pipelineStateImpl->m_pipelineType; + auto pipelineTypeIndex = (int) pipelineStateImpl->desc.type; auto pipelineLayout = pipelineStateImpl->m_pipelineLayout; submitter->setRootSignature(pipelineLayout->m_rootSignature); @@ -1350,6 +1358,8 @@ Result D3D12Renderer::initialize(const Desc& desc, void* inWindowHandle) { SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_DXBC, "sm_5_1")); + SLANG_RETURN_ON_FAIL(GraphicsAPIRenderer::initialize(desc, inWindowHandle)); + m_hwnd = (HWND)inWindowHandle; // Rather than statically link against D3D, we load it dynamically. @@ -2692,7 +2702,7 @@ void D3D12Renderer::setScissorRects(UInt count, ScissorRect const* rects) m_commandList->RSSetScissorRects(UINT(count), dxRects); } -void D3D12Renderer::setPipelineState(PipelineType pipelineType, IPipelineState* state) +void D3D12Renderer::setPipelineState(IPipelineState* state) { m_currentPipelineState = (PipelineStateImpl*)state; } @@ -2702,7 +2712,7 @@ void D3D12Renderer::draw(UInt vertexCount, UInt startVertex) ID3D12GraphicsCommandList* commandList = m_commandList; auto pipelineState = m_currentPipelineState.Ptr(); - if (!pipelineState || (pipelineState->m_pipelineType != PipelineType::Graphics)) + if (!pipelineState || (pipelineState->desc.type != PipelineType::Graphics)) { assert(!"No graphics pipeline state set"); return; @@ -3715,9 +3725,9 @@ Result D3D12Renderer::createGraphicsPipelineState(const GraphicsPipelineStateDes SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef()))); RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(); - pipelineStateImpl->m_pipelineType = PipelineType::Graphics; pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl; pipelineStateImpl->m_pipelineState = pipelineState; + pipelineStateImpl->init(desc); *outState = pipelineStateImpl.detach(); return SLANG_OK; } @@ -3768,9 +3778,9 @@ Result D3D12Renderer::createComputePipelineState(const ComputePipelineStateDesc& } RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(); - pipelineStateImpl->m_pipelineType = PipelineType::Compute; pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl; pipelineStateImpl->m_pipelineState = pipelineState; + pipelineStateImpl->init(desc); *outState = pipelineStateImpl.detach(); return SLANG_OK; } diff --git a/tools/gfx/d3d12/render-d3d12.h b/tools/gfx/d3d12/render-d3d12.h index bc28e276b..3304d92dc 100644 --- a/tools/gfx/d3d12/render-d3d12.h +++ b/tools/gfx/d3d12/render-d3d12.h @@ -1,13 +1,11 @@ // render-d3d12.h #pragma once -#include <cstdint> -#include "slang.h" +#include "../renderer-shared.h" -namespace gfx { +namespace gfx +{ -class IRenderer; - -SlangResult SLANG_MCALL createD3D12Renderer(IRenderer** outRenderer); +SlangResult SLANG_MCALL createD3D12Renderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer); } // gfx diff --git a/tools/gfx/open-gl/render-gl.cpp b/tools/gfx/open-gl/render-gl.cpp index 80edbdceb..b84db44b6 100644 --- a/tools/gfx/open-gl/render-gl.cpp +++ b/tools/gfx/open-gl/render-gl.cpp @@ -151,8 +151,7 @@ public: setViewports(UInt count, Viewport const* viewports) override; virtual SLANG_NO_THROW void SLANG_MCALL setScissorRects(UInt count, ScissorRect const* rects) override; - virtual SLANG_NO_THROW void SLANG_MCALL - setPipelineState(PipelineType pipelineType, IPipelineState* state) override; + virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) override; virtual SLANG_NO_THROW void SLANG_MCALL draw(UInt vertexCount, UInt startVertex) override; virtual void SLANG_MCALL drawIndexed(UInt indexCount, UInt startIndex, UInt baseVertex) override; @@ -163,7 +162,10 @@ public: { return RendererType::OpenGl; } - + virtual PipelineStateBase* getCurrentPipeline() override + { + return m_currentPipelineState.Ptr(); + } GLRenderer(); ~GLRenderer(); @@ -401,20 +403,26 @@ public: RefPtr<WeakSink<GLRenderer> > m_renderer; }; - class PipelineStateImpl : public IPipelineState, public RefObject + class PipelineStateImpl : public PipelineStateBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IPipelineState* getInterface(const Guid& guid) - { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState) - return static_cast<IPipelineState*>(this); - return nullptr; - } - public: RefPtr<ShaderProgramImpl> m_program; RefPtr<PipelineLayoutImpl> m_pipelineLayout; RefPtr<InputLayoutImpl> m_inputLayout; + void init(const GraphicsPipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::Graphics; + pipelineDesc.graphics = inDesc; + initializeBase(pipelineDesc); + } + void init(const ComputePipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::Compute; + pipelineDesc.compute = inDesc; + initializeBase(pipelineDesc); + } }; enum class GlPixelFormat @@ -494,10 +502,11 @@ public: SLANG_COMPILE_TIME_ASSERT(SLANG_COUNT_OF(s_pixelFormatInfos) == int(GlPixelFormat::CountOf)); } -SlangResult SLANG_MCALL createGLRenderer(IRenderer** outRenderer) +SlangResult SLANG_MCALL createGLRenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer) { - *outRenderer = new GLRenderer(); - (*outRenderer)->addRef(); + RefPtr<GLRenderer> result = new GLRenderer(); + SLANG_RETURN_ON_FAIL(result->initialize(*desc, windowHandle)); + *outRenderer = result.detach(); return SLANG_OK; } @@ -785,7 +794,9 @@ void GLRenderer::destroyBindingEntries(const BindingState::Desc& desc, const Bin SLANG_NO_THROW Result SLANG_MCALL GLRenderer::initialize(const Desc& desc, void* inWindowHandle) { - SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_GLSL, "sm_5_0")); + SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_GLSL, "glsl_440")); + + SLANG_RETURN_ON_FAIL(GraphicsAPIRenderer::initialize(desc, inWindowHandle)); auto windowHandle = (HWND)inWindowHandle; m_desc = desc; @@ -1270,8 +1281,7 @@ SLANG_NO_THROW void SLANG_MCALL GLRenderer::setScissorRects(UInt count, ScissorR } } -SLANG_NO_THROW void SLANG_MCALL - GLRenderer::setPipelineState(PipelineType pipelineType, IPipelineState* state) +SLANG_NO_THROW void SLANG_MCALL GLRenderer::setPipelineState(IPipelineState* state) { auto pipelineStateImpl = (PipelineStateImpl*) state; @@ -1552,6 +1562,7 @@ Result GLRenderer::createGraphicsPipelineState(const GraphicsPipelineStateDesc& pipelineStateImpl->m_program = programImpl; pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl; pipelineStateImpl->m_inputLayout = inputLayoutImpl; + pipelineStateImpl->init(desc); *outState = pipelineStateImpl.detach(); return SLANG_OK; } @@ -1567,6 +1578,7 @@ Result GLRenderer::createComputePipelineState(const ComputePipelineStateDesc& in RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(); pipelineStateImpl->m_program = programImpl; pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl; + pipelineStateImpl->init(desc); *outState = pipelineStateImpl.detach(); return SLANG_OK; } diff --git a/tools/gfx/open-gl/render-gl.h b/tools/gfx/open-gl/render-gl.h index 79ff2d124..ef8bb318c 100644 --- a/tools/gfx/open-gl/render-gl.h +++ b/tools/gfx/open-gl/render-gl.h @@ -1,13 +1,10 @@ // render-d3d11.h #pragma once -#include <cstdint> -#include "slang.h" +#include "../renderer-shared.h" namespace gfx { -class IRenderer; - -SlangResult SLANG_MCALL createGLRenderer(IRenderer** outRenderer); +SlangResult SLANG_MCALL createGLRenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer); } // gfx diff --git a/tools/gfx/render-graphics-common.cpp b/tools/gfx/render-graphics-common.cpp index adc943d1d..fb01867d8 100644 --- a/tools/gfx/render-graphics-common.cpp +++ b/tools/gfx/render-graphics-common.cpp @@ -6,64 +6,9 @@ using namespace Slang; namespace gfx { -const Slang::Guid GfxGUID::IID_ISlangUnknown = SLANG_UUID_ISlangUnknown; -const Slang::Guid GfxGUID::IID_IDescriptorSetLayout = SLANG_UUID_IDescriptorSetLayout; -const Slang::Guid GfxGUID::IID_IDescriptorSet = SLANG_UUID_IDescriptorSet; -const Slang::Guid GfxGUID::IID_IShaderProgram = SLANG_UUID_IShaderProgram; -const Slang::Guid GfxGUID::IID_IPipelineLayout = SLANG_UUID_IPipelineLayout; -const Slang::Guid GfxGUID::IID_IInputLayout = SLANG_UUID_IInputLayout; -const Slang::Guid GfxGUID::IID_IPipelineState = SLANG_UUID_IPipelineState; -const Slang::Guid GfxGUID::IID_IResourceView = SLANG_UUID_IResourceView; -const Slang::Guid GfxGUID::IID_ISamplerState = SLANG_UUID_ISamplerState; -const Slang::Guid GfxGUID::IID_IResource = SLANG_UUID_IResource; -const Slang::Guid GfxGUID::IID_IBufferResource = SLANG_UUID_IBufferResource; -const Slang::Guid GfxGUID::IID_ITextureResource = SLANG_UUID_ITextureResource; -const Slang::Guid GfxGUID::IID_IRenderer = SLANG_UUID_IRenderer; -const Slang::Guid GfxGUID::IID_IShaderObjectLayout = SLANG_UUID_IShaderObjectLayout; -const Slang::Guid GfxGUID::IID_IShaderObject = SLANG_UUID_IShaderObject; - -gfx::StageType translateStage(SlangStage slangStage) -{ - switch (slangStage) - { - default: - SLANG_ASSERT(!"unhandled case"); - return gfx::StageType::Unknown; - -#define CASE(FROM, TO) \ - case SLANG_STAGE_##FROM: \ - return gfx::StageType::TO - - CASE(VERTEX, Vertex); - CASE(HULL, Hull); - CASE(DOMAIN, Domain); - CASE(GEOMETRY, Geometry); - CASE(FRAGMENT, Fragment); - - CASE(COMPUTE, Compute); - - CASE(RAY_GENERATION, RayGeneration); - CASE(INTERSECTION, Intersection); - CASE(ANY_HIT, AnyHit); - CASE(CLOSEST_HIT, ClosestHit); - CASE(MISS, Miss); - CASE(CALLABLE, Callable); - -#undef CASE - } -} - -class GraphicsCommonShaderObjectLayout : public IShaderObjectLayout, public RefObject +class GraphicsCommonShaderObjectLayout : public ShaderObjectLayoutBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IShaderObjectLayout* getInterface(const Guid& guid) - { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObjectLayout) - return static_cast<IShaderObjectLayout*>(this); - return nullptr; - } -public: struct BindingRangeInfo { slang::BindingType bindingType; @@ -71,7 +16,12 @@ public: Index baseIndex; Index descriptorSetIndex; Index rangeIndexInDescriptorSet; - // Index subObjectRangeIndex = -1; + + // Returns true if this binding range consumes a specialization argument slot. + bool isSpecializationArg() const + { + return bindingType == slang::BindingType::ExistentialValue; + } }; struct SubObjectRangeInfo @@ -91,10 +41,13 @@ public: struct Builder { public: - Builder(IRenderer* renderer) + Builder(RendererBase* renderer) : m_renderer(renderer) {} + RendererBase* m_renderer; + slang::TypeLayoutReflection* m_elementTypeLayout; + List<BindingRangeInfo> m_bindingRanges; List<SubObjectRangeInfo> m_subObjectRanges; @@ -196,6 +149,15 @@ public: for (SlangInt r = 0; r < descriptorRangeCount; ++r) { auto slangBindingType = typeLayout->getDescriptorSetDescriptorRangeType(s, r); + + switch (slangBindingType) + { + case slang::BindingType::ExistentialValue: + continue; + default: + break; + } + auto gfxDescriptorType = _mapDescriptorType(slangBindingType); IDescriptorSetLayout::SlotRangeDesc descriptorRangeDesc; @@ -284,9 +246,6 @@ public: BindingRangeInfo bindingRangeInfo; bindingRangeInfo.bindingType = slangBindingType; bindingRangeInfo.count = count; - // bindingRangeInfo.descriptorSetIndex = descriptorSetIndex; - // bindingRangeInfo.rangeIndexInDescriptorSet = slotRangeIndex; - // bindingRangeInfo.subObjectRangeIndex = subObjectRangeIndex; bindingRangeInfo.baseIndex = baseIndex; bindingRangeInfo.descriptorSetIndex = descriptorSetIndex; bindingRangeInfo.rangeIndexInDescriptorSet = rangeIndexInDescriptorSet; @@ -364,13 +323,10 @@ public: *outLayout = layout.detach(); return SLANG_OK; } - - IRenderer* m_renderer = nullptr; - slang::TypeLayoutReflection* m_elementTypeLayout = nullptr; }; static Result createForElementType( - IRenderer* renderer, + RendererBase* renderer, slang::TypeLayoutReflection* elementType, GraphicsCommonShaderObjectLayout** outLayout) { @@ -397,15 +353,19 @@ public: SubObjectRangeInfo const& getSubObjectRange(Index index) { return m_subObjectRanges[index]; } List<SubObjectRangeInfo> const& getSubObjectRanges() { return m_subObjectRanges; } - IRenderer* getRenderer() { return m_renderer; } + RendererBase* getRenderer() { return m_renderer; } + slang::TypeReflection* getType() + { + return m_elementTypeLayout->getType(); + } protected: Result _init(Builder const* builder) { auto renderer = builder->m_renderer; - m_renderer = renderer; - m_elementTypeLayout = builder->m_elementTypeLayout; + initBase(renderer, builder->m_elementTypeLayout); + m_bindingRanges = builder->m_bindingRanges; for (auto descriptorSetBuildInfo : builder->m_descriptorSetBuildInfos) @@ -430,16 +390,12 @@ protected: m_samplerCount = builder->m_samplerCount; m_combinedTextureSamplerCount = builder->m_combinedTextureSamplerCount; m_subObjectCount = builder->m_subObjectCount; - m_subObjectRanges = builder->m_subObjectRanges; - return SLANG_OK; } - IRenderer* m_renderer; List<RefPtr<DescriptorSetInfo>> m_descriptorSets; List<BindingRangeInfo> m_bindingRanges; - slang::TypeLayoutReflection* m_elementTypeLayout; Index m_resourceViewCount = 0; Index m_samplerCount = 0; Index m_combinedTextureSamplerCount = 0; @@ -461,7 +417,7 @@ public: struct Builder : Super::Builder { Builder(IRenderer* renderer) - : Super::Builder(renderer) + : Super::Builder(static_cast<RendererBase*>(renderer)) {} Result build(EntryPointLayout** outLayout) @@ -566,7 +522,7 @@ public: struct Builder : Super::Builder { Builder(IRenderer* renderer) - : Super::Builder(renderer) + : Super::Builder(static_cast<RendererBase*>(renderer)) {} Result build(GraphicsCommonProgramLayout** outLayout) @@ -709,18 +665,9 @@ protected: ComPtr<IPipelineLayout> m_pipelineLayout; }; -class GraphicsCommonShaderObject : public IShaderObject, public RefObject +class GraphicsCommonShaderObject : public ShaderObjectBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IShaderObject* getInterface(const Guid& guid) - { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObject) - return static_cast<IShaderObject*>(this); - return nullptr; - } - -public: static Result create( IRenderer* renderer, GraphicsCommonShaderObjectLayout* layout, @@ -733,7 +680,7 @@ public: return SLANG_OK; } - IRenderer* getRenderer() { return m_layout->getRenderer(); } + RendererBase* getRenderer() { return m_layout->getRenderer(); } SLANG_NO_THROW UInt SLANG_MCALL getEntryPointCount() SLANG_OVERRIDE { return 0; } @@ -746,7 +693,7 @@ public: GraphicsCommonShaderObjectLayout* getLayout() { - return m_layout; + return static_cast<GraphicsCommonShaderObjectLayout*>(m_layout.Ptr()); } SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL getElementTypeLayout() SLANG_OVERRIDE @@ -772,29 +719,20 @@ public: { if (offset.bindingRangeIndex < 0) return SLANG_E_INVALID_ARG; - if (offset.bindingRangeIndex >= m_layout->getBindingRangeCount()) + auto layout = getLayout(); + if (offset.bindingRangeIndex >= layout->getBindingRangeCount()) return SLANG_E_INVALID_ARG; - auto& bindingRange = m_layout->getBindingRange(offset.bindingRangeIndex); - - // TODO: Is this reasonable to store the base index directly in the binding range? - m_objects[bindingRange.baseIndex + offset.bindingArrayIndex] = - reinterpret_cast<GraphicsCommonShaderObject*>(object); - // auto& subObjectRange = - // m_layout->getSubObjectRange(bindingRange.subObjectRangeIndex); - // m_objects[subObjectRange.baseIndex + offset.bindingArrayIndex] = object; + auto subObject = static_cast<GraphicsCommonShaderObject*>(object); + if (!subObject->m_bindingFinalized) + return SLANG_E_INVALID_ARG; -#if 0 + auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex); - SLANG_ASSERT(bindingRange.descriptorSetIndex >= 0); - SLANG_ASSERT(bindingRange.descriptorSetIndex < m_descriptorSets.getCount()); - auto& descriptorSet = m_descriptorSets[bindingRange.descriptorSetIndex]; + // TODO: Is this reasonable to store the base index directly in the binding range? + m_objects[bindingRange.baseIndex + offset.bindingArrayIndex] = subObject; - descriptorSet->setConstantBuffer(bindingRange.rangeIndexInDescriptorSet, offset.bindingArrayIndex, buffer); - return SLANG_OK; -#else return SLANG_E_NOT_IMPLEMENTED; -#endif } virtual SLANG_NO_THROW Result SLANG_MCALL @@ -804,9 +742,10 @@ public: SLANG_ASSERT(outObject); if (offset.bindingRangeIndex < 0) return SLANG_E_INVALID_ARG; - if (offset.bindingRangeIndex >= m_layout->getBindingRangeCount()) + auto layout = getLayout(); + if (offset.bindingRangeIndex >= layout->getBindingRangeCount()) return SLANG_E_INVALID_ARG; - auto& bindingRange = m_layout->getBindingRange(offset.bindingRangeIndex); + auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex); auto object = m_objects[bindingRange.baseIndex + offset.bindingArrayIndex].Ptr(); object->addRef(); @@ -833,9 +772,10 @@ public: { if (offset.bindingRangeIndex < 0) return SLANG_E_INVALID_ARG; - if (offset.bindingRangeIndex >= m_layout->getBindingRangeCount()) + auto layout = getLayout(); + if (offset.bindingRangeIndex >= layout->getBindingRangeCount()) return SLANG_E_INVALID_ARG; - auto& bindingRange = m_layout->getBindingRange(offset.bindingRangeIndex); + auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex); m_resourceViews[bindingRange.baseIndex + offset.bindingArrayIndex] = resourceView; return SLANG_OK; @@ -846,9 +786,10 @@ public: { if (offset.bindingRangeIndex < 0) return SLANG_E_INVALID_ARG; - if (offset.bindingRangeIndex >= m_layout->getBindingRangeCount()) + auto layout = getLayout(); + if (offset.bindingRangeIndex >= layout->getBindingRangeCount()) return SLANG_E_INVALID_ARG; - auto& bindingRange = m_layout->getBindingRange(offset.bindingRangeIndex); + auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex); m_samplers[bindingRange.baseIndex + offset.bindingArrayIndex] = sampler; return SLANG_OK; @@ -859,9 +800,10 @@ public: { if (offset.bindingRangeIndex < 0) return SLANG_E_INVALID_ARG; - if (offset.bindingRangeIndex >= m_layout->getBindingRangeCount()) + auto layout = getLayout(); + if (offset.bindingRangeIndex >= layout->getBindingRangeCount()) return SLANG_E_INVALID_ARG; - auto& bindingRange = m_layout->getBindingRange(offset.bindingRangeIndex); + auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex); auto& slot = m_combinedTextureSamplers[bindingRange.baseIndex + offset.bindingArrayIndex]; slot.textureView = textureView; @@ -869,6 +811,55 @@ public: return SLANG_OK; } +public: + // Appends all types that are used to specialize the element type of this shader object in `args` list. + virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override + { + if (!m_bindingFinalized) + return SLANG_FAIL; + auto& subObjectRanges = getLayout()->getSubObjectRanges(); + // The following logic is built on the assumption that all fields that involve existential types (and + // therefore require specialization) will results in a sub-object range in the type layout. + // This allows us to simply scan the sub-object ranges to find out all specialization arguments. + for (Index subObjIndex = 0; subObjIndex < subObjectRanges.getCount(); subObjIndex++) + { + // Retrieve the corresponding binding range of the sub object. + auto bindingRange = getLayout()->getBindingRange(subObjectRanges[subObjIndex].bindingRangeIndex); + switch (bindingRange.bindingType) + { + case slang::BindingType::ExistentialValue: + { + // A binding type of `ExistentialValue` means the sub-object represents a interface-typed field. + // In this case the specialization argument for this field is the actual specialized type of the bound + // shader object. If the shader object's type is an ordinary type without existential fields, then the + // type argument will simply be the ordinary type. But if the sub object's type is itself a specialized + // type, we need to make sure to use that type as the specialization argument. + + // TODO: need to implement the case where the field is an array of existential values. + SLANG_ASSERT(bindingRange.count == 1); + ExtendedShaderObjectType specializedSubObjType; + SLANG_RETURN_ON_FAIL(m_objects[subObjIndex]->getSpecializedShaderObjectType(&specializedSubObjType)); + args.add(specializedSubObjType); + break; + } + case slang::BindingType::ParameterBlock: + case slang::BindingType::ConstantBuffer: + // Currently we only handle the case where the field's type is + // `ParameterBlock<SomeStruct>` or `ConstantBuffer<SomeStruct>`, where `SomeStruct` is a struct type + // (not directly an interface type). In this case, we just recursively collect the specialization arguments + // from the bound sub object. + SLANG_RETURN_ON_FAIL(m_objects[subObjIndex]->collectSpecializationArgs(args)); + // TODO: we need to handle the case where the field is of the form `ParameterBlock<IFoo>`. We should treat + // this case the same way as the `ExistentialValue` case here, but currently we lack a mechanism to distinguish + // the two scenarios. + break; + } + // TODO: need to handle another case where specialization happens on resources fields e.g. `StructuredBuffer<IFoo>`. + } + return SLANG_OK; + } + + protected: friend class ProgramVars; @@ -953,7 +944,7 @@ protected: IPipelineLayout* pipelineLayout, Index& ioRootIndex) { - GraphicsCommonShaderObjectLayout* layout = m_layout; + GraphicsCommonShaderObjectLayout* layout = getLayout(); // Create the descritpor sets required by the layout... // @@ -979,7 +970,7 @@ protected: Result _bindIntoDescriptorSet( IDescriptorSet* descriptorSet, Index baseRangeIndex, Index subObjectRangeArrayIndex) { - GraphicsCommonShaderObjectLayout* layout = m_layout; + GraphicsCommonShaderObjectLayout* layout = getLayout(); if (m_buffer) { @@ -1063,7 +1054,7 @@ protected: public: virtual Result _bindIntoDescriptorSets(ComPtr<IDescriptorSet>* descriptorSets) { - GraphicsCommonShaderObjectLayout* layout = m_layout; + GraphicsCommonShaderObjectLayout* layout = getLayout(); if (m_buffer) { @@ -1134,7 +1125,6 @@ public: return SLANG_OK; } - RefPtr<GraphicsCommonShaderObjectLayout> m_layout = nullptr; ComPtr<IBufferResource> m_buffer; List<ComPtr<IResourceView>> m_resourceViews; @@ -1333,8 +1323,7 @@ Result GraphicsAPIRenderer::initProgramCommon( SLANG_RETURN_ON_FAIL(builder.build(programLayout.writeRef())); } - - program->m_slangProgram = slangProgram; + program->slangProgram = slangProgram; program->m_layout = programLayout; return SLANG_OK; @@ -1347,68 +1336,16 @@ Result SLANG_MCALL if (!programVars) return SLANG_E_INVALID_HANDLE; - programVars->apply(this, pipelineType); - return SLANG_OK; -} + SLANG_RETURN_ON_FAIL(maybeSpecializePipeline(programVars)); -SLANG_NO_THROW Result SLANG_MCALL - gfx::GraphicsAPIRenderer::getFeatures( - const char** outFeatures, UInt bufferSize, UInt* outFeatureCount) -{ - if (bufferSize >= (UInt)m_features.getCount()) - { - for (Index i = 0; i < m_features.getCount(); i++) - { - outFeatures[i] = m_features[i].getUnownedSlice().begin(); - } - } - if (outFeatureCount) - *outFeatureCount = (UInt)m_features.getCount(); - return SLANG_OK; -} - -SLANG_NO_THROW bool SLANG_MCALL gfx::GraphicsAPIRenderer::hasFeature(const char* featureName) -{ - return m_features.findFirstIndex([&](Slang::String x) { return x == featureName; }) != -1; -} - -SLANG_NO_THROW Result SLANG_MCALL gfx::GraphicsAPIRenderer::getSlangSession(slang::ISession** outSlangSession) -{ - *outSlangSession = slangContext.session.get(); - slangContext.session->addRef(); + // Apply shader parameter bindings. + programVars->apply(this, pipelineType); return SLANG_OK; } -GraphicsCommonShaderProgram::~GraphicsCommonShaderProgram() -{ - // Note: It might not seem like this destructor is needed at all, since - // it is empty. - // - // In pratice, though, it seems to be required because the `m_layout` - // field is declared in the coresponding header before the `GraphicsCommonProgramLayout` - // is declared (we only have a forward declaration). - // - // `m_layout` is a `RefPtr`, and it seems that the compiler (or at least - // the Visual Studio compiler) either cannot synthesize a destructor for - // the type that properly destructs the field, or it simply synthesizes - // an incorect destructor. - // - // I suspect that part of the problem stems from the way that `GraphicsCommonProgramLayout` - // inherits from `RefObject` via multiple inheritance. - // - // No matter what, defining the destructor here in a file where - // the declaration of `GraphicsCommonProgramLayout` is visible - // seems to result in a correct destructor being emitted. - // - // TODO: Ther simpler and more robust fix would be to move the declaration - // of `GraphicsCommonProgramLayout` and related types to the header. -} - -IShaderProgram* GraphicsCommonShaderProgram::getInterface(const Guid& guid) +GraphicsCommonProgramLayout* gfx::GraphicsCommonShaderProgram::getLayout() const { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderProgram) - return static_cast<IShaderProgram*>(this); - return nullptr; + return static_cast<GraphicsCommonProgramLayout*>(m_layout.Ptr()); } void GraphicsAPIRenderer::preparePipelineDesc(GraphicsPipelineStateDesc& desc) @@ -1432,11 +1369,4 @@ void GraphicsAPIRenderer::preparePipelineDesc(ComputePipelineStateDesc& desc) } } -IRenderer* gfx::GraphicsAPIRenderer::getInterface(const Guid& guid) -{ - return (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IRenderer) - ? static_cast<IRenderer*>(this) - : nullptr; -} - } diff --git a/tools/gfx/render-graphics-common.h b/tools/gfx/render-graphics-common.h index f4d9567e3..59dfb888a 100644 --- a/tools/gfx/render-graphics-common.h +++ b/tools/gfx/render-graphics-common.h @@ -1,41 +1,25 @@ #pragma once -#include "tools/gfx/render.h" +#include "tools/gfx/renderer-shared.h" #include "core/slang-basic.h" #include "tools/gfx/slang-context.h" namespace gfx { - class GraphicsCommonProgramLayout; -class GraphicsCommonShaderProgram : public IShaderProgram, public Slang::RefObject +class GraphicsCommonShaderProgram : public ShaderProgramBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - - IShaderProgram* getInterface(const Slang::Guid& guid); - - GraphicsCommonProgramLayout* getLayout() const { return m_layout; } - -protected: - ~GraphicsCommonShaderProgram(); - + GraphicsCommonProgramLayout* getLayout() const; private: friend class GraphicsAPIRenderer; - ComPtr<slang::IComponentType> m_slangProgram; - Slang::RefPtr<GraphicsCommonProgramLayout> m_layout; + Slang::RefPtr<ShaderObjectLayoutBase> m_layout; }; -class GraphicsAPIRenderer : public IRenderer, public Slang::RefObject +class GraphicsAPIRenderer : public RendererBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - virtual SLANG_NO_THROW Result SLANG_MCALL getFeatures( - const char** outFeatures, UInt bufferSize, UInt* outFeatureCount) SLANG_OVERRIDE; - virtual SLANG_NO_THROW bool SLANG_MCALL hasFeature(const char* featureName) SLANG_OVERRIDE; - virtual SLANG_NO_THROW Result SLANG_MCALL getSlangSession(slang::ISession** outSlangSession) SLANG_OVERRIDE; - virtual SLANG_NO_THROW Result SLANG_MCALL createShaderObjectLayout( slang::TypeLayoutReflection* typeLayout, IShaderObjectLayout** outLayout) SLANG_OVERRIDE; virtual SLANG_NO_THROW Result SLANG_MCALL @@ -47,34 +31,9 @@ public: bindRootShaderObject(PipelineType pipelineType, IShaderObject* object) SLANG_OVERRIDE; void preparePipelineDesc(GraphicsPipelineStateDesc& desc); void preparePipelineDesc(ComputePipelineStateDesc& desc); - IRenderer* getInterface(const Slang::Guid& guid); Result initProgramCommon( GraphicsCommonShaderProgram* program, IShaderProgram::Desc const& desc); - -protected: - Slang::List<Slang::String> m_features; - SlangContext slangContext; }; - -struct GfxGUID -{ - static const Slang::Guid IID_ISlangUnknown; - static const Slang::Guid IID_IDescriptorSetLayout; - static const Slang::Guid IID_IDescriptorSet; - static const Slang::Guid IID_IShaderProgram; - static const Slang::Guid IID_IPipelineLayout; - static const Slang::Guid IID_IPipelineState; - static const Slang::Guid IID_IResourceView; - static const Slang::Guid IID_ISamplerState; - static const Slang::Guid IID_IResource; - static const Slang::Guid IID_IBufferResource; - static const Slang::Guid IID_ITextureResource; - static const Slang::Guid IID_IInputLayout; - static const Slang::Guid IID_IRenderer; - static const Slang::Guid IID_IShaderObjectLayout; - static const Slang::Guid IID_IShaderObject; -}; - } diff --git a/tools/gfx/render.cpp b/tools/gfx/render.cpp index a25393714..042f4253b 100644 --- a/tools/gfx/render.cpp +++ b/tools/gfx/render.cpp @@ -85,35 +85,35 @@ extern "C" } } - SGRendererCreateFunc SLANG_MCALL gfxGetCreateFunc(RendererType type) + SLANG_GFX_API SlangResult SLANG_MCALL gfxCreateRenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer) { - switch (type) + switch (desc->rendererType) { #if SLANG_WINDOWS_FAMILY case RendererType::DirectX11: { - return &createD3D11Renderer; + return createD3D11Renderer(desc, windowHandle, outRenderer); } case RendererType::DirectX12: { - return &createD3D12Renderer; + return createD3D12Renderer(desc, windowHandle, outRenderer); } case RendererType::OpenGl: { - return &createGLRenderer; + return createGLRenderer(desc, windowHandle, outRenderer); } case RendererType::Vulkan: { - return &createVKRenderer; + return createVKRenderer(desc, windowHandle, outRenderer); } case RendererType::CUDA: { - return &createCUDARenderer; + return createCUDARenderer(desc, windowHandle, outRenderer); } #endif default: - return nullptr; + return SLANG_FAIL; } } diff --git a/tools/gfx/render.h b/tools/gfx/render.h index 325bab4ee..13af56550 100644 --- a/tools/gfx/render.h +++ b/tools/gfx/render.h @@ -880,6 +880,7 @@ public: setSampler(ShaderOffset const& offset, ISamplerState* sampler) = 0; virtual SLANG_NO_THROW Result SLANG_MCALL setCombinedTextureSampler( ShaderOffset const& offset, IResourceView* textureView, ISamplerState* sampler) = 0; + virtual SLANG_NO_THROW Result SLANG_MCALL finalizeBindings() = 0; }; #define SLANG_UUID_IShaderObject \ { \ @@ -1104,18 +1105,17 @@ public: struct Desc { + RendererType rendererType; // The underlying API/Platform of the renderer. int width = 0; // Width in pixels int height = 0; // height in pixels const char* adapter = nullptr; // Name to identify the adapter to use int requiredFeatureCount = 0; // Number of required features. const char** requiredFeatures = nullptr; // Array of required feature names, whose size is `requiredFeatureCount`. int nvapiExtnSlot = -1; // The slot (typically UAV) used to identify NVAPI intrinsics. If >=0 NVAPI is required. + ISlangFileSystem* shaderCacheFileSystem = nullptr; // The file system for loading cached shader kernels. SlangDesc slang = {}; // Configurations for Slang. }; - // Will return with SLANG_E_NOT_AVAILABLE if NVAPI can't be initialized and nvapiExtnSlot >= 0 - virtual SLANG_NO_THROW Result SLANG_MCALL initialize(const Desc& desc, void* inWindowHandle) = 0; - virtual SLANG_NO_THROW bool SLANG_MCALL hasFeature(const char* feature) = 0; /// Returns a list of features supported by the renderer. @@ -1342,7 +1342,7 @@ public: setScissorRects(1, &rect); } - virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(PipelineType pipelineType, IPipelineState* state) = 0; + virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) = 0; virtual SLANG_NO_THROW void SLANG_MCALL draw(UInt vertexCount, UInt startVertex = 0) = 0; virtual SLANG_NO_THROW void SLANG_MCALL drawIndexed(UInt indexCount, UInt startIndex = 0, UInt baseVertex = 0) = 0; @@ -1374,8 +1374,6 @@ inline void IRenderer::setVertexBuffer(UInt slot, IBufferResource* buffer, UInt extern "C" { - typedef SlangResult(SLANG_MCALL * SGRendererCreateFunc)(IRenderer** outRenderer); - /// Gets the size in bytes of a Format type. Returns 0 if a size is not defined/invalid SLANG_GFX_API size_t SLANG_MCALL gfxGetFormatSize(Format format); @@ -1394,7 +1392,7 @@ extern "C" SLANG_GFX_API const char* SLANG_MCALL gfxGetRendererName(RendererType type); /// Given a type returns a function that can construct it, or nullptr if there isn't one - SLANG_GFX_API SGRendererCreateFunc SLANG_MCALL gfxGetCreateFunc(RendererType type); + SLANG_GFX_API SlangResult SLANG_MCALL gfxCreateRenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer); } }// renderer_test diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index 2072d52ef..a2e1751fd 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -1,11 +1,60 @@ #include "renderer-shared.h" #include "render-graphics-common.h" +#include "core/slang-io.h" +#include "core/slang-token-reader.h" using namespace Slang; namespace gfx { +const Slang::Guid GfxGUID::IID_ISlangUnknown = SLANG_UUID_ISlangUnknown; +const Slang::Guid GfxGUID::IID_IDescriptorSetLayout = SLANG_UUID_IDescriptorSetLayout; +const Slang::Guid GfxGUID::IID_IDescriptorSet = SLANG_UUID_IDescriptorSet; +const Slang::Guid GfxGUID::IID_IShaderProgram = SLANG_UUID_IShaderProgram; +const Slang::Guid GfxGUID::IID_IPipelineLayout = SLANG_UUID_IPipelineLayout; +const Slang::Guid GfxGUID::IID_IInputLayout = SLANG_UUID_IInputLayout; +const Slang::Guid GfxGUID::IID_IPipelineState = SLANG_UUID_IPipelineState; +const Slang::Guid GfxGUID::IID_IResourceView = SLANG_UUID_IResourceView; +const Slang::Guid GfxGUID::IID_ISamplerState = SLANG_UUID_ISamplerState; +const Slang::Guid GfxGUID::IID_IResource = SLANG_UUID_IResource; +const Slang::Guid GfxGUID::IID_IBufferResource = SLANG_UUID_IBufferResource; +const Slang::Guid GfxGUID::IID_ITextureResource = SLANG_UUID_ITextureResource; +const Slang::Guid GfxGUID::IID_IRenderer = SLANG_UUID_IRenderer; +const Slang::Guid GfxGUID::IID_IShaderObjectLayout = SLANG_UUID_IShaderObjectLayout; +const Slang::Guid GfxGUID::IID_IShaderObject = SLANG_UUID_IShaderObject; + +gfx::StageType translateStage(SlangStage slangStage) +{ + switch (slangStage) + { + default: + SLANG_ASSERT(!"unhandled case"); + return gfx::StageType::Unknown; + +#define CASE(FROM, TO) \ +case SLANG_STAGE_##FROM: \ +return gfx::StageType::TO + + CASE(VERTEX, Vertex); + CASE(HULL, Hull); + CASE(DOMAIN, Domain); + CASE(GEOMETRY, Geometry); + CASE(FRAGMENT, Fragment); + + CASE(COMPUTE, Compute); + + CASE(RAY_GENERATION, RayGeneration); + CASE(INTERSECTION, Intersection); + CASE(ANY_HIT, AnyHit); + CASE(CLOSEST_HIT, ClosestHit); + CASE(MISS, Miss); + CASE(CALLABLE, Callable); + +#undef CASE + } +} + IResource* BufferResource::getInterface(const Slang::Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IResource || @@ -96,4 +145,396 @@ Result createProgramFromSlang(IRenderer* renderer, IShaderProgram::Desc const& o return renderer->createProgram(programDesc, outProgram); } +IShaderObject* gfx::ShaderObjectBase::getInterface(const Guid& guid) +{ + if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObject) + return static_cast<IShaderObject*>(this); + return nullptr; +} + +IShaderProgram* gfx::ShaderProgramBase::getInterface(const Guid& guid) +{ + if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderProgram) + return static_cast<IShaderProgram*>(this); + return nullptr; +} + +IPipelineState* gfx::PipelineStateBase::getInterface(const Guid& guid) +{ + if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState) + return static_cast<IPipelineState*>(this); + return nullptr; +} + +void PipelineStateBase::initializeBase(const PipelineStateDesc& inDesc) +{ + desc = inDesc; + + auto program = desc.getProgram(); + isSpecializable = (program->slangProgram && program->slangProgram->getSpecializationParamCount() != 0); +} + +IRenderer* gfx::RendererBase::getInterface(const Guid& guid) +{ + return (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IRenderer) + ? static_cast<IRenderer*>(this) + : nullptr; +} + +SLANG_NO_THROW Result SLANG_MCALL RendererBase::initialize(const Desc& desc, void* inWindowHandle) +{ + SLANG_UNUSED(inWindowHandle); + + shaderCache.init(desc.shaderCacheFileSystem); + return SLANG_OK; +} + +SLANG_NO_THROW Result SLANG_MCALL RendererBase::getFeatures( + const char** outFeatures, UInt bufferSize, UInt* outFeatureCount) +{ + if (bufferSize >= (UInt)m_features.getCount()) + { + for (Index i = 0; i < m_features.getCount(); i++) + { + outFeatures[i] = m_features[i].getUnownedSlice().begin(); + } + } + if (outFeatureCount) + *outFeatureCount = (UInt)m_features.getCount(); + return SLANG_OK; +} + +SLANG_NO_THROW bool SLANG_MCALL RendererBase::hasFeature(const char* featureName) +{ + return m_features.findFirstIndex([&](Slang::String x) { return x == featureName; }) != -1; +} + +SLANG_NO_THROW Result SLANG_MCALL RendererBase::getSlangSession(slang::ISession** outSlangSession) +{ + *outSlangSession = slangContext.session.get(); + slangContext.session->addRef(); + return SLANG_OK; +} + +ShaderComponentID ShaderCache::getComponentId(slang::TypeReflection* type) +{ + ComponentKey key; + key.typeName = UnownedStringSlice(type->getName()); + switch (type->getKind()) + { + case slang::TypeReflection::Kind::Specialized: + // TODO: collect specialization arguments and append them to `key`. + SLANG_UNIMPLEMENTED_X("specialized type"); + default: + break; + } + key.updateHash(); + return getComponentId(key); +} + +ShaderComponentID ShaderCache::getComponentId(UnownedStringSlice name) +{ + ComponentKey key; + key.typeName = name; + key.updateHash(); + return getComponentId(key); +} + +ShaderComponentID ShaderCache::getComponentId(ComponentKey key) +{ + ShaderComponentID componentId = 0; + if (componentIds.TryGetValue(key, componentId)) + return componentId; + OwningComponentKey owningTypeKey; + owningTypeKey.hash = key.hash; + owningTypeKey.typeName = key.typeName; + owningTypeKey.specializationArgs.addRange(key.specializationArgs); + ShaderComponentID resultId = static_cast<ShaderComponentID>(componentIds.Count()); + componentIds[owningTypeKey] = resultId; + return resultId; +} + +void ShaderCache::init(ISlangFileSystem* cacheFileSystem) +{ + fileSystem = cacheFileSystem; + + ComPtr<ISlangBlob> indexFileBlob; + if (fileSystem && fileSystem->loadFile("index", indexFileBlob.writeRef()) == SLANG_OK) + { + UnownedStringSlice indexText = UnownedStringSlice(static_cast<const char*>(indexFileBlob->getBufferPointer())); + TokenReader reader = TokenReader(indexText); + auto componentCountInFileSystem = reader.ReadUInt(); + for (uint32_t i = 0; i < componentCountInFileSystem; i++) + { + OwningComponentKey key; + auto componentId = reader.ReadUInt(); + key.typeName = reader.ReadWord(); + key.specializationArgs.setCount(reader.ReadUInt()); + for (auto& arg : key.specializationArgs) + arg = reader.ReadUInt(); + componentIds[key] = componentId; + } + } +} + +void ShaderCache::writeToFileSystem(ISlangMutableFileSystem* outputFileSystem) +{ + StringBuilder indexBuilder; + indexBuilder << componentIds.Count() << Slang::EndLine; + for (auto id : componentIds) + { + indexBuilder << id.Value << " "; + indexBuilder << id.Key.typeName << " " << id.Key.specializationArgs.getCount(); + for (auto arg : id.Key.specializationArgs) + indexBuilder << " " << arg; + indexBuilder << Slang::EndLine; + } + outputFileSystem->saveFile("index", indexBuilder.getBuffer(), indexBuilder.getLength()); + for (auto& binary : shaderBinaries) + { + ComPtr<ISlangBlob> blob; + binary.Value->writeToBlob(blob.writeRef()); + outputFileSystem->saveFile(String(binary.Key).getBuffer(), blob->getBufferPointer(), blob->getBufferSize()); + } +} + +Slang::RefPtr<ShaderBinary> ShaderCache::tryLoadShaderBinary(ShaderComponentID componentId) +{ + Slang::ComPtr<ISlangBlob> entryBlob; + Slang::RefPtr<ShaderBinary> binary; + if (shaderBinaries.TryGetValue(componentId, binary)) + return binary; + + if (fileSystem && fileSystem->loadFile(String(componentId).getBuffer(), entryBlob.writeRef()) == SLANG_OK) + { + binary = new ShaderBinary(); + binary->loadFromBlob(entryBlob.get()); + return binary; + } + return nullptr; +} + +void ShaderCache::addShaderBinary(ShaderComponentID componentId, ShaderBinary* binary) +{ + shaderBinaries[componentId] = binary; +} + +void ShaderCache::addSpecializedPipeline(PipelineKey key, Slang::RefPtr<PipelineStateBase> specializedPipeline) +{ + specializedPipelines[key] = specializedPipeline; +} + +struct ShaderBinaryEntryHeader +{ + StageType stage; + uint32_t nameLength; + uint32_t codeLength; +}; + +Result ShaderBinary::loadFromBlob(ISlangBlob* blob) +{ + MemoryStreamBase memStream(Slang::FileAccess::Read, blob->getBufferPointer(), blob->getBufferSize()); + uint32_t nameLength = 0; + ShaderBinaryEntryHeader header; + if (memStream.read(&header, sizeof(header)) != sizeof(header)) + return SLANG_FAIL; + const uint8_t* name = memStream.getContents().getBuffer() + memStream.getPosition(); + const uint8_t* code = name + header.nameLength; + entryPointName = reinterpret_cast<const char*>(name); + stage = header.stage; + source.addRange(code, header.codeLength); + return SLANG_OK; +} + +Result ShaderBinary::writeToBlob(ISlangBlob** outBlob) +{ + OwnedMemoryStream outStream(FileAccess::Write); + ShaderBinaryEntryHeader header; + header.stage = stage; + header.nameLength = static_cast<uint32_t>(entryPointName.getLength() + 1); + header.codeLength = static_cast<uint32_t>(source.getCount()); + outStream.write(&header, sizeof(header)); + outStream.write(entryPointName.getBuffer(), header.nameLength - 1); + uint8_t zeroTerminator = 0; + outStream.write(&zeroTerminator, 1); + outStream.write(source.getBuffer(), header.codeLength); + RefPtr<RawBlob> blob = new RawBlob(outStream.getContents().getBuffer(), outStream.getContents().getCount()); + *outBlob = blob.detach(); + return SLANG_OK; +} + +IShaderObjectLayout* ShaderObjectLayoutBase::getInterface(const Slang::Guid& guid) +{ + if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObjectLayout) + return static_cast<IShaderObjectLayout*>(this); + return nullptr; +} + +void ShaderObjectLayoutBase::initBase(RendererBase* renderer, slang::TypeLayoutReflection* elementTypeLayout) +{ + m_renderer = renderer; + m_elementTypeLayout = elementTypeLayout; + m_componentID = m_renderer->shaderCache.getComponentId(m_elementTypeLayout->getType()); +} + +SLANG_NO_THROW Result SLANG_MCALL ShaderObjectBase::finalizeBindings() +{ + m_bindingFinalized = true; + + // With all binding fixed, the shader object's type can be determined by specializing the + // shader object's type with the type of bound sub objects. + // Now obtain a componentID for the specialized shader object type from the shader cache. + SLANG_RETURN_ON_FAIL(getSpecializedShaderObjectType(&shaderObjectType)); + return SLANG_OK; +} + + +// Get the final type this shader object represents. If the shader object's type has existential fields, +// this function will return a specialized type using the bound sub-objects' type as specialization argument. + +Result ShaderObjectBase::getSpecializedShaderObjectType(ExtendedShaderObjectType* outType) +{ + if (shaderObjectType.slangType) + *outType = shaderObjectType; + ExtendedShaderObjectTypeList specializationArgs; + SLANG_RETURN_ON_FAIL(collectSpecializationArgs(specializationArgs)); + if (specializationArgs.getCount() == 0) + { + shaderObjectType.componentID = getLayout()->getComponentID(); + shaderObjectType.slangType = getLayout()->getElementTypeLayout()->getType(); + } + else + { + shaderObjectType.slangType = getRenderer()->slangContext.session->specializeType( + getElementTypeLayout()->getType(), + specializationArgs.components.getArrayView().getBuffer(), specializationArgs.getCount()); + shaderObjectType.componentID = getRenderer()->shaderCache.getComponentId(shaderObjectType.slangType); + } + *outType = shaderObjectType; + return SLANG_OK; +} + +Result RendererBase::maybeSpecializePipeline(ShaderObjectBase* rootObject) +{ + auto currentPipeline = getCurrentPipeline(); + auto pipelineType = currentPipeline->desc.type; + if (currentPipeline->unspecializedPipelineState) + currentPipeline = currentPipeline->unspecializedPipelineState; + // If the currently bound pipeline is specializable, we need to specialize it based on bound shader objects. + if (currentPipeline->isSpecializable) + { + specializationArgs.clear(); + SLANG_RETURN_ON_FAIL(rootObject->collectSpecializationArgs(specializationArgs)); + + // Construct a shader cache key that represents the specialized shader kernels. + PipelineKey pipelineKey; + pipelineKey.pipeline = currentPipeline; + pipelineKey.specializationArgs.addRange(specializationArgs.componentIDs); + pipelineKey.updateHash(); + + auto specializedPipelineState = shaderCache.getSpecializedPipelineState(pipelineKey); + // Try to find specialized pipeline from shader cache. + if (!specializedPipelineState) + { + auto unspecializedProgram = static_cast<ShaderProgramBase*>(pipelineType == PipelineType::Compute + ? currentPipeline->desc.compute.program + : currentPipeline->desc.graphics.program); + List<RefPtr<ShaderBinary>> entryPointBinaries; + auto unspecializedProgramLayout = unspecializedProgram->slangProgram->getLayout(); + for (SlangUInt i = 0; i < unspecializedProgramLayout->getEntryPointCount(); i++) + { + auto unspecializedEntryPoint = unspecializedProgramLayout->getEntryPointByIndex(i); + UnownedStringSlice entryPointName = UnownedStringSlice(unspecializedEntryPoint->getName()); + ComponentKey specializedKernelKey; + specializedKernelKey.typeName = entryPointName; + specializedKernelKey.specializationArgs.addRange(specializationArgs.componentIDs); + specializedKernelKey.updateHash(); + // If the pipeline is not created, check if the kernel binaries has been compiled. + auto specializedKernelComponentID = shaderCache.getComponentId(specializedKernelKey); + RefPtr<ShaderBinary> binary = shaderCache.tryLoadShaderBinary(specializedKernelComponentID); + if (!binary) + { + // If the specialized shader binary does not exist in cache, use slang to generate it. + entryPointBinaries.clear(); + ComPtr<slang::IComponentType> specializedComponentType; + ComPtr<slang::IBlob> diagnosticBlob; + auto result = unspecializedProgram->slangProgram->specialize(specializationArgs.components.getArrayView().getBuffer(), + specializationArgs.getCount(), specializedComponentType.writeRef(), diagnosticBlob.writeRef()); + + // TODO: print diagnostic message via debug output interface. + + if (result != SLANG_OK) + return result; + + // Cache specialized binaries. + auto programLayout = specializedComponentType->getLayout(); + for (SlangUInt j = 0; j < programLayout->getEntryPointCount(); j++) + { + auto entryPointLayout = programLayout->getEntryPointByIndex(j); + ComPtr<slang::IBlob> entryPointCode; + SLANG_RETURN_ON_FAIL(specializedComponentType->getEntryPointCode(j, 0, entryPointCode.writeRef(), diagnosticBlob.writeRef())); + binary = new ShaderBinary(); + binary->stage = gfx::translateStage(entryPointLayout->getStage()); + binary->entryPointName = entryPointLayout->getName(); + binary->source.addRange((uint8_t*)entryPointCode->getBufferPointer(), entryPointCode->getBufferSize()); + entryPointBinaries.add(binary); + shaderCache.addShaderBinary(specializedKernelComponentID, binary); + } + + // We have already obtained all kernel binaries from this program, so break out of the outer loop since we no longer + // need to examine the rest of the kernels. + break; + } + entryPointBinaries.add(binary); + } + + // Now create specialized shader program using compiled binaries. + ComPtr<IShaderProgram> specializedProgram; + IShaderProgram::Desc specializedProgramDesc = {}; + specializedProgramDesc.kernelCount = unspecializedProgramLayout->getEntryPointCount(); + ShortList<IShaderProgram::KernelDesc> kernelDescs; + for (Slang::Index i = 0; i < entryPointBinaries.getCount(); i++) + { + auto entryPoint = unspecializedProgramLayout->getEntryPointByIndex(i);; + auto& kernelDesc = kernelDescs[i]; + kernelDesc.stage = entryPointBinaries[i]->stage; + kernelDesc.entryPointName = entryPointBinaries[i]->entryPointName.getBuffer(); + kernelDesc.codeBegin = entryPointBinaries[i]->source.begin(); + kernelDesc.codeEnd = entryPointBinaries[i]->source.end(); + } + specializedProgramDesc.kernels = kernelDescs.getArrayView().getBuffer(); + specializedProgramDesc.pipelineType = pipelineType; + SLANG_RETURN_ON_FAIL(createProgram(specializedProgramDesc, specializedProgram.writeRef())); + + // Create specialized pipeline state. + ComPtr<IPipelineState> specializedPipelineState; + switch (pipelineType) + { + case PipelineType::Compute: + { + auto pipelineDesc = currentPipeline->desc.compute; + pipelineDesc.program = specializedProgram; + SLANG_RETURN_ON_FAIL(createComputePipelineState(pipelineDesc, specializedPipelineState.writeRef())); + break; + } + case PipelineType::Graphics: + { + auto pipelineDesc = currentPipeline->desc.graphics; + pipelineDesc.program = specializedProgram; + SLANG_RETURN_ON_FAIL(createGraphicsPipelineState(pipelineDesc, specializedPipelineState.writeRef())); + break; + } + default: + break; + } + auto specializedPipelineStateBase = static_cast<PipelineStateBase*>(specializedPipelineState.get()); + specializedPipelineStateBase->unspecializedPipelineState = currentPipeline; + shaderCache.addSpecializedPipeline(pipelineKey, specializedPipelineStateBase); + } + setPipelineState(specializedPipelineState); + } + return SLANG_OK; +} + + } // namespace gfx diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index c2924a7fd..c4a000e87 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -1,12 +1,33 @@ #pragma once #include "tools/gfx/render.h" - -#include "core/slang-smart-pointer.h" +#include "slang-context.h" +#include "core/slang-basic.h" namespace gfx { +struct GfxGUID +{ + static const Slang::Guid IID_ISlangUnknown; + static const Slang::Guid IID_IDescriptorSetLayout; + static const Slang::Guid IID_IDescriptorSet; + static const Slang::Guid IID_IShaderProgram; + static const Slang::Guid IID_IPipelineLayout; + static const Slang::Guid IID_IPipelineState; + static const Slang::Guid IID_IResourceView; + static const Slang::Guid IID_ISamplerState; + static const Slang::Guid IID_IResource; + static const Slang::Guid IID_IBufferResource; + static const Slang::Guid IID_ITextureResource; + static const Slang::Guid IID_IInputLayout; + static const Slang::Guid IID_IRenderer; + static const Slang::Guid IID_IShaderObjectLayout; + static const Slang::Guid IID_IShaderObject; +}; + +gfx::StageType translateStage(SlangStage slangStage); + class Resource : public Slang::RefObject { public: @@ -70,4 +91,297 @@ protected: Result createProgramFromSlang(IRenderer* renderer, IShaderProgram::Desc const& desc, IShaderProgram** outProgram); +class RendererBase; + +typedef uint32_t ShaderComponentID; +const ShaderComponentID kInvalidComponentID = 0xFFFFFFFF; + +struct ExtendedShaderObjectType +{ + slang::TypeReflection* slangType; + ShaderComponentID componentID; +}; + +struct ExtendedShaderObjectTypeList +{ + Slang::ShortList<ShaderComponentID, 16> componentIDs; + Slang::ShortList<slang::SpecializationArg, 16> components; + void add(const ExtendedShaderObjectType& component) + { + componentIDs.add(component.componentID); + components.add(slang::SpecializationArg{ slang::SpecializationArg::Kind::Type, component.slangType }); + } + ExtendedShaderObjectType operator[](Slang::Index index) const + { + ExtendedShaderObjectType result; + result.componentID = componentIDs[index]; + result.slangType = components[index].type; + return result; + } + void clear() + { + componentIDs.clear(); + components.clear(); + } + Slang::Index getCount() + { + return componentIDs.getCount(); + } +}; + +class ShaderObjectLayoutBase : public IShaderObjectLayout, public Slang::RefObject +{ +protected: + RendererBase* m_renderer; + slang::TypeLayoutReflection* m_elementTypeLayout = nullptr; + ShaderComponentID m_componentID = 0; + +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + IShaderObjectLayout* getInterface(const Slang::Guid& guid); + + RendererBase* getRenderer() { return m_renderer; } + + slang::TypeLayoutReflection* getElementTypeLayout() + { + return m_elementTypeLayout; + } + + ShaderComponentID getComponentID() + { + return m_componentID; + } + + void initBase(RendererBase* renderer, slang::TypeLayoutReflection* elementTypeLayout); +}; + +class ShaderObjectBase : public IShaderObject, public Slang::RefObject +{ +protected: + // The shader object layout used to create this shader object. + Slang::RefPtr<ShaderObjectLayoutBase> m_layout = nullptr; + + // Indicates whether all bindings have been finalized. + bool m_bindingFinalized = false; + + // The specialized shader object type. + ExtendedShaderObjectType shaderObjectType = { nullptr, kInvalidComponentID }; +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + IShaderObject* getInterface(const Slang::Guid& guid); + +public: + ShaderComponentID getComponentID() + { + return shaderObjectType.componentID; + } + + // Get the final type this shader object represents. If the shader object's type has existential fields, + // this function will return a specialized type using the bound sub-objects' type as specialization argument. + Result getSpecializedShaderObjectType(ExtendedShaderObjectType* outType); + + RendererBase* getRenderer() { return m_layout->getRenderer(); } + + SLANG_NO_THROW UInt SLANG_MCALL getEntryPointCount() SLANG_OVERRIDE { return 0; } + + SLANG_NO_THROW Result SLANG_MCALL getEntryPoint(UInt index, IShaderObject** outEntryPoint) + SLANG_OVERRIDE + { + *outEntryPoint = nullptr; + return SLANG_OK; + } + + ShaderObjectLayoutBase* getLayout() + { + return m_layout; + } + + SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL getElementTypeLayout() SLANG_OVERRIDE + { + return m_layout->getElementTypeLayout(); + } + + SLANG_NO_THROW Result SLANG_MCALL finalizeBindings() SLANG_OVERRIDE; + + virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) = 0; +}; + +class ShaderProgramBase : public IShaderProgram, public Slang::RefObject +{ +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + + IShaderProgram* getInterface(const Slang::Guid& guid); + + ComPtr<slang::IComponentType> slangProgram; +}; + +class PipelineStateBase : public IPipelineState, public Slang::RefObject +{ +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + + IPipelineState* getInterface(const Slang::Guid& guid); + + struct PipelineStateDesc + { + PipelineType type; + GraphicsPipelineStateDesc graphics; + ComputePipelineStateDesc compute; + ShaderProgramBase* getProgram() + { + return static_cast<ShaderProgramBase*>(type == PipelineType::Compute ? compute.program : graphics.program); + } + } desc; + + // The pipeline state from which this pipeline state is specialized. + // If null, this pipeline is either an unspecialized pipeline. + Slang::RefPtr<PipelineStateBase> unspecializedPipelineState = nullptr; + + // Indicates whether this is a specializable pipeline. A specializable + // pipeline cannot be used directly and must be specialized first. + bool isSpecializable = false; + +protected: + void initializeBase(const PipelineStateDesc& inDesc); +}; + +class ShaderBinary : public Slang::RefObject +{ +public: + Slang::List<uint8_t> source; + StageType stage; + Slang::String entryPointName; + Result loadFromBlob(ISlangBlob* blob); + Result writeToBlob(ISlangBlob** outBlob); +}; + +struct ComponentKey +{ + Slang::UnownedStringSlice typeName; + Slang::ShortList<ShaderComponentID> specializationArgs; + Slang::HashCode hash; + Slang::HashCode getHashCode() + { + return hash; + } + void updateHash() + { + hash = typeName.getHashCode(); + for (auto& arg : specializationArgs) + hash = Slang::combineHash(hash, arg); + } +}; + +struct PipelineKey +{ + PipelineStateBase* pipeline; + Slang::ShortList<ShaderComponentID> specializationArgs; + Slang::HashCode hash; + Slang::HashCode getHashCode() + { + return hash; + } + void updateHash() + { + hash = Slang::getHashCode(pipeline); + for (auto& arg : specializationArgs) + hash = Slang::combineHash(hash, arg); + } + bool operator==(const PipelineKey& other) + { + if (pipeline != other.pipeline) + return false; + if (specializationArgs.getCount() != other.specializationArgs.getCount()) + return false; + for (Slang::Index i = 0; i < other.specializationArgs.getCount(); i++) + { + if (specializationArgs[i] != other.specializationArgs[i]) + return false; + } + return true; + } +}; + +struct OwningComponentKey +{ + Slang::String typeName; + Slang::ShortList<ShaderComponentID> specializationArgs; + Slang::HashCode hash; + Slang::HashCode getHashCode() + { + return hash; + } + template<typename KeyType> + bool operator==(const KeyType& other) + { + if (typeName != other.typeName) + return false; + if (specializationArgs.getCount() != other.specializationArgs.getCount()) + return false; + for (Slang::Index i = 0; i < other.specializationArgs.getCount(); i++) + { + if (specializationArgs[i] != other.specializationArgs[i]) + return false; + } + return true; + } +}; + +// A cache from specialization keys to a specialized `ShaderKernel`. +class ShaderCache : public Slang::RefObject +{ +public: + ShaderComponentID getComponentId(slang::TypeReflection* type); + ShaderComponentID getComponentId(Slang::UnownedStringSlice name); + ShaderComponentID getComponentId(ComponentKey key); + + void init(ISlangFileSystem* cacheFileSystem); + void writeToFileSystem(ISlangMutableFileSystem* outputFileSystem); + Slang::RefPtr<PipelineStateBase> getSpecializedPipelineState(PipelineKey programKey) + { + Slang::RefPtr<PipelineStateBase> result; + if (specializedPipelines.TryGetValue(programKey, result)) + return result; + return nullptr; + } + Slang::RefPtr<ShaderBinary> tryLoadShaderBinary(ShaderComponentID componentId); + void addShaderBinary(ShaderComponentID componentId, ShaderBinary* binary); + void addSpecializedPipeline(PipelineKey key, Slang::RefPtr<PipelineStateBase> specializedPipeline); +protected: + Slang::ComPtr<ISlangFileSystem> fileSystem; + Slang::OrderedDictionary<OwningComponentKey, ShaderComponentID> componentIds; + Slang::OrderedDictionary<PipelineKey, Slang::RefPtr<PipelineStateBase>> specializedPipelines; + Slang::OrderedDictionary<ShaderComponentID, Slang::RefPtr<ShaderBinary>> shaderBinaries; +}; + +// Renderer implementation shared by all platforms. +// Responsible for shader compilation, specialization and caching. +class RendererBase : public Slang::RefObject, public IRenderer +{ +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + + virtual SLANG_NO_THROW Result SLANG_MCALL getFeatures( + const char** outFeatures, UInt bufferSize, UInt* outFeatureCount) SLANG_OVERRIDE; + virtual SLANG_NO_THROW bool SLANG_MCALL hasFeature(const char* featureName) SLANG_OVERRIDE; + virtual SLANG_NO_THROW Result SLANG_MCALL getSlangSession(slang::ISession** outSlangSession) SLANG_OVERRIDE; + IRenderer* getInterface(const Slang::Guid& guid); + +protected: + // Retrieves the currently bound unspecialized pipeline. + // If the bound pipeline is not created from a Slang component, an implementation should return null. + virtual PipelineStateBase* getCurrentPipeline() = 0; + ExtendedShaderObjectTypeList specializationArgs; + // Given current pipeline and root shader object binding, generate and bind a specialized pipeline if necessary. + Result maybeSpecializePipeline(ShaderObjectBase* inRootShaderObject); +protected: + virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc, void* inWindowHandle); +protected: + Slang::List<Slang::String> m_features; +public: + SlangContext slangContext; + ShaderCache shaderCache; +}; + } diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index 21034d167..565a8e96f 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -114,8 +114,7 @@ public: setViewports(UInt count, Viewport const* viewports) override; virtual SLANG_NO_THROW void SLANG_MCALL setScissorRects(UInt count, ScissorRect const* rects) override; - virtual SLANG_NO_THROW void SLANG_MCALL - setPipelineState(PipelineType pipelineType, IPipelineState* state) override; + virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) override; virtual SLANG_NO_THROW void SLANG_MCALL draw(UInt vertexCount, UInt startVertex) override; virtual SLANG_NO_THROW void SLANG_MCALL drawIndexed(UInt indexCount, UInt startIndex, UInt baseVertex) override; @@ -126,7 +125,10 @@ public: { return RendererType::Vulkan; } - + virtual PipelineStateBase* getCurrentPipeline() override + { + return m_currentPipeline.Ptr(); + } /// Dtor ~VKRenderer(); @@ -515,17 +517,9 @@ public: int m_offset; }; - class PipelineStateImpl : public IPipelineState, public RefObject + class PipelineStateImpl : public PipelineStateBase { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL - IPipelineState* getInterface(const Guid& guid) - { - if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState) - return static_cast<IPipelineState*>(this); - return nullptr; - } - public: PipelineStateImpl(const VulkanApi& api): m_api(&api) { @@ -538,6 +532,21 @@ public: } } + void init(const GraphicsPipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::Graphics; + pipelineDesc.graphics = inDesc; + initializeBase(pipelineDesc); + } + void init(const ComputePipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::Compute; + pipelineDesc.compute = inDesc; + initializeBase(pipelineDesc); + } + const VulkanApi* m_api; RefPtr<PipelineLayoutImpl> m_pipelineLayout; @@ -913,10 +922,11 @@ void VKRenderer::_endRender() m_deviceQueue.flush(); } -Result SLANG_MCALL createVKRenderer(IRenderer** outRenderer) +Result SLANG_MCALL createVKRenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer) { - *outRenderer = new VKRenderer(); - (*outRenderer)->addRef(); + RefPtr<VKRenderer> result = new VKRenderer(); + SLANG_RETURN_ON_FAIL(result->initialize(*desc, windowHandle)); + *outRenderer = result.detach(); return SLANG_OK; } @@ -1035,6 +1045,8 @@ SlangResult VKRenderer::initialize(const Desc& desc, void* inWindowHandle) { SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_SPIRV, "sm_5_1")); + SLANG_RETURN_ON_FAIL(GraphicsAPIRenderer::initialize(desc, inWindowHandle)); + SLANG_RETURN_ON_FAIL(m_module.init()); SLANG_RETURN_ON_FAIL(m_api.initGlobalProcs(m_module)); @@ -2274,9 +2286,9 @@ void VKRenderer::setScissorRects(UInt count, ScissorRect const* rects) m_api.vkCmdSetScissor(commandBuffer, 0, uint32_t(count), vkRects); } -void VKRenderer::setPipelineState(PipelineType pipelineType, IPipelineState* state) +void VKRenderer::setPipelineState(IPipelineState* state) { - m_currentPipeline = (PipelineStateImpl*)state; + m_currentPipeline = static_cast<PipelineStateImpl*>(state); } void VKRenderer::_flushBindingState(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipelineBindPoint) @@ -2980,6 +2992,7 @@ Result VKRenderer::createGraphicsPipelineState(const GraphicsPipelineStateDesc& pipelineStateImpl->m_pipeline = pipeline; pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl; pipelineStateImpl->m_shaderProgram = programImpl; + pipelineStateImpl->init(desc); *outState = pipelineStateImpl.detach(); return SLANG_OK; } @@ -3005,6 +3018,7 @@ Result VKRenderer::createComputePipelineState(const ComputePipelineStateDesc& in pipelineStateImpl->m_pipeline = pipeline; pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl; pipelineStateImpl->m_shaderProgram = programImpl; + pipelineStateImpl->init(desc); *outState = pipelineStateImpl.detach(); return SLANG_OK; } diff --git a/tools/gfx/vulkan/render-vk.h b/tools/gfx/vulkan/render-vk.h index 7e086f6c0..f259ab44c 100644 --- a/tools/gfx/vulkan/render-vk.h +++ b/tools/gfx/vulkan/render-vk.h @@ -2,12 +2,10 @@ #pragma once #include <cstdint> -#include "slang.h" +#include "../renderer-shared.h" namespace gfx { -class IRenderer; - -SlangResult SLANG_MCALL createVKRenderer(IRenderer** outRenderer); +SlangResult SLANG_MCALL createVKRenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer); } // gfx diff --git a/tools/graphics-app-framework/gui.cpp b/tools/graphics-app-framework/gui.cpp index fded0d76a..3bc365701 100644 --- a/tools/graphics-app-framework/gui.cpp +++ b/tools/graphics-app-framework/gui.cpp @@ -335,8 +335,7 @@ void GUI::endFrame() renderer->setViewport(viewport); - auto pipelineType = PipelineType::Graphics; - renderer->setPipelineState(pipelineType, pipelineState); + renderer->setPipelineState(pipelineState); renderer->setVertexBuffer(0, vertexBuffer, sizeof(ImDrawVert)); renderer->setIndexBuffer(indexBuffer, sizeof(ImDrawIdx) == 2 ? Format::R_UInt16 : Format::R_UInt32); @@ -376,7 +375,7 @@ void GUI::endFrame() samplerState); renderer->setDescriptorSet( - pipelineType, + PipelineType::Graphics, pipelineLayout, 0, descriptorSet); diff --git a/tools/graphics-app-framework/windows/win-window.cpp b/tools/graphics-app-framework/windows/win-window.cpp index 7433603cb..3bbf2575a 100644 --- a/tools/graphics-app-framework/windows/win-window.cpp +++ b/tools/graphics-app-framework/windows/win-window.cpp @@ -1,6 +1,8 @@ // win-window.cpp #include "../window.h" +#include "core/slang-smart-pointer.h" + #include <stdio.h> #ifdef _MSC_VER @@ -64,7 +66,7 @@ private: } }; -struct ApplicationContext +struct ApplicationContext : public Slang::RefObject { HINSTANCE instance; int showCommand = SW_SHOWDEFAULT; diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index 27070154e..b9fd5c725 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -711,7 +711,7 @@ void RenderTestApp::renderFrame() auto pipelineType = PipelineType::Graphics; - m_renderer->setPipelineState(pipelineType, m_pipelineState); + m_renderer->setPipelineState(m_pipelineState); m_renderer->setPrimitiveTopology(PrimitiveTopology::TriangleList); m_renderer->setVertexBuffer(0, m_vertexBuffer, sizeof(Vertex)); @@ -724,7 +724,7 @@ void RenderTestApp::renderFrame() void RenderTestApp::runCompute() { auto pipelineType = PipelineType::Compute; - m_renderer->setPipelineState(pipelineType, m_pipelineState); + m_renderer->setPipelineState(m_pipelineState); applyBinding(pipelineType); m_startTicks = ProcessUtil::getClockTick(); @@ -1279,22 +1279,8 @@ static SlangResult _innerMain(Slang::StdWriters* stdWriters, SlangSession* sessi Slang::ComPtr<IRenderer> renderer; { - SGRendererCreateFunc createFunc = gfxGetCreateFunc(options.rendererType); - if (createFunc) - { - createFunc(renderer.writeRef()); - } - - if (!renderer) - { - if (!options.onlyStartup) - { - fprintf(stderr, "Unable to create renderer %s\n", rendererName.getBuffer()); - } - return SLANG_FAIL; - } - - IRenderer::Desc desc; + IRenderer::Desc desc = {}; + desc.rendererType = options.rendererType; desc.width = gWindowWidth; desc.height = gWindowHeight; desc.adapter = options.adapter.getBuffer(); @@ -1306,18 +1292,21 @@ static SlangResult _innerMain(Slang::StdWriters* stdWriters, SlangSession* sessi desc.nvapiExtnSlot = int(nvapiExtnSlot); desc.slang.slangGlobalSession = session; window = renderer_test::Window::create(); - SLANG_RETURN_ON_FAIL(window->initialize(gWindowWidth, gWindowHeight)); + void* windowHandle = nullptr; + if (window) + { + SLANG_RETURN_ON_FAIL(window->initialize(gWindowWidth, gWindowHeight)); + windowHandle = window->getHandle(); + } + gfxCreateRenderer(&desc, windowHandle, renderer.writeRef()); - SlangResult res = renderer->initialize(desc, window->getHandle()); - if (SLANG_FAILED(res)) + if (!renderer) { - // Returns E_NOT_AVAILABLE only when specified features are not available. - // Will cause to be ignored. - if (!options.onlyStartup && res != SLANG_E_NOT_AVAILABLE) + if (!options.onlyStartup) { - fprintf(stderr, "Unable to initialize renderer %s\n", rendererName.getBuffer()); + fprintf(stderr, "Unable to create renderer %s\n", rendererName.getBuffer()); } - return res; + return SLANG_FAIL; } for (const auto& feature : requiredFeatureList) |
