From b89421cb3b803165455020f5b70d582b6aec6e76 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 10 Jul 2024 14:09:18 -0700 Subject: Add reflection API for functions. (#4587) * Add reflection API for functions. This change adds `SlangFunctionReflection` type in the reflection API that provides methods for querying function result type, parameters and user-defined attributes. `ProgramLayout::findFunctionByName` can now find a function with the given name and returns a `FunctionReflection`. `IEntryPoint` now has a `getFunctionReflection` method that returns an `FunctionReflection` for the entrypoint. * More modifiers; make reflection API consistent. --- .../unit-test-function-reflection.cpp | 94 ++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 tools/slang-unit-test/unit-test-function-reflection.cpp (limited to 'tools/slang-unit-test/unit-test-function-reflection.cpp') diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp new file mode 100644 index 000000000..ddbfa439b --- /dev/null +++ b/tools/slang-unit-test/unit-test-function-reflection.cpp @@ -0,0 +1,94 @@ +// unit-test-translation-unit-import.cpp + +#include "../../slang.h" + +#include +#include + +#include "tools/unit-test/slang-unit-test.h" +#include "../../slang-com-ptr.h" +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" + +using namespace Slang; + +static String getTypeFullName(slang::TypeReflection* type) +{ + ComPtr blob; + type->getFullName(blob.writeRef()); + return String((const char*)blob->getBufferPointer()); +} + +// Test that the reflection API provides correct info about entry point and ordinary functions. + +SLANG_UNIT_TEST(functionReflection) +{ + // Source for a module that contains an undecorated entrypoint. + const char* userSourceBody = R"( + [__AttributeUsage(_AttributeTargets.Function)] + struct MyFuncPropertyAttribute {int v;} + + [MyFuncProperty(1024)] + [Differentiable] + float ordinaryFunc(no_diff float x, int y) { return x + y; } + + float4 fragMain(float4 pos:SV_Position) : SV_Position + { + return pos; + } + )"; + + auto moduleName = "moduleG" + String(Process::getId()); + String userSource = "import " + moduleName + ";\n" + userSourceBody; + ComPtr globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr diagnosticBlob; + auto module = session->loadModuleFromSourceString("m", "m.slang", userSourceBody, diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr entryPoint; + module->findAndCheckEntryPoint("fragMain", SLANG_STAGE_FRAGMENT, entryPoint.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(entryPoint != nullptr); + + auto entryPointFuncReflection = entryPoint->getFunctionReflection(); + SLANG_CHECK(entryPointFuncReflection != nullptr); + SLANG_CHECK(UnownedStringSlice(entryPointFuncReflection->getName()) == "fragMain"); + SLANG_CHECK(entryPointFuncReflection->getParameterCount() == 1); + SLANG_CHECK(UnownedStringSlice(entryPointFuncReflection->getParameterByIndex(0)->getName()) == "pos"); + SLANG_CHECK(getTypeFullName(entryPointFuncReflection->getParameterByIndex(0)->getType()) == "vector"); + + auto funcReflection = module->getLayout()->findFunctionByName("ordinaryFunc"); + SLANG_CHECK(funcReflection != nullptr); + + SLANG_CHECK(funcReflection->findModifier(slang::Modifier::Differentiable) != nullptr); + SLANG_CHECK(getTypeFullName(funcReflection->getReturnType()) == "float"); + SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "ordinaryFunc"); + SLANG_CHECK(funcReflection->getParameterCount() == 2); + SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "x"); + SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "float"); + SLANG_CHECK(funcReflection->getParameterByIndex(0)->findModifier(slang::Modifier::NoDiff) != nullptr); + + SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(1)->getName()) == "y"); + SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(1)->getType()) == "int"); + + SLANG_CHECK(funcReflection->getUserAttributeCount() == 1); + auto userAttribute = funcReflection->getUserAttributeByIndex(0); + SLANG_CHECK(UnownedStringSlice(userAttribute->getName()) == "MyFuncProperty"); + SLANG_CHECK(userAttribute->getArgumentCount() == 1); + SLANG_CHECK(getTypeFullName(userAttribute->getArgumentType(0)) == "int"); + int val = 0; + auto result = userAttribute->getArgumentValueInt(0, &val); + SLANG_CHECK(result == SLANG_OK); + SLANG_CHECK(val == 1024); + SLANG_CHECK(funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == userAttribute); +} + -- cgit v1.2.3