diff options
Diffstat (limited to 'tools/slang-unit-test/unit-test-link-time-type-reflection.cpp')
| -rw-r--r-- | tools/slang-unit-test/unit-test-link-time-type-reflection.cpp | 90 |
1 files changed, 87 insertions, 3 deletions
diff --git a/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp b/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp index c42fd2f16..0bd580c84 100644 --- a/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp +++ b/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp @@ -27,16 +27,21 @@ SLANG_UNIT_TEST(linkTimeTypeReflection) interface IMaterial { float4 load(); } extern struct Material : IMaterial; ConstantBuffer<Material> gMaterial; - + + interface IFoo { float getVal(); } + struct DefaultFoo : IFoo { float getVal() { return 0.0f; } } + extern struct Foo<T, int x> : IFoo = DefaultFoo; + RWTexture2D tex; extern static const int count; uniform uint4 buffers[count]; + uniform Foo<int4, 1> gFoo; [numthreads(1,1,1)] [shader("compute")] void computeMain() { - tex[uint2(0, 0)] = gMaterial.load(); + tex[uint2(0, 0)] = gMaterial.load() + gFoo.getVal(); } )"; @@ -65,7 +70,8 @@ SLANG_UNIT_TEST(linkTimeTypeReflection) String configModuleSource = "import " + moduleName + ";\n" + R"( export struct Material : IMaterial = MyMaterial; export static const int count = 11; - + struct FooImpl<T, int x> : IFoo { T vals[x]; float getVal() { return x; } } + export struct Foo<T, int x> : IFoo = FooImpl<T, x + 1>; struct MyMaterial : IMaterial { int data; Texture2D diffuse; @@ -110,6 +116,9 @@ SLANG_UNIT_TEST(linkTimeTypeReflection) auto var2 = programLayout->getParameterByIndex(2); SLANG_CHECK(var2->getTypeLayout()->getSize() == 11 * 16); + auto var3 = programLayout->getParameterByIndex(3); + SLANG_CHECK(var3->getTypeLayout()->getSize() == 32); + ComPtr<slang::IBlob> codeBlob; linkedProgram->getTargetCode(0, codeBlob.writeRef(), diagnosticBlob.writeRef()); @@ -226,3 +235,78 @@ SLANG_UNIT_TEST(linkTimeConditionalReflection) SLANG_CHECK(spirvStr.indexOf(toSlice("Location 1")) != -1); SLANG_CHECK(spirvStr.indexOf(toSlice("Location 2")) == -1); } + +// Test that loading a module that defines an `export` type, but not linking with the module should +// not affect the type layout. + +SLANG_UNIT_TEST(linkTimeTypeReflectionWithLoadedButNotLinkedModule) +{ + // Source for a module that contains can be specialized with a link-time type. + const char* userSourceBody = R"( + interface IFoo { float getVal(); } + struct DefaultFoo : IFoo { float getVal() { return 0.0f; } } + extern struct Foo<T, int x> : IFoo = DefaultFoo; + + uniform Foo<int4, 1> gFoo; + RWTexture2D tex; + + [numthreads(1,1,1)] + [shader("compute")] + void computeMain() { + tex[uint2(0, 0)] = gFoo.getVal(); + } + )"; + + String moduleName = "linkTimeTypeReflection_Compute"; + + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_SPIRV_ASM; + targetDesc.profile = globalSession->findProfile("spirv_1_5"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + moduleName.getBuffer(), + (moduleName + ".slang").getBuffer(), + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(module != nullptr); + + // Source for a module that defines the link-time type, but we won't link with it. + String configModuleSource = "import " + moduleName + ";\n" + R"( + struct FooImpl<T, int x> : IFoo { T vals[x]; float getVal() { return x; } } + export struct Foo<T, int x> : IFoo = FooImpl<T, x + 1>; + )"; + auto configModule = session->loadModuleFromSourceString( + "config", + "config.slang", + configModuleSource.getBuffer(), + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(configModule != nullptr); + + ComPtr<slang::IComponentType> linkedProgram; + module->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(linkedProgram != nullptr); + + auto programLayout = linkedProgram->getLayout(); + auto var0 = programLayout->getParameterByIndex(0); + + // Size of `gFoo` is 0, because the module that defines `Foo = FooImpl` is not linked. + // Therefore `gFoo`'s type is defaulted to `DefaultFoo`, which has no fields. + SLANG_CHECK(var0->getTypeLayout()->getSize() == 0); + + ComPtr<slang::IBlob> codeBlob; + linkedProgram->getTargetCode(0, codeBlob.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK_ABORT(codeBlob.get()); + + auto spirvStr = UnownedStringSlice((const char*)codeBlob->getBufferPointer()); + + SLANG_CHECK(spirvStr.indexOf(toSlice("OpDecorate %tex Binding 0")) != -1); +} |
