diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-09-16 16:04:45 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-16 16:04:45 -0400 |
| commit | d866c0b9dfc0fdc8ad8cede4d7a8593f7ddf4716 (patch) | |
| tree | 77cd8713987e575aaf8c7436cd9d2fda8ddc9e63 /tools | |
| parent | c46ca4cfeff2c78078aa3c4014cd6b0341ee01fc (diff) | |
Add API method to specialize function reference with argument types (#4966)
* Add `FunctionReflection::specializeWithArgTypes()`
* Update slang.cpp
* Use a shared semantics context on linkage
Improve performance on reflection queries
* Try to fix linux/mac compile errors
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/slang-unit-test/unit-test-decl-tree-reflection.cpp | 67 |
1 files changed, 63 insertions, 4 deletions
diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp index d98ea0423..89579c585 100644 --- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp +++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp @@ -64,12 +64,16 @@ SLANG_UNIT_TEST(declTreeReflection) struct MyGenericType<T : IArithmetic & IFloat> { T z; + + __init(T _z) { z = _z; } + T g() { return z; } U h<U>(U x, out T y) { y = z; return x; } T j<let N : int>(T x, out int o) { o = N; return x; } - } + U q<U>(U x, T y) { return x; } + } namespace MyNamespace { @@ -79,6 +83,8 @@ SLANG_UNIT_TEST(declTreeReflection) } } + T foo<T, U>(T t, U u) { return t; } + )"; auto moduleName = "moduleG" + String(Process::getId()); @@ -110,7 +116,7 @@ SLANG_UNIT_TEST(declTreeReflection) auto moduleDeclReflection = module->getModuleReflection(); SLANG_CHECK(moduleDeclReflection != nullptr); SLANG_CHECK(moduleDeclReflection->getKind() == slang::DeclReflection::Kind::Module); - SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 8); + SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 9); // First declaration should be a struct with 1 variable auto firstDecl = moduleDeclReflection->getChild(0); @@ -379,6 +385,59 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(compositeProgram->getLayout()->isSubType(uintType, diffType) == false); } + // Check specializeWithArgTypes() + { + auto unspecializedFoo = compositeProgram->getLayout()->findFunctionByName("foo"); + SLANG_CHECK(unspecializedFoo != nullptr); + + auto floatType = compositeProgram->getLayout()->findTypeByName("float"); + SLANG_CHECK(floatType != nullptr); + auto uintType = compositeProgram->getLayout()->findTypeByName("uint"); + SLANG_CHECK(uintType != nullptr); + + List<slang::TypeReflection*> argTypes; + argTypes.add(floatType); + argTypes.add(uintType); + + slang::FunctionReflection* specializedFoo = unspecializedFoo->specializeWithArgTypes(argTypes.getCount(), argTypes.getBuffer()); + SLANG_CHECK(specializedFoo != nullptr); + + SLANG_CHECK(getTypeFullName(specializedFoo->getReturnType()) == "float"); + SLANG_CHECK(specializedFoo->getParameterCount() == 2); + + SLANG_CHECK(UnownedStringSlice(specializedFoo->getParameterByIndex(0)->getName()) == "t"); + SLANG_CHECK(getTypeFullName(specializedFoo->getParameterByIndex(0)->getType()) == "float"); + + SLANG_CHECK(UnownedStringSlice(specializedFoo->getParameterByIndex(1)->getName()) == "u"); + SLANG_CHECK(getTypeFullName(specializedFoo->getParameterByIndex(1)->getType()) == "uint"); + } + + // Check specializeArgTypes on member method looked up through a specialized type + { + auto specializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType<half>"); + SLANG_CHECK(specializedType != nullptr); + + auto unspecializedMethod = compositeProgram->getLayout()->findFunctionByNameInType(specializedType, "h"); + SLANG_CHECK(unspecializedMethod != nullptr); + + // Specialize the method with float + auto floatType = compositeProgram->getLayout()->findTypeByName("float"); + SLANG_CHECK(floatType != nullptr); + + auto halfType = compositeProgram->getLayout()->findTypeByName("half"); + SLANG_CHECK(halfType != nullptr); + + List<slang::TypeReflection*> argTypes; + argTypes.add(floatType); + argTypes.add(halfType); + + auto specializedMethodWithFloat = unspecializedMethod->specializeWithArgTypes( + argTypes.getCount(), + argTypes.getBuffer()); + SLANG_CHECK(specializedMethodWithFloat != nullptr); + SLANG_CHECK(getTypeFullName(specializedMethodWithFloat->getReturnType()) == "float"); + } + // Check iterators { unsigned int count = 0; @@ -386,7 +445,7 @@ SLANG_UNIT_TEST(declTreeReflection) { count++; } - SLANG_CHECK(count == 8); + SLANG_CHECK(count == 9); count = 0; for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>()) @@ -407,7 +466,7 @@ SLANG_UNIT_TEST(declTreeReflection) { count++; } - SLANG_CHECK(count == 1); + SLANG_CHECK(count == 2); count = 0; for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Namespace>()) |
