diff options
Diffstat (limited to 'tools/slang-unit-test/unit-test-function-reflection.cpp')
| -rw-r--r-- | tools/slang-unit-test/unit-test-function-reflection.cpp | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp index 2f57d4151..7e1d87981 100644 --- a/tools/slang-unit-test/unit-test-function-reflection.cpp +++ b/tools/slang-unit-test/unit-test-function-reflection.cpp @@ -217,3 +217,112 @@ SLANG_UNIT_TEST(functionReflection) auto ctor = module->getLayout()->findFunctionByNameInType(fooType, "$init"); SLANG_CHECK(ctor != nullptr); } + +// Test that findFunctionByNameInType finds all functions with the same name but different +// signatures +SLANG_UNIT_TEST(findFunctionByNameInType) +{ + // Test shader with extensions that have functions with same name but different signatures + const char* userSourceBody = R"( + public interface IModel<float:IDifferentiable> + { + public float forward(float x); + } + + interface IScalarActivation<float:IDifferentiable> {} + + public extension<float:IDifferentiable, Act:IScalarActivation<float>> Act: IModel<float> + { + public float forward(float x) { return x;} + } + + public struct MyStruct<float:IDifferentiable>: IScalarActivation<float> {} + + public extension<float:IDifferentiable> MyStruct<float>: IModel<float[2]> + { + public float[2] forward(float[2] x) { return x;} + } + + [shader("compute")] + void computeMain(uint3 tid: SV_DispatchThreadID) + { + } + )"; + + auto moduleName = "moduleH" + String(Process::getId()); + String userSource = "import " + moduleName + ";\n" + userSourceBody; + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "test_module", + "test_module.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + auto myStructType = module->getLayout()->findTypeByName("MyStruct<float>"); + SLANG_CHECK_ABORT(myStructType != nullptr); + + // Try to find the "forward" function in MyStruct<float> + // This should find functions with different signatures from both extensions: + // 1. float forward(float x) from the generic extension Act: IModel<float> + // 2. float[2] forward(float[2] x) from the MyStruct-specific extension MyStruct<float>: + // IModel<float[2]> + auto forwardFunc = module->getLayout()->findFunctionByNameInType(myStructType, "forward"); + + // With the fix, this should find functions with different signatures + SLANG_CHECK(forwardFunc != nullptr); + + // The function should be overloaded since there are multiple functions with different + // signatures + if (forwardFunc->isOverloaded()) + { + // If it's overloaded, verify we can access both variants + SLANG_CHECK(forwardFunc->getOverloadCount() >= 2); + + // We should be able to find both: + // - One with float parameter type (from generic extension) + // - One with float[2] parameter type (from MyStruct-specific extension) + bool foundFloatParam = false; + bool foundFloatArrayParam = false; + + for (int i = 0; i < forwardFunc->getOverloadCount(); i++) + { + auto overload = forwardFunc->getOverload(i); + // Check that each overload has the correct name + SLANG_CHECK(UnownedStringSlice(overload->getName()) == "forward"); + if (overload->getParameterCount() > 0) + { + auto paramTypeName = overload->getParameterByIndex(0)->getType()->getName(); + if (strstr(paramTypeName, "Array")) + { + foundFloatArrayParam = true; + } + else if (strstr(paramTypeName, "float")) + { + foundFloatParam = true; + } + } + } + + // Both variants should be found + SLANG_CHECK(foundFloatParam); + SLANG_CHECK(foundFloatArrayParam); + } + else + { + // The function should be overloaded since there are multiple functions with different + // signatures. If it's not overloaded, the fix didn't work properly. + SLANG_CHECK_ABORT(false && "Expected function to be overloaded with multiple signatures"); + } +} |
