diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2024-10-29 14:49:26 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-29 14:49:26 +0800 |
| commit | f65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch) | |
| tree | ea1d61342cd29368e19135000ec2948813096205 /tools/slang-unit-test/unit-test-decl-tree-reflection.cpp | |
| parent | a729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff) | |
format
* format
* Minor test fixes
* enable checking cpp format in ci
Diffstat (limited to 'tools/slang-unit-test/unit-test-decl-tree-reflection.cpp')
| -rw-r--r-- | tools/slang-unit-test/unit-test-decl-tree-reflection.cpp | 162 |
1 files changed, 106 insertions, 56 deletions
diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp index 89579c585..cbe0eb80b 100644 --- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp +++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp @@ -1,15 +1,14 @@ // unit-test-translation-unit-import.cpp +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "slang-com-ptr.h" #include "slang.h" +#include "tools/unit-test/slang-unit-test.h" #include <stdio.h> #include <stdlib.h> -#include "tools/unit-test/slang-unit-test.h" -#include "slang-com-ptr.h" -#include "../../source/core/slang-io.h" -#include "../../source/core/slang-process.h" - using namespace Slang; static String getTypeFullName(slang::TypeReflection* type) @@ -27,7 +26,8 @@ static void printRefl(slang::DeclReflection* refl, unsigned int level = 0) { std::cout << " "; } - std::cout<< "[" << names[(unsigned int)refl->getKind()] << "] (" << refl->getChildrenCount() << ")" << std::endl; + std::cout << "[" << names[(unsigned int)refl->getKind()] << "] (" << refl->getChildrenCount() + << ")" << std::endl; for (auto* child : refl->getChildren()) { @@ -101,16 +101,28 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); ComPtr<slang::IBlob> diagnosticBlob; - auto module = session->loadModuleFromSourceString("m", "m.slang", userSourceBody, diagnosticBlob.writeRef()); + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); SLANG_CHECK(module != nullptr); ComPtr<slang::IEntryPoint> entryPoint; - module->findAndCheckEntryPoint("fragMain", SLANG_STAGE_FRAGMENT, entryPoint.writeRef(), diagnosticBlob.writeRef()); + module->findAndCheckEntryPoint( + "fragMain", + SLANG_STAGE_FRAGMENT, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); SLANG_CHECK(entryPoint != nullptr); ComPtr<slang::IComponentType> compositeProgram; - slang::IComponentType* components[] = { module, entryPoint.get() }; - session->createCompositeComponentType(components, 2, compositeProgram.writeRef(), diagnosticBlob.writeRef()); + slang::IComponentType* components[] = {module, entryPoint.get()}; + session->createCompositeComponentType( + components, + 2, + compositeProgram.writeRef(), + diagnosticBlob.writeRef()); SLANG_CHECK(compositeProgram != nullptr); auto moduleDeclReflection = module->getModuleReflection(); @@ -137,7 +149,9 @@ SLANG_UNIT_TEST(declTreeReflection) // Second declaration should be a function auto secondDecl = moduleDeclReflection->getChild(1); SLANG_CHECK(secondDecl->getKind() == slang::DeclReflection::Kind::Func); - SLANG_CHECK(secondDecl->getChildrenCount() == 2); // Parameter declarations are children (return type is not) + SLANG_CHECK( + secondDecl->getChildrenCount() == + 2); // Parameter declarations are children (return type is not) { auto funcReflection = secondDecl->asFunction(); @@ -147,7 +161,9 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(funcReflection->getParameterCount() == 2); SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "x"); SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "float"); - SLANG_CHECK(funcReflection->getParameterByIndex(0)->findModifier(slang::Modifier::NoDiff) != nullptr); + SLANG_CHECK( + funcReflection->getParameterByIndex(0)->findModifier(slang::Modifier::NoDiff) != + nullptr); SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(1)->getName()) == "y"); SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(1)->getType()) == "int"); @@ -161,13 +177,15 @@ SLANG_UNIT_TEST(declTreeReflection) auto result = userAttribute->getArgumentValueInt(0, &val); SLANG_CHECK(result == SLANG_OK); SLANG_CHECK(val == 1024); - SLANG_CHECK(funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == userAttribute); + SLANG_CHECK( + funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == + userAttribute); } // Third declaration should also be a function auto thirdDecl = moduleDeclReflection->getChild(2); SLANG_CHECK(thirdDecl->getKind() == slang::DeclReflection::Kind::Func); - SLANG_CHECK(thirdDecl->getChildrenCount() == 1); + SLANG_CHECK(thirdDecl->getChildrenCount() == 1); { auto funcReflection = thirdDecl->asFunction(); @@ -175,7 +193,9 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "fragMain"); SLANG_CHECK(funcReflection->getParameterCount() == 1); SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "pos"); - SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "vector<float,4>"); + SLANG_CHECK( + getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == + "vector<float,4>"); } // Sixth declaration should be a generic struct @@ -187,9 +207,11 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(UnownedStringSlice(typeParamT->getName()) == "T"); auto typeParamTConstraintCount = genericReflection->getTypeParameterConstraintCount(typeParamT); SLANG_CHECK(typeParamTConstraintCount == 2); - auto typeParamTConstraintType1 = genericReflection->getTypeParameterConstraintType(typeParamT, 0); + auto typeParamTConstraintType1 = + genericReflection->getTypeParameterConstraintType(typeParamT, 0); SLANG_CHECK(getTypeFullName(typeParamTConstraintType1) == "IArithmetic"); - auto typeParamTConstraintType2 = genericReflection->getTypeParameterConstraintType(typeParamT, 1); + auto typeParamTConstraintType2 = + genericReflection->getTypeParameterConstraintType(typeParamT, 1); SLANG_CHECK(getTypeFullName(typeParamTConstraintType2) == "IFloat"); auto innerStruct = genericReflection->getInnerDecl(); @@ -205,7 +227,7 @@ SLANG_UNIT_TEST(declTreeReflection) { auto type = compositeProgram->getLayout()->findTypeByName("MyType"); SLANG_CHECK(type != nullptr); - //SLANG_CHECK(type->getKind() == slang::DeclReflection::Kind::Struct); + // SLANG_CHECK(type->getKind() == slang::DeclReflection::Kind::Struct); SLANG_CHECK(UnownedStringSlice(type->getName()) == "MyType"); auto funcReflection = compositeProgram->getLayout()->findFunctionByNameInType(type, "f"); SLANG_CHECK(funcReflection != nullptr); @@ -220,7 +242,7 @@ SLANG_UNIT_TEST(declTreeReflection) { auto type = compositeProgram->getLayout()->findTypeByName("MyGenericType<half>"); SLANG_CHECK(type != nullptr); - //SLANG_CHECK(type->getKind() == slang::DeclReflection::Kind::Struct); + // SLANG_CHECK(type->getKind() == slang::DeclReflection::Kind::Struct); SLANG_CHECK(getTypeFullName(type) == "MyGenericType<half>"); auto funcReflection = compositeProgram->getLayout()->findFunctionByNameInType(type, "g"); SLANG_CHECK(funcReflection != nullptr); @@ -242,85 +264,98 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "float"); SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(1)->getName()) == "y"); SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(1)->getType()) == "half"); - + // Access parent generic container from a specialized method. auto specializationInfo = funcReflection->getGenericContainer(); SLANG_CHECK(specializationInfo != nullptr); SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "h"); - SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); + SLANG_CHECK( + specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); // Check type parameters SLANG_CHECK(specializationInfo->getTypeParameterCount() == 1); auto typeParam = specializationInfo->getTypeParameter(0); SLANG_CHECK(UnownedStringSlice(typeParam->getName()) == "U"); // generic name - SLANG_CHECK(getTypeFullName(specializationInfo->getConcreteType(typeParam)) == "float"); // specialized type name under the context in which the generic is obtained + SLANG_CHECK( + getTypeFullName(specializationInfo->getConcreteType(typeParam)) == + "float"); // specialized type name under the context in which the generic is obtained SLANG_CHECK(specializationInfo->getTypeParameterConstraintCount(typeParam) == 0); // Go up another level to the generic struct specializationInfo = specializationInfo->getOuterGenericContainer(); SLANG_CHECK(specializationInfo != nullptr); SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "MyGenericType"); - SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); + SLANG_CHECK( + specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); // Check type parameters SLANG_CHECK(specializationInfo->getTypeParameterCount() == 1); typeParam = specializationInfo->getTypeParameter(0); SLANG_CHECK(UnownedStringSlice(typeParam->getName()) == "T"); // generic name - SLANG_CHECK(getTypeFullName(specializationInfo->getConcreteType(typeParam)) == "half"); // specialized type name under the context in which the generic is obtained + SLANG_CHECK( + getTypeFullName(specializationInfo->getConcreteType(typeParam)) == + "half"); // specialized type name under the context in which the generic is obtained SLANG_CHECK(specializationInfo->getTypeParameterConstraintCount(typeParam) == 2); // Query 'j' on the type 'half' funcReflection = compositeProgram->getLayout()->findFunctionByNameInType(type, "j<10>"); SLANG_CHECK(funcReflection != nullptr); SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "j"); - + // Check the generic parameters specializationInfo = funcReflection->getGenericContainer(); SLANG_CHECK(specializationInfo != nullptr); SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "j"); - SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); + SLANG_CHECK( + specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); SLANG_CHECK(specializationInfo->getValueParameterCount() == 1); auto valueParam = specializationInfo->getValueParameter(0); SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name SLANG_CHECK(specializationInfo->getConcreteIntVal(valueParam) == 10); } - + // Check specializeGeneric() and applySpecializations() { auto unspecializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType"); SLANG_CHECK(unspecializedType != nullptr); auto halfType = compositeProgram->getLayout()->findTypeByName("half"); SLANG_CHECK(halfType != nullptr); - + slang::GenericReflection* genericContainer = unspecializedType->getGenericContainer(); SLANG_CHECK(genericContainer != nullptr); - //auto typeParamT = genericContainer->getTypeParameter(0); - + // auto typeParamT = genericContainer->getTypeParameter(0); + List<slang::GenericArgType> argTypes; List<slang::GenericArgReflection> args; argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_TYPE); args.add({halfType}); auto specializedContainer = compositeProgram->getLayout()->specializeGeneric( - genericContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr); - + genericContainer, + argTypes.getCount(), + argTypes.getBuffer(), + args.getBuffer(), + nullptr); + SLANG_CHECK(specializedContainer != nullptr); auto specializedType = unspecializedType->applySpecializations(specializedContainer); SLANG_CHECK(specializedType != nullptr); SLANG_CHECK(getTypeFullName(specializedType) == "MyGenericType<half>"); - } - // Check specializeGeneric() and applySpecializations() on multiple levels (generic function nested in a generic struct) + // Check specializeGeneric() and applySpecializations() on multiple levels (generic function + // nested in a generic struct) { auto unspecializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType"); - auto unspecializedFunc = compositeProgram->getLayout()->findFunctionByNameInType(unspecializedType, "j"); + auto unspecializedFunc = + compositeProgram->getLayout()->findFunctionByNameInType(unspecializedType, "j"); SLANG_CHECK(unspecializedFunc != nullptr); auto halfType = compositeProgram->getLayout()->findTypeByName("half"); SLANG_CHECK(halfType != nullptr); - + slang::GenericReflection* genericFuncContainer = unspecializedFunc->getGenericContainer(); SLANG_CHECK(genericFuncContainer != nullptr); - slang::GenericReflection* genericStructContainer = genericFuncContainer->getOuterGenericContainer(); + slang::GenericReflection* genericStructContainer = + genericFuncContainer->getOuterGenericContainer(); SLANG_CHECK(genericStructContainer != nullptr); // Specialize the outer container with half @@ -329,11 +364,16 @@ SLANG_UNIT_TEST(declTreeReflection) argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_TYPE); args.add({halfType}); auto specializedStructContainer = compositeProgram->getLayout()->specializeGeneric( - genericStructContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr); + genericStructContainer, + argTypes.getCount(), + argTypes.getBuffer(), + args.getBuffer(), + nullptr); SLANG_CHECK(specializedStructContainer != nullptr); // apply T=half. N is still left unspecialized. - genericFuncContainer = genericFuncContainer->applySpecializations(specializedStructContainer); + genericFuncContainer = + genericFuncContainer->applySpecializations(specializedStructContainer); // Specialize the inner container with 10 separately.. argTypes.clear(); @@ -345,16 +385,21 @@ SLANG_UNIT_TEST(declTreeReflection) args.add(argN); auto specializedFuncContainer = compositeProgram->getLayout()->specializeGeneric( - genericFuncContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr); + genericFuncContainer, + argTypes.getCount(), + argTypes.getBuffer(), + args.getBuffer(), + nullptr); auto specializedFunc = unspecializedFunc->applySpecializations(specializedFuncContainer); SLANG_CHECK(specializedFunc != nullptr); - + // ------ check the specialized function auto specializationInfo = specializedFunc->getGenericContainer(); SLANG_CHECK(specializationInfo != nullptr); SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "j"); - SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); + SLANG_CHECK( + specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); SLANG_CHECK(specializationInfo->getValueParameterCount() == 1); auto valueParam = specializationInfo->getValueParameter(0); SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name @@ -364,7 +409,8 @@ SLANG_UNIT_TEST(declTreeReflection) specializationInfo = specializationInfo->getOuterGenericContainer(); SLANG_CHECK(specializationInfo != nullptr); SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "MyGenericType"); - SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); + SLANG_CHECK( + specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); // Check type parameters SLANG_CHECK(specializationInfo->getTypeParameterCount() == 1); auto typeParam = specializationInfo->getTypeParameter(0); @@ -399,15 +445,16 @@ SLANG_UNIT_TEST(declTreeReflection) argTypes.add(floatType); argTypes.add(uintType); - slang::FunctionReflection* specializedFoo = unspecializedFoo->specializeWithArgTypes(argTypes.getCount(), argTypes.getBuffer()); + slang::FunctionReflection* specializedFoo = + unspecializedFoo->specializeWithArgTypes(argTypes.getCount(), argTypes.getBuffer()); SLANG_CHECK(specializedFoo != nullptr); - + SLANG_CHECK(getTypeFullName(specializedFoo->getReturnType()) == "float"); SLANG_CHECK(specializedFoo->getParameterCount() == 2); SLANG_CHECK(UnownedStringSlice(specializedFoo->getParameterByIndex(0)->getName()) == "t"); SLANG_CHECK(getTypeFullName(specializedFoo->getParameterByIndex(0)->getType()) == "float"); - + SLANG_CHECK(UnownedStringSlice(specializedFoo->getParameterByIndex(1)->getName()) == "u"); SLANG_CHECK(getTypeFullName(specializedFoo->getParameterByIndex(1)->getType()) == "uint"); } @@ -417,7 +464,8 @@ SLANG_UNIT_TEST(declTreeReflection) auto specializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType<half>"); SLANG_CHECK(specializedType != nullptr); - auto unspecializedMethod = compositeProgram->getLayout()->findFunctionByNameInType(specializedType, "h"); + auto unspecializedMethod = + compositeProgram->getLayout()->findFunctionByNameInType(specializedType, "h"); SLANG_CHECK(unspecializedMethod != nullptr); // Specialize the method with float @@ -431,13 +479,12 @@ SLANG_UNIT_TEST(declTreeReflection) argTypes.add(floatType); argTypes.add(halfType); - auto specializedMethodWithFloat = unspecializedMethod->specializeWithArgTypes( - argTypes.getCount(), - argTypes.getBuffer()); + auto specializedMethodWithFloat = + unspecializedMethod->specializeWithArgTypes(argTypes.getCount(), argTypes.getBuffer()); SLANG_CHECK(specializedMethodWithFloat != nullptr); SLANG_CHECK(getTypeFullName(specializedMethodWithFloat->getReturnType()) == "float"); } - + // Check iterators { unsigned int count = 0; @@ -448,32 +495,35 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(count == 9); count = 0; - for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>()) + for (auto* child : + moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>()) { count++; } SLANG_CHECK(count == 3); count = 0; - for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Struct>()) + for (auto* child : + moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Struct>()) { count++; } SLANG_CHECK(count == 2); count = 0; - for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Generic>()) + for (auto* child : + moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Generic>()) { count++; } SLANG_CHECK(count == 2); count = 0; - for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Namespace>()) + for (auto* child : + moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Namespace>()) { count++; } SLANG_CHECK(count == 1); } } - |
