summaryrefslogtreecommitdiffstats
path: root/tools/slang-unit-test/unit-test-function-reflection.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tools/slang-unit-test/unit-test-function-reflection.cpp')
-rw-r--r--tools/slang-unit-test/unit-test-function-reflection.cpp109
1 files changed, 109 insertions, 0 deletions
diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp
index 2f57d4151..7e1d87981 100644
--- a/tools/slang-unit-test/unit-test-function-reflection.cpp
+++ b/tools/slang-unit-test/unit-test-function-reflection.cpp
@@ -217,3 +217,112 @@ SLANG_UNIT_TEST(functionReflection)
auto ctor = module->getLayout()->findFunctionByNameInType(fooType, "$init");
SLANG_CHECK(ctor != nullptr);
}
+
+// Test that findFunctionByNameInType finds all functions with the same name but different
+// signatures
+SLANG_UNIT_TEST(findFunctionByNameInType)
+{
+ // Test shader with extensions that have functions with same name but different signatures
+ const char* userSourceBody = R"(
+ public interface IModel<float:IDifferentiable>
+ {
+ public float forward(float x);
+ }
+
+ interface IScalarActivation<float:IDifferentiable> {}
+
+ public extension<float:IDifferentiable, Act:IScalarActivation<float>> Act: IModel<float>
+ {
+ public float forward(float x) { return x;}
+ }
+
+ public struct MyStruct<float:IDifferentiable>: IScalarActivation<float> {}
+
+ public extension<float:IDifferentiable> MyStruct<float>: IModel<float[2]>
+ {
+ public float[2] forward(float[2] x) { return x;}
+ }
+
+ [shader("compute")]
+ void computeMain(uint3 tid: SV_DispatchThreadID)
+ {
+ }
+ )";
+
+ auto moduleName = "moduleH" + String(Process::getId());
+ String userSource = "import " + moduleName + ";\n" + userSourceBody;
+ ComPtr<slang::IGlobalSession> globalSession;
+ SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
+ slang::TargetDesc targetDesc = {};
+ targetDesc.format = SLANG_HLSL;
+ targetDesc.profile = globalSession->findProfile("sm_5_0");
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ sessionDesc.targets = &targetDesc;
+ ComPtr<slang::ISession> session;
+ SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto module = session->loadModuleFromSourceString(
+ "test_module",
+ "test_module.slang",
+ userSourceBody,
+ diagnosticBlob.writeRef());
+ SLANG_CHECK(module != nullptr);
+
+ auto myStructType = module->getLayout()->findTypeByName("MyStruct<float>");
+ SLANG_CHECK_ABORT(myStructType != nullptr);
+
+ // Try to find the "forward" function in MyStruct<float>
+ // This should find functions with different signatures from both extensions:
+ // 1. float forward(float x) from the generic extension Act: IModel<float>
+ // 2. float[2] forward(float[2] x) from the MyStruct-specific extension MyStruct<float>:
+ // IModel<float[2]>
+ auto forwardFunc = module->getLayout()->findFunctionByNameInType(myStructType, "forward");
+
+ // With the fix, this should find functions with different signatures
+ SLANG_CHECK(forwardFunc != nullptr);
+
+ // The function should be overloaded since there are multiple functions with different
+ // signatures
+ if (forwardFunc->isOverloaded())
+ {
+ // If it's overloaded, verify we can access both variants
+ SLANG_CHECK(forwardFunc->getOverloadCount() >= 2);
+
+ // We should be able to find both:
+ // - One with float parameter type (from generic extension)
+ // - One with float[2] parameter type (from MyStruct-specific extension)
+ bool foundFloatParam = false;
+ bool foundFloatArrayParam = false;
+
+ for (int i = 0; i < forwardFunc->getOverloadCount(); i++)
+ {
+ auto overload = forwardFunc->getOverload(i);
+ // Check that each overload has the correct name
+ SLANG_CHECK(UnownedStringSlice(overload->getName()) == "forward");
+ if (overload->getParameterCount() > 0)
+ {
+ auto paramTypeName = overload->getParameterByIndex(0)->getType()->getName();
+ if (strstr(paramTypeName, "Array"))
+ {
+ foundFloatArrayParam = true;
+ }
+ else if (strstr(paramTypeName, "float"))
+ {
+ foundFloatParam = true;
+ }
+ }
+ }
+
+ // Both variants should be found
+ SLANG_CHECK(foundFloatParam);
+ SLANG_CHECK(foundFloatArrayParam);
+ }
+ else
+ {
+ // The function should be overloaded since there are multiple functions with different
+ // signatures. If it's not overloaded, the fix didn't work properly.
+ SLANG_CHECK_ABORT(false && "Expected function to be overloaded with multiple signatures");
+ }
+}