diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-08-27 16:47:05 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-27 16:47:05 -0400 |
| commit | 4aac22da6ae902eca1e7750f4e5b83ba238b5874 (patch) | |
| tree | f266e3c7c3a646473ac4af80ddbcd72702ced917 /tools | |
| parent | d40c143eb4f19f1dfd0d0dcf9b718be6e495ca27 (diff) | |
Add ability to specialize generic references to functions, types and more (#4909)
* More reflection API features.
+ Lookup methods and members (by string) on types
+ Fix issue with looking up non-static members through the scope operator '::'
+ `GenericReflection`: Cast a decl to generic to access unspecialized generic parameter names and constraints
+ `GenericReflection`: Use `getGenericContainer()` from function, variable or type to access the 'nearest' generic parent along with specialization info
+ `GenericReflection::getConcreteType` and `GenericReflection::getConcreteIntVal`: to get the concrete type of a param in the context of the reflection object
+ `GenericReflection::getOuterGenericContainer` to go up one level and get the outer generic declarations (if there are more than one enclosing generic scopes)
+ `DeclReflection::getParent`: go to parent declaration.
+ Change `VariableReflection` to be a `DeclRef` rather than a decl (allows us to return properly substituted types for methods, members, and more)
* Fix Falcor issue
* Initial namespace reflection support
* FIx issue with specializing witness tables
* Add API method for specializing parameters of a generic decl
* Add ability to specialize generic references to functions, types and more
This PR adds the following end-points:
- `specializeGeneric()` method that can be called on a generic reflection to substitute arguments for generic type and value parameters. It returns another generic reflection, but this time with the appropriate substitution.
- `applySpecializations()` method to then copy these specializations onto an existing type or function reflection.
- `isSubType()` to check if a type is a subtype of another type (useful to check if a type is differentiable by checking `IDifferentiable`)
This PR also:
- Adds `DeclReflection::Kind::Namespace` so that namespace containers are correctly reflected when walking the decl-tree. the name can be obtained through `getName()` but there's no need to cast to a namespace (since there's nothing else we can do with a namespace decl)
- Fixes an issue with name-based lookups that fail if a type or function is referenced without specializations. Its helpful to be able to form a reference to a function with default substitutions, so that we can we can specialize it later (either directly, or via argument types).
* Update slang.h
* Fix up naming
* Update slang-compiler.h
* Update slang-reflection-api.cpp
* Update slang.cpp
* Update slang.cpp
* Update slang.cpp
* Use `checkGenericAppWithCheckedArgs` to do specialization
* Update slang-reflection-api.cpp
* Update slang-check-decl.cpp
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/slang-unit-test/unit-test-decl-tree-reflection.cpp | 126 |
1 files changed, 124 insertions, 2 deletions
diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp index fb35f323c..d98ea0423 100644 --- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp +++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp @@ -70,6 +70,15 @@ SLANG_UNIT_TEST(declTreeReflection) T j<let N : int>(T x, out int o) { o = N; return x; } } + + namespace MyNamespace + { + struct MyStruct + { + int x; + } + } + )"; auto moduleName = "moduleG" + String(Process::getId()); @@ -101,7 +110,7 @@ SLANG_UNIT_TEST(declTreeReflection) auto moduleDeclReflection = module->getModuleReflection(); SLANG_CHECK(moduleDeclReflection != nullptr); SLANG_CHECK(moduleDeclReflection->getKind() == slang::DeclReflection::Kind::Module); - SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 7); + SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 8); // First declaration should be a struct with 1 variable auto firstDecl = moduleDeclReflection->getChild(0); @@ -180,6 +189,11 @@ SLANG_UNIT_TEST(declTreeReflection) auto innerStruct = genericReflection->getInnerDecl(); SLANG_CHECK(innerStruct->getKind() == slang::DeclReflection::Kind::Struct); + // Check that the seventh declaration is a namespace + auto seventhDecl = moduleDeclReflection->getChild(6); + SLANG_CHECK(seventhDecl->getKind() == slang::DeclReflection::Kind::Namespace); + SLANG_CHECK(UnownedStringSlice(seventhDecl->getName()) == "MyNamespace"); + // Check type-lookup-by-name { @@ -262,7 +276,108 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name SLANG_CHECK(specializationInfo->getConcreteIntVal(valueParam) == 10); } + + // Check specializeGeneric() and applySpecializations() + { + auto unspecializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType"); + SLANG_CHECK(unspecializedType != nullptr); + auto halfType = compositeProgram->getLayout()->findTypeByName("half"); + SLANG_CHECK(halfType != nullptr); + + slang::GenericReflection* genericContainer = unspecializedType->getGenericContainer(); + SLANG_CHECK(genericContainer != nullptr); + //auto typeParamT = genericContainer->getTypeParameter(0); + + List<slang::GenericArgType> argTypes; + List<slang::GenericArgReflection> args; + argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_TYPE); + args.add({halfType}); + auto specializedContainer = compositeProgram->getLayout()->specializeGeneric( + genericContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr); + + SLANG_CHECK(specializedContainer != nullptr); + + auto specializedType = unspecializedType->applySpecializations(specializedContainer); + SLANG_CHECK(specializedType != nullptr); + SLANG_CHECK(getTypeFullName(specializedType) == "MyGenericType<half>"); + + } + + // Check specializeGeneric() and applySpecializations() on multiple levels (generic function nested in a generic struct) + { + auto unspecializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType"); + auto unspecializedFunc = compositeProgram->getLayout()->findFunctionByNameInType(unspecializedType, "j"); + + SLANG_CHECK(unspecializedFunc != nullptr); + auto halfType = compositeProgram->getLayout()->findTypeByName("half"); + SLANG_CHECK(halfType != nullptr); + + slang::GenericReflection* genericFuncContainer = unspecializedFunc->getGenericContainer(); + SLANG_CHECK(genericFuncContainer != nullptr); + slang::GenericReflection* genericStructContainer = genericFuncContainer->getOuterGenericContainer(); + SLANG_CHECK(genericStructContainer != nullptr); + + // Specialize the outer container with half + List<slang::GenericArgType> argTypes; + List<slang::GenericArgReflection> args; + argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_TYPE); + args.add({halfType}); + auto specializedStructContainer = compositeProgram->getLayout()->specializeGeneric( + genericStructContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr); + SLANG_CHECK(specializedStructContainer != nullptr); + + // apply T=half. N is still left unspecialized. + genericFuncContainer = genericFuncContainer->applySpecializations(specializedStructContainer); + + // Specialize the inner container with 10 separately.. + argTypes.clear(); + args.clear(); + + slang::GenericArgReflection argN; + argN.intVal = 10; + argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_INT); + args.add(argN); + + auto specializedFuncContainer = compositeProgram->getLayout()->specializeGeneric( + genericFuncContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr); + + auto specializedFunc = unspecializedFunc->applySpecializations(specializedFuncContainer); + SLANG_CHECK(specializedFunc != nullptr); + + // ------ check the specialized function + auto specializationInfo = specializedFunc->getGenericContainer(); + SLANG_CHECK(specializationInfo != nullptr); + SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "j"); + SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); + SLANG_CHECK(specializationInfo->getValueParameterCount() == 1); + auto valueParam = specializationInfo->getValueParameter(0); + SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name + SLANG_CHECK(specializationInfo->getConcreteIntVal(valueParam) == 10); + + // check outer container + specializationInfo = specializationInfo->getOuterGenericContainer(); + SLANG_CHECK(specializationInfo != nullptr); + SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "MyGenericType"); + SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); + // Check type parameters + SLANG_CHECK(specializationInfo->getTypeParameterCount() == 1); + auto typeParam = specializationInfo->getTypeParameter(0); + SLANG_CHECK(UnownedStringSlice(typeParam->getName()) == "T"); // generic name + SLANG_CHECK(getTypeFullName(specializationInfo->getConcreteType(typeParam)) == "half"); + } + + // Check sub-type relations + { + auto floatType = compositeProgram->getLayout()->findTypeByName("float"); + SLANG_CHECK(floatType != nullptr); + auto diffType = compositeProgram->getLayout()->findTypeByName("IDifferentiable"); + SLANG_CHECK(diffType != nullptr); + SLANG_CHECK(compositeProgram->getLayout()->isSubType(floatType, diffType) == true); + + auto uintType = compositeProgram->getLayout()->findTypeByName("uint"); + SLANG_CHECK(compositeProgram->getLayout()->isSubType(uintType, diffType) == false); + } // Check iterators { @@ -271,7 +386,7 @@ SLANG_UNIT_TEST(declTreeReflection) { count++; } - SLANG_CHECK(count == 7); + SLANG_CHECK(count == 8); count = 0; for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>()) @@ -293,6 +408,13 @@ SLANG_UNIT_TEST(declTreeReflection) count++; } SLANG_CHECK(count == 1); + + count = 0; + for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Namespace>()) + { + count++; + } + SLANG_CHECK(count == 1); } } |
