diff options
| author | Yong He <yonghe@outlook.com> | 2024-03-11 14:42:14 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-11 14:42:14 -0700 |
| commit | 1bbcf25af514a9ae24f7006747177f2d1b3b7c0d (patch) | |
| tree | f42c17d32040d033742e741548e7b73ff24a5e92 /source | |
| parent | 25a7d51445e64253beca5c4f70ddd52f40226b1d (diff) | |
Link-time specialization fixes. (#3734)
* Fix method synthesis logic for static differentiable methods.
* Support link-time constants in thread group size reflection.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-decl.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 32 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 50 | ||||
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 8 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-parameter-binding.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-reflection-api.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-type-layout.h | 2 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 72 |
10 files changed, 169 insertions, 27 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 61e1b751f..f7f537ed6 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -95,6 +95,9 @@ class VarDeclBase : public Decl // Initializer expression (optional) Expr* initExpr = nullptr; + + // Folded IntVal if the initializer is a constant integer. + IntVal* val = nullptr; }; // Ordinary potentially-mutable variables (locals, globals, and member variables) diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index b2b874fad..0dbe65ee0 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -7,6 +7,7 @@ #include "slang-diagnostics.h" #include "slang-syntax.h" #include "slang-ast-val.h" +#include "slang-mangle.h" namespace Slang { @@ -234,6 +235,15 @@ bool GenericParamIntVal::_isLinkTimeValOverride() return getDeclRef().getDecl()->hasModifier<ExternModifier>(); } +Val* GenericParamIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map) +{ + auto name = getMangledName(getCurrentASTBuilder(), getDeclRef().declRefBase); + IntVal* v; + if (map.tryGetValue(name, v)) + return v; + return this; +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ErrorIntVal::_toTextOverride(StringBuilder& out) @@ -1088,6 +1098,15 @@ Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* return nullptr; } +Val* TypeCastIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map) +{ + auto intValBase = as<IntVal>(getBase()); + if (!intValBase) + return this; + auto resolvedBase = intValBase->linkTimeResolve(map); + return tryFoldImpl(getCurrentASTBuilder(), getType(), resolvedBase, nullptr); +} + Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; @@ -1310,6 +1329,14 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR return nullptr; } +Val* FuncCallIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map) +{ + List<IntVal*> newArgs; + for (auto arg : getArgs()) + newArgs.add(as<IntVal>(arg->linkTimeResolve(map))); + return tryFoldImpl(getCurrentASTBuilder(), getType(), getFuncDeclRef(), newArgs, nullptr); +} + Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; @@ -1506,4 +1533,9 @@ bool IntVal::isLinkTimeVal() SLANG_AST_NODE_VIRTUAL_CALL(IntVal, isLinkTimeVal, ()); } +Val* IntVal::linkTimeResolve(Dictionary<String, IntVal*>& mapMangledNameToVal) +{ + SLANG_AST_NODE_VIRTUAL_CALL(IntVal, linkTimeResolve, (mapMangledNameToVal)); +} + } // namespace Slang diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index ce494b9da..f94cafbda 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -144,6 +144,8 @@ class IntVal : public Val bool isLinkTimeVal(); bool _isLinkTimeValOverride() { return false; } + Val* linkTimeResolve(Dictionary<String, IntVal*>& mapMangledNameToVal); + Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>&) { return this; } }; // Trivial case of a value that is just a constant integer @@ -180,6 +182,7 @@ class GenericParamIntVal : public IntVal } bool _isLinkTimeValOverride(); + Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map); }; class TypeCastIntVal : public IntVal @@ -204,6 +207,9 @@ class TypeCastIntVal : public IntVal return intBase->isLinkTimeVal(); return false; } + + Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map); + }; // An compile time int val as result of some general computation. @@ -238,6 +244,8 @@ class FuncCallIntVal : public IntVal } return false; } + + Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map); }; class WitnessLookupIntVal : public IntVal diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 8dee7b0c5..39f7de89a 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1538,7 +1538,6 @@ namespace Slang varDecl->initExpr = initExpr; varDecl->type.type = initExpr->type; - _validateCircularVarDefinition(varDecl); } @@ -1602,6 +1601,19 @@ namespace Slang } } + if (varDecl->initExpr) + { + if (as<BasicExpressionType>(varDecl->type.type)) + { + auto parentDecl = getParentDecl(varDecl); + if (varDecl->findModifier<ConstModifier>() && + (as<NamespaceDeclBase>(parentDecl) || as<FileDecl>(parentDecl) || varDecl->findModifier<HLSLStaticModifier>())) + { + varDecl->val = tryConstantFoldExpr(varDecl->initExpr, ConstantFoldingKind::LinkTime, nullptr); + } + } + } + checkMeshOutputDecl(varDecl); // The NVAPI library allows user code to express extended operations @@ -3559,24 +3571,24 @@ namespace Slang auto noDiffThisAttr = m_astBuilder->create<NoDiffThisAttribute>(); addModifier(synFuncDecl, noDiffThisAttr); } - if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>()) - { - auto attr = m_astBuilder->create<ForwardDifferentiableAttribute>(); - addModifier(synFuncDecl, attr); - } - if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()) - { - auto attr = m_astBuilder->create<BackwardDifferentiableAttribute>(); - addModifier(synFuncDecl, attr); - } - // The visibility of synthesized decl should be the min of the parent decl and the requirement. - if (requiredMemberDeclRef.getDecl()->findModifier<VisibilityModifier>()) - { - auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); - auto thisVisibility = getDeclVisibility(context->parentDecl); - auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(m_astBuilder, synFuncDecl, visibility); - } + } + if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>()) + { + auto attr = m_astBuilder->create<ForwardDifferentiableAttribute>(); + addModifier(synFuncDecl, attr); + } + if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()) + { + auto attr = m_astBuilder->create<BackwardDifferentiableAttribute>(); + addModifier(synFuncDecl, attr); + } + // The visibility of synthesized decl should be the min of the parent decl and the requirement. + if (requiredMemberDeclRef.getDecl()->findModifier<VisibilityModifier>()) + { + auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); + auto thisVisibility = getDeclVisibility(context->parentDecl); + auto visibility = Math::Min(thisVisibility, requirementVisibility); + addVisibilityModifier(m_astBuilder, synFuncDecl, visibility); } return synFuncDecl; diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 1aa93d019..c588a9018 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -1305,7 +1305,7 @@ namespace Slang sink); } - Scope* ComponentType::_createScopeForLegacyLookup(ASTBuilder* astBuilder) + Scope* ComponentType::_getOrCreateScopeForLegacyLookup(ASTBuilder* astBuilder) { // The shape of this logic is dictated by the legacy // behavior for name-based lookup/parsing of types @@ -1316,6 +1316,8 @@ namespace Slang // definitions (that scope is necessary because // it defines keywords like `true` and `false`). // + if (m_lookupScope) + return m_lookupScope; Scope* scope = astBuilder->create<Scope>(); scope->parent = getLinkage()->getSessionImpl()->slangLanguageScope; @@ -1338,7 +1340,7 @@ namespace Slang scope->nextSibling = moduleScope; } } - + m_lookupScope = scope; return scope; } @@ -1359,7 +1361,7 @@ namespace Slang // We create the scopes on the linkages ASTBuilder. We might want to create a temporary ASTBuilder, // and let that memory get freed, but is like this because it's not clear if the scopes in ASTNode members // will dangle if we do. - Scope* scope = unspecialiedProgram->_createScopeForLegacyLookup(endToEndReq->getLinkage()->getASTBuilder()); + Scope* scope = unspecialiedProgram->_getOrCreateScopeForLegacyLookup(endToEndReq->getLinkage()->getASTBuilder()); // We are going to do some semantic checking, so we need to // set up a `SemanticsVistitor` that we can use. diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index b0350d618..c23eddfde 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -404,6 +404,9 @@ namespace Slang String const& typeStr, DiagnosticSink* sink); + Dictionary<String, IntVal*>& getMangledNameToIntValMap(); + ConstantIntVal* tryFoldIntVal(IntVal* intVal); + /// Get a list of modules that this component type depends on. /// virtual List<Module*> const& getModuleDependencies() = 0; @@ -526,7 +529,7 @@ namespace Slang /// This facility is only needed to support legacy APIs for string-based lookup /// and parsing via Slang reflection, and is not recommended for future APIs to use. /// - Scope* _createScopeForLegacyLookup(ASTBuilder* astBuilder); + Scope* _getOrCreateScopeForLegacyLookup(ASTBuilder* astBuilder); protected: ComponentType(Linkage* linkage); @@ -544,6 +547,9 @@ namespace Slang // TODO: Remove this. Type lookup should only be supported on `Module`s. // Dictionary<String, Type*> m_types; + + Scope* m_lookupScope = nullptr; + std::unique_ptr<Dictionary<String, IntVal*>> m_mapMangledNameToIntVal; }; /// A component type built up from other component types. diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index c2f8b2d4d..267f23e6c 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -2718,6 +2718,7 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters( auto entryPointType = DeclRefType::create(astBuilder, entryPointFuncDeclRef); entryPointLayout->entryPoint = entryPointFuncDeclRef; + entryPointLayout->program = context->getTargetProgram()->getProgram(); // For the duration of our parameter collection work we will // establish this entry point as the current one in the context. diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 7af3ce0a3..d91dd5858 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -2811,12 +2811,18 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( auto numThreadsAttribute = entryPointFunc.getDecl()->findModifier<NumThreadsAttribute>(); if (numThreadsAttribute) { - if (auto cint = as<ConstantIntVal>(numThreadsAttribute->x)) + if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->x)) sizeAlongAxis[0] = (SlangUInt)cint->getValue(); - if (auto cint = as<ConstantIntVal>(numThreadsAttribute->y)) + else if (numThreadsAttribute->x) + sizeAlongAxis[0] = 0; + if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->y)) sizeAlongAxis[1] = (SlangUInt)cint->getValue(); - if (auto cint = as<ConstantIntVal>(numThreadsAttribute->z)) + else if (numThreadsAttribute->y) + sizeAlongAxis[1] = 0; + if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->z)) sizeAlongAxis[2] = (SlangUInt)cint->getValue(); + else if (numThreadsAttribute->z) + sizeAlongAxis[2] = 0; } // diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h index f11ee342e..c17f2ebb4 100644 --- a/source/slang/slang-type-layout.h +++ b/source/slang/slang-type-layout.h @@ -793,6 +793,8 @@ public: // The corresponding function declaration DeclRef<FuncDecl> entryPoint; + ComponentType* program = nullptr; + DeclRef<FuncDecl> getFuncDeclRef() { return entryPoint; } FuncDecl* getFuncDecl() { return entryPoint.getDecl(); } diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index fe3f8dfa5..69c0f0e14 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -2137,7 +2137,7 @@ Type* ComponentType::getTypeFromString( // the modules that were directly or // indirectly referenced. // - Scope* scope = _createScopeForLegacyLookup(astBuilder); + Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); auto linkage = getLinkage(); @@ -2154,6 +2154,76 @@ Type* ComponentType::getTypeFromString( return type; } +static void collectExportedConstantInContainer( + Dictionary<String, IntVal*>& dict, + ASTBuilder* builder, + ContainerDecl* containerDecl) +{ + for (auto m : containerDecl->members) + { + auto varMember = as<VarDeclBase>(m); + if (!varMember) + continue; + if (!varMember->val) + continue; + bool isExported = false; + bool isConst = true; + bool isExtern = false; + for (auto modifier : m->modifiers) + { + if (as<HLSLExportModifier>(modifier)) + isExported = true; + if (as<ExternAttribute>(modifier) || as<ExternModifier>(modifier)) + { + isExtern = true; + isExported = true; + } + if (as<ConstModifier>(modifier)) + isConst = true; + if (isExported && isConst) + break; + } + if (isExported && isConst) + { + auto mangledName = getMangledName(builder, m); + if (isExtern && dict.containsKey(mangledName)) + continue; + dict[mangledName] = varMember->val; + } + } + + for (auto member : containerDecl->members) + { + if (as<NamespaceDecl>(member) || as<FileDecl>(member)) + { + collectExportedConstantInContainer(dict, builder, (ContainerDecl*)member); + } + } +} + +Dictionary<String, IntVal*>& ComponentType::getMangledNameToIntValMap() +{ + if (m_mapMangledNameToIntVal) + { + return *m_mapMangledNameToIntVal; + } + m_mapMangledNameToIntVal = std::make_unique<Dictionary<String, IntVal*>>(); + auto astBuilder = getLinkage()->getASTBuilder(); + SLANG_AST_BUILDER_RAII(astBuilder); + Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); + for (; scope; scope = scope->nextSibling) + { + if (scope->containerDecl) + collectExportedConstantInContainer(*m_mapMangledNameToIntVal, astBuilder, scope->containerDecl); + } + return *m_mapMangledNameToIntVal; +} + +ConstantIntVal* ComponentType::tryFoldIntVal(IntVal* intVal) +{ + return as<ConstantIntVal>(intVal->linkTimeResolve(getMangledNameToIntValMap())); +} + CompileRequestBase::CompileRequestBase( Linkage* linkage, DiagnosticSink* sink) |
