diff options
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/gfx-unit-test/link-time-type-generic.cpp | 229 | ||||
| -rw-r--r-- | tools/gfx-unit-test/link-time-type-multi-use-generic.cpp | 156 | ||||
| -rw-r--r-- | tools/gfx-unit-test/link-time-type-multi-use.cpp | 156 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-link-time-type-reflection.cpp | 90 |
4 files changed, 628 insertions, 3 deletions
diff --git a/tools/gfx-unit-test/link-time-type-generic.cpp b/tools/gfx-unit-test/link-time-type-generic.cpp new file mode 100644 index 000000000..ac1750518 --- /dev/null +++ b/tools/gfx-unit-test/link-time-type-generic.cpp @@ -0,0 +1,229 @@ +#include "core/slang-basic.h" +#include "core/slang-blob.h" +#include "gfx-test-util.h" +#include "slang-rhi.h" +#include "slang-rhi/shader-cursor.h" +#include "unit-test/slang-unit-test.h" + +using namespace rhi; + +// Test that generic link time types conforming to a generic interface with generic +// methods/subscript members work correctly. +// Also test that global generic link-time functions works correctly. + +namespace gfx_test +{ +static Slang::Result loadProgram( + rhi::IDevice* device, + Slang::ComPtr<rhi::IShaderProgram>& outShaderProgram, + slang::ProgramLayout*& slangReflection, + bool linkSpecialization = false) +{ + const char* moduleInterfaceSrc = R"( + interface ISimple { float getVal(); } + interface IHasProperty { property float val2{get;set;} } + interface IFoo<T:__BuiltinFloatingPointType> : IHasProperty + { + static const int offset; + [mutating] void setValue(float v); + + T getValue<U:ISimple>(U u); + + __subscript<U:__BuiltinIntegerType>(U index) -> T { get; } + } + struct FooImpl<T:__BuiltinFloatingPointType, int x> : IFoo<T> + { + T val; + static const int offset = x; + [mutating] void setValue(float v) { val = T(v); } + T getValue<U:ISimple>(U u){ return val + T(u.getVal()); } + property float val2 { + get { return __real_cast<float>(val) + 2.0; } + set { val = T(newValue); } + } + __subscript<U:__BuiltinIntegerType>(U index) -> T { get {return T(1.0); } } + }; + struct BarImpl<T:__BuiltinFloatingPointType, int x> : IFoo<T> + { + T val; + static const int offset = -x; + [mutating] void setValue(float v) { val = T(v); } + T getValue<U:ISimple>(U u){ return val - T(1.0); } + property float val2 { + get { return __real_cast<float>(val) + 2.0; } + set { val = T(newValue); } + } + __subscript<U:__BuiltinIntegerType>(U index) -> T { get {return T(2.0); } } + }; + )"; + const char* module0Src = R"( + import ifoo; + extern struct Foo<T:__BuiltinFloatingPointType, int i> : IFoo<T> = FooImpl<T, i+1>; + extern static const float c = 0.0; + extern int linkTimeFunc<int x>() { return x; } + struct SimpleImpl : ISimple + { + float getVal() { return 100.0; } + }; + + // Use an indirect generic function to retrieve val2, to make sure intermediate witness tables + // can be obtained correctly from link-time witnesses. + float getVal2<T:IHasProperty>(T t) { return t.val2; } + + [numthreads(1,1,1)] + void computeMain(uniform RWStructuredBuffer<float> buffer) + { + Foo<float, 0> foo; + foo.setValue(3.0); + buffer[0] = foo.getValue(SimpleImpl()) + getVal2(foo) + Foo<float, 0>.offset + c + foo[0] + linkTimeFunc<0>(); + } + )"; + const char* module1Src = R"( + import ifoo; + export struct Foo<T1:__BuiltinFloatingPointType, int i> : IFoo<T1> = BarImpl<T1, i+1>; + export static const float c = 1.0; + export int linkTimeFunc<int x>() { return x + 1; } + )"; + Slang::ComPtr<slang::ISession> slangSession; + SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef())); + Slang::ComPtr<slang::IBlob> diagnosticsBlob; + auto moduleInterfaceBlob = + Slang::UnownedRawBlob::create(moduleInterfaceSrc, strlen(moduleInterfaceSrc)); + auto module0Blob = Slang::UnownedRawBlob::create(module0Src, strlen(module0Src)); + auto module1Blob = Slang::UnownedRawBlob::create(module1Src, strlen(module1Src)); + slang::IModule* moduleInterface = + slangSession->loadModuleFromSource("ifoo", "ifoo.slang", moduleInterfaceBlob); + slang::IModule* module0 = slangSession->loadModuleFromSource("module0", "path0", module0Blob); + slang::IModule* module1 = slangSession->loadModuleFromSource("module1", "path1", module1Blob); + ComPtr<slang::IEntryPoint> computeEntryPoint; + SLANG_RETURN_ON_FAIL( + module0->findEntryPointByName("computeMain", computeEntryPoint.writeRef())); + + Slang::List<slang::IComponentType*> componentTypes; + componentTypes.add(moduleInterface); + componentTypes.add(module0); + if (linkSpecialization) + componentTypes.add(module1); + componentTypes.add(computeEntryPoint); + + Slang::ComPtr<slang::IComponentType> composedProgram; + SlangResult result = slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + composedProgram.writeRef(), + diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + ComPtr<slang::IComponentType> linkedProgram; + result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + composedProgram = linkedProgram; + slangReflection = composedProgram->getLayout(); + + ShaderProgramDesc programDesc = {}; + programDesc.slangGlobalScope = composedProgram.get(); + + auto shaderProgram = device->createShaderProgram(programDesc); + + outShaderProgram = shaderProgram; + return SLANG_OK; +} + +void linkTimeTypeGenericTestImpl(IDevice* device, UnitTestContext* context) +{ + // Create pipeline without linking a specialization override module, so we should + // see the default value of `extern Foo`. + ComPtr<IShaderProgram> shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram, slangReflection, false)); + + ComputePipelineDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + ComPtr<IComputePipeline> pipelineState; + GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc, pipelineState.writeRef())); + + // Create pipeline with a specialization override module linked in, so we should + // see the result of using `BarImpl<T>` for `extern Foo<T>`. + ComPtr<IShaderProgram> shaderProgram1; + GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram1, slangReflection, true)); + + ComputePipelineDesc pipelineDesc1 = {}; + pipelineDesc1.program = shaderProgram1.get(); + ComPtr<IComputePipeline> pipelineState1; + GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc1, pipelineState1.writeRef())); + + const int numberCount = 4; + float initialData[] = {0.0f, 0.0f, 0.0f, 0.0f}; + BufferDesc bufferDesc = {}; + bufferDesc.size = numberCount * sizeof(float); + bufferDesc.format = rhi::Format::Undefined; + bufferDesc.elementSize = sizeof(float); + bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::UnorderedAccess | + BufferUsage::CopyDestination | BufferUsage::CopySource; + bufferDesc.defaultState = ResourceState::UnorderedAccess; + bufferDesc.memoryType = MemoryType::DeviceLocal; + + ComPtr<IBuffer> numbersBuffer; + GFX_CHECK_CALL_ABORT( + device->createBuffer(bufferDesc, (void*)initialData, numbersBuffer.writeRef())); + + auto queue = device->getQueue(QueueType::Graphics); + + // We have done all the set up work, now it is time to start recording a command buffer for + // GPU execution. + { + auto commandEncoder = queue->createCommandEncoder(); + auto computePassEncoder = commandEncoder->beginComputePass(); + + auto rootObject = computePassEncoder->bindPipeline(pipelineState); + + ShaderCursor entryPointCursor( + rootObject->getEntryPoint(0)); // get a cursor the the first entry-point. + // Bind buffer to the entry point. + entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer)); + + computePassEncoder->dispatchCompute(1, 1, 1); + computePassEncoder->end(); + auto commandBuffer = commandEncoder->finish(); + queue->submit(commandBuffer); + queue->waitOnHost(); + } + + compareComputeResult(device, numbersBuffer, std::array{110.0f}); + + // Now run again with the overrided program. + { + auto commandEncoder = queue->createCommandEncoder(); + auto computePassEncoder = commandEncoder->beginComputePass(); + + auto rootObject = computePassEncoder->bindPipeline(pipelineState1); + + ShaderCursor entryPointCursor( + rootObject->getEntryPoint(0)); // get a cursor the the first entry-point. + // Bind buffer to the entry point. + entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer)); + + computePassEncoder->dispatchCompute(1, 1, 1); + computePassEncoder->end(); + auto commandBuffer = commandEncoder->finish(); + queue->submit(commandBuffer); + queue->waitOnHost(); + } + + compareComputeResult(device, numbersBuffer, std::array{10.0f}); +} + +SLANG_UNIT_TEST(linkTimeTypeGenericD3D12) +{ + runTestImpl(linkTimeTypeGenericTestImpl, unitTestContext, DeviceType::D3D12); +} + +SLANG_UNIT_TEST(linkTimeTypeGenerictVulkan) +{ + runTestImpl(linkTimeTypeGenericTestImpl, unitTestContext, DeviceType::Vulkan); +} + +} // namespace gfx_test diff --git a/tools/gfx-unit-test/link-time-type-multi-use-generic.cpp b/tools/gfx-unit-test/link-time-type-multi-use-generic.cpp new file mode 100644 index 000000000..c640389e5 --- /dev/null +++ b/tools/gfx-unit-test/link-time-type-multi-use-generic.cpp @@ -0,0 +1,156 @@ +#include "core/slang-basic.h" +#include "core/slang-blob.h" +#include "gfx-test-util.h" +#include "slang-rhi.h" +#include "slang-rhi/shader-cursor.h" +#include "unit-test/slang-unit-test.h" + +using namespace rhi; + +// Test that a generic type can be used to serve multiple link-time type requirements. + +namespace gfx_test +{ +static Slang::Result loadProgram( + rhi::IDevice* device, + Slang::ComPtr<rhi::IShaderProgram>& outShaderProgram, + slang::ProgramLayout*& slangReflection, + bool linkSpecialization = false) +{ + const char* moduleInterfaceSrc = R"( + interface IFoo { int getFoo(); } + interface IBar { int getBar(); } + struct SimpleImpl<int y> : IFoo, IBar + { + int getFoo() { return y; } + int getBar() { return y * 2; } + } + )"; + const char* module0Src = R"( + import ifoo; + extern struct Foo<int x> : IFoo; + extern struct Bar<int x> : IBar; + uniform Foo<10> gFoo; + uniform Bar<20> gBar; + [numthreads(1,1,1)] + void computeMain(uniform RWStructuredBuffer<int> buffer) + { + buffer[0] = gFoo.getFoo() + gBar.getBar(); + } + )"; + const char* module1Src = R"( + import ifoo; + export struct Foo<int x> : IFoo = SimpleImpl<x>; + export struct Bar<int x> : IBar = SimpleImpl<x>; + )"; + Slang::ComPtr<slang::ISession> slangSession; + SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef())); + Slang::ComPtr<slang::IBlob> diagnosticsBlob; + auto moduleInterfaceBlob = + Slang::UnownedRawBlob::create(moduleInterfaceSrc, strlen(moduleInterfaceSrc)); + auto module0Blob = Slang::UnownedRawBlob::create(module0Src, strlen(module0Src)); + auto module1Blob = Slang::UnownedRawBlob::create(module1Src, strlen(module1Src)); + slang::IModule* moduleInterface = + slangSession->loadModuleFromSource("ifoo", "ifoo.slang", moduleInterfaceBlob); + slang::IModule* module0 = slangSession->loadModuleFromSource("module0", "path0", module0Blob); + slang::IModule* module1 = slangSession->loadModuleFromSource("module1", "path1", module1Blob); + ComPtr<slang::IEntryPoint> computeEntryPoint; + SLANG_RETURN_ON_FAIL( + module0->findEntryPointByName("computeMain", computeEntryPoint.writeRef())); + + Slang::List<slang::IComponentType*> componentTypes; + componentTypes.add(moduleInterface); + componentTypes.add(module0); + if (linkSpecialization) + componentTypes.add(module1); + componentTypes.add(computeEntryPoint); + + Slang::ComPtr<slang::IComponentType> composedProgram; + SlangResult result = slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + composedProgram.writeRef(), + diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + ComPtr<slang::IComponentType> linkedProgram; + result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + composedProgram = linkedProgram; + slangReflection = composedProgram->getLayout(); + + ShaderProgramDesc programDesc = {}; + programDesc.slangGlobalScope = composedProgram.get(); + + auto shaderProgram = device->createShaderProgram(programDesc); + + outShaderProgram = shaderProgram; + return SLANG_OK; +} + +void linkTimeTypeMultiUseGenericTestImpl(IDevice* device, UnitTestContext* context) +{ + // Create pipeline without both modules linked, specifying both Foo and Bar to be SimpleImpl. + ComPtr<IShaderProgram> shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram, slangReflection, true)); + + ComputePipelineDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + ComPtr<IComputePipeline> pipelineState; + GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc, pipelineState.writeRef())); + + const int numberCount = 4; + float initialData[] = {0.0f, 0.0f, 0.0f, 0.0f}; + BufferDesc bufferDesc = {}; + bufferDesc.size = numberCount * sizeof(float); + bufferDesc.format = rhi::Format::Undefined; + bufferDesc.elementSize = sizeof(float); + bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::UnorderedAccess | + BufferUsage::CopyDestination | BufferUsage::CopySource; + bufferDesc.defaultState = ResourceState::UnorderedAccess; + bufferDesc.memoryType = MemoryType::DeviceLocal; + + ComPtr<IBuffer> numbersBuffer; + GFX_CHECK_CALL_ABORT( + device->createBuffer(bufferDesc, (void*)initialData, numbersBuffer.writeRef())); + + auto queue = device->getQueue(QueueType::Graphics); + + // We have done all the set up work, now it is time to start recording a command buffer for + // GPU execution. + { + auto commandEncoder = queue->createCommandEncoder(); + auto computePassEncoder = commandEncoder->beginComputePass(); + + auto rootObject = computePassEncoder->bindPipeline(pipelineState); + + ShaderCursor entryPointCursor( + rootObject->getEntryPoint(0)); // get a cursor the the first entry-point. + // Bind buffer to the entry point. + entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer)); + + computePassEncoder->dispatchCompute(1, 1, 1); + computePassEncoder->end(); + auto commandBuffer = commandEncoder->finish(); + queue->submit(commandBuffer); + queue->waitOnHost(); + } + + compareComputeResult(device, numbersBuffer, std::array{50}); +} + +SLANG_UNIT_TEST(linkTimeTypeMultiUseGenericD3D12) +{ + runTestImpl(linkTimeTypeMultiUseGenericTestImpl, unitTestContext, DeviceType::D3D12); +} + +SLANG_UNIT_TEST(linkTimeTypeMultiUseGenericVulkan) +{ + runTestImpl(linkTimeTypeMultiUseGenericTestImpl, unitTestContext, DeviceType::Vulkan); +} + +} // namespace gfx_test diff --git a/tools/gfx-unit-test/link-time-type-multi-use.cpp b/tools/gfx-unit-test/link-time-type-multi-use.cpp new file mode 100644 index 000000000..4dc6d085e --- /dev/null +++ b/tools/gfx-unit-test/link-time-type-multi-use.cpp @@ -0,0 +1,156 @@ +#include "core/slang-basic.h" +#include "core/slang-blob.h" +#include "gfx-test-util.h" +#include "slang-rhi.h" +#include "slang-rhi/shader-cursor.h" +#include "unit-test/slang-unit-test.h" + +using namespace rhi; + +// Test that a type can be used to serve multiple link-time type requirements. + +namespace gfx_test +{ +static Slang::Result loadProgram( + rhi::IDevice* device, + Slang::ComPtr<rhi::IShaderProgram>& outShaderProgram, + slang::ProgramLayout*& slangReflection, + bool linkSpecialization = false) +{ + const char* moduleInterfaceSrc = R"( + interface IFoo { int getFoo(); } + interface IBar { int getBar(); } + struct SimpleImpl : IFoo, IBar + { + int getFoo() { return 10; } + int getBar() { return 20; } + } + )"; + const char* module0Src = R"( + import ifoo; + extern struct Foo : IFoo; + extern struct Bar : IBar; + uniform Foo gFoo; + uniform Bar gBar; + [numthreads(1,1,1)] + void computeMain(uniform RWStructuredBuffer<int> buffer) + { + buffer[0] = gFoo.getFoo() + gBar.getBar(); + } + )"; + const char* module1Src = R"( + import ifoo; + export struct Foo : IFoo = SimpleImpl; + export struct Bar : IBar = SimpleImpl; + )"; + Slang::ComPtr<slang::ISession> slangSession; + SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef())); + Slang::ComPtr<slang::IBlob> diagnosticsBlob; + auto moduleInterfaceBlob = + Slang::UnownedRawBlob::create(moduleInterfaceSrc, strlen(moduleInterfaceSrc)); + auto module0Blob = Slang::UnownedRawBlob::create(module0Src, strlen(module0Src)); + auto module1Blob = Slang::UnownedRawBlob::create(module1Src, strlen(module1Src)); + slang::IModule* moduleInterface = + slangSession->loadModuleFromSource("ifoo", "ifoo.slang", moduleInterfaceBlob); + slang::IModule* module0 = slangSession->loadModuleFromSource("module0", "path0", module0Blob); + slang::IModule* module1 = slangSession->loadModuleFromSource("module1", "path1", module1Blob); + ComPtr<slang::IEntryPoint> computeEntryPoint; + SLANG_RETURN_ON_FAIL( + module0->findEntryPointByName("computeMain", computeEntryPoint.writeRef())); + + Slang::List<slang::IComponentType*> componentTypes; + componentTypes.add(moduleInterface); + componentTypes.add(module0); + if (linkSpecialization) + componentTypes.add(module1); + componentTypes.add(computeEntryPoint); + + Slang::ComPtr<slang::IComponentType> composedProgram; + SlangResult result = slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + composedProgram.writeRef(), + diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + ComPtr<slang::IComponentType> linkedProgram; + result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + composedProgram = linkedProgram; + slangReflection = composedProgram->getLayout(); + + ShaderProgramDesc programDesc = {}; + programDesc.slangGlobalScope = composedProgram.get(); + + auto shaderProgram = device->createShaderProgram(programDesc); + + outShaderProgram = shaderProgram; + return SLANG_OK; +} + +void linkTimeTypeMultiUseTestImpl(IDevice* device, UnitTestContext* context) +{ + // Create pipeline without both modules linked, specifying both Foo and Bar to be SimpleImpl. + ComPtr<IShaderProgram> shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT(loadProgram(device, shaderProgram, slangReflection, true)); + + ComputePipelineDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + ComPtr<IComputePipeline> pipelineState; + GFX_CHECK_CALL_ABORT(device->createComputePipeline(pipelineDesc, pipelineState.writeRef())); + + const int numberCount = 4; + float initialData[] = {0.0f, 0.0f, 0.0f, 0.0f}; + BufferDesc bufferDesc = {}; + bufferDesc.size = numberCount * sizeof(float); + bufferDesc.format = rhi::Format::Undefined; + bufferDesc.elementSize = sizeof(float); + bufferDesc.usage = BufferUsage::ShaderResource | BufferUsage::UnorderedAccess | + BufferUsage::CopyDestination | BufferUsage::CopySource; + bufferDesc.defaultState = ResourceState::UnorderedAccess; + bufferDesc.memoryType = MemoryType::DeviceLocal; + + ComPtr<IBuffer> numbersBuffer; + GFX_CHECK_CALL_ABORT( + device->createBuffer(bufferDesc, (void*)initialData, numbersBuffer.writeRef())); + + auto queue = device->getQueue(QueueType::Graphics); + + // We have done all the set up work, now it is time to start recording a command buffer for + // GPU execution. + { + auto commandEncoder = queue->createCommandEncoder(); + auto computePassEncoder = commandEncoder->beginComputePass(); + + auto rootObject = computePassEncoder->bindPipeline(pipelineState); + + ShaderCursor entryPointCursor( + rootObject->getEntryPoint(0)); // get a cursor the the first entry-point. + // Bind buffer to the entry point. + entryPointCursor.getPath("buffer").setBinding(Binding(numbersBuffer)); + + computePassEncoder->dispatchCompute(1, 1, 1); + computePassEncoder->end(); + auto commandBuffer = commandEncoder->finish(); + queue->submit(commandBuffer); + queue->waitOnHost(); + } + + compareComputeResult(device, numbersBuffer, std::array{30}); +} + +SLANG_UNIT_TEST(linkTimeTypeMultiUseD3D12) +{ + runTestImpl(linkTimeTypeMultiUseTestImpl, unitTestContext, DeviceType::D3D12); +} + +SLANG_UNIT_TEST(linkTimeTypeMultiUseVulkan) +{ + runTestImpl(linkTimeTypeMultiUseTestImpl, unitTestContext, DeviceType::Vulkan); +} + +} // namespace gfx_test diff --git a/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp b/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp index c42fd2f16..0bd580c84 100644 --- a/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp +++ b/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp @@ -27,16 +27,21 @@ SLANG_UNIT_TEST(linkTimeTypeReflection) interface IMaterial { float4 load(); } extern struct Material : IMaterial; ConstantBuffer<Material> gMaterial; - + + interface IFoo { float getVal(); } + struct DefaultFoo : IFoo { float getVal() { return 0.0f; } } + extern struct Foo<T, int x> : IFoo = DefaultFoo; + RWTexture2D tex; extern static const int count; uniform uint4 buffers[count]; + uniform Foo<int4, 1> gFoo; [numthreads(1,1,1)] [shader("compute")] void computeMain() { - tex[uint2(0, 0)] = gMaterial.load(); + tex[uint2(0, 0)] = gMaterial.load() + gFoo.getVal(); } )"; @@ -65,7 +70,8 @@ SLANG_UNIT_TEST(linkTimeTypeReflection) String configModuleSource = "import " + moduleName + ";\n" + R"( export struct Material : IMaterial = MyMaterial; export static const int count = 11; - + struct FooImpl<T, int x> : IFoo { T vals[x]; float getVal() { return x; } } + export struct Foo<T, int x> : IFoo = FooImpl<T, x + 1>; struct MyMaterial : IMaterial { int data; Texture2D diffuse; @@ -110,6 +116,9 @@ SLANG_UNIT_TEST(linkTimeTypeReflection) auto var2 = programLayout->getParameterByIndex(2); SLANG_CHECK(var2->getTypeLayout()->getSize() == 11 * 16); + auto var3 = programLayout->getParameterByIndex(3); + SLANG_CHECK(var3->getTypeLayout()->getSize() == 32); + ComPtr<slang::IBlob> codeBlob; linkedProgram->getTargetCode(0, codeBlob.writeRef(), diagnosticBlob.writeRef()); @@ -226,3 +235,78 @@ SLANG_UNIT_TEST(linkTimeConditionalReflection) SLANG_CHECK(spirvStr.indexOf(toSlice("Location 1")) != -1); SLANG_CHECK(spirvStr.indexOf(toSlice("Location 2")) == -1); } + +// Test that loading a module that defines an `export` type, but not linking with the module should +// not affect the type layout. + +SLANG_UNIT_TEST(linkTimeTypeReflectionWithLoadedButNotLinkedModule) +{ + // Source for a module that contains can be specialized with a link-time type. + const char* userSourceBody = R"( + interface IFoo { float getVal(); } + struct DefaultFoo : IFoo { float getVal() { return 0.0f; } } + extern struct Foo<T, int x> : IFoo = DefaultFoo; + + uniform Foo<int4, 1> gFoo; + RWTexture2D tex; + + [numthreads(1,1,1)] + [shader("compute")] + void computeMain() { + tex[uint2(0, 0)] = gFoo.getVal(); + } + )"; + + String moduleName = "linkTimeTypeReflection_Compute"; + + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_SPIRV_ASM; + targetDesc.profile = globalSession->findProfile("spirv_1_5"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + moduleName.getBuffer(), + (moduleName + ".slang").getBuffer(), + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(module != nullptr); + + // Source for a module that defines the link-time type, but we won't link with it. + String configModuleSource = "import " + moduleName + ";\n" + R"( + struct FooImpl<T, int x> : IFoo { T vals[x]; float getVal() { return x; } } + export struct Foo<T, int x> : IFoo = FooImpl<T, x + 1>; + )"; + auto configModule = session->loadModuleFromSourceString( + "config", + "config.slang", + configModuleSource.getBuffer(), + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(configModule != nullptr); + + ComPtr<slang::IComponentType> linkedProgram; + module->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(linkedProgram != nullptr); + + auto programLayout = linkedProgram->getLayout(); + auto var0 = programLayout->getParameterByIndex(0); + + // Size of `gFoo` is 0, because the module that defines `Foo = FooImpl` is not linked. + // Therefore `gFoo`'s type is defaulted to `DefaultFoo`, which has no fields. + SLANG_CHECK(var0->getTypeLayout()->getSize() == 0); + + ComPtr<slang::IBlob> codeBlob; + linkedProgram->getTargetCode(0, codeBlob.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK_ABORT(codeBlob.get()); + + auto spirvStr = UnownedStringSlice((const char*)codeBlob->getBufferPointer()); + + SLANG_CHECK(spirvStr.indexOf(toSlice("OpDecorate %tex Binding 0")) != -1); +} |
