summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-08-24 12:05:19 -0700
committerGitHub <noreply@github.com>2022-08-24 12:05:19 -0700
commitba6f55ed9481960b4f6c7f0a6b8f1cf7d450c752 (patch)
treebd92bf3cca5614585f8be6ad6f57510b18565b47 /source
parent3746a47ce407b14c4afbfc5b625513cf81b5e890 (diff)
Allow `static const` interface requirements. (#2378)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-decl.cpp14
-rw-r--r--source/slang/slang-ast-decl.h3
-rw-r--r--source/slang/slang-ast-val.cpp104
-rw-r--r--source/slang/slang-ast-val.h27
-rw-r--r--source/slang/slang-check-decl.cpp76
-rw-r--r--source/slang/slang-check-expr.cpp18
-rw-r--r--source/slang/slang-check-impl.h4
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-lower-to-ir.cpp28
9 files changed, 238 insertions, 38 deletions
diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp
index edc79c030..2df9164fb 100644
--- a/source/slang/slang-ast-decl.cpp
+++ b/source/slang/slang-ast-decl.cpp
@@ -18,4 +18,18 @@ const TypeExp& TypeConstraintDecl::_getSupOverride() const
}
+bool isInterfaceRequirement(Decl* decl)
+{
+ auto ancestor = decl->parentDecl;
+ for (; ancestor; ancestor = ancestor->parentDecl)
+ {
+ if (as<InterfaceDecl>(ancestor))
+ return true;
+
+ if (as<ExtensionDecl>(ancestor))
+ return false;
+ }
+ return false;
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 9d2c99f14..147bc7d22 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -494,4 +494,7 @@ class AttributeDecl : public ContainerDecl
SyntaxClass<NodeBase> syntaxClass;
};
+
+bool isInterfaceRequirement(Decl* decl);
+
} // namespace Slang
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 4ed69e282..dd5aff238 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -619,15 +619,15 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out)
for (Index i = 0; i < terms.getCount(); i++)
{
auto& term = *(terms[i]);
- if (i > 0)
+ if (term.constFactor > 0)
{
- if (term.constFactor > 0)
+ if (i > 0)
out << "+";
- else
- out << "-";
}
+ else
+ out << "-";
bool isFirstFactor = true;
- if (term.constFactor != 1 || term.paramFactors.getCount() == 0)
+ if (abs(term.constFactor) != 1 || term.paramFactors.getCount() == 0)
{
out << abs(term.constFactor);
isFirstFactor = false;
@@ -1039,19 +1039,19 @@ IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder)
return this;
}
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SomeIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool SomeIntVal::_equalsValOverride(Val* val)
+bool FuncCallIntVal::_equalsValOverride(Val* val)
{
- if (auto someIntVal = as<SomeIntVal>(val))
+ if (auto funcCallIntVal = as<FuncCallIntVal>(val))
{
- if (!funcDeclRef.equals(someIntVal->funcDeclRef))
+ if (!funcDeclRef.equals(funcCallIntVal->funcDeclRef))
return false;
- if (args.getCount() != someIntVal->args.getCount())
+ if (args.getCount() != funcCallIntVal->args.getCount())
return false;
for (Index i = 0; i < args.getCount(); i++)
{
- if (!args[i]->equalsVal(someIntVal->args[i]))
+ if (!args[i]->equalsVal(funcCallIntVal->args[i]))
return false;
}
return true;
@@ -1059,11 +1059,11 @@ bool SomeIntVal::_equalsValOverride(Val* val)
return false;
}
-void SomeIntVal::_toTextOverride(StringBuilder& out)
+void FuncCallIntVal::_toTextOverride(StringBuilder& out)
{
auto argToText = [&](int index)
{
- if (as<PolynomialIntVal>(args[index]) || as<SomeIntVal>(args[index]))
+ if (as<PolynomialIntVal>(args[index]) || as<FuncCallIntVal>(args[index]))
{
out << "(";
args[index]->toText(out);
@@ -1110,7 +1110,7 @@ void SomeIntVal::_toTextOverride(StringBuilder& out)
}
}
-HashCode SomeIntVal::_getHashCodeOverride()
+HashCode FuncCallIntVal::_getHashCodeOverride()
{
HashCode result = funcDeclRef.getHashCode();
for (auto arg : args)
@@ -1127,7 +1127,7 @@ static bool nameIs(Name* name, const char* val)
return false;
}
-Val* SomeIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink)
+Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink)
{
// Are all args const now?
List<ConstantIntVal*> constArgs;
@@ -1205,14 +1205,14 @@ Val* SomeIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl,
}
else
{
- SLANG_UNREACHABLE("constant folding of SomeIntVal");
+ SLANG_UNREACHABLE("constant folding of FuncCallIntVal");
}
return astBuilder->create<ConstantIntVal>(resultValue);
}
return nullptr;
}
-Val* SomeIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
auto newFuncDeclRef = funcDeclRef.substituteImpl(astBuilder, subst, &diff);
@@ -1233,7 +1233,7 @@ Val* SomeIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet
return newVal;
else
{
- auto result = astBuilder->create<SomeIntVal>();
+ auto result = astBuilder->create<FuncCallIntVal>();
result->args = _Move(newArgs);
result->funcDeclRef = newFuncDeclRef;
result->funcType = funcType;
@@ -1244,4 +1244,72 @@ Val* SomeIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet
return this;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! WitnessLookupIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+bool WitnessLookupIntVal::_equalsValOverride(Val* val)
+{
+ if (auto lookupIntVal = as<WitnessLookupIntVal>(val))
+ {
+ if (!witness->equalsVal(lookupIntVal->witness))
+ return false;
+ if (key != lookupIntVal->key)
+ return false;
+ return true;
+ }
+ return false;
+}
+
+void WitnessLookupIntVal::_toTextOverride(StringBuilder& out)
+{
+ witness->sub->toText(out);
+ out << ".";
+ out << (key->getName() ? key->getName()->text : "??");
+}
+
+HashCode WitnessLookupIntVal::_getHashCodeOverride()
+{
+ HashCode result = witness->getHashCode();
+ result = combineHash(result, Slang::getHashCode(key));
+ return result;
+}
+Val* WitnessLookupIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+ auto newWitness = witness->substituteImpl(astBuilder, subst, &diff);
+ *ioDiff += diff;
+ if (diff)
+ {
+ auto witnessEntry = tryFoldOrNull(astBuilder, as<SubtypeWitness>(newWitness), key);
+ if (witnessEntry)
+ return witnessEntry;
+ }
+ // Nothing found: don't substitute.
+ return this;
+}
+
+Val* WitnessLookupIntVal::tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key)
+{
+ auto witnessEntry = tryLookUpRequirementWitness(astBuilder, witness, key);
+ switch (witnessEntry.getFlavor())
+ {
+ case RequirementWitness::Flavor::val:
+ return witnessEntry.getVal();
+ break;
+ default:
+ break;
+ }
+ return nullptr;
+}
+
+Val* WitnessLookupIntVal::tryFold(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key, Type* type)
+{
+ if (auto result = tryFoldOrNull(astBuilder, witness, key))
+ return result;
+ auto witnessResult = astBuilder->create<WitnessLookupIntVal>();
+ witnessResult->witness = witness;
+ witnessResult->key = key;
+ witnessResult->type = type;
+ return witnessResult;
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 64f04abf9..259be75c3 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -53,9 +53,9 @@ protected:
};
// An compile time int val as result of some general computation.
-class SomeIntVal : public IntVal
+class FuncCallIntVal : public IntVal
{
- SLANG_AST_CLASS(SomeIntVal)
+ SLANG_AST_CLASS(FuncCallIntVal)
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
@@ -63,9 +63,28 @@ class SomeIntVal : public IntVal
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
DeclRef<Decl> funcDeclRef;
- List<IntVal*> args;
Type* funcType;
- static Val* tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*> &newArgs, DiagnosticSink* sink);
+ List<IntVal*> args;
+
+ static Val* tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink);
+};
+
+class WitnessLookupIntVal : public IntVal
+{
+ SLANG_AST_CLASS(WitnessLookupIntVal)
+
+ bool _equalsValOverride(Val* val);
+ void _toTextOverride(StringBuilder& out);
+ HashCode _getHashCodeOverride();
+ Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+
+ SubtypeWitness* witness;
+ Decl* key;
+ Type* type;
+
+ static Val* tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key);
+
+ static Val* tryFold(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key, Type* type);
};
// polynomial expression "2*a*b^3 + 1" will be represented as:
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index abe7642fb..9e3fbaa8d 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1053,6 +1053,33 @@ namespace Slang
addModifier(varDecl, m_astBuilder->create<NVAPIMagicModifier>());
}
}
+
+ if (auto interfaceDecl = as<InterfaceDecl>(varDecl->parentDecl))
+ {
+ if (auto basicType = as<BasicExpressionType>(varDecl->getType()))
+ {
+ switch (basicType->baseType)
+ {
+ case BaseType::Bool:
+ case BaseType::Int8:
+ case BaseType::Int16:
+ case BaseType::Int:
+ case BaseType::Int64:
+ case BaseType::UInt8:
+ case BaseType::UInt16:
+ case BaseType::UInt:
+ case BaseType::UInt64:
+ break;
+ default:
+ getSink()->diagnose(varDecl, Diagnostics::staticConstRequirementMustBeIntOrBool);
+ break;
+ }
+ }
+ if (!varDecl->findModifier<HLSLStaticModifier>() || !varDecl->findModifier<ConstModifier>())
+ {
+ getSink()->diagnose(varDecl, Diagnostics::valueRequirementMustBeCompileTimeConst);
+ }
+ }
}
void SemanticsDeclHeaderVisitor::visitStructDecl(StructDecl* structDecl)
@@ -1588,6 +1615,47 @@ namespace Slang
return true;
}
+ bool SemanticsVisitor::doesVarMatchRequirement(
+ DeclRef<VarDeclBase> satisfyingMemberDeclRef,
+ DeclRef<VarDeclBase> requiredMemberDeclRef,
+ RefPtr<WitnessTable> witnessTable)
+ {
+ // The type of the satisfying member must match the type of the required member.
+ auto satisfyingType = getType(getASTBuilder(), satisfyingMemberDeclRef);
+ auto requiredType = getType(getASTBuilder(), requiredMemberDeclRef);
+ if (!satisfyingType->equals(requiredType))
+ return false;
+
+ for (auto modifier : requiredMemberDeclRef.getDecl()->modifiers)
+ {
+ bool found = false;
+ for (auto satisfyingModifier : satisfyingMemberDeclRef.getDecl()->modifiers)
+ {
+ if (satisfyingModifier->astNodeType == modifier->astNodeType)
+ {
+ found = true;
+ break;
+ }
+ }
+ if (!found)
+ return false;
+ }
+
+ auto satisfyingVal = tryConstantFoldDeclRef(satisfyingMemberDeclRef, nullptr);
+ if (satisfyingVal)
+ {
+ witnessTable->add(
+ requiredMemberDeclRef,
+ RequirementWitness(satisfyingVal));
+ }
+ else
+ {
+ witnessTable->add(
+ requiredMemberDeclRef.getDecl(),
+ RequirementWitness(satisfyingMemberDeclRef));
+ }
+ return true;
+ }
bool SemanticsVisitor::doesGenericSignatureMatchRequirement(
DeclRef<GenericDecl> satisfyingGenericDeclRef,
@@ -1975,6 +2043,14 @@ namespace Slang
return doesPropertyMatchRequirement(propertyDeclRef, requiredPropertyDeclRef, witnessTable);
}
}
+ else if (auto varDeclRef = memberDeclRef.as<VarDeclBase>())
+ {
+ if (auto requiredVarDeclRef = requiredMemberDeclRef.as<VarDeclBase>())
+ {
+ ensureDecl(varDeclRef, DeclCheckState::SignatureChecked);
+ return doesVarMatchRequirement(varDeclRef, requiredVarDeclRef, witnessTable);
+ }
+ }
// Default: just assume that thing aren't being satisfied.
return false;
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 8e14af72a..e33d26c0c 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -982,7 +982,7 @@ namespace Slang
|| opName == getName("|") || opName == getName("&") || opName == getName("^") || opName == getName("~") || opName == getName("%") ||
opName == getName("?:") || opName == getName("<<") || opName == getName(">>"))
{
- auto result = m_astBuilder->create<SomeIntVal>();
+ auto result = m_astBuilder->create<FuncCallIntVal>();
result->args.addRange(argVals, argCount);
result->funcDeclRef = funcDeclRef;
result->funcType = as<Type>(funcDeclRefExpr.getExpr()->type->substitute(
@@ -1131,6 +1131,22 @@ namespace Slang
if(!decl->hasModifier<ConstModifier>())
return nullptr;
+ if (isInterfaceRequirement(decl))
+ {
+ for (auto subst = declRef.substitutions.substitutions; subst; subst = subst->outer)
+ {
+ if (auto thisTypeSubst = as<ThisTypeSubstitution>(subst))
+ {
+ auto val = WitnessLookupIntVal::tryFold(
+ m_astBuilder,
+ thisTypeSubst->witness,
+ decl,
+ declRef.substitute(m_astBuilder, decl->type.type));
+ return as<IntVal>(val);
+ }
+ }
+ }
+
auto initExpr = getInitExpr(m_astBuilder, declRef);
if(!initExpr)
return nullptr;
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 9a849384c..091990cf3 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -898,6 +898,10 @@ namespace Slang
DeclRef<PropertyDecl> satisfyingMemberDeclRef,
DeclRef<PropertyDecl> requiredMemberDeclRef,
RefPtr<WitnessTable> witnessTable);
+ bool doesVarMatchRequirement(
+ DeclRef<VarDeclBase> satisfyingMemberDeclRef,
+ DeclRef<VarDeclBase> requiredMemberDeclRef,
+ RefPtr<WitnessTable> witnessTable);
bool doesGenericSignatureMatchRequirement(
DeclRef<GenericDecl> genDecl,
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 27076681d..c3062ea4f 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -338,6 +338,8 @@ DIAGNOSTIC(32003, Error, unexpectedEnumTagExpr, "unexpected form for 'enum'
// 303xx: interfaces and associated types
DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.")
DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'type_param' can only be defined global scope.")
+DIAGNOSTIC(30302, Error, staticConstRequirementMustBeIntOrBool, "'static const' requirement can only have int or bool type.")
+DIAGNOSTIC(30303, Error, valueRequirementMustBeCompileTimeConst, "requirement in the form of a simple value must be declared as 'static const'.")
// Interop
DIAGNOSTIC(30400, Error, cannotDefinePtrTypeToManagedResource, "pointer to a managed resource is invalid, use `NativeRef<T>` instead")
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index b351bfe21..0921d36fb 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1199,6 +1199,9 @@ bool shouldDeclBeTreatedAsInterfaceRequirement(Decl* requirementDecl)
else if (auto typeConstraint = as<TypeConstraintDecl>(requirementDecl))
{
}
+ else if (auto varDecl = as<VarDeclBase>(requirementDecl))
+ {
+ }
else if (auto genericDecl = as<GenericDecl>(requirementDecl))
{
return shouldDeclBeTreatedAsInterfaceRequirement(genericDecl->inner);
@@ -1287,7 +1290,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
lowerType(context, getType(context->astBuilder, val->declRef)));
}
- LoweredValInfo visitSomeIntVal(SomeIntVal* val)
+ LoweredValInfo visitFuncCallIntVal(FuncCallIntVal* val)
{
TryClauseEnvironment tryEnv;
List<IRInst*> args;
@@ -1306,6 +1309,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
tryEnv);
}
+ LoweredValInfo visitWitnessLookupIntVal(WitnessLookupIntVal* val)
+ {
+ auto witnessVal = lowerVal(context, val->witness);
+ auto key = getInterfaceRequirementKey(context, val->key);
+ auto type = lowerType(context, val->type);
+ return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst(
+ type, witnessVal.val, key));
+ }
+
LoweredValInfo visitPolynomialIntVal(PolynomialIntVal* val)
{
auto irBuilder = getBuilder();
@@ -8259,20 +8271,6 @@ bool canDeclLowerToAGeneric(Decl* decl)
return false;
}
-static bool isInterfaceRequirement(Decl* decl)
-{
- auto ancestor = decl->parentDecl;
- for (; ancestor; ancestor = ancestor->parentDecl)
- {
- if (as<InterfaceDecl>(ancestor))
- return true;
-
- if (as<ExtensionDecl>(ancestor))
- return false;
- }
- return false;
-}
-
/// Add flattened "leaf" elements from `val` to the `ioArgs` list
static void _addFlattenedTupleArgs(
List<IRInst*>& ioArgs,