summaryrefslogtreecommitdiffstats
path: root/tools
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-08-27 16:47:05 -0400
committerGitHub <noreply@github.com>2024-08-27 16:47:05 -0400
commit4aac22da6ae902eca1e7750f4e5b83ba238b5874 (patch)
treef266e3c7c3a646473ac4af80ddbcd72702ced917 /tools
parentd40c143eb4f19f1dfd0d0dcf9b718be6e495ca27 (diff)
Add ability to specialize generic references to functions, types and more (#4909)
* More reflection API features. + Lookup methods and members (by string) on types + Fix issue with looking up non-static members through the scope operator '::' + `GenericReflection`: Cast a decl to generic to access unspecialized generic parameter names and constraints + `GenericReflection`: Use `getGenericContainer()` from function, variable or type to access the 'nearest' generic parent along with specialization info + `GenericReflection::getConcreteType` and `GenericReflection::getConcreteIntVal`: to get the concrete type of a param in the context of the reflection object + `GenericReflection::getOuterGenericContainer` to go up one level and get the outer generic declarations (if there are more than one enclosing generic scopes) + `DeclReflection::getParent`: go to parent declaration. + Change `VariableReflection` to be a `DeclRef` rather than a decl (allows us to return properly substituted types for methods, members, and more) * Fix Falcor issue * Initial namespace reflection support * FIx issue with specializing witness tables * Add API method for specializing parameters of a generic decl * Add ability to specialize generic references to functions, types and more This PR adds the following end-points: - `specializeGeneric()` method that can be called on a generic reflection to substitute arguments for generic type and value parameters. It returns another generic reflection, but this time with the appropriate substitution. - `applySpecializations()` method to then copy these specializations onto an existing type or function reflection. - `isSubType()` to check if a type is a subtype of another type (useful to check if a type is differentiable by checking `IDifferentiable`) This PR also: - Adds `DeclReflection::Kind::Namespace` so that namespace containers are correctly reflected when walking the decl-tree. the name can be obtained through `getName()` but there's no need to cast to a namespace (since there's nothing else we can do with a namespace decl) - Fixes an issue with name-based lookups that fail if a type or function is referenced without specializations. Its helpful to be able to form a reference to a function with default substitutions, so that we can we can specialize it later (either directly, or via argument types). * Update slang.h * Fix up naming * Update slang-compiler.h * Update slang-reflection-api.cpp * Update slang.cpp * Update slang.cpp * Update slang.cpp * Use `checkGenericAppWithCheckedArgs` to do specialization * Update slang-reflection-api.cpp * Update slang-check-decl.cpp
Diffstat (limited to 'tools')
-rw-r--r--tools/slang-unit-test/unit-test-decl-tree-reflection.cpp126
1 files changed, 124 insertions, 2 deletions
diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
index fb35f323c..d98ea0423 100644
--- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
+++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
@@ -70,6 +70,15 @@ SLANG_UNIT_TEST(declTreeReflection)
T j<let N : int>(T x, out int o) { o = N; return x; }
}
+
+ namespace MyNamespace
+ {
+ struct MyStruct
+ {
+ int x;
+ }
+ }
+
)";
auto moduleName = "moduleG" + String(Process::getId());
@@ -101,7 +110,7 @@ SLANG_UNIT_TEST(declTreeReflection)
auto moduleDeclReflection = module->getModuleReflection();
SLANG_CHECK(moduleDeclReflection != nullptr);
SLANG_CHECK(moduleDeclReflection->getKind() == slang::DeclReflection::Kind::Module);
- SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 7);
+ SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 8);
// First declaration should be a struct with 1 variable
auto firstDecl = moduleDeclReflection->getChild(0);
@@ -180,6 +189,11 @@ SLANG_UNIT_TEST(declTreeReflection)
auto innerStruct = genericReflection->getInnerDecl();
SLANG_CHECK(innerStruct->getKind() == slang::DeclReflection::Kind::Struct);
+ // Check that the seventh declaration is a namespace
+ auto seventhDecl = moduleDeclReflection->getChild(6);
+ SLANG_CHECK(seventhDecl->getKind() == slang::DeclReflection::Kind::Namespace);
+ SLANG_CHECK(UnownedStringSlice(seventhDecl->getName()) == "MyNamespace");
+
// Check type-lookup-by-name
{
@@ -262,7 +276,108 @@ SLANG_UNIT_TEST(declTreeReflection)
SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name
SLANG_CHECK(specializationInfo->getConcreteIntVal(valueParam) == 10);
}
+
+ // Check specializeGeneric() and applySpecializations()
+ {
+ auto unspecializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType");
+ SLANG_CHECK(unspecializedType != nullptr);
+ auto halfType = compositeProgram->getLayout()->findTypeByName("half");
+ SLANG_CHECK(halfType != nullptr);
+
+ slang::GenericReflection* genericContainer = unspecializedType->getGenericContainer();
+ SLANG_CHECK(genericContainer != nullptr);
+ //auto typeParamT = genericContainer->getTypeParameter(0);
+
+ List<slang::GenericArgType> argTypes;
+ List<slang::GenericArgReflection> args;
+ argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_TYPE);
+ args.add({halfType});
+ auto specializedContainer = compositeProgram->getLayout()->specializeGeneric(
+ genericContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr);
+
+ SLANG_CHECK(specializedContainer != nullptr);
+
+ auto specializedType = unspecializedType->applySpecializations(specializedContainer);
+ SLANG_CHECK(specializedType != nullptr);
+ SLANG_CHECK(getTypeFullName(specializedType) == "MyGenericType<half>");
+
+ }
+
+ // Check specializeGeneric() and applySpecializations() on multiple levels (generic function nested in a generic struct)
+ {
+ auto unspecializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType");
+ auto unspecializedFunc = compositeProgram->getLayout()->findFunctionByNameInType(unspecializedType, "j");
+
+ SLANG_CHECK(unspecializedFunc != nullptr);
+ auto halfType = compositeProgram->getLayout()->findTypeByName("half");
+ SLANG_CHECK(halfType != nullptr);
+
+ slang::GenericReflection* genericFuncContainer = unspecializedFunc->getGenericContainer();
+ SLANG_CHECK(genericFuncContainer != nullptr);
+ slang::GenericReflection* genericStructContainer = genericFuncContainer->getOuterGenericContainer();
+ SLANG_CHECK(genericStructContainer != nullptr);
+
+ // Specialize the outer container with half
+ List<slang::GenericArgType> argTypes;
+ List<slang::GenericArgReflection> args;
+ argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_TYPE);
+ args.add({halfType});
+ auto specializedStructContainer = compositeProgram->getLayout()->specializeGeneric(
+ genericStructContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr);
+ SLANG_CHECK(specializedStructContainer != nullptr);
+
+ // apply T=half. N is still left unspecialized.
+ genericFuncContainer = genericFuncContainer->applySpecializations(specializedStructContainer);
+
+ // Specialize the inner container with 10 separately..
+ argTypes.clear();
+ args.clear();
+
+ slang::GenericArgReflection argN;
+ argN.intVal = 10;
+ argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_INT);
+ args.add(argN);
+
+ auto specializedFuncContainer = compositeProgram->getLayout()->specializeGeneric(
+ genericFuncContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr);
+
+ auto specializedFunc = unspecializedFunc->applySpecializations(specializedFuncContainer);
+ SLANG_CHECK(specializedFunc != nullptr);
+
+ // ------ check the specialized function
+ auto specializationInfo = specializedFunc->getGenericContainer();
+ SLANG_CHECK(specializationInfo != nullptr);
+ SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "j");
+ SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
+ SLANG_CHECK(specializationInfo->getValueParameterCount() == 1);
+ auto valueParam = specializationInfo->getValueParameter(0);
+ SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name
+ SLANG_CHECK(specializationInfo->getConcreteIntVal(valueParam) == 10);
+
+ // check outer container
+ specializationInfo = specializationInfo->getOuterGenericContainer();
+ SLANG_CHECK(specializationInfo != nullptr);
+ SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "MyGenericType");
+ SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
+ // Check type parameters
+ SLANG_CHECK(specializationInfo->getTypeParameterCount() == 1);
+ auto typeParam = specializationInfo->getTypeParameter(0);
+ SLANG_CHECK(UnownedStringSlice(typeParam->getName()) == "T"); // generic name
+ SLANG_CHECK(getTypeFullName(specializationInfo->getConcreteType(typeParam)) == "half");
+ }
+
+ // Check sub-type relations
+ {
+ auto floatType = compositeProgram->getLayout()->findTypeByName("float");
+ SLANG_CHECK(floatType != nullptr);
+ auto diffType = compositeProgram->getLayout()->findTypeByName("IDifferentiable");
+ SLANG_CHECK(diffType != nullptr);
+ SLANG_CHECK(compositeProgram->getLayout()->isSubType(floatType, diffType) == true);
+
+ auto uintType = compositeProgram->getLayout()->findTypeByName("uint");
+ SLANG_CHECK(compositeProgram->getLayout()->isSubType(uintType, diffType) == false);
+ }
// Check iterators
{
@@ -271,7 +386,7 @@ SLANG_UNIT_TEST(declTreeReflection)
{
count++;
}
- SLANG_CHECK(count == 7);
+ SLANG_CHECK(count == 8);
count = 0;
for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>())
@@ -293,6 +408,13 @@ SLANG_UNIT_TEST(declTreeReflection)
count++;
}
SLANG_CHECK(count == 1);
+
+ count = 0;
+ for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Namespace>())
+ {
+ count++;
+ }
+ SLANG_CHECK(count == 1);
}
}