summaryrefslogtreecommitdiff
path: root/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tools/slang-unit-test/unit-test-link-time-type-reflection.cpp')
-rw-r--r--tools/slang-unit-test/unit-test-link-time-type-reflection.cpp90
1 files changed, 87 insertions, 3 deletions
diff --git a/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp b/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp
index c42fd2f16..0bd580c84 100644
--- a/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp
+++ b/tools/slang-unit-test/unit-test-link-time-type-reflection.cpp
@@ -27,16 +27,21 @@ SLANG_UNIT_TEST(linkTimeTypeReflection)
interface IMaterial { float4 load(); }
extern struct Material : IMaterial;
ConstantBuffer<Material> gMaterial;
-
+
+ interface IFoo { float getVal(); }
+ struct DefaultFoo : IFoo { float getVal() { return 0.0f; } }
+ extern struct Foo<T, int x> : IFoo = DefaultFoo;
+
RWTexture2D tex;
extern static const int count;
uniform uint4 buffers[count];
+ uniform Foo<int4, 1> gFoo;
[numthreads(1,1,1)]
[shader("compute")]
void computeMain() {
- tex[uint2(0, 0)] = gMaterial.load();
+ tex[uint2(0, 0)] = gMaterial.load() + gFoo.getVal();
}
)";
@@ -65,7 +70,8 @@ SLANG_UNIT_TEST(linkTimeTypeReflection)
String configModuleSource = "import " + moduleName + ";\n" + R"(
export struct Material : IMaterial = MyMaterial;
export static const int count = 11;
-
+ struct FooImpl<T, int x> : IFoo { T vals[x]; float getVal() { return x; } }
+ export struct Foo<T, int x> : IFoo = FooImpl<T, x + 1>;
struct MyMaterial : IMaterial {
int data;
Texture2D diffuse;
@@ -110,6 +116,9 @@ SLANG_UNIT_TEST(linkTimeTypeReflection)
auto var2 = programLayout->getParameterByIndex(2);
SLANG_CHECK(var2->getTypeLayout()->getSize() == 11 * 16);
+ auto var3 = programLayout->getParameterByIndex(3);
+ SLANG_CHECK(var3->getTypeLayout()->getSize() == 32);
+
ComPtr<slang::IBlob> codeBlob;
linkedProgram->getTargetCode(0, codeBlob.writeRef(), diagnosticBlob.writeRef());
@@ -226,3 +235,78 @@ SLANG_UNIT_TEST(linkTimeConditionalReflection)
SLANG_CHECK(spirvStr.indexOf(toSlice("Location 1")) != -1);
SLANG_CHECK(spirvStr.indexOf(toSlice("Location 2")) == -1);
}
+
+// Test that loading a module that defines an `export` type, but not linking with the module should
+// not affect the type layout.
+
+SLANG_UNIT_TEST(linkTimeTypeReflectionWithLoadedButNotLinkedModule)
+{
+ // Source for a module that contains can be specialized with a link-time type.
+ const char* userSourceBody = R"(
+ interface IFoo { float getVal(); }
+ struct DefaultFoo : IFoo { float getVal() { return 0.0f; } }
+ extern struct Foo<T, int x> : IFoo = DefaultFoo;
+
+ uniform Foo<int4, 1> gFoo;
+ RWTexture2D tex;
+
+ [numthreads(1,1,1)]
+ [shader("compute")]
+ void computeMain() {
+ tex[uint2(0, 0)] = gFoo.getVal();
+ }
+ )";
+
+ String moduleName = "linkTimeTypeReflection_Compute";
+
+ ComPtr<slang::IGlobalSession> globalSession;
+ SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
+ slang::TargetDesc targetDesc = {};
+ targetDesc.format = SLANG_SPIRV_ASM;
+ targetDesc.profile = globalSession->findProfile("spirv_1_5");
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ sessionDesc.targets = &targetDesc;
+ ComPtr<slang::ISession> session;
+ SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto module = session->loadModuleFromSourceString(
+ moduleName.getBuffer(),
+ (moduleName + ".slang").getBuffer(),
+ userSourceBody,
+ diagnosticBlob.writeRef());
+ SLANG_CHECK_ABORT(module != nullptr);
+
+ // Source for a module that defines the link-time type, but we won't link with it.
+ String configModuleSource = "import " + moduleName + ";\n" + R"(
+ struct FooImpl<T, int x> : IFoo { T vals[x]; float getVal() { return x; } }
+ export struct Foo<T, int x> : IFoo = FooImpl<T, x + 1>;
+ )";
+ auto configModule = session->loadModuleFromSourceString(
+ "config",
+ "config.slang",
+ configModuleSource.getBuffer(),
+ diagnosticBlob.writeRef());
+ SLANG_CHECK_ABORT(configModule != nullptr);
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ module->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
+ SLANG_CHECK_ABORT(linkedProgram != nullptr);
+
+ auto programLayout = linkedProgram->getLayout();
+ auto var0 = programLayout->getParameterByIndex(0);
+
+ // Size of `gFoo` is 0, because the module that defines `Foo = FooImpl` is not linked.
+ // Therefore `gFoo`'s type is defaulted to `DefaultFoo`, which has no fields.
+ SLANG_CHECK(var0->getTypeLayout()->getSize() == 0);
+
+ ComPtr<slang::IBlob> codeBlob;
+ linkedProgram->getTargetCode(0, codeBlob.writeRef(), diagnosticBlob.writeRef());
+
+ SLANG_CHECK_ABORT(codeBlob.get());
+
+ auto spirvStr = UnownedStringSlice((const char*)codeBlob->getBufferPointer());
+
+ SLANG_CHECK(spirvStr.indexOf(toSlice("OpDecorate %tex Binding 0")) != -1);
+}