summaryrefslogtreecommitdiffstats
path: root/tools/gfx-unit-test/link-time-type-multi-use.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tools/gfx-unit-test/link-time-type-multi-use.cpp')
-rw-r--r--tools/gfx-unit-test/link-time-type-multi-use.cpp156
1 files changed, 156 insertions, 0 deletions
diff --git a/tools/gfx-unit-test/link-time-type-multi-use.cpp b/tools/gfx-unit-test/link-time-type-multi-use.cpp
new file mode 100644
index 000000000..4dc6d085e
--- /dev/null
+++ b/tools/gfx-unit-test/link-time-type-multi-use.cpp
@@ -0,0 +1,156 @@
+#include "core/slang-basic.h"
+#include "core/slang-blob.h"
+#include "gfx-test-util.h"
+#include "slang-rhi.h"
+#include "slang-rhi/shader-cursor.h"
+#include "unit-test/slang-unit-test.h"
+
+using namespace rhi;
+
+// Test that a type can be used to serve multiple link-time type requirements.
+
+namespace gfx_test
+{
+static Slang::Result loadProgram(
+ rhi::IDevice* device,
+ Slang::ComPtr<rhi::IShaderProgram>& outShaderProgram,
+ slang::ProgramLayout*& slangReflection,
+ bool linkSpecialization = false)
+{
+ const char* moduleInterfaceSrc = R"(
+ interface IFoo { int getFoo(); }
+ interface IBar { int getBar(); }
+ struct SimpleImpl : IFoo, IBar
+ {
+ int getFoo() { return 10; }
+ int getBar() { return 20; }
+ }
+ )";
+ const char* module0Src = R"(
+ import ifoo;
+ extern struct Foo : IFoo;
+ extern struct Bar : IBar;
+ uniform Foo gFoo;
+ uniform Bar gBar;
+ [numthreads(1,1,1)]
+ void computeMain(uniform RWStructuredBuffer<int> buffer)
+ {
+ buffer[0] = gFoo.getFoo() + gBar.getBar();
+ }
+ )";
+ const char* module1Src = R"(
+ import ifoo;
+ export struct Foo : IFoo = SimpleImpl;
+ export struct Bar : IBar = SimpleImpl;
+ )";
+ Slang::ComPtr<slang::ISession> slangSession;
+ SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
+ Slang::ComPtr<slang::IBlob> diagnosticsBlob;
+ auto moduleInterfaceBlob =
+ Slang::UnownedRawBlob::create(moduleInterfaceSrc, strlen(moduleInterfaceSrc));
+ auto module0Blob = Slang::UnownedRawBlob::create(module0Src, strlen(module0Src));
+ auto module1Blob = Slang::UnownedRawBlob::create(module1Src, strlen(module1Src));
+ slang::IModule* moduleInterface =
+ slangSession->loadModuleFromSource("ifoo", "ifoo.slang", moduleInterfaceBlob);
+ slang::IModule* module0 = slangSession->loadModuleFromSource("module0", "path0", module0Blob);
+ slang::IModule* module1 = slangSession->loadModuleFromSource("module1", "path1", module1Blob);
+ ComPtr<slang::IEntryPoint> computeEntryPoint;
+ SLANG_RETURN_ON_FAIL(
+ module0->findEntryPointByName("computeMain", computeEntryPoint.writeRef()));
+
+ Slang::List<slang::IComponentType*> componentTypes;
+ componentTypes.add(moduleInterface);
+ componentTypes.add(module0);
+ if (linkSpecialization)
+ componentTypes.add(module1);
+ componentTypes.add(computeEntryPoint);
+
+ Slang::ComPtr<slang::IComponentType> composedProgram;
+ SlangResult result = slangSession->createCompositeComponentType(
+ componentTypes.getBuffer(),
+ componentTypes.getCount(),
+ composedProgram.writeRef(),
+ diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ SLANG_RETURN_ON_FAIL(result);
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ SLANG_RETURN_ON_FAIL(result);
+
+ composedProgram = linkedProgram;
+ slangReflection = composedProgram->getLayout();
+
+ ShaderProgramDesc programDesc = {};
+ programDesc.slangGlobalScope = composedProgram.get();
+
+ auto shaderProgram = device->createShaderProgram(programDesc);
+
+ outShaderProgram = shaderProgram;
+ return SLANG_OK;
+}
+
+void linkTimeTypeMultiUseTestImpl(IDevice* device, UnitTestContext* context)
+{
+ // Create pipeline without both modules linked, specifying both Foo and Bar to be SimpleImpl.
+ ComPtr<IShaderProgram> shaderProgram;
+ slang::ProgramLayout* slangReflection;
+ GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram, slangReflection, true));
+
+ ComputePipelineDesc pipelineDesc = {};
+ pipelineDesc.program = shaderProgram.get();
+ ComPtr<IComputePipeline> pipelineState;
+ GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc, pipelineState.writeRef()));
+
+ const int numberCount = 4;
+ float initialData[] = {0.0f, 0.0f, 0.0f, 0.0f};
+ BufferDesc bufferDesc = {};
+ bufferDesc.size = numberCount * sizeof(float);
+ bufferDesc.format = rhi::Format::Undefined;
+ bufferDesc.elementSize = sizeof(float);
+ bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::UnorderedAccess |
+ BufferUsage::CopyDestination | BufferUsage::CopySource;
+ bufferDesc.defaultState = ResourceState::UnorderedAccess;
+ bufferDesc.memoryType = MemoryType::DeviceLocal;
+
+ ComPtr<IBuffer> numbersBuffer;
+ GFX_CHECK_CALL_ABORT(
+ device->createBuffer(bufferDesc, (void*)initialData, numbersBuffer.writeRef()));
+
+ auto queue = device->getQueue(QueueType::Graphics);
+
+ // We have done all the set up work, now it is time to start recording a command buffer for
+ // GPU execution.
+ {
+ auto commandEncoder = queue->createCommandEncoder();
+ auto computePassEncoder = commandEncoder->beginComputePass();
+
+ auto rootObject = computePassEncoder->bindPipeline(pipelineState);
+
+ ShaderCursor entryPointCursor(
+ rootObject->getEntryPoint(0)); // get a cursor the the first entry-point.
+ // Bind buffer to the entry point.
+ entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer));
+
+ computePassEncoder->dispatchCompute(1, 1, 1);
+ computePassEncoder->end();
+ auto commandBuffer = commandEncoder->finish();
+ queue->submit(commandBuffer);
+ queue->waitOnHost();
+ }
+
+ compareComputeResult(device, numbersBuffer, std::array{30});
+}
+
+SLANG_UNIT_TEST(linkTimeTypeMultiUseD3D12)
+{
+ runTestImpl(linkTimeTypeMultiUseTestImpl, unitTestContext, DeviceType::D3D12);
+}
+
+SLANG_UNIT_TEST(linkTimeTypeMultiUseVulkan)
+{
+ runTestImpl(linkTimeTypeMultiUseTestImpl, unitTestContext, DeviceType::Vulkan);
+}
+
+} // namespace gfx_test