summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--examples/gpu-printing/main.cpp10
-rw-r--r--examples/hello-world/main.cpp12
-rw-r--r--examples/heterogeneous-hello-world/main.cpp13
-rw-r--r--examples/model-viewer/main.cpp8
-rw-r--r--examples/shader-toy/main.cpp10
-rw-r--r--slang.h4
-rw-r--r--source/core/slang-dictionary.h17
-rw-r--r--source/core/slang-short-list.h7
-rw-r--r--source/core/slang-smart-pointer.h6
-rwxr-xr-xsource/slang/slang-compiler.h11
-rw-r--r--tests/compute/dynamic-dispatch-11.slang6
-rw-r--r--tools/gfx/cuda/render-cuda.cpp237
-rw-r--r--tools/gfx/cuda/render-cuda.h6
-rw-r--r--tools/gfx/d3d11/render-d3d11.cpp76
-rw-r--r--tools/gfx/d3d11/render-d3d11.h10
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp52
-rw-r--r--tools/gfx/d3d12/render-d3d12.h10
-rw-r--r--tools/gfx/open-gl/render-gl.cpp48
-rw-r--r--tools/gfx/open-gl/render-gl.h7
-rw-r--r--tools/gfx/render-graphics-common.cpp292
-rw-r--r--tools/gfx/render-graphics-common.h51
-rw-r--r--tools/gfx/render.cpp16
-rw-r--r--tools/gfx/render.h12
-rw-r--r--tools/gfx/renderer-shared.cpp441
-rw-r--r--tools/gfx/renderer-shared.h318
-rw-r--r--tools/gfx/vulkan/render-vk.cpp48
-rw-r--r--tools/gfx/vulkan/render-vk.h6
-rw-r--r--tools/graphics-app-framework/gui.cpp5
-rw-r--r--tools/graphics-app-framework/windows/win-window.cpp4
-rw-r--r--tools/render-test/render-test-main.cpp41
31 files changed, 1242 insertions, 544 deletions
diff --git a/.gitignore b/.gitignore
index e941540c1..fad8a1030 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,7 +9,9 @@
*.obj
*.slang-module
*.zip
+*.ini
.clang-format
+
bin/
intermediate/
build.*/
diff --git a/examples/gpu-printing/main.cpp b/examples/gpu-printing/main.cpp
index 6a4aabf3a..8dc0b0f3d 100644
--- a/examples/gpu-printing/main.cpp
+++ b/examples/gpu-printing/main.cpp
@@ -118,14 +118,12 @@ Result execute()
windowDesc.height = gWindowHeight;
gWindow = createWindow(windowDesc);
- gfxGetCreateFunc(gfx::RendererType::DirectX11)(gRenderer.writeRef());
IRenderer::Desc rendererDesc;
+ rendererDesc.rendererType = gfx::RendererType::DirectX11;
rendererDesc.width = gWindowWidth;
rendererDesc.height = gWindowHeight;
- {
- Result res = gRenderer->initialize(rendererDesc, getPlatformWindowHandle(gWindow));
- if(SLANG_FAILED(res)) return res;
- }
+ Result res = gfxCreateRenderer(&rendererDesc, getPlatformWindowHandle(gWindow), gRenderer.writeRef());
+ if(SLANG_FAILED(res)) return res;
gSlangSession = createSlangSession(gRenderer);
gSlangModule = compileShaderModuleFromFile(gSlangSession, "kernels.slang");
@@ -191,7 +189,7 @@ Result execute()
gDescriptorSet->setResource(0, 0, printBufferView);
gRenderer->setDescriptorSet(PipelineType::Compute, gPipelineLayout, 0, gDescriptorSet);
- gRenderer->setPipelineState(PipelineType::Compute, gPipelineState);
+ gRenderer->setPipelineState(gPipelineState);
gRenderer->dispatchCompute(1, 1, 1);
// TODO: need to copy from the print buffer to a staging buffer...
diff --git a/examples/hello-world/main.cpp b/examples/hello-world/main.cpp
index e8e3c45e0..91c9c0627 100644
--- a/examples/hello-world/main.cpp
+++ b/examples/hello-world/main.cpp
@@ -240,14 +240,12 @@ Slang::Result initialize()
// A future version of this example may support multiple
// platforms/APIs.
//
- gfxGetCreateFunc(gfx::RendererType::DirectX11)(gRenderer.writeRef());
- IRenderer::Desc rendererDesc;
+ IRenderer::Desc rendererDesc = {};
+ rendererDesc.rendererType = gfx::RendererType::DirectX11;
rendererDesc.width = gWindowWidth;
rendererDesc.height = gWindowHeight;
- {
- gfx::Result res = gRenderer->initialize(rendererDesc, getPlatformWindowHandle(gWindow));
- if(SLANG_FAILED(res)) return res;
- }
+ gfx::Result res = gfxCreateRenderer(&rendererDesc, getPlatformWindowHandle(gWindow), gRenderer.writeRef());
+ if(SLANG_FAILED(res)) return res;
// Now we will create objects needed to configur the "input assembler"
// (IA) stage of the D3D pipeline.
@@ -414,7 +412,7 @@ void renderFrame()
// PSO, binding our root shader object to it (which references
// the `Uniforms` buffer that will filled in above).
//
- gRenderer->setPipelineState(PipelineType::Graphics, gPipelineState);
+ gRenderer->setPipelineState(gPipelineState);
gRenderer->bindRootShaderObject(PipelineType::Graphics, gRootObject);
// We also need to set up a few pieces of fixed-function pipeline
diff --git a/examples/heterogeneous-hello-world/main.cpp b/examples/heterogeneous-hello-world/main.cpp
index 8610a5fa2..163b17deb 100644
--- a/examples/heterogeneous-hello-world/main.cpp
+++ b/examples/heterogeneous-hello-world/main.cpp
@@ -43,7 +43,6 @@ using namespace gfx;
// We create global ref pointers to avoid dereferencing values
//
ComPtr<gfx::IShaderProgram> gShaderProgram;
-Slang::RefPtr<gfx::ApplicationContext> gAppContext;
Slang::ComPtr<gfx::IRenderer> gRenderer;
ComPtr<gfx::IBufferResource> gStructuredBuffer;
@@ -123,14 +122,12 @@ gfx::IRenderer* createRenderer(
// A future version of this example may support multiple
// platforms/APIs.
//
- gfxGetCreateFunc(gfx::RendererType::DirectX11)(gRenderer.writeRef());
- IRenderer::Desc rendererDesc;
+ IRenderer::Desc rendererDesc = {};
+ rendererDesc.rendererType = gfx::RendererType::DirectX11;
rendererDesc.width = windowWidth;
rendererDesc.height = windowHeight;
- {
- Result res = gRenderer->initialize(rendererDesc, getPlatformWindowHandle(window));
- if (SLANG_FAILED(res)) return nullptr;
- }
+ Result res = gfxCreateRenderer(&rendererDesc, getPlatformWindowHandle(window), gRenderer.writeRef());
+ if (SLANG_FAILED(res)) return nullptr;
return gRenderer;
}
@@ -249,7 +246,7 @@ void dispatchComputation(
unsigned int gridDimsZ)
{
- gRenderer->setPipelineState(PipelineType::Compute, gPipelineState);
+ gRenderer->setPipelineState(gPipelineState);
gRenderer->setDescriptorSet(PipelineType::Compute, gPipelineLayout, 0, gDescriptorSet);
gRenderer->dispatchCompute(gridDimsX, gridDimsY, gridDimsZ);
diff --git a/examples/model-viewer/main.cpp b/examples/model-viewer/main.cpp
index 384cc5eac..f830d4044 100644
--- a/examples/model-viewer/main.cpp
+++ b/examples/model-viewer/main.cpp
@@ -1254,7 +1254,7 @@ public:
// we simply bind its PSO into the GPU state, and
// remember the variant we've selected.
//
- renderer->setPipelineState(PipelineType::Graphics, variant->pipelineState);
+ renderer->setPipelineState(variant->pipelineState);
currentEffectVariant = variant;
}
@@ -2050,11 +2050,11 @@ Result initialize()
windowDesc.userData = this;
gWindow = createWindow(windowDesc);
- gfxGetCreateFunc(gfx::RendererType::DirectX11)(gRenderer.writeRef());
- IRenderer::Desc rendererDesc;
+ IRenderer::Desc rendererDesc = {};
+ rendererDesc.rendererType = gfx::RendererType::DirectX11;
rendererDesc.width = gWindowWidth;
rendererDesc.height = gWindowHeight;
- gRenderer->initialize(rendererDesc, getPlatformWindowHandle(gWindow));
+ gfxCreateRenderer(&rendererDesc, getPlatformWindowHandle(gWindow), gRenderer.writeRef());
InputElementDesc inputElements[] = {
{"POSITION", 0, Format::RGB_Float32, offsetof(Model::Vertex, position) },
diff --git a/examples/shader-toy/main.cpp b/examples/shader-toy/main.cpp
index a1408e38e..2bbc59113 100644
--- a/examples/shader-toy/main.cpp
+++ b/examples/shader-toy/main.cpp
@@ -347,14 +347,12 @@ Result initialize()
windowDesc.userData = this;
gWindow = createWindow(windowDesc);
- gfxGetCreateFunc(gfx::RendererType::DirectX11)(gRenderer.writeRef());
IRenderer::Desc rendererDesc;
+ rendererDesc.rendererType = RendererType::DirectX11;
rendererDesc.width = gWindowWidth;
rendererDesc.height = gWindowHeight;
- {
- Result res = gRenderer->initialize(rendererDesc, getPlatformWindowHandle(gWindow));
- if(SLANG_FAILED(res)) return res;
- }
+ Result res = gfxCreateRenderer(&rendererDesc, getPlatformWindowHandle(gWindow), gRenderer.writeRef());
+ if(SLANG_FAILED(res)) return res;
int constantBufferSize = sizeof(Uniforms);
@@ -477,7 +475,7 @@ void renderFrame()
gRenderer->unmap(gConstantBuffer);
}
- gRenderer->setPipelineState(PipelineType::Graphics, gPipelineState);
+ gRenderer->setPipelineState(gPipelineState);
gRenderer->setDescriptorSet(PipelineType::Graphics, gPipelineLayout, 0, gDescriptorSet);
gRenderer->setVertexBuffer(0, gVertexBuffer, sizeof(FullScreenTriangle::Vertex));
diff --git a/slang.h b/slang.h
index 1e93a4f55..e0975c853 100644
--- a/slang.h
+++ b/slang.h
@@ -3991,6 +3991,10 @@ namespace slang
SlangInt targetIndex = 0,
IBlob** outDiagnostics = nullptr) = 0;
+ /** Get the number of (unspecialized) specialization parameters for the component type.
+ */
+ virtual SLANG_NO_THROW SlangInt SLANG_MCALL getSpecializationParamCount() = 0;
+
/** Get the compiled code for the entry point at `entryPointIndex` for the chosen `targetIndex`
Entry point code can only be computed for a component type that
diff --git a/source/core/slang-dictionary.h b/source/core/slang-dictionary.h
index 4c352d55b..51b61de60 100644
--- a/source/core/slang-dictionary.h
+++ b/source/core/slang-dictionary.h
@@ -135,15 +135,17 @@ namespace Slang
}
};
- inline int GetHashPos(TKey& key) const
+ template<typename KeyType>
+ inline int GetHashPos(KeyType& key) const
{
SLANG_ASSERT(bucketSizeMinusOne > 0);
const unsigned int hash = (unsigned int)getHashCode(key);
return (hash * 2654435761u) % (unsigned int)(bucketSizeMinusOne);
}
- FindPositionResult FindPosition(const TKey& key) const
+ template<typename KeyType>
+ FindPositionResult FindPosition(const KeyType& key) const
{
- int hashPos = GetHashPos(const_cast<TKey&>(key));
+ int hashPos = GetHashPos(const_cast<KeyType&>(key));
int insertPos = -1;
int numProbes = 0;
while (numProbes <= bucketSizeMinusOne)
@@ -380,14 +382,16 @@ namespace Slang
throw InvalidOperationException("Inconsistent find result returned. This is a bug in Dictionary implementation.");
}
- bool ContainsKey(const TKey& key) const
+ template<typename KeyType>
+ bool ContainsKey(const KeyType& key) const
{
if (bucketSizeMinusOne == -1)
return false;
auto pos = FindPosition(key);
return pos.ObjectPosition != -1;
}
- bool TryGetValue(const TKey& key, TValue& value) const
+ template<typename KeyType>
+ bool TryGetValue(const KeyType& key, TValue& value) const
{
if (bucketSizeMinusOne == -1)
return false;
@@ -399,7 +403,8 @@ namespace Slang
}
return false;
}
- TValue* TryGetValue(const TKey& key) const
+ template<typename KeyType>
+ TValue* TryGetValue(const KeyType& key) const
{
if (bucketSizeMinusOne == -1)
return nullptr;
diff --git a/source/core/slang-short-list.h b/source/core/slang-short-list.h
index 82ad4fe1e..7d51a8abf 100644
--- a/source/core/slang-short-list.h
+++ b/source/core/slang-short-list.h
@@ -47,6 +47,13 @@ namespace Slang
return *this;
}
+ ThisType& operator=(const ThisType& other)
+ {
+ clearAndDeallocate();
+ addRange(other);
+ return *this;
+ }
+
ThisType& operator=(ThisType&& list)
{
// Could just do a swap here, and memory would be freed on rhs dtor
diff --git a/source/core/slang-smart-pointer.h b/source/core/slang-smart-pointer.h
index 53ac010ed..eb9fe7be5 100644
--- a/source/core/slang-smart-pointer.h
+++ b/source/core/slang-smart-pointer.h
@@ -114,9 +114,9 @@ namespace Slang
template <typename U>
RefPtr(RefPtr<U> const& p,
typename EnableIf<IsConvertible<T*, U*>::Value, void>::type * = 0)
- : pointer((U*) p)
+ : pointer(static_cast<U*>(p))
{
- addReference((U*) p);
+ addReference(static_cast<U*>(p));
}
#if 0
@@ -200,7 +200,7 @@ namespace Slang
~RefPtr()
{
- releaseReference((Slang::RefObject*) pointer);
+ releaseReference(static_cast<Slang::RefObject*>(pointer));
}
T& operator*() const
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index ac2d72862..90b1e30f3 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -318,9 +318,6 @@ namespace Slang
/// Get one of the global shader parametesr linked into this component type.
virtual ShaderParamInfo getShaderParam(Index index) = 0;
- /// Get the number of (unspecialized) specialization parameters for the component type.
- virtual Index getSpecializationParamCount() = 0;
-
/// Get the specialization parameter at `index`.
virtual SpecializationParam const& getSpecializationParam(Index index) = 0;
@@ -511,7 +508,7 @@ namespace Slang
Index getShaderParamCount() SLANG_OVERRIDE;
ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE;
- Index getSpecializationParamCount() SLANG_OVERRIDE;
+ SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE;
SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE;
Index getRequirementCount() SLANG_OVERRIDE;
@@ -598,7 +595,7 @@ namespace Slang
Index getShaderParamCount() SLANG_OVERRIDE { return m_base->getShaderParamCount(); }
ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_base->getShaderParam(index); }
- Index getSpecializationParamCount() SLANG_OVERRIDE { return 0; }
+ SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; }
SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { SLANG_UNUSED(index); static SpecializationParam dummy; return dummy; }
Index getRequirementCount() SLANG_OVERRIDE;
@@ -759,7 +756,7 @@ namespace Slang
String mangledName);
/// Get the number of existential type parameters for the entry point.
- Index getSpecializationParamCount() SLANG_OVERRIDE;
+ SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE;
/// Get the existential type parameter at `index`.
SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE;
@@ -969,7 +966,7 @@ namespace Slang
Index getShaderParamCount() SLANG_OVERRIDE { return m_shaderParams.getCount(); }
ShaderParamInfo getShaderParam(Index index) SLANG_OVERRIDE { return m_shaderParams[index]; }
- Index getSpecializationParamCount() SLANG_OVERRIDE { return m_specializationParams.getCount(); }
+ SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return m_specializationParams.getCount(); }
SpecializationParam const& getSpecializationParam(Index index) SLANG_OVERRIDE { return m_specializationParams[index]; }
Index getRequirementCount() SLANG_OVERRIDE;
diff --git a/tests/compute/dynamic-dispatch-11.slang b/tests/compute/dynamic-dispatch-11.slang
index daebff39a..964431aaf 100644
--- a/tests/compute/dynamic-dispatch-11.slang
+++ b/tests/compute/dynamic-dispatch-11.slang
@@ -1,9 +1,9 @@
// Test using interface typed shader parameters with dynamic dispatch.
-//TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj
-//TEST(compute):COMPARE_COMPUTE:-vk -shaderobj
+//DISABLE_TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj
+//DISABLE_TEST(compute):COMPARE_COMPUTE:-vk -shaderobj
//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj
-//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj
+//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj
[anyValueSize(8)]
interface IInterface
diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp
index 057674550..7d7ee8eb9 100644
--- a/tools/gfx/cuda/render-cuda.cpp
+++ b/tools/gfx/cuda/render-cuda.cpp
@@ -244,21 +244,12 @@ public:
class CUDAProgramLayout;
-class CUDAShaderProgram : public IShaderProgram, public RefObject
+class CUDAShaderProgram : public ShaderProgramBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IShaderProgram* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderProgram)
- return static_cast<IShaderProgram*>(this);
- return nullptr;
- }
-public:
CUmodule cudaModule = nullptr;
CUfunction cudaKernel;
String kernelName;
- ComPtr<slang::IComponentType> slangProgram;
RefPtr<CUDAProgramLayout> layout;
~CUDAShaderProgram()
@@ -268,33 +259,22 @@ public:
}
};
-class CUDAPipelineState : public IPipelineState, public RefObject
+class CUDAPipelineState : public PipelineStateBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IPipelineState* getInterface(const Guid& guid)
+ RefPtr<CUDAShaderProgram> shaderProgram;
+ void init(const ComputePipelineStateDesc& inDesc)
{
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState)
- return static_cast<IPipelineState*>(this);
- return nullptr;
+ PipelineStateDesc pipelineDesc;
+ pipelineDesc.type = PipelineType::Compute;
+ pipelineDesc.compute = inDesc;
+ initializeBase(pipelineDesc);
}
-public:
- RefPtr<CUDAShaderProgram> shaderProgram;
};
-class CUDAShaderObjectLayout : public IShaderObjectLayout, public RefObject
+class CUDAShaderObjectLayout : public ShaderObjectLayoutBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IShaderObjectLayout* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObjectLayout)
- return static_cast<IShaderObjectLayout*>(this);
- return nullptr;
- }
-public:
- slang::TypeLayoutReflection* typeLayout = nullptr;
-
struct BindingRangeInfo
{
slang::BindingType bindingType;
@@ -335,30 +315,32 @@ public:
}
}
- CUDAShaderObjectLayout(slang::TypeLayoutReflection* layout)
+ CUDAShaderObjectLayout(RendererBase* renderer, slang::TypeLayoutReflection* layout)
{
+ initBase(renderer, layout);
+
Index subObjectCount = 0;
- typeLayout = unwrapParameterGroups(layout);
+ m_elementTypeLayout = unwrapParameterGroups(layout);
// Compute the binding ranges that are used to store
// the logical contents of the object in memory. These will relate
// to the descriptor ranges in the various sets, but not always
// in a one-to-one fashion.
- SlangInt bindingRangeCount = typeLayout->getBindingRangeCount();
+ SlangInt bindingRangeCount = m_elementTypeLayout->getBindingRangeCount();
for (SlangInt r = 0; r < bindingRangeCount; ++r)
{
- slang::BindingType slangBindingType = typeLayout->getBindingRangeType(r);
- SlangInt count = typeLayout->getBindingRangeBindingCount(r);
+ slang::BindingType slangBindingType = m_elementTypeLayout->getBindingRangeType(r);
+ SlangInt count = m_elementTypeLayout->getBindingRangeBindingCount(r);
slang::TypeLayoutReflection* slangLeafTypeLayout =
- typeLayout->getBindingRangeLeafTypeLayout(r);
+ m_elementTypeLayout->getBindingRangeLeafTypeLayout(r);
- SlangInt descriptorSetIndex = typeLayout->getBindingRangeDescriptorSetIndex(r);
+ SlangInt descriptorSetIndex = m_elementTypeLayout->getBindingRangeDescriptorSetIndex(r);
SlangInt rangeIndexInDescriptorSet =
- typeLayout->getBindingRangeFirstDescriptorRangeIndex(r);
+ m_elementTypeLayout->getBindingRangeFirstDescriptorRangeIndex(r);
- auto uniformOffset = typeLayout->getDescriptorSetDescriptorRangeIndexOffset(
+ auto uniformOffset = m_elementTypeLayout->getDescriptorSetDescriptorRangeIndexOffset(
descriptorSetIndex, rangeIndexInDescriptorSet);
Index baseIndex = 0;
@@ -383,13 +365,13 @@ public:
m_bindingRanges.add(bindingRangeInfo);
}
- SlangInt subObjectRangeCount = typeLayout->getSubObjectRangeCount();
+ SlangInt subObjectRangeCount = m_elementTypeLayout->getSubObjectRangeCount();
for (SlangInt r = 0; r < subObjectRangeCount; ++r)
{
- SlangInt bindingRangeIndex = typeLayout->getSubObjectRangeBindingRangeIndex(r);
- auto slangBindingType = typeLayout->getBindingRangeType(bindingRangeIndex);
+ SlangInt bindingRangeIndex = m_elementTypeLayout->getSubObjectRangeBindingRangeIndex(r);
+ auto slangBindingType = m_elementTypeLayout->getBindingRangeType(bindingRangeIndex);
slang::TypeLayoutReflection* slangLeafTypeLayout =
- typeLayout->getBindingRangeLeafTypeLayout(bindingRangeIndex);
+ m_elementTypeLayout->getBindingRangeLeafTypeLayout(bindingRangeIndex);
// A sub-object range can either represent a sub-object of a known
// type, like a `ConstantBuffer<Foo>` or `ParameterBlock<Foo>`
@@ -402,7 +384,7 @@ public:
if (slangBindingType != slang::BindingType::ExistentialValue)
{
subObjectLayout =
- new CUDAShaderObjectLayout(slangLeafTypeLayout->getElementTypeLayout());
+ new CUDAShaderObjectLayout(renderer, slangLeafTypeLayout->getElementTypeLayout());
}
SubObjectRangeInfo subObjectRange;
@@ -418,13 +400,14 @@ class CUDAProgramLayout : public CUDAShaderObjectLayout
public:
slang::ProgramLayout* programLayout = nullptr;
List<RefPtr<CUDAShaderObjectLayout>> entryPointLayouts;
- CUDAProgramLayout(slang::ProgramLayout* inProgramLayout)
- : CUDAShaderObjectLayout(inProgramLayout->getGlobalParamsTypeLayout())
+ CUDAProgramLayout(RendererBase* renderer, slang::ProgramLayout* inProgramLayout)
+ : CUDAShaderObjectLayout(renderer, inProgramLayout->getGlobalParamsTypeLayout())
, programLayout(inProgramLayout)
{
for (UInt i =0; i< programLayout->getEntryPointCount(); i++)
{
entryPointLayouts.add(new CUDAShaderObjectLayout(
+ renderer,
programLayout->getEntryPointByIndex(i)->getTypeLayout()));
}
@@ -450,26 +433,21 @@ public:
}
};
-class CUDAShaderObject : public IShaderObject, public RefObject
+class CUDAShaderObject : public ShaderObjectBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IShaderObject* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObject)
- return static_cast<IShaderObject*>(this);
- return nullptr;
- }
-
-public:
RefPtr<MemoryCUDAResource> bufferResource;
- RefPtr<CUDAShaderObjectLayout> layout;
List<RefPtr<CUDAShaderObject>> objects;
List<RefPtr<CUDAResourceView>> resources;
virtual SLANG_NO_THROW Result SLANG_MCALL
init(IRenderer* renderer, CUDAShaderObjectLayout* typeLayout);
+ CUDAShaderObjectLayout* getLayout()
+ {
+ return static_cast<CUDAShaderObjectLayout*>(m_layout.Ptr());
+ }
+
virtual SLANG_NO_THROW Result SLANG_MCALL initBuffer(IRenderer* renderer, size_t bufferSize)
{
BufferResource::Desc bufferDesc;
@@ -494,7 +472,7 @@ public:
virtual SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL getElementTypeLayout() override
{
- return layout->typeLayout;
+ return getLayout()->getElementTypeLayout();
}
virtual SLANG_NO_THROW UInt SLANG_MCALL getEntryPointCount() override { return 0; }
@@ -519,7 +497,7 @@ public:
getObject(ShaderOffset const& offset, IShaderObject** object)
{
auto subObjectIndex =
- layout->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex;
+ getLayout()->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex;
if (subObjectIndex >= objects.getCount())
{
*object = nullptr;
@@ -533,10 +511,10 @@ public:
setObject(ShaderOffset const& offset, IShaderObject* object)
{
auto subObjectIndex =
- layout->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex;
+ getLayout()->m_bindingRanges[offset.bindingRangeIndex].baseIndex + offset.bindingArrayIndex;
SLANG_ASSERT(
offset.uniformOffset ==
- layout->m_bindingRanges[offset.bindingRangeIndex].uniformOffset +
+ getLayout()->m_bindingRanges[offset.bindingRangeIndex].uniformOffset +
offset.bindingArrayIndex * sizeof(void*));
auto cudaObject = dynamic_cast<CUDAShaderObject*>(object);
if (subObjectIndex >= objects.getCount())
@@ -593,6 +571,56 @@ public:
setResource(offset, textureView);
return SLANG_OK;
}
+
+ // Appends all types that are used to specialize the element type of this shader object in `args` list.
+ virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override
+ {
+ // TODO: the logic here is a copy-paste of `GraphicsCommonShaderObject::collectSpecializationArgs`,
+ // consider moving the implementation to `ShaderObjectBase` and share the logic among different implementations.
+
+ if (!m_bindingFinalized)
+ return SLANG_FAIL;
+ auto& subObjectRanges = getLayout()->subObjectRanges;
+ // The following logic is built on the assumption that all fields that involve existential types (and
+ // therefore require specialization) will results in a sub-object range in the type layout.
+ // This allows us to simply scan the sub-object ranges to find out all specialization arguments.
+ for (Index subObjIndex = 0; subObjIndex < subObjectRanges.getCount(); subObjIndex++)
+ {
+ // Retrieve the corresponding binding range of the sub object.
+ auto bindingRange = getLayout()->m_bindingRanges[subObjectRanges[subObjIndex].bindingRangeIndex];
+ switch (bindingRange.bindingType)
+ {
+ case slang::BindingType::ExistentialValue:
+ {
+ // A binding type of `ExistentialValue` means the sub-object represents a interface-typed field.
+ // In this case the specialization argument for this field is the actual specialized type of the bound
+ // shader object. If the shader object's type is an ordinary type without existential fields, then the
+ // type argument will simply be the ordinary type. But if the sub object's type is itself a specialized
+ // type, we need to make sure to use that type as the specialization argument.
+
+ // TODO: need to implement the case where the field is an array of existential values.
+ SLANG_ASSERT(bindingRange.count == 1);
+ ExtendedShaderObjectType specializedSubObjType;
+ SLANG_RETURN_ON_FAIL(objects[subObjIndex]->getSpecializedShaderObjectType(&specializedSubObjType));
+ args.add(specializedSubObjType);
+ break;
+ }
+ case slang::BindingType::ParameterBlock:
+ case slang::BindingType::ConstantBuffer:
+ // Currently we only handle the case where the field's type is
+ // `ParameterBlock<SomeStruct>` or `ConstantBuffer<SomeStruct>`, where `SomeStruct` is a struct type
+ // (not directly an interface type). In this case, we just recursively collect the specialization arguments
+ // from the bound sub object.
+ SLANG_RETURN_ON_FAIL(objects[subObjIndex]->collectSpecializationArgs(args));
+ // TODO: we need to handle the case where the field is of the form `ParameterBlock<IFoo>`. We should treat
+ // this case the same way as the `ExistentialValue` case here, but currently we lack a mechanism to distinguish
+ // the two scenarios.
+ break;
+ }
+ // TODO: need to handle another case where specialization happens on resources fields e.g. `StructuredBuffer<IFoo>`.
+ }
+ return SLANG_OK;
+ }
};
class CUDAEntryPointShaderObject : public CUDAShaderObject
@@ -652,17 +680,8 @@ public:
};
-class CUDARenderer : public IRenderer, public RefObject
+class CUDARenderer : public RendererBase
{
-public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IRenderer* getInterface(const Guid& guid)
- {
- return (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IRenderer)
- ? static_cast<IRenderer*>(this)
- : nullptr;
- }
-
private:
static const CUDAReportStyle reportType = CUDAReportStyle::Normal;
static int _calcSMCountPerMultiProcessor(int major, int minor)
@@ -781,12 +800,14 @@ private:
int m_deviceIndex = -1;
CUdevice m_device = 0;
CUcontext m_context = nullptr;
- CUDAPipelineState* currentPipeline = nullptr;
- CUDARootShaderObject* currentRootObject = nullptr;
+ RefPtr<CUDAPipelineState> currentPipeline = nullptr;
+ RefPtr<CUDARootShaderObject> currentRootObject = nullptr;
SlangContext slangContext;
public:
~CUDARenderer()
{
+ currentPipeline = nullptr;
+ currentRootObject = nullptr;
if (m_context)
{
cuCtxDestroy(m_context);
@@ -796,6 +817,8 @@ private:
{
SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_PTX, "sm_5_1"));
+ SLANG_RETURN_ON_FAIL(RendererBase::initialize(desc, inWindowHandle));
+
SLANG_RETURN_ON_FAIL(_initCuda(reportType));
SLANG_RETURN_ON_FAIL(_findMaxFlopsDeviceIndex(&m_deviceIndex));
@@ -813,13 +836,6 @@ private:
return SLANG_OK;
}
- virtual SLANG_NO_THROW Result SLANG_MCALL getSlangSession(slang::ISession** outSlangSession) override
- {
- *outSlangSession = slangContext.session.get();
- slangContext.session->addRef();
- return SLANG_OK;
- }
-
virtual SLANG_NO_THROW Result SLANG_MCALL createTextureResource(
IResource::Usage initialUsage,
const ITextureResource::Desc& desc,
@@ -1271,7 +1287,7 @@ private:
slang::TypeLayoutReflection* typeLayout, IShaderObjectLayout** outLayout) override
{
RefPtr<CUDAShaderObjectLayout> cudaLayout;
- cudaLayout = new CUDAShaderObjectLayout(typeLayout);
+ cudaLayout = new CUDAShaderObjectLayout(this, typeLayout);
*outLayout = cudaLayout.detach();
return SLANG_OK;
}
@@ -1309,6 +1325,17 @@ private:
virtual SLANG_NO_THROW Result SLANG_MCALL
createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) override
{
+ // If this is a specializable program, we just keep a reference to the slang program and
+ // don't actually create any kernels. This program will be specialized later when we know
+ // the shader object bindings.
+ if (desc.slangProgram && desc.slangProgram->getSpecializationParamCount() != 0)
+ {
+ RefPtr<CUDAShaderProgram> cudaProgram = new CUDAShaderProgram();
+ cudaProgram->slangProgram = desc.slangProgram;
+ *outProgram = cudaProgram.detach();
+ return SLANG_OK;
+ }
+
if( desc.kernelCount == 0 )
{
return createProgramFromSlang(this, desc, outProgram);
@@ -1332,7 +1359,7 @@ private:
return SLANG_FAIL;
RefPtr<CUDAProgramLayout> cudaLayout;
- cudaLayout = new CUDAProgramLayout(slangProgramLayout);
+ cudaLayout = new CUDAProgramLayout(this, slangProgramLayout);
cudaLayout->programLayout = slangProgramLayout;
cudaProgram->layout = cudaLayout;
}
@@ -1346,6 +1373,7 @@ private:
{
RefPtr<CUDAPipelineState> state = new CUDAPipelineState();
state->shaderProgram = dynamic_cast<CUDAShaderProgram*>(desc.program);
+ state->init(desc);
*outState = state.detach();
return Result();
}
@@ -1360,18 +1388,19 @@ private:
SLANG_UNUSED(buffer);
}
- virtual SLANG_NO_THROW void SLANG_MCALL
- setPipelineState(PipelineType pipelineType, IPipelineState* state) override
+ virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) override
{
- SLANG_ASSERT(pipelineType == PipelineType::Compute);
currentPipeline = dynamic_cast<CUDAPipelineState*>(state);
}
virtual SLANG_NO_THROW void SLANG_MCALL dispatchCompute(int x, int y, int z) override
{
+ // Specialize the compute kernel based on the shader object bindings.
+ maybeSpecializePipeline(currentRootObject);
+
// Find out thread group size from program reflection.
auto& kernelName = currentPipeline->shaderProgram->kernelName;
- auto programLayout = dynamic_cast<CUDAProgramLayout*>(currentRootObject->layout.Ptr());
+ auto programLayout = static_cast<CUDAProgramLayout*>(currentRootObject->getLayout());
int kernelId = programLayout->getKernelIndex(kernelName.getUnownedSlice());
SLANG_ASSERT(kernelId != -1);
UInt threadGroupSize[3];
@@ -1451,21 +1480,12 @@ private:
return RendererType::CUDA;
}
-public:
- // Unused public interfaces. These functions are not supported on CUDA.
- SLANG_NO_THROW Result SLANG_MCALL getFeatures(
- const char** outFeatures, UInt bufferSize, UInt* outFeatureCount)
- {
- if (outFeatureCount)
- *outFeatureCount = 0;
- return SLANG_OK;
- }
-
- SLANG_NO_THROW bool SLANG_MCALL hasFeature(const char* featureName)
+ virtual PipelineStateBase* getCurrentPipeline() override
{
- return false;
+ return currentPipeline;
}
+public:
virtual SLANG_NO_THROW void SLANG_MCALL setClearColor(const float color[4]) override
{
SLANG_UNUSED(color);
@@ -1602,8 +1622,8 @@ public:
SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout* typeLayout)
{
- this->layout = typeLayout;
-
+ m_layout = typeLayout;
+
// If the layout tells us that there is any uniform data,
// then we need to allocate a constant buffer to hold that data.
//
@@ -1613,8 +1633,8 @@ SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout*
// TODO: When/where do we bind this constant buffer into
// a descriptor set for later use?
//
- auto slangLayout = layout->typeLayout;
- size_t uniformSize = layout->typeLayout->getSize();
+ auto slangLayout = getLayout()->getElementTypeLayout();
+ size_t uniformSize = slangLayout->getSize();
if (uniformSize)
{
initBuffer(renderer, uniformSize);
@@ -1626,7 +1646,7 @@ SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout*
Index subObjectCount = slangLayout->getSubObjectRangeCount();
objects.setCount(subObjectCount);
- for (auto subObjectRange : layout->subObjectRanges)
+ for (auto subObjectRange : getLayout()->subObjectRanges)
{
RefPtr<CUDAShaderObjectLayout> subObjectLayout = subObjectRange.layout;
@@ -1643,7 +1663,7 @@ SlangResult CUDAShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayout*
// in each entry in this range, based on the layout
// information we already have.
- auto& bindingRangeInfo = layout->m_bindingRanges[subObjectRange.bindingRangeIndex];
+ auto& bindingRangeInfo = getLayout()->m_bindingRanges[subObjectRange.bindingRangeIndex];
for (Index i = 0; i < bindingRangeInfo.count; ++i)
{
RefPtr<CUDAShaderObject> subObject = new CUDAShaderObject();
@@ -1671,15 +1691,18 @@ SlangResult CUDARootShaderObject::init(IRenderer* renderer, CUDAShaderObjectLayo
return SLANG_OK;
}
-SlangResult SLANG_MCALL createCUDARenderer(IRenderer** outRenderer)
+SlangResult SLANG_MCALL createCUDARenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer)
{
- *outRenderer = new CUDARenderer();
- (*outRenderer)->addRef();
+ RefPtr<CUDARenderer> result = new CUDARenderer();
+ SLANG_RETURN_ON_FAIL(result->initialize(*desc, windowHandle));
+ *outRenderer = result.detach();
return SLANG_OK;
}
#else
-SlangResult SLANG_MCALL createCUDARenderer(IRenderer** outRenderer)
+SlangResult SLANG_MCALL createCUDARenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer)
{
+ SLANG_UNUSED(desc);
+ SLANG_UNUSED(windowHandle);
*outRenderer = nullptr;
return SLANG_OK;
}
diff --git a/tools/gfx/cuda/render-cuda.h b/tools/gfx/cuda/render-cuda.h
index e209af02c..39d5b60f8 100644
--- a/tools/gfx/cuda/render-cuda.h
+++ b/tools/gfx/cuda/render-cuda.h
@@ -1,11 +1,9 @@
#pragma once
-#include <cstdint>
-#include "slang.h"
+#include "../renderer-shared.h"
namespace gfx
{
-class IRenderer;
-SlangResult SLANG_MCALL createCUDARenderer(IRenderer** outRenderer);
+SlangResult SLANG_MCALL createCUDARenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer);
}
diff --git a/tools/gfx/d3d11/render-d3d11.cpp b/tools/gfx/d3d11/render-d3d11.cpp
index 3df5e25bc..079d89a59 100644
--- a/tools/gfx/d3d11/render-d3d11.cpp
+++ b/tools/gfx/d3d11/render-d3d11.cpp
@@ -59,7 +59,9 @@ public:
kMaxUAVs = 64,
kMaxRTVs = 8,
};
-
+
+ ~D3D11Renderer() {}
+
// Renderer implementation
virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc, void* inWindowHandle) override;
virtual SLANG_NO_THROW void SLANG_MCALL setClearColor(const float color[4]) override;
@@ -132,8 +134,7 @@ public:
setViewports(UInt count, Viewport const* viewports) override;
virtual SLANG_NO_THROW void SLANG_MCALL
setScissorRects(UInt count, ScissorRect const* rects) override;
- virtual SLANG_NO_THROW void SLANG_MCALL
- setPipelineState(PipelineType pipelineType, IPipelineState* state) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) override;
virtual SLANG_NO_THROW void SLANG_MCALL draw(UInt vertexCount, UInt startVertex) override;
virtual SLANG_NO_THROW void SLANG_MCALL
drawIndexed(UInt indexCount, UInt startIndex, UInt baseVertex) override;
@@ -144,9 +145,10 @@ public:
{
return RendererType::DirectX11;
}
-
- ~D3D11Renderer() {}
-
+ virtual PipelineStateBase* getCurrentPipeline() override
+ {
+ return m_currentPipelineState;
+ }
protected:
class ScopeNVAPI
@@ -444,17 +446,9 @@ public:
ComPtr<ID3D11InputLayout> m_layout;
};
- class PipelineStateImpl : public IPipelineState, public RefObject
+ class PipelineStateImpl : public PipelineStateBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IPipelineState* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState)
- return static_cast<IPipelineState*>(this);
- return nullptr;
- }
- public:
RefPtr<ShaderProgramImpl> m_program;
RefPtr<PipelineLayoutImpl> m_pipelineLayout;
};
@@ -473,11 +467,26 @@ public:
UINT m_stencilRef;
float m_blendColor[4];
UINT m_sampleMask;
+
+ void init(const GraphicsPipelineStateDesc& inDesc)
+ {
+ PipelineStateBase::PipelineStateDesc pipelineDesc;
+ pipelineDesc.graphics = inDesc;
+ pipelineDesc.type = PipelineType::Graphics;
+ initializeBase(pipelineDesc);
+ }
};
class ComputePipelineStateImpl : public PipelineStateImpl
{
public:
+ void init(const ComputePipelineStateDesc& inDesc)
+ {
+ PipelineStateBase::PipelineStateDesc pipelineDesc;
+ pipelineDesc.compute = inDesc;
+ pipelineDesc.type = PipelineType::Compute;
+ initializeBase(pipelineDesc);
+ }
};
/// Capture a texture to a file
@@ -506,8 +515,7 @@ public:
bool m_renderTargetBindingsDirty = false;
- ComPtr<GraphicsPipelineStateImpl> m_currentGraphicsState;
- ComPtr<ComputePipelineStateImpl> m_currentComputeState;
+ ComPtr<PipelineStateImpl> m_currentPipelineState;
ComPtr<ID3D11RenderTargetView> m_rtvBindings[kMaxRTVs];
ComPtr<ID3D11DepthStencilView> m_dsvBinding;
@@ -521,10 +529,11 @@ public:
bool m_nvapi = false;
};
-SlangResult SLANG_MCALL createD3D11Renderer(IRenderer** outRenderer)
+SlangResult SLANG_MCALL createD3D11Renderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer)
{
- *outRenderer = new D3D11Renderer();
- (*outRenderer)->addRef();
+ RefPtr<D3D11Renderer> result = new D3D11Renderer();
+ SLANG_RETURN_ON_FAIL(result->initialize(*desc, windowHandle));
+ *outRenderer = result.detach();
return SLANG_OK;
}
@@ -666,6 +675,8 @@ SlangResult D3D11Renderer::initialize(const Desc& desc, void* inWindowHandle)
{
SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_DXBC, "sm_5_0"));
+ SLANG_RETURN_ON_FAIL(GraphicsAPIRenderer::initialize(desc, inWindowHandle));
+
auto windowHandle = (HWND)inWindowHandle;
m_desc = desc;
@@ -1698,8 +1709,10 @@ void D3D11Renderer::setScissorRects(UInt count, ScissorRect const* rects)
}
-void D3D11Renderer::setPipelineState(PipelineType pipelineType, IPipelineState* state)
+void D3D11Renderer::setPipelineState(IPipelineState* state)
{
+ auto pipelineType = static_cast<PipelineStateBase*>(state)->desc.type;
+
switch(pipelineType)
{
default:
@@ -1722,8 +1735,8 @@ void D3D11Renderer::setPipelineState(PipelineType pipelineType, IPipelineState*
m_immediateContext->IASetInputLayout(stateImpl->m_inputLayout->m_layout);
// VS
-
- m_immediateContext->VSSetShader(programImpl->m_vertexShader, nullptr, 0);
+ if (programImpl->m_vertexShader)
+ m_immediateContext->VSSetShader(programImpl->m_vertexShader, nullptr, 0);
// HS
@@ -1736,15 +1749,15 @@ void D3D11Renderer::setPipelineState(PipelineType pipelineType, IPipelineState*
m_immediateContext->RSSetState(stateImpl->m_rasterizerState);
// PS
-
- m_immediateContext->PSSetShader(programImpl->m_pixelShader, nullptr, 0);
+ if (programImpl->m_pixelShader)
+ m_immediateContext->PSSetShader(programImpl->m_pixelShader, nullptr, 0);
// OM
m_immediateContext->OMSetBlendState(stateImpl->m_blendState, stateImpl->m_blendColor, stateImpl->m_sampleMask);
m_immediateContext->OMSetDepthStencilState(stateImpl->m_depthStencilState, stateImpl->m_stencilRef);
- m_currentGraphicsState = stateImpl;
+ m_currentPipelineState = stateImpl;
}
break;
@@ -1756,8 +1769,7 @@ void D3D11Renderer::setPipelineState(PipelineType pipelineType, IPipelineState*
// CS
m_immediateContext->CSSetShader(programImpl->m_computeShader, nullptr, 0);
-
- m_currentComputeState = stateImpl;
+ m_currentPipelineState = stateImpl;
}
break;
}
@@ -2083,7 +2095,7 @@ Result D3D11Renderer::createGraphicsPipelineState(const GraphicsPipelineStateDes
state->m_blendColor[2] = 0;
state->m_blendColor[3] = 0;
state->m_sampleMask = 0xFFFFFFFF;
-
+ state->init(desc);
*outState = state.detach();
return SLANG_OK;
}
@@ -2099,7 +2111,7 @@ Result D3D11Renderer::createComputePipelineState(const ComputePipelineStateDesc&
RefPtr<ComputePipelineStateImpl> state = new ComputePipelineStateImpl();
state->m_program = programImpl;
state->m_pipelineLayout = pipelineLayoutImpl;
-
+ state->init(desc);
*outState = state.detach();
return SLANG_OK;
}
@@ -2303,7 +2315,7 @@ void D3D11Renderer::_flushGraphicsState()
{
m_targetBindingsDirty[pipelineType] = false;
- auto pipelineState = m_currentGraphicsState.get();
+ auto pipelineState = static_cast<GraphicsPipelineStateImpl*>(m_currentPipelineState.get());
auto rtvCount = pipelineState->m_rtvCount;
auto uavCount = pipelineState->m_pipelineLayout->m_uavCount;
@@ -2326,7 +2338,7 @@ void D3D11Renderer::_flushComputeState()
{
m_targetBindingsDirty[pipelineType] = false;
- auto pipelineState = m_currentComputeState.get();
+ auto pipelineState = static_cast<ComputePipelineStateImpl*>(m_currentPipelineState.get());
auto uavCount = pipelineState->m_pipelineLayout->m_uavCount;
diff --git a/tools/gfx/d3d11/render-d3d11.h b/tools/gfx/d3d11/render-d3d11.h
index 96a6a981c..1c81b688c 100644
--- a/tools/gfx/d3d11/render-d3d11.h
+++ b/tools/gfx/d3d11/render-d3d11.h
@@ -1,13 +1,11 @@
// render-d3d11.h
#pragma once
-#include <cstdint>
-#include "slang.h"
+#include "../renderer-shared.h"
-namespace gfx {
+namespace gfx
+{
-class IRenderer;
-
-SlangResult SLANG_MCALL createD3D11Renderer(IRenderer** outRenderer);
+SlangResult SLANG_MCALL createD3D11Renderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer);
} // gfx
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index 407f0dbf2..de7cbd2e2 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -140,8 +140,7 @@ public:
setViewports(UInt count, Viewport const* viewports) override;
virtual SLANG_NO_THROW void SLANG_MCALL
setScissorRects(UInt count, ScissorRect const* rects) override;
- virtual SLANG_NO_THROW void SLANG_MCALL
- setPipelineState(PipelineType pipelineType, IPipelineState* state) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) override;
virtual SLANG_NO_THROW void SLANG_MCALL draw(UInt vertexCount, UInt startVertex) override;
virtual SLANG_NO_THROW void SLANG_MCALL
drawIndexed(UInt indexCount, UInt startIndex, UInt baseVertex) override;
@@ -152,7 +151,10 @@ public:
{
return RendererType::DirectX12;
}
-
+ virtual PipelineStateBase* getCurrentPipeline() override
+ {
+ return m_currentPipelineState;
+ }
~D3D12Renderer();
protected:
@@ -540,20 +542,25 @@ protected:
D3D12DescriptorHeap m_cpuViewHeap; ///< Cbv, Srv, Uav
D3D12DescriptorHeap m_cpuSamplerHeap; ///< Heap for samplers
- class PipelineStateImpl : public IPipelineState, public RefObject
+ class PipelineStateImpl : public PipelineStateBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IPipelineState* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState)
- return static_cast<IPipelineState*>(this);
- return nullptr;
- }
- public:
- PipelineType m_pipelineType;
RefPtr<PipelineLayoutImpl> m_pipelineLayout;
ComPtr<ID3D12PipelineState> m_pipelineState;
+ void init(const GraphicsPipelineStateDesc& inDesc)
+ {
+ PipelineStateDesc pipelineDesc;
+ pipelineDesc.type = PipelineType::Graphics;
+ pipelineDesc.graphics = inDesc;
+ initializeBase(pipelineDesc);
+ }
+ void init(const ComputePipelineStateDesc& inDesc)
+ {
+ PipelineStateDesc pipelineDesc;
+ pipelineDesc.type = PipelineType::Compute;
+ pipelineDesc.compute = inDesc;
+ initializeBase(pipelineDesc);
+ }
};
struct BoundVertexBuffer
@@ -760,10 +767,11 @@ protected:
bool m_nvapi = false;
};
-SlangResult SLANG_MCALL createD3D12Renderer(IRenderer** outRenderer)
+SlangResult SLANG_MCALL createD3D12Renderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer)
{
- *outRenderer = new D3D12Renderer();
- (*outRenderer)->addRef();
+ RefPtr<D3D12Renderer> result = new D3D12Renderer();
+ SLANG_RETURN_ON_FAIL(result->initialize(*desc, windowHandle));
+ *outRenderer = result.detach();
return SLANG_OK;
}
@@ -1160,7 +1168,7 @@ Result D3D12Renderer::_bindRenderState(PipelineStateImpl* pipelineStateImpl, ID3
{
// TODO: we should only set some of this state as needed...
- auto pipelineTypeIndex = (int) pipelineStateImpl->m_pipelineType;
+ auto pipelineTypeIndex = (int) pipelineStateImpl->desc.type;
auto pipelineLayout = pipelineStateImpl->m_pipelineLayout;
submitter->setRootSignature(pipelineLayout->m_rootSignature);
@@ -1350,6 +1358,8 @@ Result D3D12Renderer::initialize(const Desc& desc, void* inWindowHandle)
{
SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_DXBC, "sm_5_1"));
+ SLANG_RETURN_ON_FAIL(GraphicsAPIRenderer::initialize(desc, inWindowHandle));
+
m_hwnd = (HWND)inWindowHandle;
// Rather than statically link against D3D, we load it dynamically.
@@ -2692,7 +2702,7 @@ void D3D12Renderer::setScissorRects(UInt count, ScissorRect const* rects)
m_commandList->RSSetScissorRects(UINT(count), dxRects);
}
-void D3D12Renderer::setPipelineState(PipelineType pipelineType, IPipelineState* state)
+void D3D12Renderer::setPipelineState(IPipelineState* state)
{
m_currentPipelineState = (PipelineStateImpl*)state;
}
@@ -2702,7 +2712,7 @@ void D3D12Renderer::draw(UInt vertexCount, UInt startVertex)
ID3D12GraphicsCommandList* commandList = m_commandList;
auto pipelineState = m_currentPipelineState.Ptr();
- if (!pipelineState || (pipelineState->m_pipelineType != PipelineType::Graphics))
+ if (!pipelineState || (pipelineState->desc.type != PipelineType::Graphics))
{
assert(!"No graphics pipeline state set");
return;
@@ -3715,9 +3725,9 @@ Result D3D12Renderer::createGraphicsPipelineState(const GraphicsPipelineStateDes
SLANG_RETURN_ON_FAIL(m_device->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(pipelineState.writeRef())));
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
- pipelineStateImpl->m_pipelineType = PipelineType::Graphics;
pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl;
pipelineStateImpl->m_pipelineState = pipelineState;
+ pipelineStateImpl->init(desc);
*outState = pipelineStateImpl.detach();
return SLANG_OK;
}
@@ -3768,9 +3778,9 @@ Result D3D12Renderer::createComputePipelineState(const ComputePipelineStateDesc&
}
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
- pipelineStateImpl->m_pipelineType = PipelineType::Compute;
pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl;
pipelineStateImpl->m_pipelineState = pipelineState;
+ pipelineStateImpl->init(desc);
*outState = pipelineStateImpl.detach();
return SLANG_OK;
}
diff --git a/tools/gfx/d3d12/render-d3d12.h b/tools/gfx/d3d12/render-d3d12.h
index bc28e276b..3304d92dc 100644
--- a/tools/gfx/d3d12/render-d3d12.h
+++ b/tools/gfx/d3d12/render-d3d12.h
@@ -1,13 +1,11 @@
// render-d3d12.h
#pragma once
-#include <cstdint>
-#include "slang.h"
+#include "../renderer-shared.h"
-namespace gfx {
+namespace gfx
+{
-class IRenderer;
-
-SlangResult SLANG_MCALL createD3D12Renderer(IRenderer** outRenderer);
+SlangResult SLANG_MCALL createD3D12Renderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer);
} // gfx
diff --git a/tools/gfx/open-gl/render-gl.cpp b/tools/gfx/open-gl/render-gl.cpp
index 80edbdceb..b84db44b6 100644
--- a/tools/gfx/open-gl/render-gl.cpp
+++ b/tools/gfx/open-gl/render-gl.cpp
@@ -151,8 +151,7 @@ public:
setViewports(UInt count, Viewport const* viewports) override;
virtual SLANG_NO_THROW void SLANG_MCALL
setScissorRects(UInt count, ScissorRect const* rects) override;
- virtual SLANG_NO_THROW void SLANG_MCALL
- setPipelineState(PipelineType pipelineType, IPipelineState* state) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) override;
virtual SLANG_NO_THROW void SLANG_MCALL draw(UInt vertexCount, UInt startVertex) override;
virtual void SLANG_MCALL
drawIndexed(UInt indexCount, UInt startIndex, UInt baseVertex) override;
@@ -163,7 +162,10 @@ public:
{
return RendererType::OpenGl;
}
-
+ virtual PipelineStateBase* getCurrentPipeline() override
+ {
+ return m_currentPipelineState.Ptr();
+ }
GLRenderer();
~GLRenderer();
@@ -401,20 +403,26 @@ public:
RefPtr<WeakSink<GLRenderer> > m_renderer;
};
- class PipelineStateImpl : public IPipelineState, public RefObject
+ class PipelineStateImpl : public PipelineStateBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IPipelineState* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState)
- return static_cast<IPipelineState*>(this);
- return nullptr;
- }
- public:
RefPtr<ShaderProgramImpl> m_program;
RefPtr<PipelineLayoutImpl> m_pipelineLayout;
RefPtr<InputLayoutImpl> m_inputLayout;
+ void init(const GraphicsPipelineStateDesc& inDesc)
+ {
+ PipelineStateDesc pipelineDesc;
+ pipelineDesc.type = PipelineType::Graphics;
+ pipelineDesc.graphics = inDesc;
+ initializeBase(pipelineDesc);
+ }
+ void init(const ComputePipelineStateDesc& inDesc)
+ {
+ PipelineStateDesc pipelineDesc;
+ pipelineDesc.type = PipelineType::Compute;
+ pipelineDesc.compute = inDesc;
+ initializeBase(pipelineDesc);
+ }
};
enum class GlPixelFormat
@@ -494,10 +502,11 @@ public:
SLANG_COMPILE_TIME_ASSERT(SLANG_COUNT_OF(s_pixelFormatInfos) == int(GlPixelFormat::CountOf));
}
-SlangResult SLANG_MCALL createGLRenderer(IRenderer** outRenderer)
+SlangResult SLANG_MCALL createGLRenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer)
{
- *outRenderer = new GLRenderer();
- (*outRenderer)->addRef();
+ RefPtr<GLRenderer> result = new GLRenderer();
+ SLANG_RETURN_ON_FAIL(result->initialize(*desc, windowHandle));
+ *outRenderer = result.detach();
return SLANG_OK;
}
@@ -785,7 +794,9 @@ void GLRenderer::destroyBindingEntries(const BindingState::Desc& desc, const Bin
SLANG_NO_THROW Result SLANG_MCALL GLRenderer::initialize(const Desc& desc, void* inWindowHandle)
{
- SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_GLSL, "sm_5_0"));
+ SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_GLSL, "glsl_440"));
+
+ SLANG_RETURN_ON_FAIL(GraphicsAPIRenderer::initialize(desc, inWindowHandle));
auto windowHandle = (HWND)inWindowHandle;
m_desc = desc;
@@ -1270,8 +1281,7 @@ SLANG_NO_THROW void SLANG_MCALL GLRenderer::setScissorRects(UInt count, ScissorR
}
}
-SLANG_NO_THROW void SLANG_MCALL
- GLRenderer::setPipelineState(PipelineType pipelineType, IPipelineState* state)
+SLANG_NO_THROW void SLANG_MCALL GLRenderer::setPipelineState(IPipelineState* state)
{
auto pipelineStateImpl = (PipelineStateImpl*) state;
@@ -1552,6 +1562,7 @@ Result GLRenderer::createGraphicsPipelineState(const GraphicsPipelineStateDesc&
pipelineStateImpl->m_program = programImpl;
pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl;
pipelineStateImpl->m_inputLayout = inputLayoutImpl;
+ pipelineStateImpl->init(desc);
*outState = pipelineStateImpl.detach();
return SLANG_OK;
}
@@ -1567,6 +1578,7 @@ Result GLRenderer::createComputePipelineState(const ComputePipelineStateDesc& in
RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl();
pipelineStateImpl->m_program = programImpl;
pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl;
+ pipelineStateImpl->init(desc);
*outState = pipelineStateImpl.detach();
return SLANG_OK;
}
diff --git a/tools/gfx/open-gl/render-gl.h b/tools/gfx/open-gl/render-gl.h
index 79ff2d124..ef8bb318c 100644
--- a/tools/gfx/open-gl/render-gl.h
+++ b/tools/gfx/open-gl/render-gl.h
@@ -1,13 +1,10 @@
// render-d3d11.h
#pragma once
-#include <cstdint>
-#include "slang.h"
+#include "../renderer-shared.h"
namespace gfx {
-class IRenderer;
-
-SlangResult SLANG_MCALL createGLRenderer(IRenderer** outRenderer);
+SlangResult SLANG_MCALL createGLRenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer);
} // gfx
diff --git a/tools/gfx/render-graphics-common.cpp b/tools/gfx/render-graphics-common.cpp
index adc943d1d..fb01867d8 100644
--- a/tools/gfx/render-graphics-common.cpp
+++ b/tools/gfx/render-graphics-common.cpp
@@ -6,64 +6,9 @@ using namespace Slang;
namespace gfx
{
-const Slang::Guid GfxGUID::IID_ISlangUnknown = SLANG_UUID_ISlangUnknown;
-const Slang::Guid GfxGUID::IID_IDescriptorSetLayout = SLANG_UUID_IDescriptorSetLayout;
-const Slang::Guid GfxGUID::IID_IDescriptorSet = SLANG_UUID_IDescriptorSet;
-const Slang::Guid GfxGUID::IID_IShaderProgram = SLANG_UUID_IShaderProgram;
-const Slang::Guid GfxGUID::IID_IPipelineLayout = SLANG_UUID_IPipelineLayout;
-const Slang::Guid GfxGUID::IID_IInputLayout = SLANG_UUID_IInputLayout;
-const Slang::Guid GfxGUID::IID_IPipelineState = SLANG_UUID_IPipelineState;
-const Slang::Guid GfxGUID::IID_IResourceView = SLANG_UUID_IResourceView;
-const Slang::Guid GfxGUID::IID_ISamplerState = SLANG_UUID_ISamplerState;
-const Slang::Guid GfxGUID::IID_IResource = SLANG_UUID_IResource;
-const Slang::Guid GfxGUID::IID_IBufferResource = SLANG_UUID_IBufferResource;
-const Slang::Guid GfxGUID::IID_ITextureResource = SLANG_UUID_ITextureResource;
-const Slang::Guid GfxGUID::IID_IRenderer = SLANG_UUID_IRenderer;
-const Slang::Guid GfxGUID::IID_IShaderObjectLayout = SLANG_UUID_IShaderObjectLayout;
-const Slang::Guid GfxGUID::IID_IShaderObject = SLANG_UUID_IShaderObject;
-
-gfx::StageType translateStage(SlangStage slangStage)
-{
- switch (slangStage)
- {
- default:
- SLANG_ASSERT(!"unhandled case");
- return gfx::StageType::Unknown;
-
-#define CASE(FROM, TO) \
- case SLANG_STAGE_##FROM: \
- return gfx::StageType::TO
-
- CASE(VERTEX, Vertex);
- CASE(HULL, Hull);
- CASE(DOMAIN, Domain);
- CASE(GEOMETRY, Geometry);
- CASE(FRAGMENT, Fragment);
-
- CASE(COMPUTE, Compute);
-
- CASE(RAY_GENERATION, RayGeneration);
- CASE(INTERSECTION, Intersection);
- CASE(ANY_HIT, AnyHit);
- CASE(CLOSEST_HIT, ClosestHit);
- CASE(MISS, Miss);
- CASE(CALLABLE, Callable);
-
-#undef CASE
- }
-}
-
-class GraphicsCommonShaderObjectLayout : public IShaderObjectLayout, public RefObject
+class GraphicsCommonShaderObjectLayout : public ShaderObjectLayoutBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IShaderObjectLayout* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObjectLayout)
- return static_cast<IShaderObjectLayout*>(this);
- return nullptr;
- }
-public:
struct BindingRangeInfo
{
slang::BindingType bindingType;
@@ -71,7 +16,12 @@ public:
Index baseIndex;
Index descriptorSetIndex;
Index rangeIndexInDescriptorSet;
- // Index subObjectRangeIndex = -1;
+
+ // Returns true if this binding range consumes a specialization argument slot.
+ bool isSpecializationArg() const
+ {
+ return bindingType == slang::BindingType::ExistentialValue;
+ }
};
struct SubObjectRangeInfo
@@ -91,10 +41,13 @@ public:
struct Builder
{
public:
- Builder(IRenderer* renderer)
+ Builder(RendererBase* renderer)
: m_renderer(renderer)
{}
+ RendererBase* m_renderer;
+ slang::TypeLayoutReflection* m_elementTypeLayout;
+
List<BindingRangeInfo> m_bindingRanges;
List<SubObjectRangeInfo> m_subObjectRanges;
@@ -196,6 +149,15 @@ public:
for (SlangInt r = 0; r < descriptorRangeCount; ++r)
{
auto slangBindingType = typeLayout->getDescriptorSetDescriptorRangeType(s, r);
+
+ switch (slangBindingType)
+ {
+ case slang::BindingType::ExistentialValue:
+ continue;
+ default:
+ break;
+ }
+
auto gfxDescriptorType = _mapDescriptorType(slangBindingType);
IDescriptorSetLayout::SlotRangeDesc descriptorRangeDesc;
@@ -284,9 +246,6 @@ public:
BindingRangeInfo bindingRangeInfo;
bindingRangeInfo.bindingType = slangBindingType;
bindingRangeInfo.count = count;
- // bindingRangeInfo.descriptorSetIndex = descriptorSetIndex;
- // bindingRangeInfo.rangeIndexInDescriptorSet = slotRangeIndex;
- // bindingRangeInfo.subObjectRangeIndex = subObjectRangeIndex;
bindingRangeInfo.baseIndex = baseIndex;
bindingRangeInfo.descriptorSetIndex = descriptorSetIndex;
bindingRangeInfo.rangeIndexInDescriptorSet = rangeIndexInDescriptorSet;
@@ -364,13 +323,10 @@ public:
*outLayout = layout.detach();
return SLANG_OK;
}
-
- IRenderer* m_renderer = nullptr;
- slang::TypeLayoutReflection* m_elementTypeLayout = nullptr;
};
static Result createForElementType(
- IRenderer* renderer,
+ RendererBase* renderer,
slang::TypeLayoutReflection* elementType,
GraphicsCommonShaderObjectLayout** outLayout)
{
@@ -397,15 +353,19 @@ public:
SubObjectRangeInfo const& getSubObjectRange(Index index) { return m_subObjectRanges[index]; }
List<SubObjectRangeInfo> const& getSubObjectRanges() { return m_subObjectRanges; }
- IRenderer* getRenderer() { return m_renderer; }
+ RendererBase* getRenderer() { return m_renderer; }
+ slang::TypeReflection* getType()
+ {
+ return m_elementTypeLayout->getType();
+ }
protected:
Result _init(Builder const* builder)
{
auto renderer = builder->m_renderer;
- m_renderer = renderer;
- m_elementTypeLayout = builder->m_elementTypeLayout;
+ initBase(renderer, builder->m_elementTypeLayout);
+
m_bindingRanges = builder->m_bindingRanges;
for (auto descriptorSetBuildInfo : builder->m_descriptorSetBuildInfos)
@@ -430,16 +390,12 @@ protected:
m_samplerCount = builder->m_samplerCount;
m_combinedTextureSamplerCount = builder->m_combinedTextureSamplerCount;
m_subObjectCount = builder->m_subObjectCount;
-
m_subObjectRanges = builder->m_subObjectRanges;
-
return SLANG_OK;
}
- IRenderer* m_renderer;
List<RefPtr<DescriptorSetInfo>> m_descriptorSets;
List<BindingRangeInfo> m_bindingRanges;
- slang::TypeLayoutReflection* m_elementTypeLayout;
Index m_resourceViewCount = 0;
Index m_samplerCount = 0;
Index m_combinedTextureSamplerCount = 0;
@@ -461,7 +417,7 @@ public:
struct Builder : Super::Builder
{
Builder(IRenderer* renderer)
- : Super::Builder(renderer)
+ : Super::Builder(static_cast<RendererBase*>(renderer))
{}
Result build(EntryPointLayout** outLayout)
@@ -566,7 +522,7 @@ public:
struct Builder : Super::Builder
{
Builder(IRenderer* renderer)
- : Super::Builder(renderer)
+ : Super::Builder(static_cast<RendererBase*>(renderer))
{}
Result build(GraphicsCommonProgramLayout** outLayout)
@@ -709,18 +665,9 @@ protected:
ComPtr<IPipelineLayout> m_pipelineLayout;
};
-class GraphicsCommonShaderObject : public IShaderObject, public RefObject
+class GraphicsCommonShaderObject : public ShaderObjectBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IShaderObject* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObject)
- return static_cast<IShaderObject*>(this);
- return nullptr;
- }
-
-public:
static Result create(
IRenderer* renderer,
GraphicsCommonShaderObjectLayout* layout,
@@ -733,7 +680,7 @@ public:
return SLANG_OK;
}
- IRenderer* getRenderer() { return m_layout->getRenderer(); }
+ RendererBase* getRenderer() { return m_layout->getRenderer(); }
SLANG_NO_THROW UInt SLANG_MCALL getEntryPointCount() SLANG_OVERRIDE { return 0; }
@@ -746,7 +693,7 @@ public:
GraphicsCommonShaderObjectLayout* getLayout()
{
- return m_layout;
+ return static_cast<GraphicsCommonShaderObjectLayout*>(m_layout.Ptr());
}
SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL getElementTypeLayout() SLANG_OVERRIDE
@@ -772,29 +719,20 @@ public:
{
if (offset.bindingRangeIndex < 0)
return SLANG_E_INVALID_ARG;
- if (offset.bindingRangeIndex >= m_layout->getBindingRangeCount())
+ auto layout = getLayout();
+ if (offset.bindingRangeIndex >= layout->getBindingRangeCount())
return SLANG_E_INVALID_ARG;
- auto& bindingRange = m_layout->getBindingRange(offset.bindingRangeIndex);
-
- // TODO: Is this reasonable to store the base index directly in the binding range?
- m_objects[bindingRange.baseIndex + offset.bindingArrayIndex] =
- reinterpret_cast<GraphicsCommonShaderObject*>(object);
- // auto& subObjectRange =
- // m_layout->getSubObjectRange(bindingRange.subObjectRangeIndex);
- // m_objects[subObjectRange.baseIndex + offset.bindingArrayIndex] = object;
+ auto subObject = static_cast<GraphicsCommonShaderObject*>(object);
+ if (!subObject->m_bindingFinalized)
+ return SLANG_E_INVALID_ARG;
-#if 0
+ auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex);
- SLANG_ASSERT(bindingRange.descriptorSetIndex >= 0);
- SLANG_ASSERT(bindingRange.descriptorSetIndex < m_descriptorSets.getCount());
- auto& descriptorSet = m_descriptorSets[bindingRange.descriptorSetIndex];
+ // TODO: Is this reasonable to store the base index directly in the binding range?
+ m_objects[bindingRange.baseIndex + offset.bindingArrayIndex] = subObject;
- descriptorSet->setConstantBuffer(bindingRange.rangeIndexInDescriptorSet, offset.bindingArrayIndex, buffer);
- return SLANG_OK;
-#else
return SLANG_E_NOT_IMPLEMENTED;
-#endif
}
virtual SLANG_NO_THROW Result SLANG_MCALL
@@ -804,9 +742,10 @@ public:
SLANG_ASSERT(outObject);
if (offset.bindingRangeIndex < 0)
return SLANG_E_INVALID_ARG;
- if (offset.bindingRangeIndex >= m_layout->getBindingRangeCount())
+ auto layout = getLayout();
+ if (offset.bindingRangeIndex >= layout->getBindingRangeCount())
return SLANG_E_INVALID_ARG;
- auto& bindingRange = m_layout->getBindingRange(offset.bindingRangeIndex);
+ auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex);
auto object = m_objects[bindingRange.baseIndex + offset.bindingArrayIndex].Ptr();
object->addRef();
@@ -833,9 +772,10 @@ public:
{
if (offset.bindingRangeIndex < 0)
return SLANG_E_INVALID_ARG;
- if (offset.bindingRangeIndex >= m_layout->getBindingRangeCount())
+ auto layout = getLayout();
+ if (offset.bindingRangeIndex >= layout->getBindingRangeCount())
return SLANG_E_INVALID_ARG;
- auto& bindingRange = m_layout->getBindingRange(offset.bindingRangeIndex);
+ auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex);
m_resourceViews[bindingRange.baseIndex + offset.bindingArrayIndex] = resourceView;
return SLANG_OK;
@@ -846,9 +786,10 @@ public:
{
if (offset.bindingRangeIndex < 0)
return SLANG_E_INVALID_ARG;
- if (offset.bindingRangeIndex >= m_layout->getBindingRangeCount())
+ auto layout = getLayout();
+ if (offset.bindingRangeIndex >= layout->getBindingRangeCount())
return SLANG_E_INVALID_ARG;
- auto& bindingRange = m_layout->getBindingRange(offset.bindingRangeIndex);
+ auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex);
m_samplers[bindingRange.baseIndex + offset.bindingArrayIndex] = sampler;
return SLANG_OK;
@@ -859,9 +800,10 @@ public:
{
if (offset.bindingRangeIndex < 0)
return SLANG_E_INVALID_ARG;
- if (offset.bindingRangeIndex >= m_layout->getBindingRangeCount())
+ auto layout = getLayout();
+ if (offset.bindingRangeIndex >= layout->getBindingRangeCount())
return SLANG_E_INVALID_ARG;
- auto& bindingRange = m_layout->getBindingRange(offset.bindingRangeIndex);
+ auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex);
auto& slot = m_combinedTextureSamplers[bindingRange.baseIndex + offset.bindingArrayIndex];
slot.textureView = textureView;
@@ -869,6 +811,55 @@ public:
return SLANG_OK;
}
+public:
+ // Appends all types that are used to specialize the element type of this shader object in `args` list.
+ virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override
+ {
+ if (!m_bindingFinalized)
+ return SLANG_FAIL;
+ auto& subObjectRanges = getLayout()->getSubObjectRanges();
+ // The following logic is built on the assumption that all fields that involve existential types (and
+ // therefore require specialization) will results in a sub-object range in the type layout.
+ // This allows us to simply scan the sub-object ranges to find out all specialization arguments.
+ for (Index subObjIndex = 0; subObjIndex < subObjectRanges.getCount(); subObjIndex++)
+ {
+ // Retrieve the corresponding binding range of the sub object.
+ auto bindingRange = getLayout()->getBindingRange(subObjectRanges[subObjIndex].bindingRangeIndex);
+ switch (bindingRange.bindingType)
+ {
+ case slang::BindingType::ExistentialValue:
+ {
+ // A binding type of `ExistentialValue` means the sub-object represents a interface-typed field.
+ // In this case the specialization argument for this field is the actual specialized type of the bound
+ // shader object. If the shader object's type is an ordinary type without existential fields, then the
+ // type argument will simply be the ordinary type. But if the sub object's type is itself a specialized
+ // type, we need to make sure to use that type as the specialization argument.
+
+ // TODO: need to implement the case where the field is an array of existential values.
+ SLANG_ASSERT(bindingRange.count == 1);
+ ExtendedShaderObjectType specializedSubObjType;
+ SLANG_RETURN_ON_FAIL(m_objects[subObjIndex]->getSpecializedShaderObjectType(&specializedSubObjType));
+ args.add(specializedSubObjType);
+ break;
+ }
+ case slang::BindingType::ParameterBlock:
+ case slang::BindingType::ConstantBuffer:
+ // Currently we only handle the case where the field's type is
+ // `ParameterBlock<SomeStruct>` or `ConstantBuffer<SomeStruct>`, where `SomeStruct` is a struct type
+ // (not directly an interface type). In this case, we just recursively collect the specialization arguments
+ // from the bound sub object.
+ SLANG_RETURN_ON_FAIL(m_objects[subObjIndex]->collectSpecializationArgs(args));
+ // TODO: we need to handle the case where the field is of the form `ParameterBlock<IFoo>`. We should treat
+ // this case the same way as the `ExistentialValue` case here, but currently we lack a mechanism to distinguish
+ // the two scenarios.
+ break;
+ }
+ // TODO: need to handle another case where specialization happens on resources fields e.g. `StructuredBuffer<IFoo>`.
+ }
+ return SLANG_OK;
+ }
+
+
protected:
friend class ProgramVars;
@@ -953,7 +944,7 @@ protected:
IPipelineLayout* pipelineLayout,
Index& ioRootIndex)
{
- GraphicsCommonShaderObjectLayout* layout = m_layout;
+ GraphicsCommonShaderObjectLayout* layout = getLayout();
// Create the descritpor sets required by the layout...
//
@@ -979,7 +970,7 @@ protected:
Result _bindIntoDescriptorSet(
IDescriptorSet* descriptorSet, Index baseRangeIndex, Index subObjectRangeArrayIndex)
{
- GraphicsCommonShaderObjectLayout* layout = m_layout;
+ GraphicsCommonShaderObjectLayout* layout = getLayout();
if (m_buffer)
{
@@ -1063,7 +1054,7 @@ protected:
public:
virtual Result _bindIntoDescriptorSets(ComPtr<IDescriptorSet>* descriptorSets)
{
- GraphicsCommonShaderObjectLayout* layout = m_layout;
+ GraphicsCommonShaderObjectLayout* layout = getLayout();
if (m_buffer)
{
@@ -1134,7 +1125,6 @@ public:
return SLANG_OK;
}
- RefPtr<GraphicsCommonShaderObjectLayout> m_layout = nullptr;
ComPtr<IBufferResource> m_buffer;
List<ComPtr<IResourceView>> m_resourceViews;
@@ -1333,8 +1323,7 @@ Result GraphicsAPIRenderer::initProgramCommon(
SLANG_RETURN_ON_FAIL(builder.build(programLayout.writeRef()));
}
-
- program->m_slangProgram = slangProgram;
+ program->slangProgram = slangProgram;
program->m_layout = programLayout;
return SLANG_OK;
@@ -1347,68 +1336,16 @@ Result SLANG_MCALL
if (!programVars)
return SLANG_E_INVALID_HANDLE;
- programVars->apply(this, pipelineType);
- return SLANG_OK;
-}
+ SLANG_RETURN_ON_FAIL(maybeSpecializePipeline(programVars));
-SLANG_NO_THROW Result SLANG_MCALL
- gfx::GraphicsAPIRenderer::getFeatures(
- const char** outFeatures, UInt bufferSize, UInt* outFeatureCount)
-{
- if (bufferSize >= (UInt)m_features.getCount())
- {
- for (Index i = 0; i < m_features.getCount(); i++)
- {
- outFeatures[i] = m_features[i].getUnownedSlice().begin();
- }
- }
- if (outFeatureCount)
- *outFeatureCount = (UInt)m_features.getCount();
- return SLANG_OK;
-}
-
-SLANG_NO_THROW bool SLANG_MCALL gfx::GraphicsAPIRenderer::hasFeature(const char* featureName)
-{
- return m_features.findFirstIndex([&](Slang::String x) { return x == featureName; }) != -1;
-}
-
-SLANG_NO_THROW Result SLANG_MCALL gfx::GraphicsAPIRenderer::getSlangSession(slang::ISession** outSlangSession)
-{
- *outSlangSession = slangContext.session.get();
- slangContext.session->addRef();
+ // Apply shader parameter bindings.
+ programVars->apply(this, pipelineType);
return SLANG_OK;
}
-GraphicsCommonShaderProgram::~GraphicsCommonShaderProgram()
-{
- // Note: It might not seem like this destructor is needed at all, since
- // it is empty.
- //
- // In pratice, though, it seems to be required because the `m_layout`
- // field is declared in the coresponding header before the `GraphicsCommonProgramLayout`
- // is declared (we only have a forward declaration).
- //
- // `m_layout` is a `RefPtr`, and it seems that the compiler (or at least
- // the Visual Studio compiler) either cannot synthesize a destructor for
- // the type that properly destructs the field, or it simply synthesizes
- // an incorect destructor.
- //
- // I suspect that part of the problem stems from the way that `GraphicsCommonProgramLayout`
- // inherits from `RefObject` via multiple inheritance.
- //
- // No matter what, defining the destructor here in a file where
- // the declaration of `GraphicsCommonProgramLayout` is visible
- // seems to result in a correct destructor being emitted.
- //
- // TODO: Ther simpler and more robust fix would be to move the declaration
- // of `GraphicsCommonProgramLayout` and related types to the header.
-}
-
-IShaderProgram* GraphicsCommonShaderProgram::getInterface(const Guid& guid)
+GraphicsCommonProgramLayout* gfx::GraphicsCommonShaderProgram::getLayout() const
{
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderProgram)
- return static_cast<IShaderProgram*>(this);
- return nullptr;
+ return static_cast<GraphicsCommonProgramLayout*>(m_layout.Ptr());
}
void GraphicsAPIRenderer::preparePipelineDesc(GraphicsPipelineStateDesc& desc)
@@ -1432,11 +1369,4 @@ void GraphicsAPIRenderer::preparePipelineDesc(ComputePipelineStateDesc& desc)
}
}
-IRenderer* gfx::GraphicsAPIRenderer::getInterface(const Guid& guid)
-{
- return (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IRenderer)
- ? static_cast<IRenderer*>(this)
- : nullptr;
-}
-
}
diff --git a/tools/gfx/render-graphics-common.h b/tools/gfx/render-graphics-common.h
index f4d9567e3..59dfb888a 100644
--- a/tools/gfx/render-graphics-common.h
+++ b/tools/gfx/render-graphics-common.h
@@ -1,41 +1,25 @@
#pragma once
-#include "tools/gfx/render.h"
+#include "tools/gfx/renderer-shared.h"
#include "core/slang-basic.h"
#include "tools/gfx/slang-context.h"
namespace gfx
{
-
class GraphicsCommonProgramLayout;
-class GraphicsCommonShaderProgram : public IShaderProgram, public Slang::RefObject
+class GraphicsCommonShaderProgram : public ShaderProgramBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
-
- IShaderProgram* getInterface(const Slang::Guid& guid);
-
- GraphicsCommonProgramLayout* getLayout() const { return m_layout; }
-
-protected:
- ~GraphicsCommonShaderProgram();
-
+ GraphicsCommonProgramLayout* getLayout() const;
private:
friend class GraphicsAPIRenderer;
- ComPtr<slang::IComponentType> m_slangProgram;
- Slang::RefPtr<GraphicsCommonProgramLayout> m_layout;
+ Slang::RefPtr<ShaderObjectLayoutBase> m_layout;
};
-class GraphicsAPIRenderer : public IRenderer, public Slang::RefObject
+class GraphicsAPIRenderer : public RendererBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- virtual SLANG_NO_THROW Result SLANG_MCALL getFeatures(
- const char** outFeatures, UInt bufferSize, UInt* outFeatureCount) SLANG_OVERRIDE;
- virtual SLANG_NO_THROW bool SLANG_MCALL hasFeature(const char* featureName) SLANG_OVERRIDE;
- virtual SLANG_NO_THROW Result SLANG_MCALL getSlangSession(slang::ISession** outSlangSession) SLANG_OVERRIDE;
-
virtual SLANG_NO_THROW Result SLANG_MCALL createShaderObjectLayout(
slang::TypeLayoutReflection* typeLayout, IShaderObjectLayout** outLayout) SLANG_OVERRIDE;
virtual SLANG_NO_THROW Result SLANG_MCALL
@@ -47,34 +31,9 @@ public:
bindRootShaderObject(PipelineType pipelineType, IShaderObject* object) SLANG_OVERRIDE;
void preparePipelineDesc(GraphicsPipelineStateDesc& desc);
void preparePipelineDesc(ComputePipelineStateDesc& desc);
- IRenderer* getInterface(const Slang::Guid& guid);
Result initProgramCommon(
GraphicsCommonShaderProgram* program,
IShaderProgram::Desc const& desc);
-
-protected:
- Slang::List<Slang::String> m_features;
- SlangContext slangContext;
};
-
-struct GfxGUID
-{
- static const Slang::Guid IID_ISlangUnknown;
- static const Slang::Guid IID_IDescriptorSetLayout;
- static const Slang::Guid IID_IDescriptorSet;
- static const Slang::Guid IID_IShaderProgram;
- static const Slang::Guid IID_IPipelineLayout;
- static const Slang::Guid IID_IPipelineState;
- static const Slang::Guid IID_IResourceView;
- static const Slang::Guid IID_ISamplerState;
- static const Slang::Guid IID_IResource;
- static const Slang::Guid IID_IBufferResource;
- static const Slang::Guid IID_ITextureResource;
- static const Slang::Guid IID_IInputLayout;
- static const Slang::Guid IID_IRenderer;
- static const Slang::Guid IID_IShaderObjectLayout;
- static const Slang::Guid IID_IShaderObject;
-};
-
}
diff --git a/tools/gfx/render.cpp b/tools/gfx/render.cpp
index a25393714..042f4253b 100644
--- a/tools/gfx/render.cpp
+++ b/tools/gfx/render.cpp
@@ -85,35 +85,35 @@ extern "C"
}
}
- SGRendererCreateFunc SLANG_MCALL gfxGetCreateFunc(RendererType type)
+ SLANG_GFX_API SlangResult SLANG_MCALL gfxCreateRenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer)
{
- switch (type)
+ switch (desc->rendererType)
{
#if SLANG_WINDOWS_FAMILY
case RendererType::DirectX11:
{
- return &createD3D11Renderer;
+ return createD3D11Renderer(desc, windowHandle, outRenderer);
}
case RendererType::DirectX12:
{
- return &createD3D12Renderer;
+ return createD3D12Renderer(desc, windowHandle, outRenderer);
}
case RendererType::OpenGl:
{
- return &createGLRenderer;
+ return createGLRenderer(desc, windowHandle, outRenderer);
}
case RendererType::Vulkan:
{
- return &createVKRenderer;
+ return createVKRenderer(desc, windowHandle, outRenderer);
}
case RendererType::CUDA:
{
- return &createCUDARenderer;
+ return createCUDARenderer(desc, windowHandle, outRenderer);
}
#endif
default:
- return nullptr;
+ return SLANG_FAIL;
}
}
diff --git a/tools/gfx/render.h b/tools/gfx/render.h
index 325bab4ee..13af56550 100644
--- a/tools/gfx/render.h
+++ b/tools/gfx/render.h
@@ -880,6 +880,7 @@ public:
setSampler(ShaderOffset const& offset, ISamplerState* sampler) = 0;
virtual SLANG_NO_THROW Result SLANG_MCALL setCombinedTextureSampler(
ShaderOffset const& offset, IResourceView* textureView, ISamplerState* sampler) = 0;
+ virtual SLANG_NO_THROW Result SLANG_MCALL finalizeBindings() = 0;
};
#define SLANG_UUID_IShaderObject \
{ \
@@ -1104,18 +1105,17 @@ public:
struct Desc
{
+ RendererType rendererType; // The underlying API/Platform of the renderer.
int width = 0; // Width in pixels
int height = 0; // height in pixels
const char* adapter = nullptr; // Name to identify the adapter to use
int requiredFeatureCount = 0; // Number of required features.
const char** requiredFeatures = nullptr; // Array of required feature names, whose size is `requiredFeatureCount`.
int nvapiExtnSlot = -1; // The slot (typically UAV) used to identify NVAPI intrinsics. If >=0 NVAPI is required.
+ ISlangFileSystem* shaderCacheFileSystem = nullptr; // The file system for loading cached shader kernels.
SlangDesc slang = {}; // Configurations for Slang.
};
- // Will return with SLANG_E_NOT_AVAILABLE if NVAPI can't be initialized and nvapiExtnSlot >= 0
- virtual SLANG_NO_THROW Result SLANG_MCALL initialize(const Desc& desc, void* inWindowHandle) = 0;
-
virtual SLANG_NO_THROW bool SLANG_MCALL hasFeature(const char* feature) = 0;
/// Returns a list of features supported by the renderer.
@@ -1342,7 +1342,7 @@ public:
setScissorRects(1, &rect);
}
- virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(PipelineType pipelineType, IPipelineState* state) = 0;
+ virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) = 0;
virtual SLANG_NO_THROW void SLANG_MCALL draw(UInt vertexCount, UInt startVertex = 0) = 0;
virtual SLANG_NO_THROW void SLANG_MCALL drawIndexed(UInt indexCount, UInt startIndex = 0, UInt baseVertex = 0) = 0;
@@ -1374,8 +1374,6 @@ inline void IRenderer::setVertexBuffer(UInt slot, IBufferResource* buffer, UInt
extern "C"
{
- typedef SlangResult(SLANG_MCALL * SGRendererCreateFunc)(IRenderer** outRenderer);
-
/// Gets the size in bytes of a Format type. Returns 0 if a size is not defined/invalid
SLANG_GFX_API size_t SLANG_MCALL gfxGetFormatSize(Format format);
@@ -1394,7 +1392,7 @@ extern "C"
SLANG_GFX_API const char* SLANG_MCALL gfxGetRendererName(RendererType type);
/// Given a type returns a function that can construct it, or nullptr if there isn't one
- SLANG_GFX_API SGRendererCreateFunc SLANG_MCALL gfxGetCreateFunc(RendererType type);
+ SLANG_GFX_API SlangResult SLANG_MCALL gfxCreateRenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer);
}
}// renderer_test
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index 2072d52ef..a2e1751fd 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -1,11 +1,60 @@
#include "renderer-shared.h"
#include "render-graphics-common.h"
+#include "core/slang-io.h"
+#include "core/slang-token-reader.h"
using namespace Slang;
namespace gfx
{
+const Slang::Guid GfxGUID::IID_ISlangUnknown = SLANG_UUID_ISlangUnknown;
+const Slang::Guid GfxGUID::IID_IDescriptorSetLayout = SLANG_UUID_IDescriptorSetLayout;
+const Slang::Guid GfxGUID::IID_IDescriptorSet = SLANG_UUID_IDescriptorSet;
+const Slang::Guid GfxGUID::IID_IShaderProgram = SLANG_UUID_IShaderProgram;
+const Slang::Guid GfxGUID::IID_IPipelineLayout = SLANG_UUID_IPipelineLayout;
+const Slang::Guid GfxGUID::IID_IInputLayout = SLANG_UUID_IInputLayout;
+const Slang::Guid GfxGUID::IID_IPipelineState = SLANG_UUID_IPipelineState;
+const Slang::Guid GfxGUID::IID_IResourceView = SLANG_UUID_IResourceView;
+const Slang::Guid GfxGUID::IID_ISamplerState = SLANG_UUID_ISamplerState;
+const Slang::Guid GfxGUID::IID_IResource = SLANG_UUID_IResource;
+const Slang::Guid GfxGUID::IID_IBufferResource = SLANG_UUID_IBufferResource;
+const Slang::Guid GfxGUID::IID_ITextureResource = SLANG_UUID_ITextureResource;
+const Slang::Guid GfxGUID::IID_IRenderer = SLANG_UUID_IRenderer;
+const Slang::Guid GfxGUID::IID_IShaderObjectLayout = SLANG_UUID_IShaderObjectLayout;
+const Slang::Guid GfxGUID::IID_IShaderObject = SLANG_UUID_IShaderObject;
+
+gfx::StageType translateStage(SlangStage slangStage)
+{
+ switch (slangStage)
+ {
+ default:
+ SLANG_ASSERT(!"unhandled case");
+ return gfx::StageType::Unknown;
+
+#define CASE(FROM, TO) \
+case SLANG_STAGE_##FROM: \
+return gfx::StageType::TO
+
+ CASE(VERTEX, Vertex);
+ CASE(HULL, Hull);
+ CASE(DOMAIN, Domain);
+ CASE(GEOMETRY, Geometry);
+ CASE(FRAGMENT, Fragment);
+
+ CASE(COMPUTE, Compute);
+
+ CASE(RAY_GENERATION, RayGeneration);
+ CASE(INTERSECTION, Intersection);
+ CASE(ANY_HIT, AnyHit);
+ CASE(CLOSEST_HIT, ClosestHit);
+ CASE(MISS, Miss);
+ CASE(CALLABLE, Callable);
+
+#undef CASE
+ }
+}
+
IResource* BufferResource::getInterface(const Slang::Guid& guid)
{
if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IResource ||
@@ -96,4 +145,396 @@ Result createProgramFromSlang(IRenderer* renderer, IShaderProgram::Desc const& o
return renderer->createProgram(programDesc, outProgram);
}
+IShaderObject* gfx::ShaderObjectBase::getInterface(const Guid& guid)
+{
+ if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObject)
+ return static_cast<IShaderObject*>(this);
+ return nullptr;
+}
+
+IShaderProgram* gfx::ShaderProgramBase::getInterface(const Guid& guid)
+{
+ if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderProgram)
+ return static_cast<IShaderProgram*>(this);
+ return nullptr;
+}
+
+IPipelineState* gfx::PipelineStateBase::getInterface(const Guid& guid)
+{
+ if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState)
+ return static_cast<IPipelineState*>(this);
+ return nullptr;
+}
+
+void PipelineStateBase::initializeBase(const PipelineStateDesc& inDesc)
+{
+ desc = inDesc;
+
+ auto program = desc.getProgram();
+ isSpecializable = (program->slangProgram && program->slangProgram->getSpecializationParamCount() != 0);
+}
+
+IRenderer* gfx::RendererBase::getInterface(const Guid& guid)
+{
+ return (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IRenderer)
+ ? static_cast<IRenderer*>(this)
+ : nullptr;
+}
+
+SLANG_NO_THROW Result SLANG_MCALL RendererBase::initialize(const Desc& desc, void* inWindowHandle)
+{
+ SLANG_UNUSED(inWindowHandle);
+
+ shaderCache.init(desc.shaderCacheFileSystem);
+ return SLANG_OK;
+}
+
+SLANG_NO_THROW Result SLANG_MCALL RendererBase::getFeatures(
+ const char** outFeatures, UInt bufferSize, UInt* outFeatureCount)
+{
+ if (bufferSize >= (UInt)m_features.getCount())
+ {
+ for (Index i = 0; i < m_features.getCount(); i++)
+ {
+ outFeatures[i] = m_features[i].getUnownedSlice().begin();
+ }
+ }
+ if (outFeatureCount)
+ *outFeatureCount = (UInt)m_features.getCount();
+ return SLANG_OK;
+}
+
+SLANG_NO_THROW bool SLANG_MCALL RendererBase::hasFeature(const char* featureName)
+{
+ return m_features.findFirstIndex([&](Slang::String x) { return x == featureName; }) != -1;
+}
+
+SLANG_NO_THROW Result SLANG_MCALL RendererBase::getSlangSession(slang::ISession** outSlangSession)
+{
+ *outSlangSession = slangContext.session.get();
+ slangContext.session->addRef();
+ return SLANG_OK;
+}
+
+ShaderComponentID ShaderCache::getComponentId(slang::TypeReflection* type)
+{
+ ComponentKey key;
+ key.typeName = UnownedStringSlice(type->getName());
+ switch (type->getKind())
+ {
+ case slang::TypeReflection::Kind::Specialized:
+ // TODO: collect specialization arguments and append them to `key`.
+ SLANG_UNIMPLEMENTED_X("specialized type");
+ default:
+ break;
+ }
+ key.updateHash();
+ return getComponentId(key);
+}
+
+ShaderComponentID ShaderCache::getComponentId(UnownedStringSlice name)
+{
+ ComponentKey key;
+ key.typeName = name;
+ key.updateHash();
+ return getComponentId(key);
+}
+
+ShaderComponentID ShaderCache::getComponentId(ComponentKey key)
+{
+ ShaderComponentID componentId = 0;
+ if (componentIds.TryGetValue(key, componentId))
+ return componentId;
+ OwningComponentKey owningTypeKey;
+ owningTypeKey.hash = key.hash;
+ owningTypeKey.typeName = key.typeName;
+ owningTypeKey.specializationArgs.addRange(key.specializationArgs);
+ ShaderComponentID resultId = static_cast<ShaderComponentID>(componentIds.Count());
+ componentIds[owningTypeKey] = resultId;
+ return resultId;
+}
+
+void ShaderCache::init(ISlangFileSystem* cacheFileSystem)
+{
+ fileSystem = cacheFileSystem;
+
+ ComPtr<ISlangBlob> indexFileBlob;
+ if (fileSystem && fileSystem->loadFile("index", indexFileBlob.writeRef()) == SLANG_OK)
+ {
+ UnownedStringSlice indexText = UnownedStringSlice(static_cast<const char*>(indexFileBlob->getBufferPointer()));
+ TokenReader reader = TokenReader(indexText);
+ auto componentCountInFileSystem = reader.ReadUInt();
+ for (uint32_t i = 0; i < componentCountInFileSystem; i++)
+ {
+ OwningComponentKey key;
+ auto componentId = reader.ReadUInt();
+ key.typeName = reader.ReadWord();
+ key.specializationArgs.setCount(reader.ReadUInt());
+ for (auto& arg : key.specializationArgs)
+ arg = reader.ReadUInt();
+ componentIds[key] = componentId;
+ }
+ }
+}
+
+void ShaderCache::writeToFileSystem(ISlangMutableFileSystem* outputFileSystem)
+{
+ StringBuilder indexBuilder;
+ indexBuilder << componentIds.Count() << Slang::EndLine;
+ for (auto id : componentIds)
+ {
+ indexBuilder << id.Value << " ";
+ indexBuilder << id.Key.typeName << " " << id.Key.specializationArgs.getCount();
+ for (auto arg : id.Key.specializationArgs)
+ indexBuilder << " " << arg;
+ indexBuilder << Slang::EndLine;
+ }
+ outputFileSystem->saveFile("index", indexBuilder.getBuffer(), indexBuilder.getLength());
+ for (auto& binary : shaderBinaries)
+ {
+ ComPtr<ISlangBlob> blob;
+ binary.Value->writeToBlob(blob.writeRef());
+ outputFileSystem->saveFile(String(binary.Key).getBuffer(), blob->getBufferPointer(), blob->getBufferSize());
+ }
+}
+
+Slang::RefPtr<ShaderBinary> ShaderCache::tryLoadShaderBinary(ShaderComponentID componentId)
+{
+ Slang::ComPtr<ISlangBlob> entryBlob;
+ Slang::RefPtr<ShaderBinary> binary;
+ if (shaderBinaries.TryGetValue(componentId, binary))
+ return binary;
+
+ if (fileSystem && fileSystem->loadFile(String(componentId).getBuffer(), entryBlob.writeRef()) == SLANG_OK)
+ {
+ binary = new ShaderBinary();
+ binary->loadFromBlob(entryBlob.get());
+ return binary;
+ }
+ return nullptr;
+}
+
+void ShaderCache::addShaderBinary(ShaderComponentID componentId, ShaderBinary* binary)
+{
+ shaderBinaries[componentId] = binary;
+}
+
+void ShaderCache::addSpecializedPipeline(PipelineKey key, Slang::RefPtr<PipelineStateBase> specializedPipeline)
+{
+ specializedPipelines[key] = specializedPipeline;
+}
+
+struct ShaderBinaryEntryHeader
+{
+ StageType stage;
+ uint32_t nameLength;
+ uint32_t codeLength;
+};
+
+Result ShaderBinary::loadFromBlob(ISlangBlob* blob)
+{
+ MemoryStreamBase memStream(Slang::FileAccess::Read, blob->getBufferPointer(), blob->getBufferSize());
+ uint32_t nameLength = 0;
+ ShaderBinaryEntryHeader header;
+ if (memStream.read(&header, sizeof(header)) != sizeof(header))
+ return SLANG_FAIL;
+ const uint8_t* name = memStream.getContents().getBuffer() + memStream.getPosition();
+ const uint8_t* code = name + header.nameLength;
+ entryPointName = reinterpret_cast<const char*>(name);
+ stage = header.stage;
+ source.addRange(code, header.codeLength);
+ return SLANG_OK;
+}
+
+Result ShaderBinary::writeToBlob(ISlangBlob** outBlob)
+{
+ OwnedMemoryStream outStream(FileAccess::Write);
+ ShaderBinaryEntryHeader header;
+ header.stage = stage;
+ header.nameLength = static_cast<uint32_t>(entryPointName.getLength() + 1);
+ header.codeLength = static_cast<uint32_t>(source.getCount());
+ outStream.write(&header, sizeof(header));
+ outStream.write(entryPointName.getBuffer(), header.nameLength - 1);
+ uint8_t zeroTerminator = 0;
+ outStream.write(&zeroTerminator, 1);
+ outStream.write(source.getBuffer(), header.codeLength);
+ RefPtr<RawBlob> blob = new RawBlob(outStream.getContents().getBuffer(), outStream.getContents().getCount());
+ *outBlob = blob.detach();
+ return SLANG_OK;
+}
+
+IShaderObjectLayout* ShaderObjectLayoutBase::getInterface(const Slang::Guid& guid)
+{
+ if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderObjectLayout)
+ return static_cast<IShaderObjectLayout*>(this);
+ return nullptr;
+}
+
+void ShaderObjectLayoutBase::initBase(RendererBase* renderer, slang::TypeLayoutReflection* elementTypeLayout)
+{
+ m_renderer = renderer;
+ m_elementTypeLayout = elementTypeLayout;
+ m_componentID = m_renderer->shaderCache.getComponentId(m_elementTypeLayout->getType());
+}
+
+SLANG_NO_THROW Result SLANG_MCALL ShaderObjectBase::finalizeBindings()
+{
+ m_bindingFinalized = true;
+
+ // With all binding fixed, the shader object's type can be determined by specializing the
+ // shader object's type with the type of bound sub objects.
+ // Now obtain a componentID for the specialized shader object type from the shader cache.
+ SLANG_RETURN_ON_FAIL(getSpecializedShaderObjectType(&shaderObjectType));
+ return SLANG_OK;
+}
+
+
+// Get the final type this shader object represents. If the shader object's type has existential fields,
+// this function will return a specialized type using the bound sub-objects' type as specialization argument.
+
+Result ShaderObjectBase::getSpecializedShaderObjectType(ExtendedShaderObjectType* outType)
+{
+ if (shaderObjectType.slangType)
+ *outType = shaderObjectType;
+ ExtendedShaderObjectTypeList specializationArgs;
+ SLANG_RETURN_ON_FAIL(collectSpecializationArgs(specializationArgs));
+ if (specializationArgs.getCount() == 0)
+ {
+ shaderObjectType.componentID = getLayout()->getComponentID();
+ shaderObjectType.slangType = getLayout()->getElementTypeLayout()->getType();
+ }
+ else
+ {
+ shaderObjectType.slangType = getRenderer()->slangContext.session->specializeType(
+ getElementTypeLayout()->getType(),
+ specializationArgs.components.getArrayView().getBuffer(), specializationArgs.getCount());
+ shaderObjectType.componentID = getRenderer()->shaderCache.getComponentId(shaderObjectType.slangType);
+ }
+ *outType = shaderObjectType;
+ return SLANG_OK;
+}
+
+Result RendererBase::maybeSpecializePipeline(ShaderObjectBase* rootObject)
+{
+ auto currentPipeline = getCurrentPipeline();
+ auto pipelineType = currentPipeline->desc.type;
+ if (currentPipeline->unspecializedPipelineState)
+ currentPipeline = currentPipeline->unspecializedPipelineState;
+ // If the currently bound pipeline is specializable, we need to specialize it based on bound shader objects.
+ if (currentPipeline->isSpecializable)
+ {
+ specializationArgs.clear();
+ SLANG_RETURN_ON_FAIL(rootObject->collectSpecializationArgs(specializationArgs));
+
+ // Construct a shader cache key that represents the specialized shader kernels.
+ PipelineKey pipelineKey;
+ pipelineKey.pipeline = currentPipeline;
+ pipelineKey.specializationArgs.addRange(specializationArgs.componentIDs);
+ pipelineKey.updateHash();
+
+ auto specializedPipelineState = shaderCache.getSpecializedPipelineState(pipelineKey);
+ // Try to find specialized pipeline from shader cache.
+ if (!specializedPipelineState)
+ {
+ auto unspecializedProgram = static_cast<ShaderProgramBase*>(pipelineType == PipelineType::Compute
+ ? currentPipeline->desc.compute.program
+ : currentPipeline->desc.graphics.program);
+ List<RefPtr<ShaderBinary>> entryPointBinaries;
+ auto unspecializedProgramLayout = unspecializedProgram->slangProgram->getLayout();
+ for (SlangUInt i = 0; i < unspecializedProgramLayout->getEntryPointCount(); i++)
+ {
+ auto unspecializedEntryPoint = unspecializedProgramLayout->getEntryPointByIndex(i);
+ UnownedStringSlice entryPointName = UnownedStringSlice(unspecializedEntryPoint->getName());
+ ComponentKey specializedKernelKey;
+ specializedKernelKey.typeName = entryPointName;
+ specializedKernelKey.specializationArgs.addRange(specializationArgs.componentIDs);
+ specializedKernelKey.updateHash();
+ // If the pipeline is not created, check if the kernel binaries has been compiled.
+ auto specializedKernelComponentID = shaderCache.getComponentId(specializedKernelKey);
+ RefPtr<ShaderBinary> binary = shaderCache.tryLoadShaderBinary(specializedKernelComponentID);
+ if (!binary)
+ {
+ // If the specialized shader binary does not exist in cache, use slang to generate it.
+ entryPointBinaries.clear();
+ ComPtr<slang::IComponentType> specializedComponentType;
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto result = unspecializedProgram->slangProgram->specialize(specializationArgs.components.getArrayView().getBuffer(),
+ specializationArgs.getCount(), specializedComponentType.writeRef(), diagnosticBlob.writeRef());
+
+ // TODO: print diagnostic message via debug output interface.
+
+ if (result != SLANG_OK)
+ return result;
+
+ // Cache specialized binaries.
+ auto programLayout = specializedComponentType->getLayout();
+ for (SlangUInt j = 0; j < programLayout->getEntryPointCount(); j++)
+ {
+ auto entryPointLayout = programLayout->getEntryPointByIndex(j);
+ ComPtr<slang::IBlob> entryPointCode;
+ SLANG_RETURN_ON_FAIL(specializedComponentType->getEntryPointCode(j, 0, entryPointCode.writeRef(), diagnosticBlob.writeRef()));
+ binary = new ShaderBinary();
+ binary->stage = gfx::translateStage(entryPointLayout->getStage());
+ binary->entryPointName = entryPointLayout->getName();
+ binary->source.addRange((uint8_t*)entryPointCode->getBufferPointer(), entryPointCode->getBufferSize());
+ entryPointBinaries.add(binary);
+ shaderCache.addShaderBinary(specializedKernelComponentID, binary);
+ }
+
+ // We have already obtained all kernel binaries from this program, so break out of the outer loop since we no longer
+ // need to examine the rest of the kernels.
+ break;
+ }
+ entryPointBinaries.add(binary);
+ }
+
+ // Now create specialized shader program using compiled binaries.
+ ComPtr<IShaderProgram> specializedProgram;
+ IShaderProgram::Desc specializedProgramDesc = {};
+ specializedProgramDesc.kernelCount = unspecializedProgramLayout->getEntryPointCount();
+ ShortList<IShaderProgram::KernelDesc> kernelDescs;
+ for (Slang::Index i = 0; i < entryPointBinaries.getCount(); i++)
+ {
+ auto entryPoint = unspecializedProgramLayout->getEntryPointByIndex(i);;
+ auto& kernelDesc = kernelDescs[i];
+ kernelDesc.stage = entryPointBinaries[i]->stage;
+ kernelDesc.entryPointName = entryPointBinaries[i]->entryPointName.getBuffer();
+ kernelDesc.codeBegin = entryPointBinaries[i]->source.begin();
+ kernelDesc.codeEnd = entryPointBinaries[i]->source.end();
+ }
+ specializedProgramDesc.kernels = kernelDescs.getArrayView().getBuffer();
+ specializedProgramDesc.pipelineType = pipelineType;
+ SLANG_RETURN_ON_FAIL(createProgram(specializedProgramDesc, specializedProgram.writeRef()));
+
+ // Create specialized pipeline state.
+ ComPtr<IPipelineState> specializedPipelineState;
+ switch (pipelineType)
+ {
+ case PipelineType::Compute:
+ {
+ auto pipelineDesc = currentPipeline->desc.compute;
+ pipelineDesc.program = specializedProgram;
+ SLANG_RETURN_ON_FAIL(createComputePipelineState(pipelineDesc, specializedPipelineState.writeRef()));
+ break;
+ }
+ case PipelineType::Graphics:
+ {
+ auto pipelineDesc = currentPipeline->desc.graphics;
+ pipelineDesc.program = specializedProgram;
+ SLANG_RETURN_ON_FAIL(createGraphicsPipelineState(pipelineDesc, specializedPipelineState.writeRef()));
+ break;
+ }
+ default:
+ break;
+ }
+ auto specializedPipelineStateBase = static_cast<PipelineStateBase*>(specializedPipelineState.get());
+ specializedPipelineStateBase->unspecializedPipelineState = currentPipeline;
+ shaderCache.addSpecializedPipeline(pipelineKey, specializedPipelineStateBase);
+ }
+ setPipelineState(specializedPipelineState);
+ }
+ return SLANG_OK;
+}
+
+
} // namespace gfx
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index c2924a7fd..c4a000e87 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -1,12 +1,33 @@
#pragma once
#include "tools/gfx/render.h"
-
-#include "core/slang-smart-pointer.h"
+#include "slang-context.h"
+#include "core/slang-basic.h"
namespace gfx
{
+struct GfxGUID
+{
+ static const Slang::Guid IID_ISlangUnknown;
+ static const Slang::Guid IID_IDescriptorSetLayout;
+ static const Slang::Guid IID_IDescriptorSet;
+ static const Slang::Guid IID_IShaderProgram;
+ static const Slang::Guid IID_IPipelineLayout;
+ static const Slang::Guid IID_IPipelineState;
+ static const Slang::Guid IID_IResourceView;
+ static const Slang::Guid IID_ISamplerState;
+ static const Slang::Guid IID_IResource;
+ static const Slang::Guid IID_IBufferResource;
+ static const Slang::Guid IID_ITextureResource;
+ static const Slang::Guid IID_IInputLayout;
+ static const Slang::Guid IID_IRenderer;
+ static const Slang::Guid IID_IShaderObjectLayout;
+ static const Slang::Guid IID_IShaderObject;
+};
+
+gfx::StageType translateStage(SlangStage slangStage);
+
class Resource : public Slang::RefObject
{
public:
@@ -70,4 +91,297 @@ protected:
Result createProgramFromSlang(IRenderer* renderer, IShaderProgram::Desc const& desc, IShaderProgram** outProgram);
+class RendererBase;
+
+typedef uint32_t ShaderComponentID;
+const ShaderComponentID kInvalidComponentID = 0xFFFFFFFF;
+
+struct ExtendedShaderObjectType
+{
+ slang::TypeReflection* slangType;
+ ShaderComponentID componentID;
+};
+
+struct ExtendedShaderObjectTypeList
+{
+ Slang::ShortList<ShaderComponentID, 16> componentIDs;
+ Slang::ShortList<slang::SpecializationArg, 16> components;
+ void add(const ExtendedShaderObjectType& component)
+ {
+ componentIDs.add(component.componentID);
+ components.add(slang::SpecializationArg{ slang::SpecializationArg::Kind::Type, component.slangType });
+ }
+ ExtendedShaderObjectType operator[](Slang::Index index) const
+ {
+ ExtendedShaderObjectType result;
+ result.componentID = componentIDs[index];
+ result.slangType = components[index].type;
+ return result;
+ }
+ void clear()
+ {
+ componentIDs.clear();
+ components.clear();
+ }
+ Slang::Index getCount()
+ {
+ return componentIDs.getCount();
+ }
+};
+
+class ShaderObjectLayoutBase : public IShaderObjectLayout, public Slang::RefObject
+{
+protected:
+ RendererBase* m_renderer;
+ slang::TypeLayoutReflection* m_elementTypeLayout = nullptr;
+ ShaderComponentID m_componentID = 0;
+
+public:
+ SLANG_REF_OBJECT_IUNKNOWN_ALL
+ IShaderObjectLayout* getInterface(const Slang::Guid& guid);
+
+ RendererBase* getRenderer() { return m_renderer; }
+
+ slang::TypeLayoutReflection* getElementTypeLayout()
+ {
+ return m_elementTypeLayout;
+ }
+
+ ShaderComponentID getComponentID()
+ {
+ return m_componentID;
+ }
+
+ void initBase(RendererBase* renderer, slang::TypeLayoutReflection* elementTypeLayout);
+};
+
+class ShaderObjectBase : public IShaderObject, public Slang::RefObject
+{
+protected:
+ // The shader object layout used to create this shader object.
+ Slang::RefPtr<ShaderObjectLayoutBase> m_layout = nullptr;
+
+ // Indicates whether all bindings have been finalized.
+ bool m_bindingFinalized = false;
+
+ // The specialized shader object type.
+ ExtendedShaderObjectType shaderObjectType = { nullptr, kInvalidComponentID };
+public:
+ SLANG_REF_OBJECT_IUNKNOWN_ALL
+ IShaderObject* getInterface(const Slang::Guid& guid);
+
+public:
+ ShaderComponentID getComponentID()
+ {
+ return shaderObjectType.componentID;
+ }
+
+ // Get the final type this shader object represents. If the shader object's type has existential fields,
+ // this function will return a specialized type using the bound sub-objects' type as specialization argument.
+ Result getSpecializedShaderObjectType(ExtendedShaderObjectType* outType);
+
+ RendererBase* getRenderer() { return m_layout->getRenderer(); }
+
+ SLANG_NO_THROW UInt SLANG_MCALL getEntryPointCount() SLANG_OVERRIDE { return 0; }
+
+ SLANG_NO_THROW Result SLANG_MCALL getEntryPoint(UInt index, IShaderObject** outEntryPoint)
+ SLANG_OVERRIDE
+ {
+ *outEntryPoint = nullptr;
+ return SLANG_OK;
+ }
+
+ ShaderObjectLayoutBase* getLayout()
+ {
+ return m_layout;
+ }
+
+ SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL getElementTypeLayout() SLANG_OVERRIDE
+ {
+ return m_layout->getElementTypeLayout();
+ }
+
+ SLANG_NO_THROW Result SLANG_MCALL finalizeBindings() SLANG_OVERRIDE;
+
+ virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) = 0;
+};
+
+class ShaderProgramBase : public IShaderProgram, public Slang::RefObject
+{
+public:
+ SLANG_REF_OBJECT_IUNKNOWN_ALL
+
+ IShaderProgram* getInterface(const Slang::Guid& guid);
+
+ ComPtr<slang::IComponentType> slangProgram;
+};
+
+class PipelineStateBase : public IPipelineState, public Slang::RefObject
+{
+public:
+ SLANG_REF_OBJECT_IUNKNOWN_ALL
+
+ IPipelineState* getInterface(const Slang::Guid& guid);
+
+ struct PipelineStateDesc
+ {
+ PipelineType type;
+ GraphicsPipelineStateDesc graphics;
+ ComputePipelineStateDesc compute;
+ ShaderProgramBase* getProgram()
+ {
+ return static_cast<ShaderProgramBase*>(type == PipelineType::Compute ? compute.program : graphics.program);
+ }
+ } desc;
+
+ // The pipeline state from which this pipeline state is specialized.
+ // If null, this pipeline is either an unspecialized pipeline.
+ Slang::RefPtr<PipelineStateBase> unspecializedPipelineState = nullptr;
+
+ // Indicates whether this is a specializable pipeline. A specializable
+ // pipeline cannot be used directly and must be specialized first.
+ bool isSpecializable = false;
+
+protected:
+ void initializeBase(const PipelineStateDesc& inDesc);
+};
+
+class ShaderBinary : public Slang::RefObject
+{
+public:
+ Slang::List<uint8_t> source;
+ StageType stage;
+ Slang::String entryPointName;
+ Result loadFromBlob(ISlangBlob* blob);
+ Result writeToBlob(ISlangBlob** outBlob);
+};
+
+struct ComponentKey
+{
+ Slang::UnownedStringSlice typeName;
+ Slang::ShortList<ShaderComponentID> specializationArgs;
+ Slang::HashCode hash;
+ Slang::HashCode getHashCode()
+ {
+ return hash;
+ }
+ void updateHash()
+ {
+ hash = typeName.getHashCode();
+ for (auto& arg : specializationArgs)
+ hash = Slang::combineHash(hash, arg);
+ }
+};
+
+struct PipelineKey
+{
+ PipelineStateBase* pipeline;
+ Slang::ShortList<ShaderComponentID> specializationArgs;
+ Slang::HashCode hash;
+ Slang::HashCode getHashCode()
+ {
+ return hash;
+ }
+ void updateHash()
+ {
+ hash = Slang::getHashCode(pipeline);
+ for (auto& arg : specializationArgs)
+ hash = Slang::combineHash(hash, arg);
+ }
+ bool operator==(const PipelineKey& other)
+ {
+ if (pipeline != other.pipeline)
+ return false;
+ if (specializationArgs.getCount() != other.specializationArgs.getCount())
+ return false;
+ for (Slang::Index i = 0; i < other.specializationArgs.getCount(); i++)
+ {
+ if (specializationArgs[i] != other.specializationArgs[i])
+ return false;
+ }
+ return true;
+ }
+};
+
+struct OwningComponentKey
+{
+ Slang::String typeName;
+ Slang::ShortList<ShaderComponentID> specializationArgs;
+ Slang::HashCode hash;
+ Slang::HashCode getHashCode()
+ {
+ return hash;
+ }
+ template<typename KeyType>
+ bool operator==(const KeyType& other)
+ {
+ if (typeName != other.typeName)
+ return false;
+ if (specializationArgs.getCount() != other.specializationArgs.getCount())
+ return false;
+ for (Slang::Index i = 0; i < other.specializationArgs.getCount(); i++)
+ {
+ if (specializationArgs[i] != other.specializationArgs[i])
+ return false;
+ }
+ return true;
+ }
+};
+
+// A cache from specialization keys to a specialized `ShaderKernel`.
+class ShaderCache : public Slang::RefObject
+{
+public:
+ ShaderComponentID getComponentId(slang::TypeReflection* type);
+ ShaderComponentID getComponentId(Slang::UnownedStringSlice name);
+ ShaderComponentID getComponentId(ComponentKey key);
+
+ void init(ISlangFileSystem* cacheFileSystem);
+ void writeToFileSystem(ISlangMutableFileSystem* outputFileSystem);
+ Slang::RefPtr<PipelineStateBase> getSpecializedPipelineState(PipelineKey programKey)
+ {
+ Slang::RefPtr<PipelineStateBase> result;
+ if (specializedPipelines.TryGetValue(programKey, result))
+ return result;
+ return nullptr;
+ }
+ Slang::RefPtr<ShaderBinary> tryLoadShaderBinary(ShaderComponentID componentId);
+ void addShaderBinary(ShaderComponentID componentId, ShaderBinary* binary);
+ void addSpecializedPipeline(PipelineKey key, Slang::RefPtr<PipelineStateBase> specializedPipeline);
+protected:
+ Slang::ComPtr<ISlangFileSystem> fileSystem;
+ Slang::OrderedDictionary<OwningComponentKey, ShaderComponentID> componentIds;
+ Slang::OrderedDictionary<PipelineKey, Slang::RefPtr<PipelineStateBase>> specializedPipelines;
+ Slang::OrderedDictionary<ShaderComponentID, Slang::RefPtr<ShaderBinary>> shaderBinaries;
+};
+
+// Renderer implementation shared by all platforms.
+// Responsible for shader compilation, specialization and caching.
+class RendererBase : public Slang::RefObject, public IRenderer
+{
+public:
+ SLANG_REF_OBJECT_IUNKNOWN_ALL
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL getFeatures(
+ const char** outFeatures, UInt bufferSize, UInt* outFeatureCount) SLANG_OVERRIDE;
+ virtual SLANG_NO_THROW bool SLANG_MCALL hasFeature(const char* featureName) SLANG_OVERRIDE;
+ virtual SLANG_NO_THROW Result SLANG_MCALL getSlangSession(slang::ISession** outSlangSession) SLANG_OVERRIDE;
+ IRenderer* getInterface(const Slang::Guid& guid);
+
+protected:
+ // Retrieves the currently bound unspecialized pipeline.
+ // If the bound pipeline is not created from a Slang component, an implementation should return null.
+ virtual PipelineStateBase* getCurrentPipeline() = 0;
+ ExtendedShaderObjectTypeList specializationArgs;
+ // Given current pipeline and root shader object binding, generate and bind a specialized pipeline if necessary.
+ Result maybeSpecializePipeline(ShaderObjectBase* inRootShaderObject);
+protected:
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc, void* inWindowHandle);
+protected:
+ Slang::List<Slang::String> m_features;
+public:
+ SlangContext slangContext;
+ ShaderCache shaderCache;
+};
+
}
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp
index 21034d167..565a8e96f 100644
--- a/tools/gfx/vulkan/render-vk.cpp
+++ b/tools/gfx/vulkan/render-vk.cpp
@@ -114,8 +114,7 @@ public:
setViewports(UInt count, Viewport const* viewports) override;
virtual SLANG_NO_THROW void SLANG_MCALL
setScissorRects(UInt count, ScissorRect const* rects) override;
- virtual SLANG_NO_THROW void SLANG_MCALL
- setPipelineState(PipelineType pipelineType, IPipelineState* state) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL setPipelineState(IPipelineState* state) override;
virtual SLANG_NO_THROW void SLANG_MCALL draw(UInt vertexCount, UInt startVertex) override;
virtual SLANG_NO_THROW void SLANG_MCALL
drawIndexed(UInt indexCount, UInt startIndex, UInt baseVertex) override;
@@ -126,7 +125,10 @@ public:
{
return RendererType::Vulkan;
}
-
+ virtual PipelineStateBase* getCurrentPipeline() override
+ {
+ return m_currentPipeline.Ptr();
+ }
/// Dtor
~VKRenderer();
@@ -515,17 +517,9 @@ public:
int m_offset;
};
- class PipelineStateImpl : public IPipelineState, public RefObject
+ class PipelineStateImpl : public PipelineStateBase
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
- IPipelineState* getInterface(const Guid& guid)
- {
- if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IPipelineState)
- return static_cast<IPipelineState*>(this);
- return nullptr;
- }
- public:
PipelineStateImpl(const VulkanApi& api):
m_api(&api)
{
@@ -538,6 +532,21 @@ public:
}
}
+ void init(const GraphicsPipelineStateDesc& inDesc)
+ {
+ PipelineStateDesc pipelineDesc;
+ pipelineDesc.type = PipelineType::Graphics;
+ pipelineDesc.graphics = inDesc;
+ initializeBase(pipelineDesc);
+ }
+ void init(const ComputePipelineStateDesc& inDesc)
+ {
+ PipelineStateDesc pipelineDesc;
+ pipelineDesc.type = PipelineType::Compute;
+ pipelineDesc.compute = inDesc;
+ initializeBase(pipelineDesc);
+ }
+
const VulkanApi* m_api;
RefPtr<PipelineLayoutImpl> m_pipelineLayout;
@@ -913,10 +922,11 @@ void VKRenderer::_endRender()
m_deviceQueue.flush();
}
-Result SLANG_MCALL createVKRenderer(IRenderer** outRenderer)
+Result SLANG_MCALL createVKRenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer)
{
- *outRenderer = new VKRenderer();
- (*outRenderer)->addRef();
+ RefPtr<VKRenderer> result = new VKRenderer();
+ SLANG_RETURN_ON_FAIL(result->initialize(*desc, windowHandle));
+ *outRenderer = result.detach();
return SLANG_OK;
}
@@ -1035,6 +1045,8 @@ SlangResult VKRenderer::initialize(const Desc& desc, void* inWindowHandle)
{
SLANG_RETURN_ON_FAIL(slangContext.initialize(desc.slang, SLANG_SPIRV, "sm_5_1"));
+ SLANG_RETURN_ON_FAIL(GraphicsAPIRenderer::initialize(desc, inWindowHandle));
+
SLANG_RETURN_ON_FAIL(m_module.init());
SLANG_RETURN_ON_FAIL(m_api.initGlobalProcs(m_module));
@@ -2274,9 +2286,9 @@ void VKRenderer::setScissorRects(UInt count, ScissorRect const* rects)
m_api.vkCmdSetScissor(commandBuffer, 0, uint32_t(count), vkRects);
}
-void VKRenderer::setPipelineState(PipelineType pipelineType, IPipelineState* state)
+void VKRenderer::setPipelineState(IPipelineState* state)
{
- m_currentPipeline = (PipelineStateImpl*)state;
+ m_currentPipeline = static_cast<PipelineStateImpl*>(state);
}
void VKRenderer::_flushBindingState(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipelineBindPoint)
@@ -2980,6 +2992,7 @@ Result VKRenderer::createGraphicsPipelineState(const GraphicsPipelineStateDesc&
pipelineStateImpl->m_pipeline = pipeline;
pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl;
pipelineStateImpl->m_shaderProgram = programImpl;
+ pipelineStateImpl->init(desc);
*outState = pipelineStateImpl.detach();
return SLANG_OK;
}
@@ -3005,6 +3018,7 @@ Result VKRenderer::createComputePipelineState(const ComputePipelineStateDesc& in
pipelineStateImpl->m_pipeline = pipeline;
pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl;
pipelineStateImpl->m_shaderProgram = programImpl;
+ pipelineStateImpl->init(desc);
*outState = pipelineStateImpl.detach();
return SLANG_OK;
}
diff --git a/tools/gfx/vulkan/render-vk.h b/tools/gfx/vulkan/render-vk.h
index 7e086f6c0..f259ab44c 100644
--- a/tools/gfx/vulkan/render-vk.h
+++ b/tools/gfx/vulkan/render-vk.h
@@ -2,12 +2,10 @@
#pragma once
#include <cstdint>
-#include "slang.h"
+#include "../renderer-shared.h"
namespace gfx {
-class IRenderer;
-
-SlangResult SLANG_MCALL createVKRenderer(IRenderer** outRenderer);
+SlangResult SLANG_MCALL createVKRenderer(const IRenderer::Desc* desc, void* windowHandle, IRenderer** outRenderer);
} // gfx
diff --git a/tools/graphics-app-framework/gui.cpp b/tools/graphics-app-framework/gui.cpp
index fded0d76a..3bc365701 100644
--- a/tools/graphics-app-framework/gui.cpp
+++ b/tools/graphics-app-framework/gui.cpp
@@ -335,8 +335,7 @@ void GUI::endFrame()
renderer->setViewport(viewport);
- auto pipelineType = PipelineType::Graphics;
- renderer->setPipelineState(pipelineType, pipelineState);
+ renderer->setPipelineState(pipelineState);
renderer->setVertexBuffer(0, vertexBuffer, sizeof(ImDrawVert));
renderer->setIndexBuffer(indexBuffer, sizeof(ImDrawIdx) == 2 ? Format::R_UInt16 : Format::R_UInt32);
@@ -376,7 +375,7 @@ void GUI::endFrame()
samplerState);
renderer->setDescriptorSet(
- pipelineType,
+ PipelineType::Graphics,
pipelineLayout,
0,
descriptorSet);
diff --git a/tools/graphics-app-framework/windows/win-window.cpp b/tools/graphics-app-framework/windows/win-window.cpp
index 7433603cb..3bbf2575a 100644
--- a/tools/graphics-app-framework/windows/win-window.cpp
+++ b/tools/graphics-app-framework/windows/win-window.cpp
@@ -1,6 +1,8 @@
// win-window.cpp
#include "../window.h"
+#include "core/slang-smart-pointer.h"
+
#include <stdio.h>
#ifdef _MSC_VER
@@ -64,7 +66,7 @@ private:
}
};
-struct ApplicationContext
+struct ApplicationContext : public Slang::RefObject
{
HINSTANCE instance;
int showCommand = SW_SHOWDEFAULT;
diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp
index 27070154e..b9fd5c725 100644
--- a/tools/render-test/render-test-main.cpp
+++ b/tools/render-test/render-test-main.cpp
@@ -711,7 +711,7 @@ void RenderTestApp::renderFrame()
auto pipelineType = PipelineType::Graphics;
- m_renderer->setPipelineState(pipelineType, m_pipelineState);
+ m_renderer->setPipelineState(m_pipelineState);
m_renderer->setPrimitiveTopology(PrimitiveTopology::TriangleList);
m_renderer->setVertexBuffer(0, m_vertexBuffer, sizeof(Vertex));
@@ -724,7 +724,7 @@ void RenderTestApp::renderFrame()
void RenderTestApp::runCompute()
{
auto pipelineType = PipelineType::Compute;
- m_renderer->setPipelineState(pipelineType, m_pipelineState);
+ m_renderer->setPipelineState(m_pipelineState);
applyBinding(pipelineType);
m_startTicks = ProcessUtil::getClockTick();
@@ -1279,22 +1279,8 @@ static SlangResult _innerMain(Slang::StdWriters* stdWriters, SlangSession* sessi
Slang::ComPtr<IRenderer> renderer;
{
- SGRendererCreateFunc createFunc = gfxGetCreateFunc(options.rendererType);
- if (createFunc)
- {
- createFunc(renderer.writeRef());
- }
-
- if (!renderer)
- {
- if (!options.onlyStartup)
- {
- fprintf(stderr, "Unable to create renderer %s\n", rendererName.getBuffer());
- }
- return SLANG_FAIL;
- }
-
- IRenderer::Desc desc;
+ IRenderer::Desc desc = {};
+ desc.rendererType = options.rendererType;
desc.width = gWindowWidth;
desc.height = gWindowHeight;
desc.adapter = options.adapter.getBuffer();
@@ -1306,18 +1292,21 @@ static SlangResult _innerMain(Slang::StdWriters* stdWriters, SlangSession* sessi
desc.nvapiExtnSlot = int(nvapiExtnSlot);
desc.slang.slangGlobalSession = session;
window = renderer_test::Window::create();
- SLANG_RETURN_ON_FAIL(window->initialize(gWindowWidth, gWindowHeight));
+ void* windowHandle = nullptr;
+ if (window)
+ {
+ SLANG_RETURN_ON_FAIL(window->initialize(gWindowWidth, gWindowHeight));
+ windowHandle = window->getHandle();
+ }
+ gfxCreateRenderer(&desc, windowHandle, renderer.writeRef());
- SlangResult res = renderer->initialize(desc, window->getHandle());
- if (SLANG_FAILED(res))
+ if (!renderer)
{
- // Returns E_NOT_AVAILABLE only when specified features are not available.
- // Will cause to be ignored.
- if (!options.onlyStartup && res != SLANG_E_NOT_AVAILABLE)
+ if (!options.onlyStartup)
{
- fprintf(stderr, "Unable to initialize renderer %s\n", rendererName.getBuffer());
+ fprintf(stderr, "Unable to create renderer %s\n", rendererName.getBuffer());
}
- return res;
+ return SLANG_FAIL;
}
for (const auto& feature : requiredFeatureList)