diff options
| author | Yong He <yonghe@outlook.com> | 2024-07-10 14:09:18 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-10 14:09:18 -0700 |
| commit | b89421cb3b803165455020f5b70d582b6aec6e76 (patch) | |
| tree | 461327ffbb55e7cea0ca73ae11cfa18425c904a8 | |
| parent | 45ef0ce906c93c16495755fec2e597573e8631c4 (diff) | |
Add reflection API for functions. (#4587)
* Add reflection API for functions.
This change adds `SlangFunctionReflection` type in the reflection API that provides methods for querying function result type, parameters and user-defined attributes.
`ProgramLayout::findFunctionByName` can now find a function with the given name and returns a `FunctionReflection`.
`IEntryPoint` now has a `getFunctionReflection` method that returns an `FunctionReflection` for the entrypoint.
* More modifiers; make reflection API consistent.
| -rw-r--r-- | docs/user-guide/09-reflection.md | 32 | ||||
| -rw-r--r-- | docs/user-guide/toc.html | 1 | ||||
| -rw-r--r-- | slang.h | 97 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-entrypoint.cpp | 6 | ||||
| -rw-r--r-- | source/slang-capture-replay/slang-entrypoint.h | 2 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-reflection-api.cpp | 127 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 50 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-function-reflection.cpp | 94 |
9 files changed, 408 insertions, 12 deletions
diff --git a/docs/user-guide/09-reflection.md b/docs/user-guide/09-reflection.md index 8572dbe08..0c9311212 100644 --- a/docs/user-guide/09-reflection.md +++ b/docs/user-guide/09-reflection.md @@ -19,11 +19,6 @@ Note that just as with output code, the reflection object (and all other objects Unlike the other data, there is no easy way to save the reflection data for later user (we do not currently implement serialization for reflection data). Applications are encouraged to extract whatever information they need before destroying the compilation request. -For convenience (since the reflection API surface area is large), the Slang API provides a C++ wrapper interface around the reflection API, and this document will show code examples using those wrappers: - -```c++ -slang::ShaderReflection* shaderReflection = slang::ShaderReflection::get(request); -``` ## Program Reflection @@ -203,3 +198,30 @@ In the case of a compute shader entry point, you can also query the user-specifi SlangUInt threadGroupSize[3]; entryPoint->getComputeThreadGruopSize(3, &threadGroupSize[0]); ``` + +## Function Reflection + +The `slang::FunctionReflection` type provides methods to query information about a function, such as the return type, parameters and user-defined attributes. You can obtain a `FunctionReflection` object from an `IEntryPoint` with `IEntryPoint::getFunctionReflection`, which will provide more details on the entry point function. + +In addition to entry points, you can also query for ordinary functions with the `ShaderReflection::findFunctionByName` method: + +```c++ +auto funcReflection = program->getLayout()->findFunctionByName("ordinaryFunc"); + +// Get return type. +slang::TypeReflection* returnType = funcReflection->getReturnType(); + +// Get parameter count. +unsigned int paramCount = funcReflection->getParameterCount(); + +// Get Parameter. +slang::VariableReflection* param0 = funcReflection->getParameter(0); +const char* param0Name = param0->getName(); +slang::TypeReflection* param0Type = param0->getType(); + +// Get user defined attributes on the function. +unsigned int attribCount = funcReflection->getUserAttributeCount(); +slang::UserAttribute* attrib = funcReflection->getUserAttributeByIndex(0); +const char* attribName = attrib->getName(); + +```
\ No newline at end of file diff --git a/docs/user-guide/toc.html b/docs/user-guide/toc.html index 3443a95e8..2adb7655b 100644 --- a/docs/user-guide/toc.html +++ b/docs/user-guide/toc.html @@ -115,6 +115,7 @@ <li data-link="reflection#arrays"><span>Arrays</span></li> <li data-link="reflection#structures"><span>Structures</span></li> <li data-link="reflection#entry-points"><span>Entry Points</span></li> +<li data-link="reflection#function-reflection"><span>Function Reflection</span></li> </ul> </li> <li data-link="targets"><span>Supported Compilation Targets</span> @@ -2111,6 +2111,7 @@ extern "C" typedef struct SlangReflectionVariableLayout SlangReflectionVariableLayout; typedef struct SlangReflectionTypeParameter SlangReflectionTypeParameter; typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute; + typedef struct SlangReflectionFunction SlangReflectionFunction; /* Type aliases to maintain backward compatibility. @@ -2384,6 +2385,12 @@ extern "C" enum SlangModifierID : SlangModifierIDIntegral { SLANG_MODIFIER_SHARED, + SLANG_MODIFIER_NO_DIFF, + SLANG_MODIFIER_STATIC, + SLANG_MODIFIER_CONST, + SLANG_MODIFIER_EXPORT, + SLANG_MODIFIER_EXTERN, + SLANG_MODIFIER_DIFFERENTIABLE, }; // User Attribute @@ -2434,6 +2441,7 @@ extern "C" SLANG_API SlangReflectionType* spReflectionType_GetResourceResultType(SlangReflectionType* type); SLANG_API char const* spReflectionType_GetName(SlangReflectionType* type); + SLANG_API SlangResult spReflectionType_GetFullName(SlangReflectionType* type, ISlangBlob** outNameBlob); // Type Layout Reflection @@ -2516,7 +2524,7 @@ extern "C" SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflectionVariable* var, SlangModifierID modifierID); SLANG_API unsigned int spReflectionVariable_GetUserAttributeCount(SlangReflectionVariable* var); SLANG_API SlangReflectionUserAttribute* spReflectionVariable_GetUserAttribute(SlangReflectionVariable* var, unsigned int index); - SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* var, SlangSession * session, char const* name); + SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* var, SlangSession * globalSession, char const* name); // Variable Layout Reflection @@ -2530,6 +2538,17 @@ extern "C" SLANG_API char const* spReflectionVariableLayout_GetSemanticName(SlangReflectionVariableLayout* var); SLANG_API size_t spReflectionVariableLayout_GetSemanticIndex(SlangReflectionVariableLayout* var); + + // Function Reflection + + SLANG_API char const* spReflectionFunction_GetName(SlangReflectionFunction* func); + SLANG_API unsigned int spReflectionFunction_GetUserAttributeCount(SlangReflectionFunction* func); + SLANG_API SlangReflectionUserAttribute* spReflectionFunction_GetUserAttribute(SlangReflectionFunction* func, unsigned int index); + SLANG_API SlangReflectionUserAttribute* spReflectionFunction_FindUserAttributeByName(SlangReflectionFunction* func, SlangSession* globalSession, char const* name); + SLANG_API unsigned int spReflectionFunction_GetParameterCount(SlangReflectionFunction* func); + SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* func, unsigned index); + SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* func); + /** Get the stage that a variable belongs to (if any). A variable "belongs" to a specific stage when it is a varying input/output @@ -2567,6 +2586,9 @@ extern "C" SLANG_API char const* spReflectionEntryPoint_getNameOverride( SlangReflectionEntryPoint* entryPoint); + SLANG_API SlangReflectionFunction* spReflectionEntryPoint_getFunction( + SlangReflectionEntryPoint* entryPoint); + SLANG_API unsigned spReflectionEntryPoint_getParameterCount( SlangReflectionEntryPoint* entryPoint); @@ -2615,6 +2637,8 @@ extern "C" SLANG_API SlangReflectionType* spReflection_FindTypeByName(SlangReflection* reflection, char const* name); SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout(SlangReflection* reflection, SlangReflectionType* reflectionType, SlangLayoutRules rules); + SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflection* reflection, char const* name); + SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* reflection); SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangReflection* reflection, SlangUInt index); SLANG_API SlangReflectionEntryPoint* spReflection_findEntryPointByName(SlangReflection* reflection, char const* name); @@ -2830,6 +2854,11 @@ namespace slang return spReflectionType_GetName((SlangReflectionType*) this); } + SlangResult getFullName(ISlangBlob** outNameBlob) + { + return spReflectionType_GetFullName((SlangReflectionType*)this, outNameBlob); + } + unsigned int getUserAttributeCount() { return spReflectionType_GetUserAttributeCount((SlangReflectionType*)this); @@ -3255,6 +3284,12 @@ namespace slang enum ID : SlangModifierIDIntegral { Shared = SLANG_MODIFIER_SHARED, + NoDiff = SLANG_MODIFIER_NO_DIFF, + Static = SLANG_MODIFIER_STATIC, + Const = SLANG_MODIFIER_CONST, + Export = SLANG_MODIFIER_EXPORT, + Extern = SLANG_MODIFIER_EXTERN, + Differentiable = SLANG_MODIFIER_DIFFERENTIABLE, }; }; @@ -3283,9 +3318,9 @@ namespace slang { return (UserAttribute*)spReflectionVariable_GetUserAttribute((SlangReflectionVariable*)this, index); } - UserAttribute* findUserAttributeByName(SlangSession* session, char const* name) + UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name) { - return (UserAttribute*)spReflectionVariable_FindUserAttributeByName((SlangReflectionVariable*)this, session, name); + return (UserAttribute*)spReflectionVariable_FindUserAttributeByName((SlangReflectionVariable*)this, globalSession, name); } }; @@ -3373,6 +3408,47 @@ namespace slang } }; + struct FunctionReflection + { + char const* getName() + { + return spReflectionFunction_GetName((SlangReflectionFunction*)this); + } + + TypeReflection* getReturnType() + { + return (TypeReflection*)spReflectionFunction_GetResultType((SlangReflectionFunction*)this); + } + + unsigned int getParameterCount() + { + return spReflectionFunction_GetParameterCount((SlangReflectionFunction*)this); + } + + VariableReflection* getParameterByIndex(unsigned int index) + { + return (VariableReflection*)spReflectionFunction_GetParameter((SlangReflectionFunction*)this, index); + } + + unsigned int getUserAttributeCount() + { + return spReflectionVariable_GetUserAttributeCount((SlangReflectionVariable*)this); + } + UserAttribute* getUserAttributeByIndex(unsigned int index) + { + return (UserAttribute*)spReflectionVariable_GetUserAttribute((SlangReflectionVariable*)this, index); + } + UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name) + { + return (UserAttribute*)spReflectionVariable_FindUserAttributeByName((SlangReflectionVariable*)this, globalSession, name); + } + + Modifier* findModifier(Modifier::ID id) + { + return (Modifier*)spReflectionVariable_FindModifier((SlangReflectionVariable*)this, (SlangModifierID)id); + } + }; + struct EntryPointReflection { char const* getName() @@ -3390,6 +3466,11 @@ namespace slang return spReflectionEntryPoint_getParameterCount((SlangReflectionEntryPoint*) this); } + FunctionReflection* getFunction() + { + return (FunctionReflection*)spReflectionEntryPoint_getFunction((SlangReflectionEntryPoint*) this); + } + VariableLayoutReflection* getParameterByIndex(unsigned index) { return (VariableLayoutReflection*) spReflectionEntryPoint_getParameterByIndex((SlangReflectionEntryPoint*) this, index); @@ -3438,6 +3519,7 @@ namespace slang return spReflectionEntryPoint_hasDefaultConstantBuffer((SlangReflectionEntryPoint*) this) != 0; } }; + typedef EntryPointReflection EntryPointLayout; struct TypeParameterReflection @@ -3531,6 +3613,13 @@ namespace slang name); } + FunctionReflection* findFunctionByName(const char* name) + { + return (FunctionReflection*)spReflection_FindFunctionByName( + (SlangReflection*) this, + name); + } + TypeLayoutReflection* getTypeLayout( TypeReflection* type, LayoutRules rules = LayoutRules::Default) @@ -4966,6 +5055,8 @@ namespace slang struct IEntryPoint : public IComponentType { SLANG_COM_INTERFACE(0x8f241361, 0xf5bd, 0x4ca0, { 0xa3, 0xac, 0x2, 0xf7, 0xfa, 0x24, 0x2, 0xb8 }) + + virtual SLANG_NO_THROW FunctionReflection* SLANG_MCALL getFunctionReflection() = 0; }; #define SLANG_UUID_IEntryPoint IEntryPoint::getTypeGuid() diff --git a/source/slang-capture-replay/slang-entrypoint.cpp b/source/slang-capture-replay/slang-entrypoint.cpp index c8ce7a663..46d701f98 100644 --- a/source/slang-capture-replay/slang-entrypoint.cpp +++ b/source/slang-capture-replay/slang-entrypoint.cpp @@ -304,4 +304,10 @@ namespace SlangCapture return res; } + + SLANG_NO_THROW slang::FunctionReflection* EntryPointCapture::getFunctionReflection() + { + return m_actualEntryPoint->getFunctionReflection(); + } + } diff --git a/source/slang-capture-replay/slang-entrypoint.h b/source/slang-capture-replay/slang-entrypoint.h index 87b3cc430..06a577069 100644 --- a/source/slang-capture-replay/slang-entrypoint.h +++ b/source/slang-capture-replay/slang-entrypoint.h @@ -66,7 +66,7 @@ namespace SlangCapture uint32_t compilerOptionEntryCount, slang::CompilerOptionEntry* compilerOptionEntries, ISlangBlob** outDiagnostics = nullptr) override; - + virtual SLANG_NO_THROW slang::FunctionReflection* SLANG_MCALL getFunctionReflection() override; slang::IEntryPoint* getActualEntryPoint() const { return m_actualEntryPoint; } private: Slang::ComPtr<slang::IEntryPoint> m_actualEntryPoint; diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 46724119b..e9660f7ed 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -416,6 +416,10 @@ namespace Slang String const& typeStr, DiagnosticSink* sink); + DeclRef<Decl> findDeclFromString( + String const& name, + DiagnosticSink* sink); + Dictionary<String, IntVal*>& getMangledNameToIntValMap(); ConstantIntVal* tryFoldIntVal(IntVal* intVal); @@ -560,6 +564,9 @@ namespace Slang // Dictionary<String, Type*> m_types; + // Any decls looked up dynamically using `findDeclFromString`. + Dictionary<String, DeclRef<Decl>> m_decls; + Scope* m_lookupScope = nullptr; std::unique_ptr<Dictionary<String, IntVal*>> m_mapMangledNameToIntVal; }; @@ -1042,6 +1049,10 @@ namespace Slang List<ExpandedSpecializationArg> existentialSpecializationArgs; }; + SLANG_NO_THROW slang::FunctionReflection* SLANG_MCALL getFunctionReflection() SLANG_OVERRIDE + { + return (slang::FunctionReflection*)m_funcDeclRef.getDecl(); + } protected: void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index e7f4b9bf3..0642caeb6 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -65,6 +65,16 @@ static inline SlangReflectionVariable* convert(Decl* var) return (SlangReflectionVariable*) var; } +static inline FunctionDeclBase* convert(SlangReflectionFunction* func) +{ + return (FunctionDeclBase*)func; +} + +static inline SlangReflectionFunction* convert(FunctionDeclBase* func) +{ + return (SlangReflectionFunction*)func; +} + static inline VarLayout* convert(SlangReflectionVariableLayout* var) { return (VarLayout*) var; @@ -729,13 +739,48 @@ SLANG_API char const* spReflectionType_GetName(SlangReflectionType* inType) auto decl = declRef.getDecl(); if(decl->hasModifier<ImplicitParameterGroupElementTypeModifier>()) return nullptr; - return getText(declRef.getName()).begin(); } return nullptr; } +SLANG_API SlangResult spReflectionType_GetFullName(SlangReflectionType* inType, ISlangBlob** outNameBlob) +{ + auto type = convert(inType); + + if (!type) return SLANG_FAIL; + + StringBuilder sb; + type->toText(sb); + *outNameBlob = StringUtil::createStringBlob(sb.produceString()).detach(); + return SLANG_OK; +} + +SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflection* reflection, char const* name) +{ + auto programLayout = convert(reflection); + auto program = programLayout->getProgram(); + + // TODO: We should extend this API to support getting error messages + // when type lookup fails. + // + Slang::DiagnosticSink sink( + programLayout->getTargetReq()->getLinkage()->getSourceManager(), + Lexer::sourceLocationLexer); + + try + { + auto result = program->findDeclFromString(name, &sink); + if (auto funcDeclRef = result.as<FunctionDeclBase>()) + return (SlangReflectionFunction*)funcDeclRef.getDecl(); + } + catch (...) + { + } + return nullptr; +} + SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * reflection, char const * name) { auto programLayout = convert(reflection); @@ -2558,7 +2603,24 @@ SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflec case SLANG_MODIFIER_SHARED: modifier = var->findModifier<HLSLEffectSharedModifier>(); break; - + case SLANG_MODIFIER_CONST: + modifier = var->findModifier<ConstModifier>(); + break; + case SLANG_MODIFIER_NO_DIFF: + modifier = var->findModifier<NoDiffModifier>(); + break; + case SLANG_MODIFIER_STATIC: + modifier = var->findModifier<HLSLStaticModifier>(); + break; + case SLANG_MODIFIER_EXPORT: + modifier = var->findModifier<HLSLExportModifier>(); + break; + case SLANG_MODIFIER_EXTERN: + modifier = var->findModifier<ExternModifier>(); + break; + case SLANG_MODIFIER_DIFFERENTIABLE: + modifier = var->findModifier<DifferentiableAttribute>(); + break; default: return nullptr; } @@ -2729,6 +2791,57 @@ SLANG_API SlangStage spReflectionVariableLayout_getStage( return (SlangStage) varLayout->stage; } +// Function Reflection + +SLANG_API char const* spReflectionFunction_GetName(SlangReflectionFunction* inFunc) +{ + auto func = convert(inFunc); + if (!func) return nullptr; + return getText(func->getName()).getBuffer(); +} + +SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* inFunc) +{ + auto func = convert(inFunc); + if (!func) return nullptr; + + return convert(func->returnType.type); +} + +SLANG_API unsigned int spReflectionFunction_GetUserAttributeCount(SlangReflectionFunction* inFunc) +{ + auto func = convert(inFunc); + if (!func) return 0; + return getUserAttributeCount(func); +} + +SLANG_API SlangReflectionUserAttribute* spReflectionFunction_GetUserAttribute(SlangReflectionFunction* inFunc, unsigned int index) +{ + auto func = convert(inFunc); + if (!func) return nullptr; + return getUserAttributeByIndex(func, index); +} + +SLANG_API SlangReflectionUserAttribute* spReflectionFunction_FindUserAttributeByName(SlangReflectionFunction* inFunc, SlangSession* session, char const* name) +{ + auto func = convert(inFunc); + if (!func) return nullptr; + return findUserAttributeByName(asInternal(session), func, name); +} + +SLANG_API unsigned int spReflectionFunction_GetParameterCount(SlangReflectionFunction* inFunc) +{ + auto func = convert(inFunc); + if (!func) return 0; + return (unsigned int)func->getParameters().getCount(); +} + +SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* inFunc, unsigned int index) +{ + auto func = convert(inFunc); + if (!func) return nullptr; + return convert(as<Decl>(func->getParameters()[index])); +} // Shader Parameter Reflection @@ -2788,6 +2901,16 @@ SLANG_API char const* spReflectionEntryPoint_getNameOverride(SlangReflectionEntr return nullptr; } +SLANG_API SlangReflectionFunction* spReflectionEntryPoint_GetFunction(SlangReflectionEntryPoint* inEntryPoint) +{ + auto entryPointLayout = convert(inEntryPoint); + if (entryPointLayout) + { + return (SlangReflectionFunction*)entryPointLayout->entryPoint.getDecl(); + } + return nullptr; +} + SLANG_API unsigned spReflectionEntryPoint_getParameterCount( SlangReflectionEntryPoint* inEntryPoint) { diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index adfa031f5..d601dcd0e 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -2229,6 +2229,55 @@ Type* ComponentType::getTypeFromString( return type; } +DeclRef<Decl> ComponentType::findDeclFromString( + String const& name, + DiagnosticSink* sink) +{ + // If we've looked up this type name before, + // then we can re-use it. + // + DeclRef<Decl> result; + if (m_decls.tryGetValue(name, result)) + return result; + + + // TODO(JS): For now just used the linkages ASTBuilder to keep on scope + // + // The parseTermString uses the linkage ASTBuilder for it's parsing. + // + // It might be possible to just create a temporary ASTBuilder - the worry though is + // that the parsing sets a member variable in AST node to one of these scopes, and then + // it become a dangling pointer. So for now we go with the linkages. + auto astBuilder = getLinkage()->getASTBuilder(); + + // Otherwise, we need to start looking in + // the modules that were directly or + // indirectly referenced. + // + Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); + + auto linkage = getLinkage(); + + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + + Expr* expr = linkage->parseTermString(name, scope); + + SharedSemanticsContext sharedSemanticsContext( + linkage, + nullptr, + sink); + SemanticsVisitor visitor(&sharedSemanticsContext); + + auto checkedExpr = visitor.CheckExpr(expr); + if (auto declRefExpr = as<DeclRefExpr>(checkedExpr)) + { + result = declRefExpr->declRef; + } + + m_decls[name] = result; + return result; +} + static void collectExportedConstantInContainer( Dictionary<String, IntVal*>& dict, ASTBuilder* builder, @@ -4075,7 +4124,6 @@ RefPtr<EntryPoint> Module::findEntryPointByName(UnownedStringSlice const& name) return nullptr; } - RefPtr<EntryPoint> Module::findAndCheckEntryPoint( UnownedStringSlice const& name, SlangStage stage, diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp new file mode 100644 index 000000000..ddbfa439b --- /dev/null +++ b/tools/slang-unit-test/unit-test-function-reflection.cpp @@ -0,0 +1,94 @@ +// unit-test-translation-unit-import.cpp + +#include "../../slang.h" + +#include <stdio.h> +#include <stdlib.h> + +#include "tools/unit-test/slang-unit-test.h" +#include "../../slang-com-ptr.h" +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" + +using namespace Slang; + +static String getTypeFullName(slang::TypeReflection* type) +{ + ComPtr<ISlangBlob> blob; + type->getFullName(blob.writeRef()); + return String((const char*)blob->getBufferPointer()); +} + +// Test that the reflection API provides correct info about entry point and ordinary functions. + +SLANG_UNIT_TEST(functionReflection) +{ + // Source for a module that contains an undecorated entrypoint. + const char* userSourceBody = R"( + [__AttributeUsage(_AttributeTargets.Function)] + struct MyFuncPropertyAttribute {int v;} + + [MyFuncProperty(1024)] + [Differentiable] + float ordinaryFunc(no_diff float x, int y) { return x + y; } + + float4 fragMain(float4 pos:SV_Position) : SV_Position + { + return pos; + } + )"; + + auto moduleName = "moduleG" + String(Process::getId()); + String userSource = "import " + moduleName + ";\n" + userSourceBody; + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString("m", "m.slang", userSourceBody, diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint("fragMain", SLANG_STAGE_FRAGMENT, entryPoint.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(entryPoint != nullptr); + + auto entryPointFuncReflection = entryPoint->getFunctionReflection(); + SLANG_CHECK(entryPointFuncReflection != nullptr); + SLANG_CHECK(UnownedStringSlice(entryPointFuncReflection->getName()) == "fragMain"); + SLANG_CHECK(entryPointFuncReflection->getParameterCount() == 1); + SLANG_CHECK(UnownedStringSlice(entryPointFuncReflection->getParameterByIndex(0)->getName()) == "pos"); + SLANG_CHECK(getTypeFullName(entryPointFuncReflection->getParameterByIndex(0)->getType()) == "vector<float,4>"); + + auto funcReflection = module->getLayout()->findFunctionByName("ordinaryFunc"); + SLANG_CHECK(funcReflection != nullptr); + + SLANG_CHECK(funcReflection->findModifier(slang::Modifier::Differentiable) != nullptr); + SLANG_CHECK(getTypeFullName(funcReflection->getReturnType()) == "float"); + SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "ordinaryFunc"); + SLANG_CHECK(funcReflection->getParameterCount() == 2); + SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "x"); + SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "float"); + SLANG_CHECK(funcReflection->getParameterByIndex(0)->findModifier(slang::Modifier::NoDiff) != nullptr); + + SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(1)->getName()) == "y"); + SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(1)->getType()) == "int"); + + SLANG_CHECK(funcReflection->getUserAttributeCount() == 1); + auto userAttribute = funcReflection->getUserAttributeByIndex(0); + SLANG_CHECK(UnownedStringSlice(userAttribute->getName()) == "MyFuncProperty"); + SLANG_CHECK(userAttribute->getArgumentCount() == 1); + SLANG_CHECK(getTypeFullName(userAttribute->getArgumentType(0)) == "int"); + int val = 0; + auto result = userAttribute->getArgumentValueInt(0, &val); + SLANG_CHECK(result == SLANG_OK); + SLANG_CHECK(val == 1024); + SLANG_CHECK(funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == userAttribute); +} + |
