summaryrefslogtreecommitdiffstats
path: root/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2024-10-29 14:49:26 +0800
committerGitHub <noreply@github.com>2024-10-29 14:49:26 +0800
commitf65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch)
treeea1d61342cd29368e19135000ec2948813096205 /tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
parenta729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff)
format
* format * Minor test fixes * enable checking cpp format in ci
Diffstat (limited to 'tools/slang-unit-test/unit-test-decl-tree-reflection.cpp')
-rw-r--r--tools/slang-unit-test/unit-test-decl-tree-reflection.cpp162
1 files changed, 106 insertions, 56 deletions
diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
index 89579c585..cbe0eb80b 100644
--- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
+++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
@@ -1,15 +1,14 @@
// unit-test-translation-unit-import.cpp
+#include "../../source/core/slang-io.h"
+#include "../../source/core/slang-process.h"
+#include "slang-com-ptr.h"
#include "slang.h"
+#include "tools/unit-test/slang-unit-test.h"
#include <stdio.h>
#include <stdlib.h>
-#include "tools/unit-test/slang-unit-test.h"
-#include "slang-com-ptr.h"
-#include "../../source/core/slang-io.h"
-#include "../../source/core/slang-process.h"
-
using namespace Slang;
static String getTypeFullName(slang::TypeReflection* type)
@@ -27,7 +26,8 @@ static void printRefl(slang::DeclReflection* refl, unsigned int level = 0)
{
std::cout << " ";
}
- std::cout<< "[" << names[(unsigned int)refl->getKind()] << "] (" << refl->getChildrenCount() << ")" << std::endl;
+ std::cout << "[" << names[(unsigned int)refl->getKind()] << "] (" << refl->getChildrenCount()
+ << ")" << std::endl;
for (auto* child : refl->getChildren())
{
@@ -101,16 +101,28 @@ SLANG_UNIT_TEST(declTreeReflection)
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
ComPtr<slang::IBlob> diagnosticBlob;
- auto module = session->loadModuleFromSourceString("m", "m.slang", userSourceBody, diagnosticBlob.writeRef());
+ auto module = session->loadModuleFromSourceString(
+ "m",
+ "m.slang",
+ userSourceBody,
+ diagnosticBlob.writeRef());
SLANG_CHECK(module != nullptr);
ComPtr<slang::IEntryPoint> entryPoint;
- module->findAndCheckEntryPoint("fragMain", SLANG_STAGE_FRAGMENT, entryPoint.writeRef(), diagnosticBlob.writeRef());
+ module->findAndCheckEntryPoint(
+ "fragMain",
+ SLANG_STAGE_FRAGMENT,
+ entryPoint.writeRef(),
+ diagnosticBlob.writeRef());
SLANG_CHECK(entryPoint != nullptr);
ComPtr<slang::IComponentType> compositeProgram;
- slang::IComponentType* components[] = { module, entryPoint.get() };
- session->createCompositeComponentType(components, 2, compositeProgram.writeRef(), diagnosticBlob.writeRef());
+ slang::IComponentType* components[] = {module, entryPoint.get()};
+ session->createCompositeComponentType(
+ components,
+ 2,
+ compositeProgram.writeRef(),
+ diagnosticBlob.writeRef());
SLANG_CHECK(compositeProgram != nullptr);
auto moduleDeclReflection = module->getModuleReflection();
@@ -137,7 +149,9 @@ SLANG_UNIT_TEST(declTreeReflection)
// Second declaration should be a function
auto secondDecl = moduleDeclReflection->getChild(1);
SLANG_CHECK(secondDecl->getKind() == slang::DeclReflection::Kind::Func);
- SLANG_CHECK(secondDecl->getChildrenCount() == 2); // Parameter declarations are children (return type is not)
+ SLANG_CHECK(
+ secondDecl->getChildrenCount() ==
+ 2); // Parameter declarations are children (return type is not)
{
auto funcReflection = secondDecl->asFunction();
@@ -147,7 +161,9 @@ SLANG_UNIT_TEST(declTreeReflection)
SLANG_CHECK(funcReflection->getParameterCount() == 2);
SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "x");
SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "float");
- SLANG_CHECK(funcReflection->getParameterByIndex(0)->findModifier(slang::Modifier::NoDiff) != nullptr);
+ SLANG_CHECK(
+ funcReflection->getParameterByIndex(0)->findModifier(slang::Modifier::NoDiff) !=
+ nullptr);
SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(1)->getName()) == "y");
SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(1)->getType()) == "int");
@@ -161,13 +177,15 @@ SLANG_UNIT_TEST(declTreeReflection)
auto result = userAttribute->getArgumentValueInt(0, &val);
SLANG_CHECK(result == SLANG_OK);
SLANG_CHECK(val == 1024);
- SLANG_CHECK(funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == userAttribute);
+ SLANG_CHECK(
+ funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") ==
+ userAttribute);
}
// Third declaration should also be a function
auto thirdDecl = moduleDeclReflection->getChild(2);
SLANG_CHECK(thirdDecl->getKind() == slang::DeclReflection::Kind::Func);
- SLANG_CHECK(thirdDecl->getChildrenCount() == 1);
+ SLANG_CHECK(thirdDecl->getChildrenCount() == 1);
{
auto funcReflection = thirdDecl->asFunction();
@@ -175,7 +193,9 @@ SLANG_UNIT_TEST(declTreeReflection)
SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "fragMain");
SLANG_CHECK(funcReflection->getParameterCount() == 1);
SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "pos");
- SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "vector<float,4>");
+ SLANG_CHECK(
+ getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) ==
+ "vector<float,4>");
}
// Sixth declaration should be a generic struct
@@ -187,9 +207,11 @@ SLANG_UNIT_TEST(declTreeReflection)
SLANG_CHECK(UnownedStringSlice(typeParamT->getName()) == "T");
auto typeParamTConstraintCount = genericReflection->getTypeParameterConstraintCount(typeParamT);
SLANG_CHECK(typeParamTConstraintCount == 2);
- auto typeParamTConstraintType1 = genericReflection->getTypeParameterConstraintType(typeParamT, 0);
+ auto typeParamTConstraintType1 =
+ genericReflection->getTypeParameterConstraintType(typeParamT, 0);
SLANG_CHECK(getTypeFullName(typeParamTConstraintType1) == "IArithmetic");
- auto typeParamTConstraintType2 = genericReflection->getTypeParameterConstraintType(typeParamT, 1);
+ auto typeParamTConstraintType2 =
+ genericReflection->getTypeParameterConstraintType(typeParamT, 1);
SLANG_CHECK(getTypeFullName(typeParamTConstraintType2) == "IFloat");
auto innerStruct = genericReflection->getInnerDecl();
@@ -205,7 +227,7 @@ SLANG_UNIT_TEST(declTreeReflection)
{
auto type = compositeProgram->getLayout()->findTypeByName("MyType");
SLANG_CHECK(type != nullptr);
- //SLANG_CHECK(type->getKind() == slang::DeclReflection::Kind::Struct);
+ // SLANG_CHECK(type->getKind() == slang::DeclReflection::Kind::Struct);
SLANG_CHECK(UnownedStringSlice(type->getName()) == "MyType");
auto funcReflection = compositeProgram->getLayout()->findFunctionByNameInType(type, "f");
SLANG_CHECK(funcReflection != nullptr);
@@ -220,7 +242,7 @@ SLANG_UNIT_TEST(declTreeReflection)
{
auto type = compositeProgram->getLayout()->findTypeByName("MyGenericType<half>");
SLANG_CHECK(type != nullptr);
- //SLANG_CHECK(type->getKind() == slang::DeclReflection::Kind::Struct);
+ // SLANG_CHECK(type->getKind() == slang::DeclReflection::Kind::Struct);
SLANG_CHECK(getTypeFullName(type) == "MyGenericType<half>");
auto funcReflection = compositeProgram->getLayout()->findFunctionByNameInType(type, "g");
SLANG_CHECK(funcReflection != nullptr);
@@ -242,85 +264,98 @@ SLANG_UNIT_TEST(declTreeReflection)
SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "float");
SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(1)->getName()) == "y");
SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(1)->getType()) == "half");
-
+
// Access parent generic container from a specialized method.
auto specializationInfo = funcReflection->getGenericContainer();
SLANG_CHECK(specializationInfo != nullptr);
SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "h");
- SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
+ SLANG_CHECK(
+ specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
// Check type parameters
SLANG_CHECK(specializationInfo->getTypeParameterCount() == 1);
auto typeParam = specializationInfo->getTypeParameter(0);
SLANG_CHECK(UnownedStringSlice(typeParam->getName()) == "U"); // generic name
- SLANG_CHECK(getTypeFullName(specializationInfo->getConcreteType(typeParam)) == "float"); // specialized type name under the context in which the generic is obtained
+ SLANG_CHECK(
+ getTypeFullName(specializationInfo->getConcreteType(typeParam)) ==
+ "float"); // specialized type name under the context in which the generic is obtained
SLANG_CHECK(specializationInfo->getTypeParameterConstraintCount(typeParam) == 0);
// Go up another level to the generic struct
specializationInfo = specializationInfo->getOuterGenericContainer();
SLANG_CHECK(specializationInfo != nullptr);
SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "MyGenericType");
- SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
+ SLANG_CHECK(
+ specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
// Check type parameters
SLANG_CHECK(specializationInfo->getTypeParameterCount() == 1);
typeParam = specializationInfo->getTypeParameter(0);
SLANG_CHECK(UnownedStringSlice(typeParam->getName()) == "T"); // generic name
- SLANG_CHECK(getTypeFullName(specializationInfo->getConcreteType(typeParam)) == "half"); // specialized type name under the context in which the generic is obtained
+ SLANG_CHECK(
+ getTypeFullName(specializationInfo->getConcreteType(typeParam)) ==
+ "half"); // specialized type name under the context in which the generic is obtained
SLANG_CHECK(specializationInfo->getTypeParameterConstraintCount(typeParam) == 2);
// Query 'j' on the type 'half'
funcReflection = compositeProgram->getLayout()->findFunctionByNameInType(type, "j<10>");
SLANG_CHECK(funcReflection != nullptr);
SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "j");
-
+
// Check the generic parameters
specializationInfo = funcReflection->getGenericContainer();
SLANG_CHECK(specializationInfo != nullptr);
SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "j");
- SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
+ SLANG_CHECK(
+ specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
SLANG_CHECK(specializationInfo->getValueParameterCount() == 1);
auto valueParam = specializationInfo->getValueParameter(0);
SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name
SLANG_CHECK(specializationInfo->getConcreteIntVal(valueParam) == 10);
}
-
+
// Check specializeGeneric() and applySpecializations()
{
auto unspecializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType");
SLANG_CHECK(unspecializedType != nullptr);
auto halfType = compositeProgram->getLayout()->findTypeByName("half");
SLANG_CHECK(halfType != nullptr);
-
+
slang::GenericReflection* genericContainer = unspecializedType->getGenericContainer();
SLANG_CHECK(genericContainer != nullptr);
- //auto typeParamT = genericContainer->getTypeParameter(0);
-
+ // auto typeParamT = genericContainer->getTypeParameter(0);
+
List<slang::GenericArgType> argTypes;
List<slang::GenericArgReflection> args;
argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_TYPE);
args.add({halfType});
auto specializedContainer = compositeProgram->getLayout()->specializeGeneric(
- genericContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr);
-
+ genericContainer,
+ argTypes.getCount(),
+ argTypes.getBuffer(),
+ args.getBuffer(),
+ nullptr);
+
SLANG_CHECK(specializedContainer != nullptr);
auto specializedType = unspecializedType->applySpecializations(specializedContainer);
SLANG_CHECK(specializedType != nullptr);
SLANG_CHECK(getTypeFullName(specializedType) == "MyGenericType<half>");
-
}
- // Check specializeGeneric() and applySpecializations() on multiple levels (generic function nested in a generic struct)
+ // Check specializeGeneric() and applySpecializations() on multiple levels (generic function
+ // nested in a generic struct)
{
auto unspecializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType");
- auto unspecializedFunc = compositeProgram->getLayout()->findFunctionByNameInType(unspecializedType, "j");
+ auto unspecializedFunc =
+ compositeProgram->getLayout()->findFunctionByNameInType(unspecializedType, "j");
SLANG_CHECK(unspecializedFunc != nullptr);
auto halfType = compositeProgram->getLayout()->findTypeByName("half");
SLANG_CHECK(halfType != nullptr);
-
+
slang::GenericReflection* genericFuncContainer = unspecializedFunc->getGenericContainer();
SLANG_CHECK(genericFuncContainer != nullptr);
- slang::GenericReflection* genericStructContainer = genericFuncContainer->getOuterGenericContainer();
+ slang::GenericReflection* genericStructContainer =
+ genericFuncContainer->getOuterGenericContainer();
SLANG_CHECK(genericStructContainer != nullptr);
// Specialize the outer container with half
@@ -329,11 +364,16 @@ SLANG_UNIT_TEST(declTreeReflection)
argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_TYPE);
args.add({halfType});
auto specializedStructContainer = compositeProgram->getLayout()->specializeGeneric(
- genericStructContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr);
+ genericStructContainer,
+ argTypes.getCount(),
+ argTypes.getBuffer(),
+ args.getBuffer(),
+ nullptr);
SLANG_CHECK(specializedStructContainer != nullptr);
// apply T=half. N is still left unspecialized.
- genericFuncContainer = genericFuncContainer->applySpecializations(specializedStructContainer);
+ genericFuncContainer =
+ genericFuncContainer->applySpecializations(specializedStructContainer);
// Specialize the inner container with 10 separately..
argTypes.clear();
@@ -345,16 +385,21 @@ SLANG_UNIT_TEST(declTreeReflection)
args.add(argN);
auto specializedFuncContainer = compositeProgram->getLayout()->specializeGeneric(
- genericFuncContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr);
+ genericFuncContainer,
+ argTypes.getCount(),
+ argTypes.getBuffer(),
+ args.getBuffer(),
+ nullptr);
auto specializedFunc = unspecializedFunc->applySpecializations(specializedFuncContainer);
SLANG_CHECK(specializedFunc != nullptr);
-
+
// ------ check the specialized function
auto specializationInfo = specializedFunc->getGenericContainer();
SLANG_CHECK(specializationInfo != nullptr);
SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "j");
- SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
+ SLANG_CHECK(
+ specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
SLANG_CHECK(specializationInfo->getValueParameterCount() == 1);
auto valueParam = specializationInfo->getValueParameter(0);
SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name
@@ -364,7 +409,8 @@ SLANG_UNIT_TEST(declTreeReflection)
specializationInfo = specializationInfo->getOuterGenericContainer();
SLANG_CHECK(specializationInfo != nullptr);
SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "MyGenericType");
- SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
+ SLANG_CHECK(
+ specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
// Check type parameters
SLANG_CHECK(specializationInfo->getTypeParameterCount() == 1);
auto typeParam = specializationInfo->getTypeParameter(0);
@@ -399,15 +445,16 @@ SLANG_UNIT_TEST(declTreeReflection)
argTypes.add(floatType);
argTypes.add(uintType);
- slang::FunctionReflection* specializedFoo = unspecializedFoo->specializeWithArgTypes(argTypes.getCount(), argTypes.getBuffer());
+ slang::FunctionReflection* specializedFoo =
+ unspecializedFoo->specializeWithArgTypes(argTypes.getCount(), argTypes.getBuffer());
SLANG_CHECK(specializedFoo != nullptr);
-
+
SLANG_CHECK(getTypeFullName(specializedFoo->getReturnType()) == "float");
SLANG_CHECK(specializedFoo->getParameterCount() == 2);
SLANG_CHECK(UnownedStringSlice(specializedFoo->getParameterByIndex(0)->getName()) == "t");
SLANG_CHECK(getTypeFullName(specializedFoo->getParameterByIndex(0)->getType()) == "float");
-
+
SLANG_CHECK(UnownedStringSlice(specializedFoo->getParameterByIndex(1)->getName()) == "u");
SLANG_CHECK(getTypeFullName(specializedFoo->getParameterByIndex(1)->getType()) == "uint");
}
@@ -417,7 +464,8 @@ SLANG_UNIT_TEST(declTreeReflection)
auto specializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType<half>");
SLANG_CHECK(specializedType != nullptr);
- auto unspecializedMethod = compositeProgram->getLayout()->findFunctionByNameInType(specializedType, "h");
+ auto unspecializedMethod =
+ compositeProgram->getLayout()->findFunctionByNameInType(specializedType, "h");
SLANG_CHECK(unspecializedMethod != nullptr);
// Specialize the method with float
@@ -431,13 +479,12 @@ SLANG_UNIT_TEST(declTreeReflection)
argTypes.add(floatType);
argTypes.add(halfType);
- auto specializedMethodWithFloat = unspecializedMethod->specializeWithArgTypes(
- argTypes.getCount(),
- argTypes.getBuffer());
+ auto specializedMethodWithFloat =
+ unspecializedMethod->specializeWithArgTypes(argTypes.getCount(), argTypes.getBuffer());
SLANG_CHECK(specializedMethodWithFloat != nullptr);
SLANG_CHECK(getTypeFullName(specializedMethodWithFloat->getReturnType()) == "float");
}
-
+
// Check iterators
{
unsigned int count = 0;
@@ -448,32 +495,35 @@ SLANG_UNIT_TEST(declTreeReflection)
SLANG_CHECK(count == 9);
count = 0;
- for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>())
+ for (auto* child :
+ moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>())
{
count++;
}
SLANG_CHECK(count == 3);
count = 0;
- for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Struct>())
+ for (auto* child :
+ moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Struct>())
{
count++;
}
SLANG_CHECK(count == 2);
count = 0;
- for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Generic>())
+ for (auto* child :
+ moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Generic>())
{
count++;
}
SLANG_CHECK(count == 2);
count = 0;
- for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Namespace>())
+ for (auto* child :
+ moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Namespace>())
{
count++;
}
SLANG_CHECK(count == 1);
}
}
-