summaryrefslogtreecommitdiffstats
path: root/tools
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-09-16 16:04:45 -0400
committerGitHub <noreply@github.com>2024-09-16 16:04:45 -0400
commitd866c0b9dfc0fdc8ad8cede4d7a8593f7ddf4716 (patch)
tree77cd8713987e575aaf8c7436cd9d2fda8ddc9e63 /tools
parentc46ca4cfeff2c78078aa3c4014cd6b0341ee01fc (diff)
Add API method to specialize function reference with argument types (#4966)
* Add `FunctionReflection::specializeWithArgTypes()` * Update slang.cpp * Use a shared semantics context on linkage Improve performance on reflection queries * Try to fix linux/mac compile errors
Diffstat (limited to 'tools')
-rw-r--r--tools/slang-unit-test/unit-test-decl-tree-reflection.cpp67
1 files changed, 63 insertions, 4 deletions
diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
index d98ea0423..89579c585 100644
--- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
+++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
@@ -64,12 +64,16 @@ SLANG_UNIT_TEST(declTreeReflection)
struct MyGenericType<T : IArithmetic & IFloat>
{
T z;
+
+ __init(T _z) { z = _z; }
+
T g() { return z; }
U h<U>(U x, out T y) { y = z; return x; }
T j<let N : int>(T x, out int o) { o = N; return x; }
- }
+ U q<U>(U x, T y) { return x; }
+ }
namespace MyNamespace
{
@@ -79,6 +83,8 @@ SLANG_UNIT_TEST(declTreeReflection)
}
}
+ T foo<T, U>(T t, U u) { return t; }
+
)";
auto moduleName = "moduleG" + String(Process::getId());
@@ -110,7 +116,7 @@ SLANG_UNIT_TEST(declTreeReflection)
auto moduleDeclReflection = module->getModuleReflection();
SLANG_CHECK(moduleDeclReflection != nullptr);
SLANG_CHECK(moduleDeclReflection->getKind() == slang::DeclReflection::Kind::Module);
- SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 8);
+ SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 9);
// First declaration should be a struct with 1 variable
auto firstDecl = moduleDeclReflection->getChild(0);
@@ -379,6 +385,59 @@ SLANG_UNIT_TEST(declTreeReflection)
SLANG_CHECK(compositeProgram->getLayout()->isSubType(uintType, diffType) == false);
}
+ // Check specializeWithArgTypes()
+ {
+ auto unspecializedFoo = compositeProgram->getLayout()->findFunctionByName("foo");
+ SLANG_CHECK(unspecializedFoo != nullptr);
+
+ auto floatType = compositeProgram->getLayout()->findTypeByName("float");
+ SLANG_CHECK(floatType != nullptr);
+ auto uintType = compositeProgram->getLayout()->findTypeByName("uint");
+ SLANG_CHECK(uintType != nullptr);
+
+ List<slang::TypeReflection*> argTypes;
+ argTypes.add(floatType);
+ argTypes.add(uintType);
+
+ slang::FunctionReflection* specializedFoo = unspecializedFoo->specializeWithArgTypes(argTypes.getCount(), argTypes.getBuffer());
+ SLANG_CHECK(specializedFoo != nullptr);
+
+ SLANG_CHECK(getTypeFullName(specializedFoo->getReturnType()) == "float");
+ SLANG_CHECK(specializedFoo->getParameterCount() == 2);
+
+ SLANG_CHECK(UnownedStringSlice(specializedFoo->getParameterByIndex(0)->getName()) == "t");
+ SLANG_CHECK(getTypeFullName(specializedFoo->getParameterByIndex(0)->getType()) == "float");
+
+ SLANG_CHECK(UnownedStringSlice(specializedFoo->getParameterByIndex(1)->getName()) == "u");
+ SLANG_CHECK(getTypeFullName(specializedFoo->getParameterByIndex(1)->getType()) == "uint");
+ }
+
+ // Check specializeArgTypes on member method looked up through a specialized type
+ {
+ auto specializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType<half>");
+ SLANG_CHECK(specializedType != nullptr);
+
+ auto unspecializedMethod = compositeProgram->getLayout()->findFunctionByNameInType(specializedType, "h");
+ SLANG_CHECK(unspecializedMethod != nullptr);
+
+ // Specialize the method with float
+ auto floatType = compositeProgram->getLayout()->findTypeByName("float");
+ SLANG_CHECK(floatType != nullptr);
+
+ auto halfType = compositeProgram->getLayout()->findTypeByName("half");
+ SLANG_CHECK(halfType != nullptr);
+
+ List<slang::TypeReflection*> argTypes;
+ argTypes.add(floatType);
+ argTypes.add(halfType);
+
+ auto specializedMethodWithFloat = unspecializedMethod->specializeWithArgTypes(
+ argTypes.getCount(),
+ argTypes.getBuffer());
+ SLANG_CHECK(specializedMethodWithFloat != nullptr);
+ SLANG_CHECK(getTypeFullName(specializedMethodWithFloat->getReturnType()) == "float");
+ }
+
// Check iterators
{
unsigned int count = 0;
@@ -386,7 +445,7 @@ SLANG_UNIT_TEST(declTreeReflection)
{
count++;
}
- SLANG_CHECK(count == 8);
+ SLANG_CHECK(count == 9);
count = 0;
for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>())
@@ -407,7 +466,7 @@ SLANG_UNIT_TEST(declTreeReflection)
{
count++;
}
- SLANG_CHECK(count == 1);
+ SLANG_CHECK(count == 2);
count = 0;
for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Namespace>())