summaryrefslogtreecommitdiffstats
path: root/tools
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-07-18 13:11:19 -0400
committerGitHub <noreply@github.com>2024-07-18 13:11:19 -0400
commit0d06ebcefb36a19710d87832fc1ea027e21281af (patch)
tree05ac5af6fa26a658348335eac3d63199b6949507 /tools
parent89e836d42822e69dcaa4eb0a366d8c66e5aaa7e4 (diff)
Initial implementation for decl-tree reflection API (#4666)
* Initial implementation for decl-tree reflection API This patch adds Slang API methods for walking all the declarations in the AST. We expose this functionality through an abstract `DeclReflection` class that can be a type, function or a variable declaration. We also provide ways to cast the decl to a `FunctionReflection`, `TypeReflection` or `VariableReflection` and traverse through the child nodes (for instance, a struct type will have component variable declarations) This patch also adds `ISlangInternal` as an internal COM interface to allow us to cast IGlobalSession to the internal Session pointer while bypassing any wrappers (such as the capture interface) * Update slang.h * Remove `ISlangInternal` (its causing a diamond pattern w.r.t `ISlangUnknown`) and use `ComPtr` for proper ref management. * Update unit-test-decl-tree-reflection.cpp * Change `FunctionDeclBase` to use `DeclRef` instead of directly using the decl. * Update slang-reflection-api.cpp
Diffstat (limited to 'tools')
-rw-r--r--tools/slang-unit-test/unit-test-decl-tree-reflection.cpp167
1 files changed, 167 insertions, 0 deletions
diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
new file mode 100644
index 000000000..e443358fb
--- /dev/null
+++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
@@ -0,0 +1,167 @@
+// unit-test-translation-unit-import.cpp
+
+#include "slang.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "tools/unit-test/slang-unit-test.h"
+#include "slang-com-ptr.h"
+#include "../../source/core/slang-io.h"
+#include "../../source/core/slang-process.h"
+
+using namespace Slang;
+
+static String getTypeFullName(slang::TypeReflection* type)
+{
+ ComPtr<ISlangBlob> blob;
+ type->getFullName(blob.writeRef());
+ return String((const char*)blob->getBufferPointer());
+}
+
+static void printRefl(slang::DeclReflection* refl, unsigned int level = 0)
+{
+ // Mapping of kind ids to names
+ std::string names[] = {"Unsupported", "Struct", "Function", "Module", "Generic", "Variable"};
+ for (unsigned int i = 0; i < level; i++)
+ {
+ std::cout << " ";
+ }
+ std::cout<< "[" << names[(unsigned int)refl->getKind()] << "] (" << refl->getChildrenCount() << ")" << std::endl;
+
+ for (auto* child : refl->getChildren())
+ {
+ printRefl(child, level + 1);
+ }
+}
+
+// Test that the reflection API provides correct info about entry point and ordinary functions.
+
+SLANG_UNIT_TEST(declTreeReflection)
+{
+ // Source for a module that contains an undecorated entrypoint.
+ const char* userSourceBody = R"(
+ [__AttributeUsage(_AttributeTargets.Function)]
+ struct MyFuncPropertyAttribute {int v;}
+
+ [MyFuncProperty(1024)]
+ [Differentiable]
+ float ordinaryFunc(no_diff float x, int y) { return x + y; }
+
+ float4 fragMain(float4 pos:SV_Position) : SV_Position
+ {
+ return pos;
+ }
+ )";
+
+ auto moduleName = "moduleG" + String(Process::getId());
+ String userSource = "import " + moduleName + ";\n" + userSourceBody;
+ ComPtr<slang::IGlobalSession> globalSession;
+ SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
+ slang::TargetDesc targetDesc = {};
+ targetDesc.format = SLANG_HLSL;
+ targetDesc.profile = globalSession->findProfile("sm_5_0");
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ sessionDesc.targets = &targetDesc;
+ ComPtr<slang::ISession> session;
+ SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto module = session->loadModuleFromSourceString("m", "m.slang", userSourceBody, diagnosticBlob.writeRef());
+ SLANG_CHECK(module != nullptr);
+
+ ComPtr<slang::IEntryPoint> entryPoint;
+ module->findAndCheckEntryPoint("fragMain", SLANG_STAGE_FRAGMENT, entryPoint.writeRef(), diagnosticBlob.writeRef());
+ SLANG_CHECK(entryPoint != nullptr);
+
+ auto moduleDeclReflection = module->getModuleReflection();
+ SLANG_CHECK(moduleDeclReflection != nullptr);
+ SLANG_CHECK(moduleDeclReflection->getKind() == slang::DeclReflection::Kind::Module);
+ SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 4);
+
+ // First declaration should be a struct with 1 variable
+ auto firstDecl = moduleDeclReflection->getChild(0);
+ SLANG_CHECK(firstDecl->getKind() == slang::DeclReflection::Kind::Struct);
+ SLANG_CHECK(firstDecl->getChildrenCount() == 1);
+
+ {
+ slang::TypeReflection* type = firstDecl->getType(globalSession);
+ SLANG_CHECK(getTypeFullName(type) == "MyFuncPropertyAttribute");
+
+ // Check the field of the struct.
+ SLANG_CHECK(type->getFieldCount() == 1);
+ auto field = type->getFieldByIndex(0);
+ SLANG_CHECK(UnownedStringSlice(field->getName()) == "v");
+ SLANG_CHECK(getTypeFullName(field->getType()) == "int");
+ }
+
+ // Second declaration should be a function
+ auto secondDecl = moduleDeclReflection->getChild(1);
+ SLANG_CHECK(secondDecl->getKind() == slang::DeclReflection::Kind::Func);
+ SLANG_CHECK(secondDecl->getChildrenCount() == 2); // Parameter declarations are children (return type is not)
+
+ {
+ auto funcReflection = secondDecl->asFunction();
+ SLANG_CHECK(funcReflection->findModifier(slang::Modifier::Differentiable) != nullptr);
+ SLANG_CHECK(getTypeFullName(funcReflection->getReturnType()) == "float");
+ SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "ordinaryFunc");
+ SLANG_CHECK(funcReflection->getParameterCount() == 2);
+ SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "x");
+ SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "float");
+ SLANG_CHECK(funcReflection->getParameterByIndex(0)->findModifier(slang::Modifier::NoDiff) != nullptr);
+
+ SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(1)->getName()) == "y");
+ SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(1)->getType()) == "int");
+
+ SLANG_CHECK(funcReflection->getUserAttributeCount() == 1);
+ auto userAttribute = funcReflection->getUserAttributeByIndex(0);
+ SLANG_CHECK(UnownedStringSlice(userAttribute->getName()) == "MyFuncProperty");
+ SLANG_CHECK(userAttribute->getArgumentCount() == 1);
+ SLANG_CHECK(getTypeFullName(userAttribute->getArgumentType(0)) == "int");
+ int val = 0;
+ auto result = userAttribute->getArgumentValueInt(0, &val);
+ SLANG_CHECK(result == SLANG_OK);
+ SLANG_CHECK(val == 1024);
+ SLANG_CHECK(funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == userAttribute);
+ }
+
+ // Third declaration should also be a function
+ auto thirdDecl = moduleDeclReflection->getChild(2);
+ SLANG_CHECK(thirdDecl->getKind() == slang::DeclReflection::Kind::Func);
+ SLANG_CHECK(thirdDecl->getChildrenCount() == 1);
+
+ {
+ auto funcReflection = thirdDecl->asFunction();
+ SLANG_CHECK(getTypeFullName(funcReflection->getReturnType()) == "vector<float,4>");
+ SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "fragMain");
+ SLANG_CHECK(funcReflection->getParameterCount() == 1);
+ SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "pos");
+ SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "vector<float,4>");
+ }
+
+ // Check iterators
+ {
+ unsigned int count = 0;
+ for (auto* child : moduleDeclReflection->getChildren())
+ {
+ count++;
+ }
+ SLANG_CHECK(count == 4);
+
+ count = 0;
+ for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>())
+ {
+ count++;
+ }
+ SLANG_CHECK(count == 2);
+
+ count = 0;
+ for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Struct>())
+ {
+ count++;
+ }
+ SLANG_CHECK(count == 1);
+ }
+}
+