summaryrefslogtreecommitdiffstats
path: root/tools/slang-unit-test/unit-test-function-reflection.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-10-23 20:28:49 -0400
committerGitHub <noreply@github.com>2024-10-23 17:28:49 -0700
commita0bea07503c68160ad2e88986ba98cfc2161bdff (patch)
tree4afbd4009607a5b44e2bc72d13a27627a3501acb /tools/slang-unit-test/unit-test-function-reflection.cpp
parent5a161dd799cfc62dcfee281bfaff9819a8be43ad (diff)
Fix several bugs with `specializeWithArgTypes()` (#5365)
* Fix several bugs with `specializeWithArgTypes()` * Make all types L-values for the purposes of reflection API resolution
Diffstat (limited to 'tools/slang-unit-test/unit-test-function-reflection.cpp')
-rw-r--r--tools/slang-unit-test/unit-test-function-reflection.cpp65
1 files changed, 65 insertions, 0 deletions
diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp
index 2b52a8691..f893da69d 100644
--- a/tools/slang-unit-test/unit-test-function-reflection.cpp
+++ b/tools/slang-unit-test/unit-test-function-reflection.cpp
@@ -39,6 +39,10 @@ SLANG_UNIT_TEST(functionReflection)
float foo(float x) { return x; }
float foo(float x, uint i) { return x + i; }
+
+ int bar1(IFloat a, IFloat b) { return 0; }
+ int bar2<T>(T a, float3 b) { return 0; }
+ int bar3(float3 b) { return 0; }
)";
auto moduleName = "moduleG" + String(Process::getId());
@@ -122,5 +126,66 @@ SLANG_UNIT_TEST(functionReflection)
};
auto resolvedFunctionReflection = overloadReflection->specializeWithArgTypes(2, argTypes);
SLANG_CHECK(resolvedFunctionReflection == firstOverload);
+
+ //
+ // More testing for specializeWithArgTypes
+ //
+
+ // bar1 (IFloat, IFloat) -> int
+ //
+ auto bar1Reflection = module->getLayout()->findFunctionByName("bar1");
+ SLANG_CHECK(bar1Reflection != nullptr);
+ SLANG_CHECK(bar1Reflection->isOverloaded() == false);
+ SLANG_CHECK(bar1Reflection->getParameterCount() == 2);
+
+ auto float3Type = module->getLayout()->findTypeByName("float3");
+ SLANG_CHECK(float3Type != nullptr);
+ argTypes[0] = float3Type;
+ argTypes[1] = float3Type;
+
+ resolvedFunctionReflection = bar1Reflection->specializeWithArgTypes(2, argTypes);
+
+ SLANG_CHECK(resolvedFunctionReflection != nullptr);
+ SLANG_CHECK(resolvedFunctionReflection->getParameterCount() == 2);
+ SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(0)->getType()) == "IFloat");
+ SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(1)->getType()) == "IFloat");
+
+ // bar2 (T : IFloat, float3) -> int
+ //
+ auto bar2Reflection = module->getLayout()->findFunctionByName("bar2");
+ SLANG_CHECK(bar2Reflection != nullptr);
+ SLANG_CHECK(bar2Reflection->isOverloaded() == false);
+ SLANG_CHECK(bar2Reflection->getParameterCount() == 2);
+
+ auto floatType = module->getLayout()->findTypeByName("float");
+ SLANG_CHECK(floatType != nullptr);
+ argTypes[0] = floatType;
+ argTypes[1] = float3Type;
+
+ resolvedFunctionReflection = bar2Reflection->specializeWithArgTypes(2, argTypes);
+
+ SLANG_CHECK(resolvedFunctionReflection != nullptr);
+ SLANG_CHECK(resolvedFunctionReflection->getParameterCount() == 2);
+ SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(0)->getType()) == "float");
+ SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(1)->getType()) == "vector<float,3>");
+
+
+ // failure case
+ argTypes[0] = floatType;
+ argTypes[1] = module->getLayout()->findTypeByName("float2");
+ resolvedFunctionReflection = bar2Reflection->specializeWithArgTypes(2, argTypes);
+ SLANG_CHECK(resolvedFunctionReflection == nullptr); // any errors should result in a nullptr.
+
+ // bar3 (float3) -> int
+ // (trivial case)
+ auto bar3Reflection = module->getLayout()->findFunctionByName("bar3");
+ SLANG_CHECK(bar3Reflection != nullptr);
+ SLANG_CHECK(bar3Reflection->isOverloaded() == false);
+ SLANG_CHECK(bar3Reflection->getParameterCount() == 1);
+
+ argTypes[0] = float3Type;
+ resolvedFunctionReflection = bar3Reflection->specializeWithArgTypes(1, argTypes);
+ SLANG_CHECK(resolvedFunctionReflection != nullptr);
+ SLANG_CHECK(resolvedFunctionReflection == bar3Reflection);
}