diff options
| author | Copilot <198982749+Copilot@users.noreply.github.com> | 2025-07-22 09:03:05 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-22 09:03:05 -0700 |
| commit | 0d26dbaad90f5eac604e148971d14e552bf9d5b8 (patch) | |
| tree | c45813fdb1d613817fca5673b000097ac83c097e /tools | |
| parent | f25e5a89f00bcecacee4f09901d5cfdc1be341c6 (diff) | |
Fix findFunctionByNameInType to preserve functions with different signatures (#7827)
findFunctionByNameInType was only returning one function when multiple functions existed with the same name but different signatures. This broke reflection functionality for extension methods.
Fix the issue by changing findDeclFromStringInType by not calling maybeResolveOverloadedExpr if checkedTerm is overloaded functions. We still call maybeResolveOverloadedExpr when any candidates in the overloaded list is not DeclRefExpr referencing a function.
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/slang-unit-test/unit-test-function-reflection.cpp | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp index 2f57d4151..7e1d87981 100644 --- a/tools/slang-unit-test/unit-test-function-reflection.cpp +++ b/tools/slang-unit-test/unit-test-function-reflection.cpp @@ -217,3 +217,112 @@ SLANG_UNIT_TEST(functionReflection) auto ctor = module->getLayout()->findFunctionByNameInType(fooType, "$init"); SLANG_CHECK(ctor != nullptr); } + +// Test that findFunctionByNameInType finds all functions with the same name but different +// signatures +SLANG_UNIT_TEST(findFunctionByNameInType) +{ + // Test shader with extensions that have functions with same name but different signatures + const char* userSourceBody = R"( + public interface IModel<float:IDifferentiable> + { + public float forward(float x); + } + + interface IScalarActivation<float:IDifferentiable> {} + + public extension<float:IDifferentiable, Act:IScalarActivation<float>> Act: IModel<float> + { + public float forward(float x) { return x;} + } + + public struct MyStruct<float:IDifferentiable>: IScalarActivation<float> {} + + public extension<float:IDifferentiable> MyStruct<float>: IModel<float[2]> + { + public float[2] forward(float[2] x) { return x;} + } + + [shader("compute")] + void computeMain(uint3 tid: SV_DispatchThreadID) + { + } + )"; + + auto moduleName = "moduleH" + String(Process::getId()); + String userSource = "import " + moduleName + ";\n" + userSourceBody; + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "test_module", + "test_module.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + auto myStructType = module->getLayout()->findTypeByName("MyStruct<float>"); + SLANG_CHECK_ABORT(myStructType != nullptr); + + // Try to find the "forward" function in MyStruct<float> + // This should find functions with different signatures from both extensions: + // 1. float forward(float x) from the generic extension Act: IModel<float> + // 2. float[2] forward(float[2] x) from the MyStruct-specific extension MyStruct<float>: + // IModel<float[2]> + auto forwardFunc = module->getLayout()->findFunctionByNameInType(myStructType, "forward"); + + // With the fix, this should find functions with different signatures + SLANG_CHECK(forwardFunc != nullptr); + + // The function should be overloaded since there are multiple functions with different + // signatures + if (forwardFunc->isOverloaded()) + { + // If it's overloaded, verify we can access both variants + SLANG_CHECK(forwardFunc->getOverloadCount() >= 2); + + // We should be able to find both: + // - One with float parameter type (from generic extension) + // - One with float[2] parameter type (from MyStruct-specific extension) + bool foundFloatParam = false; + bool foundFloatArrayParam = false; + + for (int i = 0; i < forwardFunc->getOverloadCount(); i++) + { + auto overload = forwardFunc->getOverload(i); + // Check that each overload has the correct name + SLANG_CHECK(UnownedStringSlice(overload->getName()) == "forward"); + if (overload->getParameterCount() > 0) + { + auto paramTypeName = overload->getParameterByIndex(0)->getType()->getName(); + if (strstr(paramTypeName, "Array")) + { + foundFloatArrayParam = true; + } + else if (strstr(paramTypeName, "float")) + { + foundFloatParam = true; + } + } + } + + // Both variants should be found + SLANG_CHECK(foundFloatParam); + SLANG_CHECK(foundFloatArrayParam); + } + else + { + // The function should be overloaded since there are multiple functions with different + // signatures. If it's not overloaded, the fix didn't work properly. + SLANG_CHECK_ABORT(false && "Expected function to be overloaded with multiple signatures"); + } +} |
