summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-03-11 14:42:14 -0700
committerGitHub <noreply@github.com>2024-03-11 14:42:14 -0700
commit1bbcf25af514a9ae24f7006747177f2d1b3b7c0d (patch)
treef42c17d32040d033742e741548e7b73ff24a5e92 /source
parent25a7d51445e64253beca5c4f70ddd52f40226b1d (diff)
Link-time specialization fixes. (#3734)
* Fix method synthesis logic for static differentiable methods. * Support link-time constants in thread group size reflection.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-decl.h3
-rw-r--r--source/slang/slang-ast-val.cpp32
-rw-r--r--source/slang/slang-ast-val.h8
-rw-r--r--source/slang/slang-check-decl.cpp50
-rw-r--r--source/slang/slang-check-shader.cpp8
-rwxr-xr-xsource/slang/slang-compiler.h8
-rw-r--r--source/slang/slang-parameter-binding.cpp1
-rw-r--r--source/slang/slang-reflection-api.cpp12
-rw-r--r--source/slang/slang-type-layout.h2
-rw-r--r--source/slang/slang.cpp72
10 files changed, 169 insertions, 27 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 61e1b751f..f7f537ed6 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -95,6 +95,9 @@ class VarDeclBase : public Decl
// Initializer expression (optional)
Expr* initExpr = nullptr;
+
+ // Folded IntVal if the initializer is a constant integer.
+ IntVal* val = nullptr;
};
// Ordinary potentially-mutable variables (locals, globals, and member variables)
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index b2b874fad..0dbe65ee0 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -7,6 +7,7 @@
#include "slang-diagnostics.h"
#include "slang-syntax.h"
#include "slang-ast-val.h"
+#include "slang-mangle.h"
namespace Slang {
@@ -234,6 +235,15 @@ bool GenericParamIntVal::_isLinkTimeValOverride()
return getDeclRef().getDecl()->hasModifier<ExternModifier>();
}
+Val* GenericParamIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map)
+{
+ auto name = getMangledName(getCurrentASTBuilder(), getDeclRef().declRefBase);
+ IntVal* v;
+ if (map.tryGetValue(name, v))
+ return v;
+ return this;
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void ErrorIntVal::_toTextOverride(StringBuilder& out)
@@ -1088,6 +1098,15 @@ Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val*
return nullptr;
}
+Val* TypeCastIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map)
+{
+ auto intValBase = as<IntVal>(getBase());
+ if (!intValBase)
+ return this;
+ auto resolvedBase = intValBase->linkTimeResolve(map);
+ return tryFoldImpl(getCurrentASTBuilder(), getType(), resolvedBase, nullptr);
+}
+
Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
@@ -1310,6 +1329,14 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR
return nullptr;
}
+Val* FuncCallIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map)
+{
+ List<IntVal*> newArgs;
+ for (auto arg : getArgs())
+ newArgs.add(as<IntVal>(arg->linkTimeResolve(map)));
+ return tryFoldImpl(getCurrentASTBuilder(), getType(), getFuncDeclRef(), newArgs, nullptr);
+}
+
Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
@@ -1506,4 +1533,9 @@ bool IntVal::isLinkTimeVal()
SLANG_AST_NODE_VIRTUAL_CALL(IntVal, isLinkTimeVal, ());
}
+Val* IntVal::linkTimeResolve(Dictionary<String, IntVal*>& mapMangledNameToVal)
+{
+ SLANG_AST_NODE_VIRTUAL_CALL(IntVal, linkTimeResolve, (mapMangledNameToVal));
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index ce494b9da..f94cafbda 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -144,6 +144,8 @@ class IntVal : public Val
bool isLinkTimeVal();
bool _isLinkTimeValOverride() { return false; }
+ Val* linkTimeResolve(Dictionary<String, IntVal*>& mapMangledNameToVal);
+ Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>&) { return this; }
};
// Trivial case of a value that is just a constant integer
@@ -180,6 +182,7 @@ class GenericParamIntVal : public IntVal
}
bool _isLinkTimeValOverride();
+ Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map);
};
class TypeCastIntVal : public IntVal
@@ -204,6 +207,9 @@ class TypeCastIntVal : public IntVal
return intBase->isLinkTimeVal();
return false;
}
+
+ Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map);
+
};
// An compile time int val as result of some general computation.
@@ -238,6 +244,8 @@ class FuncCallIntVal : public IntVal
}
return false;
}
+
+ Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map);
};
class WitnessLookupIntVal : public IntVal
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 8dee7b0c5..39f7de89a 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1538,7 +1538,6 @@ namespace Slang
varDecl->initExpr = initExpr;
varDecl->type.type = initExpr->type;
-
_validateCircularVarDefinition(varDecl);
}
@@ -1602,6 +1601,19 @@ namespace Slang
}
}
+ if (varDecl->initExpr)
+ {
+ if (as<BasicExpressionType>(varDecl->type.type))
+ {
+ auto parentDecl = getParentDecl(varDecl);
+ if (varDecl->findModifier<ConstModifier>() &&
+ (as<NamespaceDeclBase>(parentDecl) || as<FileDecl>(parentDecl) || varDecl->findModifier<HLSLStaticModifier>()))
+ {
+ varDecl->val = tryConstantFoldExpr(varDecl->initExpr, ConstantFoldingKind::LinkTime, nullptr);
+ }
+ }
+ }
+
checkMeshOutputDecl(varDecl);
// The NVAPI library allows user code to express extended operations
@@ -3559,24 +3571,24 @@ namespace Slang
auto noDiffThisAttr = m_astBuilder->create<NoDiffThisAttribute>();
addModifier(synFuncDecl, noDiffThisAttr);
}
- if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>())
- {
- auto attr = m_astBuilder->create<ForwardDifferentiableAttribute>();
- addModifier(synFuncDecl, attr);
- }
- if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>())
- {
- auto attr = m_astBuilder->create<BackwardDifferentiableAttribute>();
- addModifier(synFuncDecl, attr);
- }
- // The visibility of synthesized decl should be the min of the parent decl and the requirement.
- if (requiredMemberDeclRef.getDecl()->findModifier<VisibilityModifier>())
- {
- auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl());
- auto thisVisibility = getDeclVisibility(context->parentDecl);
- auto visibility = Math::Min(thisVisibility, requirementVisibility);
- addVisibilityModifier(m_astBuilder, synFuncDecl, visibility);
- }
+ }
+ if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>())
+ {
+ auto attr = m_astBuilder->create<ForwardDifferentiableAttribute>();
+ addModifier(synFuncDecl, attr);
+ }
+ if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>())
+ {
+ auto attr = m_astBuilder->create<BackwardDifferentiableAttribute>();
+ addModifier(synFuncDecl, attr);
+ }
+ // The visibility of synthesized decl should be the min of the parent decl and the requirement.
+ if (requiredMemberDeclRef.getDecl()->findModifier<VisibilityModifier>())
+ {
+ auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl());
+ auto thisVisibility = getDeclVisibility(context->parentDecl);
+ auto visibility = Math::Min(thisVisibility, requirementVisibility);
+ addVisibilityModifier(m_astBuilder, synFuncDecl, visibility);
}
return synFuncDecl;
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index 1aa93d019..c588a9018 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -1305,7 +1305,7 @@ namespace Slang
sink);
}
- Scope* ComponentType::_createScopeForLegacyLookup(ASTBuilder* astBuilder)
+ Scope* ComponentType::_getOrCreateScopeForLegacyLookup(ASTBuilder* astBuilder)
{
// The shape of this logic is dictated by the legacy
// behavior for name-based lookup/parsing of types
@@ -1316,6 +1316,8 @@ namespace Slang
// definitions (that scope is necessary because
// it defines keywords like `true` and `false`).
//
+ if (m_lookupScope)
+ return m_lookupScope;
Scope* scope = astBuilder->create<Scope>();
scope->parent = getLinkage()->getSessionImpl()->slangLanguageScope;
@@ -1338,7 +1340,7 @@ namespace Slang
scope->nextSibling = moduleScope;
}
}
-
+ m_lookupScope = scope;
return scope;
}
@@ -1359,7 +1361,7 @@ namespace Slang
// We create the scopes on the linkages ASTBuilder. We might want to create a temporary ASTBuilder,
// and let that memory get freed, but is like this because it's not clear if the scopes in ASTNode members
// will dangle if we do.
- Scope* scope = unspecialiedProgram->_createScopeForLegacyLookup(endToEndReq->getLinkage()->getASTBuilder());
+ Scope* scope = unspecialiedProgram->_getOrCreateScopeForLegacyLookup(endToEndReq->getLinkage()->getASTBuilder());
// We are going to do some semantic checking, so we need to
// set up a `SemanticsVistitor` that we can use.
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index b0350d618..c23eddfde 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -404,6 +404,9 @@ namespace Slang
String const& typeStr,
DiagnosticSink* sink);
+ Dictionary<String, IntVal*>& getMangledNameToIntValMap();
+ ConstantIntVal* tryFoldIntVal(IntVal* intVal);
+
/// Get a list of modules that this component type depends on.
///
virtual List<Module*> const& getModuleDependencies() = 0;
@@ -526,7 +529,7 @@ namespace Slang
/// This facility is only needed to support legacy APIs for string-based lookup
/// and parsing via Slang reflection, and is not recommended for future APIs to use.
///
- Scope* _createScopeForLegacyLookup(ASTBuilder* astBuilder);
+ Scope* _getOrCreateScopeForLegacyLookup(ASTBuilder* astBuilder);
protected:
ComponentType(Linkage* linkage);
@@ -544,6 +547,9 @@ namespace Slang
// TODO: Remove this. Type lookup should only be supported on `Module`s.
//
Dictionary<String, Type*> m_types;
+
+ Scope* m_lookupScope = nullptr;
+ std::unique_ptr<Dictionary<String, IntVal*>> m_mapMangledNameToIntVal;
};
/// A component type built up from other component types.
diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp
index c2f8b2d4d..267f23e6c 100644
--- a/source/slang/slang-parameter-binding.cpp
+++ b/source/slang/slang-parameter-binding.cpp
@@ -2718,6 +2718,7 @@ static RefPtr<EntryPointLayout> collectEntryPointParameters(
auto entryPointType = DeclRefType::create(astBuilder, entryPointFuncDeclRef);
entryPointLayout->entryPoint = entryPointFuncDeclRef;
+ entryPointLayout->program = context->getTargetProgram()->getProgram();
// For the duration of our parameter collection work we will
// establish this entry point as the current one in the context.
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 7af3ce0a3..d91dd5858 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -2811,12 +2811,18 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize(
auto numThreadsAttribute = entryPointFunc.getDecl()->findModifier<NumThreadsAttribute>();
if (numThreadsAttribute)
{
- if (auto cint = as<ConstantIntVal>(numThreadsAttribute->x))
+ if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->x))
sizeAlongAxis[0] = (SlangUInt)cint->getValue();
- if (auto cint = as<ConstantIntVal>(numThreadsAttribute->y))
+ else if (numThreadsAttribute->x)
+ sizeAlongAxis[0] = 0;
+ if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->y))
sizeAlongAxis[1] = (SlangUInt)cint->getValue();
- if (auto cint = as<ConstantIntVal>(numThreadsAttribute->z))
+ else if (numThreadsAttribute->y)
+ sizeAlongAxis[1] = 0;
+ if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->z))
sizeAlongAxis[2] = (SlangUInt)cint->getValue();
+ else if (numThreadsAttribute->z)
+ sizeAlongAxis[2] = 0;
}
//
diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h
index f11ee342e..c17f2ebb4 100644
--- a/source/slang/slang-type-layout.h
+++ b/source/slang/slang-type-layout.h
@@ -793,6 +793,8 @@ public:
// The corresponding function declaration
DeclRef<FuncDecl> entryPoint;
+ ComponentType* program = nullptr;
+
DeclRef<FuncDecl> getFuncDeclRef() { return entryPoint; }
FuncDecl* getFuncDecl() { return entryPoint.getDecl(); }
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index fe3f8dfa5..69c0f0e14 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -2137,7 +2137,7 @@ Type* ComponentType::getTypeFromString(
// the modules that were directly or
// indirectly referenced.
//
- Scope* scope = _createScopeForLegacyLookup(astBuilder);
+ Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder);
auto linkage = getLinkage();
@@ -2154,6 +2154,76 @@ Type* ComponentType::getTypeFromString(
return type;
}
+static void collectExportedConstantInContainer(
+ Dictionary<String, IntVal*>& dict,
+ ASTBuilder* builder,
+ ContainerDecl* containerDecl)
+{
+ for (auto m : containerDecl->members)
+ {
+ auto varMember = as<VarDeclBase>(m);
+ if (!varMember)
+ continue;
+ if (!varMember->val)
+ continue;
+ bool isExported = false;
+ bool isConst = true;
+ bool isExtern = false;
+ for (auto modifier : m->modifiers)
+ {
+ if (as<HLSLExportModifier>(modifier))
+ isExported = true;
+ if (as<ExternAttribute>(modifier) || as<ExternModifier>(modifier))
+ {
+ isExtern = true;
+ isExported = true;
+ }
+ if (as<ConstModifier>(modifier))
+ isConst = true;
+ if (isExported && isConst)
+ break;
+ }
+ if (isExported && isConst)
+ {
+ auto mangledName = getMangledName(builder, m);
+ if (isExtern && dict.containsKey(mangledName))
+ continue;
+ dict[mangledName] = varMember->val;
+ }
+ }
+
+ for (auto member : containerDecl->members)
+ {
+ if (as<NamespaceDecl>(member) || as<FileDecl>(member))
+ {
+ collectExportedConstantInContainer(dict, builder, (ContainerDecl*)member);
+ }
+ }
+}
+
+Dictionary<String, IntVal*>& ComponentType::getMangledNameToIntValMap()
+{
+ if (m_mapMangledNameToIntVal)
+ {
+ return *m_mapMangledNameToIntVal;
+ }
+ m_mapMangledNameToIntVal = std::make_unique<Dictionary<String, IntVal*>>();
+ auto astBuilder = getLinkage()->getASTBuilder();
+ SLANG_AST_BUILDER_RAII(astBuilder);
+ Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder);
+ for (; scope; scope = scope->nextSibling)
+ {
+ if (scope->containerDecl)
+ collectExportedConstantInContainer(*m_mapMangledNameToIntVal, astBuilder, scope->containerDecl);
+ }
+ return *m_mapMangledNameToIntVal;
+}
+
+ConstantIntVal* ComponentType::tryFoldIntVal(IntVal* intVal)
+{
+ return as<ConstantIntVal>(intVal->linkTimeResolve(getMangledNameToIntValMap()));
+}
+
CompileRequestBase::CompileRequestBase(
Linkage* linkage,
DiagnosticSink* sink)