diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-10-23 20:28:49 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-23 17:28:49 -0700 |
| commit | a0bea07503c68160ad2e88986ba98cfc2161bdff (patch) | |
| tree | 4afbd4009607a5b44e2bc72d13a27627a3501acb /tools | |
| parent | 5a161dd799cfc62dcfee281bfaff9819a8be43ad (diff) | |
Fix several bugs with `specializeWithArgTypes()` (#5365)
* Fix several bugs with `specializeWithArgTypes()`
* Make all types L-values for the purposes of reflection API resolution
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/slang-unit-test/unit-test-function-reflection.cpp | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp index 2b52a8691..f893da69d 100644 --- a/tools/slang-unit-test/unit-test-function-reflection.cpp +++ b/tools/slang-unit-test/unit-test-function-reflection.cpp @@ -39,6 +39,10 @@ SLANG_UNIT_TEST(functionReflection) float foo(float x) { return x; } float foo(float x, uint i) { return x + i; } + + int bar1(IFloat a, IFloat b) { return 0; } + int bar2<T>(T a, float3 b) { return 0; } + int bar3(float3 b) { return 0; } )"; auto moduleName = "moduleG" + String(Process::getId()); @@ -122,5 +126,66 @@ SLANG_UNIT_TEST(functionReflection) }; auto resolvedFunctionReflection = overloadReflection->specializeWithArgTypes(2, argTypes); SLANG_CHECK(resolvedFunctionReflection == firstOverload); + + // + // More testing for specializeWithArgTypes + // + + // bar1 (IFloat, IFloat) -> int + // + auto bar1Reflection = module->getLayout()->findFunctionByName("bar1"); + SLANG_CHECK(bar1Reflection != nullptr); + SLANG_CHECK(bar1Reflection->isOverloaded() == false); + SLANG_CHECK(bar1Reflection->getParameterCount() == 2); + + auto float3Type = module->getLayout()->findTypeByName("float3"); + SLANG_CHECK(float3Type != nullptr); + argTypes[0] = float3Type; + argTypes[1] = float3Type; + + resolvedFunctionReflection = bar1Reflection->specializeWithArgTypes(2, argTypes); + + SLANG_CHECK(resolvedFunctionReflection != nullptr); + SLANG_CHECK(resolvedFunctionReflection->getParameterCount() == 2); + SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(0)->getType()) == "IFloat"); + SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(1)->getType()) == "IFloat"); + + // bar2 (T : IFloat, float3) -> int + // + auto bar2Reflection = module->getLayout()->findFunctionByName("bar2"); + SLANG_CHECK(bar2Reflection != nullptr); + SLANG_CHECK(bar2Reflection->isOverloaded() == false); + SLANG_CHECK(bar2Reflection->getParameterCount() == 2); + + auto floatType = module->getLayout()->findTypeByName("float"); + SLANG_CHECK(floatType != nullptr); + argTypes[0] = floatType; + argTypes[1] = float3Type; + + resolvedFunctionReflection = bar2Reflection->specializeWithArgTypes(2, argTypes); + + SLANG_CHECK(resolvedFunctionReflection != nullptr); + SLANG_CHECK(resolvedFunctionReflection->getParameterCount() == 2); + SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(0)->getType()) == "float"); + SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(1)->getType()) == "vector<float,3>"); + + + // failure case + argTypes[0] = floatType; + argTypes[1] = module->getLayout()->findTypeByName("float2"); + resolvedFunctionReflection = bar2Reflection->specializeWithArgTypes(2, argTypes); + SLANG_CHECK(resolvedFunctionReflection == nullptr); // any errors should result in a nullptr. + + // bar3 (float3) -> int + // (trivial case) + auto bar3Reflection = module->getLayout()->findFunctionByName("bar3"); + SLANG_CHECK(bar3Reflection != nullptr); + SLANG_CHECK(bar3Reflection->isOverloaded() == false); + SLANG_CHECK(bar3Reflection->getParameterCount() == 1); + + argTypes[0] = float3Type; + resolvedFunctionReflection = bar3Reflection->specializeWithArgTypes(1, argTypes); + SLANG_CHECK(resolvedFunctionReflection != nullptr); + SLANG_CHECK(resolvedFunctionReflection == bar3Reflection); } |
