summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
Diffstat (limited to 'tools')
-rw-r--r--tools/gfx-unit-test/link-time-type-generic.cpp229
-rw-r--r--tools/gfx-unit-test/link-time-type-multi-use-generic.cpp156
-rw-r--r--tools/gfx-unit-test/link-time-type-multi-use.cpp156
-rw-r--r--tools/slang-unit-test/unit-test-link-time-type-reflection.cpp90
4 files changed, 628 insertions, 3 deletions
diff --git a/tools/gfx-unit-test/link-time-type-generic.cpp b/tools/gfx-unit-test/link-time-type-generic.cpp
new file mode 100644
index 000000000..ac1750518
--- /dev/null
+++ b/tools/gfx-unit-test/link-time-type-generic.cpp
@@ -0,0 +1,229 @@
+#include "core/slang-basic.h"
+#include "core/slang-blob.h"
+#include "gfx-test-util.h"
+#include "slang-rhi.h"
+#include "slang-rhi/shader-cursor.h"
+#include "unit-test/slang-unit-test.h"
+
+using namespace rhi;
+
+// Test that generic link time types conforming to a generic interface with generic
+// methods/subscript members work correctly.
+// Also test that global generic link-time functions works correctly.
+
+namespace gfx_test
+{
+static Slang::Result loadProgram(
+ rhi::IDevice* device,
+ Slang::ComPtr<rhi::IShaderProgram>& outShaderProgram,
+ slang::ProgramLayout*& slangReflection,
+ bool linkSpecialization = false)
+{
+ const char* moduleInterfaceSrc = R"(
+ interface ISimple { float getVal(); }
+ interface IHasProperty { property float val2{get;set;} }
+ interface IFoo<T:__BuiltinFloatingPointType> : IHasProperty
+ {
+ static const int offset;
+ [mutating] void setValue(float v);
+
+ T getValue<U:ISimple>(U u);
+
+ __subscript<U:__BuiltinIntegerType>(U index) -> T { get; }
+ }
+ struct FooImpl<T:__BuiltinFloatingPointType, int x> : IFoo<T>
+ {
+ T val;
+ static const int offset = x;
+ [mutating] void setValue(float v) { val = T(v); }
+ T getValue<U:ISimple>(U u){ return val + T(u.getVal()); }
+ property float val2 {
+ get { return __real_cast<float>(val) + 2.0; }
+ set { val = T(newValue); }
+ }
+ __subscript<U:__BuiltinIntegerType>(U index) -> T { get {return T(1.0); } }
+ };
+ struct BarImpl<T:__BuiltinFloatingPointType, int x> : IFoo<T>
+ {
+ T val;
+ static const int offset = -x;
+ [mutating] void setValue(float v) { val = T(v); }
+ T getValue<U:ISimple>(U u){ return val - T(1.0); }
+ property float val2 {
+ get { return __real_cast<float>(val) + 2.0; }
+ set { val = T(newValue); }
+ }
+ __subscript<U:__BuiltinIntegerType>(U index) -> T { get {return T(2.0); } }
+ };
+ )";
+ const char* module0Src = R"(
+ import ifoo;
+ extern struct Foo<T:__BuiltinFloatingPointType, int i> : IFoo<T> = FooImpl<T, i+1>;
+ extern static const float c = 0.0;
+ extern int linkTimeFunc<int x>() { return x; }
+ struct SimpleImpl : ISimple
+ {
+ float getVal() { return 100.0; }
+ };
+
+ // Use an indirect generic function to retrieve val2, to make sure intermediate witness tables
+ // can be obtained correctly from link-time witnesses.
+ float getVal2<T:IHasProperty>(T t) { return t.val2; }
+
+ [numthreads(1,1,1)]
+ void computeMain(uniform RWStructuredBuffer<float> buffer)
+ {
+ Foo<float, 0> foo;
+ foo.setValue(3.0);
+ buffer[0] = foo.getValue(SimpleImpl()) + getVal2(foo) + Foo<float, 0>.offset + c + foo[0] + linkTimeFunc<0>();
+ }
+ )";
+ const char* module1Src = R"(
+ import ifoo;
+ export struct Foo<T1:__BuiltinFloatingPointType, int i> : IFoo<T1> = BarImpl<T1, i+1>;
+ export static const float c = 1.0;
+ export int linkTimeFunc<int x>() { return x + 1; }
+ )";
+ Slang::ComPtr<slang::ISession> slangSession;
+ SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
+ Slang::ComPtr<slang::IBlob> diagnosticsBlob;
+ auto moduleInterfaceBlob =
+ Slang::UnownedRawBlob::create(moduleInterfaceSrc, strlen(moduleInterfaceSrc));
+ auto module0Blob = Slang::UnownedRawBlob::create(module0Src, strlen(module0Src));
+ auto module1Blob = Slang::UnownedRawBlob::create(module1Src, strlen(module1Src));
+ slang::IModule* moduleInterface =
+ slangSession->loadModuleFromSource("ifoo", "ifoo.slang", moduleInterfaceBlob);
+ slang::IModule* module0 = slangSession->loadModuleFromSource("module0", "path0", module0Blob);
+ slang::IModule* module1 = slangSession->loadModuleFromSource("module1", "path1", module1Blob);
+ ComPtr<slang::IEntryPoint> computeEntryPoint;
+ SLANG_RETURN_ON_FAIL(
+ module0->findEntryPointByName("computeMain", computeEntryPoint.writeRef()));
+
+ Slang::List<slang::IComponentType*> componentTypes;
+ componentTypes.add(moduleInterface);
+ componentTypes.add(module0);
+ if (linkSpecialization)
+ componentTypes.add(module1);
+ componentTypes.add(computeEntryPoint);
+
+ Slang::ComPtr<slang::IComponentType> composedProgram;
+ SlangResult result = slangSession->createCompositeComponentType(
+ componentTypes.getBuffer(),
+ componentTypes.getCount(),
+ composedProgram.writeRef(),
+ diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ SLANG_RETURN_ON_FAIL(result);
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ SLANG_RETURN_ON_FAIL(result);
+
+ composedProgram = linkedProgram;
+ slangReflection = composedProgram->getLayout();
+
+ ShaderProgramDesc programDesc = {};
+ programDesc.slangGlobalScope = composedProgram.get();
+
+ auto shaderProgram = device->createShaderProgram(programDesc);
+
+ outShaderProgram = shaderProgram;
+ return SLANG_OK;
+}
+
+void linkTimeTypeGenericTestImpl(IDevice* device, UnitTestContext* context)
+{
+ // Create pipeline without linking a specialization override module, so we should
+ // see the default value of `extern Foo`.
+ ComPtr<IShaderProgram> shaderProgram;
+ slang::ProgramLayout* slangReflection;
+ GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram, slangReflection, false));
+
+ ComputePipelineDesc pipelineDesc = {};
+ pipelineDesc.program = shaderProgram.get();
+ ComPtr<IComputePipeline> pipelineState;
+ GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc, pipelineState.writeRef()));
+
+ // Create pipeline with a specialization override module linked in, so we should
+ // see the result of using `BarImpl<T>` for `extern Foo<T>`.
+ ComPtr<IShaderProgram> shaderProgram1;
+ GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram1, slangReflection, true));
+
+ ComputePipelineDesc pipelineDesc1 = {};
+ pipelineDesc1.program = shaderProgram1.get();
+ ComPtr<IComputePipeline> pipelineState1;
+ GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc1, pipelineState1.writeRef()));
+
+ const int numberCount = 4;
+ float initialData[] = {0.0f, 0.0f, 0.0f, 0.0f};
+ BufferDesc bufferDesc = {};
+ bufferDesc.size = numberCount * sizeof(float);
+ bufferDesc.format = rhi::Format::Undefined;
+ bufferDesc.elementSize = sizeof(float);
+ bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::UnorderedAccess |
+ BufferUsage::CopyDestination | BufferUsage::CopySource;
+ bufferDesc.defaultState = ResourceState::UnorderedAccess;
+ bufferDesc.memoryType = MemoryType::DeviceLocal;
+
+ ComPtr<IBuffer> numbersBuffer;
+ GFX_CHECK_CALL_ABORT(
+ device->createBuffer(bufferDesc, (void*)initialData, numbersBuffer.writeRef()));
+
+ auto queue = device->getQueue(QueueType::Graphics);
+
+ // We have done all the set up work, now it is time to start recording a command buffer for
+ // GPU execution.
+ {
+ auto commandEncoder = queue->createCommandEncoder();
+ auto computePassEncoder = commandEncoder->beginComputePass();
+
+ auto rootObject = computePassEncoder->bindPipeline(pipelineState);
+
+ ShaderCursor entryPointCursor(
+ rootObject->getEntryPoint(0)); // get a cursor the the first entry-point.
+ // Bind buffer to the entry point.
+ entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer));
+
+ computePassEncoder->dispatchCompute(1, 1, 1);
+ computePassEncoder->end();
+ auto commandBuffer = commandEncoder->finish();
+ queue->submit(commandBuffer);
+ queue->waitOnHost();
+ }
+
+ compareComputeResult(device, numbersBuffer, std::array{110.0f});
+
+ // Now run again with the overrided program.
+ {
+ auto commandEncoder = queue->createCommandEncoder();
+ auto computePassEncoder = commandEncoder->beginComputePass();
+
+ auto rootObject = computePassEncoder->bindPipeline(pipelineState1);
+
+ ShaderCursor entryPointCursor(
+ rootObject->getEntryPoint(0)); // get a cursor the the first entry-point.
+ // Bind buffer to the entry point.
+ entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer));
+
+ computePassEncoder->dispatchCompute(1, 1, 1);
+ computePassEncoder->end();
+ auto commandBuffer = commandEncoder->finish();
+ queue->submit(commandBuffer);
+ queue->waitOnHost();
+ }
+
+ compareComputeResult(device, numbersBuffer, std::array{10.0f});
+}
+
+SLANG_UNIT_TEST(linkTimeTypeGenericD3D12)
+{
+ runTestImpl(linkTimeTypeGenericTestImpl, unitTestContext, DeviceType::D3D12);
+}
+
+SLANG_UNIT_TEST(linkTimeTypeGenerictVulkan)
+{
+ runTestImpl(linkTimeTypeGenericTestImpl, unitTestContext, DeviceType::Vulkan);
+}
+
+} // namespace gfx_test
diff --git a/tools/gfx-unit-test/link-time-type-multi-use-generic.cpp b/tools/gfx-unit-test/link-time-type-multi-use-generic.cpp
new file mode 100644
index 000000000..c640389e5
--- /dev/null
+++ b/tools/gfx-unit-test/link-time-type-multi-use-generic.cpp
@@ -0,0 +1,156 @@
+#include "core/slang-basic.h"
+#include "core/slang-blob.h"
+#include "gfx-test-util.h"
+#include "slang-rhi.h"
+#include "slang-rhi/shader-cursor.h"
+#include "unit-test/slang-unit-test.h"
+
+using namespace rhi;
+
+// Test that a generic type can be used to serve multiple link-time type requirements.
+
+namespace gfx_test
+{
+static Slang::Result loadProgram(
+ rhi::IDevice* device,
+ Slang::ComPtr<rhi::IShaderProgram>& outShaderProgram,
+ slang::ProgramLayout*& slangReflection,
+ bool linkSpecialization = false)
+{
+ const char* moduleInterfaceSrc = R"(
+ interface IFoo { int getFoo(); }
+ interface IBar { int getBar(); }
+ struct SimpleImpl<int y> : IFoo, IBar
+ {
+ int getFoo() { return y; }
+ int getBar() { return y * 2; }
+ }
+ )";
+ const char* module0Src = R"(
+ import ifoo;
+ extern struct Foo<int x> : IFoo;
+ extern struct Bar<int x> : IBar;
+ uniform Foo<10> gFoo;
+ uniform Bar<20> gBar;
+ [numthreads(1,1,1)]
+ void computeMain(uniform RWStructuredBuffer<int> buffer)
+ {
+ buffer[0] = gFoo.getFoo() + gBar.getBar();
+ }
+ )";
+ const char* module1Src = R"(
+ import ifoo;
+ export struct Foo<int x> : IFoo = SimpleImpl<x>;
+ export struct Bar<int x> : IBar = SimpleImpl<x>;
+ )";
+ Slang::ComPtr<slang::ISession> slangSession;
+ SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
+ Slang::ComPtr<slang::IBlob> diagnosticsBlob;
+ auto moduleInterfaceBlob =
+ Slang::UnownedRawBlob::create(moduleInterfaceSrc, strlen(moduleInterfaceSrc));
+ auto module0Blob = Slang::UnownedRawBlob::create(module0Src, strlen(module0Src));
+ auto module1Blob = Slang::UnownedRawBlob::create(module1Src, strlen(module1Src));
+ slang::IModule* moduleInterface =
+ slangSession->loadModuleFromSource("ifoo", "ifoo.slang", moduleInterfaceBlob);
+ slang::IModule* module0 = slangSession->loadModuleFromSource("module0", "path0", module0Blob);
+ slang::IModule* module1 = slangSession->loadModuleFromSource("module1", "path1", module1Blob);
+ ComPtr<slang::IEntryPoint> computeEntryPoint;
+ SLANG_RETURN_ON_FAIL(
+ module0->findEntryPointByName("computeMain", computeEntryPoint.writeRef()));
+
+ Slang::List<slang::IComponentType*> componentTypes;
+ componentTypes.add(moduleInterface);
+ componentTypes.add(module0);
+ if (linkSpecialization)
+ componentTypes.add(module1);
+ componentTypes.add(computeEntryPoint);
+
+ Slang::ComPtr<slang::IComponentType> composedProgram;
+ SlangResult result = slangSession->createCompositeComponentType(
+ componentTypes.getBuffer(),
+ componentTypes.getCount(),
+ composedProgram.writeRef(),
+ diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ SLANG_RETURN_ON_FAIL(result);
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ SLANG_RETURN_ON_FAIL(result);
+
+ composedProgram = linkedProgram;
+ slangReflection = composedProgram->getLayout();
+
+ ShaderProgramDesc programDesc = {};
+ programDesc.slangGlobalScope = composedProgram.get();
+
+ auto shaderProgram = device->createShaderProgram(programDesc);
+
+ outShaderProgram = shaderProgram;
+ return SLANG_OK;
+}
+
+void linkTimeTypeMultiUseGenericTestImpl(IDevice* device, UnitTestContext* context)
+{
+ // Create pipeline without both modules linked, specifying both Foo and Bar to be SimpleImpl.
+ ComPtr<IShaderProgram> shaderProgram;
+ slang::ProgramLayout* slangReflection;
+ GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram, slangReflection, true));
+
+ ComputePipelineDesc pipelineDesc = {};
+ pipelineDesc.program = shaderProgram.get();
+ ComPtr<IComputePipeline> pipelineState;
+ GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc, pipelineState.writeRef()));
+
+ const int numberCount = 4;
+ float initialData[] = {0.0f, 0.0f, 0.0f, 0.0f};
+ BufferDesc bufferDesc = {};
+ bufferDesc.size = numberCount * sizeof(float);
+ bufferDesc.format = rhi::Format::Undefined;
+ bufferDesc.elementSize = sizeof(float);
+ bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::UnorderedAccess |
+ BufferUsage::CopyDestination | BufferUsage::CopySource;
+ bufferDesc.defaultState = ResourceState::UnorderedAccess;
+ bufferDesc.memoryType = MemoryType::DeviceLocal;
+
+ ComPtr<IBuffer> numbersBuffer;
+ GFX_CHECK_CALL_ABORT(
+ device->createBuffer(bufferDesc, (void*)initialData, numbersBuffer.writeRef()));
+
+ auto queue = device->getQueue(QueueType::Graphics);
+
+ // We have done all the set up work, now it is time to start recording a command buffer for
+ // GPU execution.
+ {
+ auto commandEncoder = queue->createCommandEncoder();
+ auto computePassEncoder = commandEncoder->beginComputePass();
+
+ auto rootObject = computePassEncoder->bindPipeline(pipelineState);
+
+ ShaderCursor entryPointCursor(
+ rootObject->getEntryPoint(0)); // get a cursor the the first entry-point.
+ // Bind buffer to the entry point.
+ entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer));
+
+ computePassEncoder->dispatchCompute(1, 1, 1);
+ computePassEncoder->end();
+ auto commandBuffer = commandEncoder->finish();
+ queue->submit(commandBuffer);
+ queue->waitOnHost();
+ }
+
+ compareComputeResult(device, numbersBuffer, std::array{50});
+}
+
+SLANG_UNIT_TEST(linkTimeTypeMultiUseGenericD3D12)
+{
+ runTestImpl(linkTimeTypeMultiUseGenericTestImpl, unitTestContext, DeviceType::D3D12);
+}
+
+SLANG_UNIT_TEST(linkTimeTypeMultiUseGenericVulkan)
+{
+ runTestImpl(linkTimeTypeMultiUseGenericTestImpl, unitTestContext, DeviceType::Vulkan);
+}
+
+} // namespace gfx_test
diff --git a/tools/gfx-unit-test/link-time-type-multi-use.cpp b/tools/gfx-unit-test/link-time-type-multi-use.cpp
new file mode 100644
index 000000000..4dc6d085e
--- /dev/null
+++ b/tools/gfx-unit-test/link-time-type-multi-use.cpp
@@ -0,0 +1,156 @@
+#include "core/slang-basic.h"
+#include "core/slang-blob.h"
+#include "gfx-test-util.h"
+#include "slang-rhi.h"
+#include "slang-rhi/shader-cursor.h"
+#include "unit-test/slang-unit-test.h"
+
+using namespace rhi;
+
+// Test that a type can be used to serve multiple link-time type requirements.
+
+namespace gfx_test
+{
+static Slang::Result loadProgram(
+ rhi::IDevice* device,
+ Slang::ComPtr<rhi::IShaderProgram>& outShaderProgram,
+ slang::ProgramLayout*& slangReflection,
+ bool linkSpecialization = false)
+{
+ const char* moduleInterfaceSrc = R"(
+ interface IFoo { int getFoo(); }
+ interface IBar { int getBar(); }
+ struct SimpleImpl : IFoo, IBar
+ {
+ int getFoo() { return 10; }
+ int getBar() { return 20; }
+ }
+ )";
+ const char* module0Src = R"(
+ import ifoo;
+ extern struct Foo : IFoo;
+ extern struct Bar : IBar;
+ uniform Foo gFoo;
+ uniform Bar gBar;
+ [numthreads(1,1,1)]
+ void computeMain(uniform RWStructuredBuffer<int> buffer)
+ {
+ buffer[0] = gFoo.getFoo() + gBar.getBar();
+ }
+ )";
+ const char* module1Src = R"(
+ import ifoo;
+ export struct Foo : IFoo = SimpleImpl;
+ export struct Bar : IBar = SimpleImpl;
+ )";
+ Slang::ComPtr<slang::ISession> slangSession;
+ SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
+ Slang::ComPtr<slang::IBlob> diagnosticsBlob;
+ auto moduleInterfaceBlob =
+ Slang::UnownedRawBlob::create(moduleInterfaceSrc, strlen(moduleInterfaceSrc));
+ auto module0Blob = Slang::UnownedRawBlob::create(module0Src, strlen(module0Src));
+ auto module1Blob = Slang::UnownedRawBlob::create(module1Src, strlen(module1Src));
+ slang::IModule* moduleInterface =
+ slangSession->loadModuleFromSource("ifoo", "ifoo.slang", moduleInterfaceBlob);
+ slang::IModule* module0 = slangSession->loadModuleFromSource("module0", "path0", module0Blob);
+ slang::IModule* module1 = slangSession->loadModuleFromSource("module1", "path1", module1Blob);
+ ComPtr<slang::IEntryPoint> computeEntryPoint;
+ SLANG_RETURN_ON_FAIL(
+ module0->findEntryPointByName("computeMain", computeEntryPoint.writeRef()));
+
+ Slang::List<slang::IComponentType*> componentTypes;
+ componentTypes.add(moduleInterface);
+ componentTypes.add(module0);
+ if (linkSpecialization)
+ componentTypes.add(module1);
+ componentTypes.add(computeEntryPoint);
+
+ Slang::ComPtr<slang::IComponentType> composedProgram;
+ SlangResult result = slangSession->createCompositeComponentType(
+ componentTypes.getBuffer(),
+ componentTypes.getCount(),
+ composedProgram.writeRef(),
+ diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ SLANG_RETURN_ON_FAIL(result);
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ SLANG_RETURN_ON_FAIL(result);
+
+ composedProgram = linkedProgram;
+ slangReflection = composedProgram->getLayout();
+
+ ShaderProgramDesc programDesc = {};
+ programDesc.slangGlobalScope = composedProgram.get();
+
+ auto shaderProgram = device->createShaderProgram(programDesc);
+
+ outShaderProgram = shaderProgram;
+ return SLANG_OK;
+}
+
+void linkTimeTypeMultiUseTestImpl(IDevice* device, UnitTestContext* context)
+{
+ // Create pipeline without both modules linked, specifying both Foo and Bar to be SimpleImpl.
+ ComPtr<IShaderProgram> shaderProgram;
+ slang::ProgramLayout* slangReflection;
+ GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram, slangReflection, true));
+
+ ComputePipelineDesc pipelineDesc = {};
+ pipelineDesc.program = shaderProgram.get();
+ ComPtr<IComputePipeline> pipelineState;
+ GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc, pipelineState.writeRef()));
+
+ const int numberCount = 4;
+ float initialData[] = {0.0f, 0.0f, 0.0f, 0.0f};
+ BufferDesc bufferDesc = {};
+ bufferDesc.size = numberCount * sizeof(float);
+ bufferDesc.format = rhi::Format::Undefined;
+ bufferDesc.elementSize = sizeof(float);
+ bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::UnorderedAccess |
+ BufferUsage::CopyDestination | BufferUsage::CopySource;
+ bufferDesc.defaultState = ResourceState::UnorderedAccess;
+ bufferDesc.memoryType = MemoryType::DeviceLocal;
+
+ ComPtr<IBuffer> numbersBuffer;
+ GFX_CHECK_CALL_ABORT(
+ device->createBuffer(bufferDesc, (void*)initialData, numbersBuffer.writeRef()));
+
+ auto queue = device->getQueue(QueueType::Graphics);
+
+ // We have done all the set up work, now it is time to start recording a command buffer for
+ // GPU execution.
+ {
+ auto commandEncoder = queue->createCommandEncoder();
+ auto computePassEncoder = commandEncoder->beginComputePass();
+
+ auto rootObject = computePassEncoder->bindPipeline(pipelineState);
+
+ ShaderCursor entryPointCursor(
+ rootObject->getEntryPoint(0)); // get a cursor the the first entry-point.
+ // Bind buffer to the entry point.
+ entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer));
+
+ computePassEncoder->dispatchCompute(1, 1, 1);
+ computePassEncoder->end();
+ auto commandBuffer = commandEncoder->finish();
+ queue->submit(commandBuffer);
+ queue->waitOnHost();
+ }
+
+ compareComputeResult(device, numbersBuffer, std::array{30});
+}
+
+SLANG_UNIT_TEST(linkTimeTypeMultiUseD3D12)
+{
+ runTestImpl(linkTimeTypeMultiUseTestImpl, unitTestContext, DeviceType::D3D12);
+}
+
+SLANG_UNIT_TEST(linkTimeTypeMultiUseVulkan)
+{
+ runTestImpl(linkTimeTypeMultiUseTestImpl, unitTestContext, DeviceType::Vulkan);
+}
+
+} // namespace gfx_test
diff --git a/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp b/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp
index c42fd2f16..0bd580c84 100644
--- a/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp
+++ b/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp
@@ -27,16 +27,21 @@ SLANG_UNIT_TEST(linkTimeTypeReflection)
interface IMaterial { float4 load(); }
extern struct Material : IMaterial;
ConstantBuffer<Material> gMaterial;
-
+
+ interface IFoo { float getVal(); }
+ struct DefaultFoo : IFoo { float getVal() { return 0.0f; } }
+ extern struct Foo<T, int x> : IFoo = DefaultFoo;
+
RWTexture2D tex;
extern static const int count;
uniform uint4 buffers[count];
+ uniform Foo<int4, 1> gFoo;
[numthreads(1,1,1)]
[shader("compute")]
void computeMain() {
- tex[uint2(0, 0)] = gMaterial.load();
+ tex[uint2(0, 0)] = gMaterial.load() + gFoo.getVal();
}
)";
@@ -65,7 +70,8 @@ SLANG_UNIT_TEST(linkTimeTypeReflection)
String configModuleSource = "import " + moduleName + ";\n" + R"(
export struct Material : IMaterial = MyMaterial;
export static const int count = 11;
-
+ struct FooImpl<T, int x> : IFoo { T vals[x]; float getVal() { return x; } }
+ export struct Foo<T, int x> : IFoo = FooImpl<T, x + 1>;
struct MyMaterial : IMaterial {
int data;
Texture2D diffuse;
@@ -110,6 +116,9 @@ SLANG_UNIT_TEST(linkTimeTypeReflection)
auto var2 = programLayout->getParameterByIndex(2);
SLANG_CHECK(var2->getTypeLayout()->getSize() == 11 * 16);
+ auto var3 = programLayout->getParameterByIndex(3);
+ SLANG_CHECK(var3->getTypeLayout()->getSize() == 32);
+
ComPtr<slang::IBlob> codeBlob;
linkedProgram->getTargetCode(0, codeBlob.writeRef(), diagnosticBlob.writeRef());
@@ -226,3 +235,78 @@ SLANG_UNIT_TEST(linkTimeConditionalReflection)
SLANG_CHECK(spirvStr.indexOf(toSlice("Location 1")) != -1);
SLANG_CHECK(spirvStr.indexOf(toSlice("Location 2")) == -1);
}
+
+// Test that loading a module that defines an `export` type, but not linking with the module should
+// not affect the type layout.
+
+SLANG_UNIT_TEST(linkTimeTypeReflectionWithLoadedButNotLinkedModule)
+{
+ // Source for a module that contains can be specialized with a link-time type.
+ const char* userSourceBody = R"(
+ interface IFoo { float getVal(); }
+ struct DefaultFoo : IFoo { float getVal() { return 0.0f; } }
+ extern struct Foo<T, int x> : IFoo = DefaultFoo;
+
+ uniform Foo<int4, 1> gFoo;
+ RWTexture2D tex;
+
+ [numthreads(1,1,1)]
+ [shader("compute")]
+ void computeMain() {
+ tex[uint2(0, 0)] = gFoo.getVal();
+ }
+ )";
+
+ String moduleName = "linkTimeTypeReflection_Compute";
+
+ ComPtr<slang::IGlobalSession> globalSession;
+ SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
+ slang::TargetDesc targetDesc = {};
+ targetDesc.format = SLANG_SPIRV_ASM;
+ targetDesc.profile = globalSession->findProfile("spirv_1_5");
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ sessionDesc.targets = &targetDesc;
+ ComPtr<slang::ISession> session;
+ SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto module = session->loadModuleFromSourceString(
+ moduleName.getBuffer(),
+ (moduleName + ".slang").getBuffer(),
+ userSourceBody,
+ diagnosticBlob.writeRef());
+ SLANG_CHECK_ABORT(module != nullptr);
+
+ // Source for a module that defines the link-time type, but we won't link with it.
+ String configModuleSource = "import " + moduleName + ";\n" + R"(
+ struct FooImpl<T, int x> : IFoo { T vals[x]; float getVal() { return x; } }
+ export struct Foo<T, int x> : IFoo = FooImpl<T, x + 1>;
+ )";
+ auto configModule = session->loadModuleFromSourceString(
+ "config",
+ "config.slang",
+ configModuleSource.getBuffer(),
+ diagnosticBlob.writeRef());
+ SLANG_CHECK_ABORT(configModule != nullptr);
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ module->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
+ SLANG_CHECK_ABORT(linkedProgram != nullptr);
+
+ auto programLayout = linkedProgram->getLayout();
+ auto var0 = programLayout->getParameterByIndex(0);
+
+ // Size of `gFoo` is 0, because the module that defines `Foo = FooImpl` is not linked.
+ // Therefore `gFoo`'s type is defaulted to `DefaultFoo`, which has no fields.
+ SLANG_CHECK(var0->getTypeLayout()->getSize() == 0);
+
+ ComPtr<slang::IBlob> codeBlob;
+ linkedProgram->getTargetCode(0, codeBlob.writeRef(), diagnosticBlob.writeRef());
+
+ SLANG_CHECK_ABORT(codeBlob.get());
+
+ auto spirvStr = UnownedStringSlice((const char*)codeBlob->getBufferPointer());
+
+ SLANG_CHECK(spirvStr.indexOf(toSlice("OpDecorate %tex Binding 0")) != -1);
+}