diff options
| author | Yong He <yonghe@outlook.com> | 2022-08-24 12:05:19 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-08-24 12:05:19 -0700 |
| commit | ba6f55ed9481960b4f6c7f0a6b8f1cf7d450c752 (patch) | |
| tree | bd92bf3cca5614585f8be6ad6f57510b18565b47 /source | |
| parent | 3746a47ce407b14c4afbfc5b625513cf81b5e890 (diff) | |
Allow `static const` interface requirements. (#2378)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-decl.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 104 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 27 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 76 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 28 |
9 files changed, 238 insertions, 38 deletions
diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp index edc79c030..2df9164fb 100644 --- a/source/slang/slang-ast-decl.cpp +++ b/source/slang/slang-ast-decl.cpp @@ -18,4 +18,18 @@ const TypeExp& TypeConstraintDecl::_getSupOverride() const } +bool isInterfaceRequirement(Decl* decl) +{ + auto ancestor = decl->parentDecl; + for (; ancestor; ancestor = ancestor->parentDecl) + { + if (as<InterfaceDecl>(ancestor)) + return true; + + if (as<ExtensionDecl>(ancestor)) + return false; + } + return false; +} + } // namespace Slang diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 9d2c99f14..147bc7d22 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -494,4 +494,7 @@ class AttributeDecl : public ContainerDecl SyntaxClass<NodeBase> syntaxClass; }; + +bool isInterfaceRequirement(Decl* decl); + } // namespace Slang diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 4ed69e282..dd5aff238 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -619,15 +619,15 @@ void PolynomialIntVal::_toTextOverride(StringBuilder& out) for (Index i = 0; i < terms.getCount(); i++) { auto& term = *(terms[i]); - if (i > 0) + if (term.constFactor > 0) { - if (term.constFactor > 0) + if (i > 0) out << "+"; - else - out << "-"; } + else + out << "-"; bool isFirstFactor = true; - if (term.constFactor != 1 || term.paramFactors.getCount() == 0) + if (abs(term.constFactor) != 1 || term.paramFactors.getCount() == 0) { out << abs(term.constFactor); isFirstFactor = false; @@ -1039,19 +1039,19 @@ IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder) return this; } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SomeIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool SomeIntVal::_equalsValOverride(Val* val) +bool FuncCallIntVal::_equalsValOverride(Val* val) { - if (auto someIntVal = as<SomeIntVal>(val)) + if (auto funcCallIntVal = as<FuncCallIntVal>(val)) { - if (!funcDeclRef.equals(someIntVal->funcDeclRef)) + if (!funcDeclRef.equals(funcCallIntVal->funcDeclRef)) return false; - if (args.getCount() != someIntVal->args.getCount()) + if (args.getCount() != funcCallIntVal->args.getCount()) return false; for (Index i = 0; i < args.getCount(); i++) { - if (!args[i]->equalsVal(someIntVal->args[i])) + if (!args[i]->equalsVal(funcCallIntVal->args[i])) return false; } return true; @@ -1059,11 +1059,11 @@ bool SomeIntVal::_equalsValOverride(Val* val) return false; } -void SomeIntVal::_toTextOverride(StringBuilder& out) +void FuncCallIntVal::_toTextOverride(StringBuilder& out) { auto argToText = [&](int index) { - if (as<PolynomialIntVal>(args[index]) || as<SomeIntVal>(args[index])) + if (as<PolynomialIntVal>(args[index]) || as<FuncCallIntVal>(args[index])) { out << "("; args[index]->toText(out); @@ -1110,7 +1110,7 @@ void SomeIntVal::_toTextOverride(StringBuilder& out) } } -HashCode SomeIntVal::_getHashCodeOverride() +HashCode FuncCallIntVal::_getHashCodeOverride() { HashCode result = funcDeclRef.getHashCode(); for (auto arg : args) @@ -1127,7 +1127,7 @@ static bool nameIs(Name* name, const char* val) return false; } -Val* SomeIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink) +Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink) { // Are all args const now? List<ConstantIntVal*> constArgs; @@ -1205,14 +1205,14 @@ Val* SomeIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, } else { - SLANG_UNREACHABLE("constant folding of SomeIntVal"); + SLANG_UNREACHABLE("constant folding of FuncCallIntVal"); } return astBuilder->create<ConstantIntVal>(resultValue); } return nullptr; } -Val* SomeIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) { int diff = 0; auto newFuncDeclRef = funcDeclRef.substituteImpl(astBuilder, subst, &diff); @@ -1233,7 +1233,7 @@ Val* SomeIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet return newVal; else { - auto result = astBuilder->create<SomeIntVal>(); + auto result = astBuilder->create<FuncCallIntVal>(); result->args = _Move(newArgs); result->funcDeclRef = newFuncDeclRef; result->funcType = funcType; @@ -1244,4 +1244,72 @@ Val* SomeIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet return this; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! WitnessLookupIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +bool WitnessLookupIntVal::_equalsValOverride(Val* val) +{ + if (auto lookupIntVal = as<WitnessLookupIntVal>(val)) + { + if (!witness->equalsVal(lookupIntVal->witness)) + return false; + if (key != lookupIntVal->key) + return false; + return true; + } + return false; +} + +void WitnessLookupIntVal::_toTextOverride(StringBuilder& out) +{ + witness->sub->toText(out); + out << "."; + out << (key->getName() ? key->getName()->text : "??"); +} + +HashCode WitnessLookupIntVal::_getHashCodeOverride() +{ + HashCode result = witness->getHashCode(); + result = combineHash(result, Slang::getHashCode(key)); + return result; +} +Val* WitnessLookupIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + auto newWitness = witness->substituteImpl(astBuilder, subst, &diff); + *ioDiff += diff; + if (diff) + { + auto witnessEntry = tryFoldOrNull(astBuilder, as<SubtypeWitness>(newWitness), key); + if (witnessEntry) + return witnessEntry; + } + // Nothing found: don't substitute. + return this; +} + +Val* WitnessLookupIntVal::tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key) +{ + auto witnessEntry = tryLookUpRequirementWitness(astBuilder, witness, key); + switch (witnessEntry.getFlavor()) + { + case RequirementWitness::Flavor::val: + return witnessEntry.getVal(); + break; + default: + break; + } + return nullptr; +} + +Val* WitnessLookupIntVal::tryFold(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key, Type* type) +{ + if (auto result = tryFoldOrNull(astBuilder, witness, key)) + return result; + auto witnessResult = astBuilder->create<WitnessLookupIntVal>(); + witnessResult->witness = witness; + witnessResult->key = key; + witnessResult->type = type; + return witnessResult; +} + } // namespace Slang diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 64f04abf9..259be75c3 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -53,9 +53,9 @@ protected: }; // An compile time int val as result of some general computation. -class SomeIntVal : public IntVal +class FuncCallIntVal : public IntVal { - SLANG_AST_CLASS(SomeIntVal) + SLANG_AST_CLASS(FuncCallIntVal) bool _equalsValOverride(Val* val); void _toTextOverride(StringBuilder& out); @@ -63,9 +63,28 @@ class SomeIntVal : public IntVal Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); DeclRef<Decl> funcDeclRef; - List<IntVal*> args; Type* funcType; - static Val* tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*> &newArgs, DiagnosticSink* sink); + List<IntVal*> args; + + static Val* tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink); +}; + +class WitnessLookupIntVal : public IntVal +{ + SLANG_AST_CLASS(WitnessLookupIntVal) + + bool _equalsValOverride(Val* val); + void _toTextOverride(StringBuilder& out); + HashCode _getHashCodeOverride(); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + + SubtypeWitness* witness; + Decl* key; + Type* type; + + static Val* tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key); + + static Val* tryFold(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key, Type* type); }; // polynomial expression "2*a*b^3 + 1" will be represented as: diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index abe7642fb..9e3fbaa8d 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1053,6 +1053,33 @@ namespace Slang addModifier(varDecl, m_astBuilder->create<NVAPIMagicModifier>()); } } + + if (auto interfaceDecl = as<InterfaceDecl>(varDecl->parentDecl)) + { + if (auto basicType = as<BasicExpressionType>(varDecl->getType())) + { + switch (basicType->baseType) + { + case BaseType::Bool: + case BaseType::Int8: + case BaseType::Int16: + case BaseType::Int: + case BaseType::Int64: + case BaseType::UInt8: + case BaseType::UInt16: + case BaseType::UInt: + case BaseType::UInt64: + break; + default: + getSink()->diagnose(varDecl, Diagnostics::staticConstRequirementMustBeIntOrBool); + break; + } + } + if (!varDecl->findModifier<HLSLStaticModifier>() || !varDecl->findModifier<ConstModifier>()) + { + getSink()->diagnose(varDecl, Diagnostics::valueRequirementMustBeCompileTimeConst); + } + } } void SemanticsDeclHeaderVisitor::visitStructDecl(StructDecl* structDecl) @@ -1588,6 +1615,47 @@ namespace Slang return true; } + bool SemanticsVisitor::doesVarMatchRequirement( + DeclRef<VarDeclBase> satisfyingMemberDeclRef, + DeclRef<VarDeclBase> requiredMemberDeclRef, + RefPtr<WitnessTable> witnessTable) + { + // The type of the satisfying member must match the type of the required member. + auto satisfyingType = getType(getASTBuilder(), satisfyingMemberDeclRef); + auto requiredType = getType(getASTBuilder(), requiredMemberDeclRef); + if (!satisfyingType->equals(requiredType)) + return false; + + for (auto modifier : requiredMemberDeclRef.getDecl()->modifiers) + { + bool found = false; + for (auto satisfyingModifier : satisfyingMemberDeclRef.getDecl()->modifiers) + { + if (satisfyingModifier->astNodeType == modifier->astNodeType) + { + found = true; + break; + } + } + if (!found) + return false; + } + + auto satisfyingVal = tryConstantFoldDeclRef(satisfyingMemberDeclRef, nullptr); + if (satisfyingVal) + { + witnessTable->add( + requiredMemberDeclRef, + RequirementWitness(satisfyingVal)); + } + else + { + witnessTable->add( + requiredMemberDeclRef.getDecl(), + RequirementWitness(satisfyingMemberDeclRef)); + } + return true; + } bool SemanticsVisitor::doesGenericSignatureMatchRequirement( DeclRef<GenericDecl> satisfyingGenericDeclRef, @@ -1975,6 +2043,14 @@ namespace Slang return doesPropertyMatchRequirement(propertyDeclRef, requiredPropertyDeclRef, witnessTable); } } + else if (auto varDeclRef = memberDeclRef.as<VarDeclBase>()) + { + if (auto requiredVarDeclRef = requiredMemberDeclRef.as<VarDeclBase>()) + { + ensureDecl(varDeclRef, DeclCheckState::SignatureChecked); + return doesVarMatchRequirement(varDeclRef, requiredVarDeclRef, witnessTable); + } + } // Default: just assume that thing aren't being satisfied. return false; } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 8e14af72a..e33d26c0c 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -982,7 +982,7 @@ namespace Slang || opName == getName("|") || opName == getName("&") || opName == getName("^") || opName == getName("~") || opName == getName("%") || opName == getName("?:") || opName == getName("<<") || opName == getName(">>")) { - auto result = m_astBuilder->create<SomeIntVal>(); + auto result = m_astBuilder->create<FuncCallIntVal>(); result->args.addRange(argVals, argCount); result->funcDeclRef = funcDeclRef; result->funcType = as<Type>(funcDeclRefExpr.getExpr()->type->substitute( @@ -1131,6 +1131,22 @@ namespace Slang if(!decl->hasModifier<ConstModifier>()) return nullptr; + if (isInterfaceRequirement(decl)) + { + for (auto subst = declRef.substitutions.substitutions; subst; subst = subst->outer) + { + if (auto thisTypeSubst = as<ThisTypeSubstitution>(subst)) + { + auto val = WitnessLookupIntVal::tryFold( + m_astBuilder, + thisTypeSubst->witness, + decl, + declRef.substitute(m_astBuilder, decl->type.type)); + return as<IntVal>(val); + } + } + } + auto initExpr = getInitExpr(m_astBuilder, declRef); if(!initExpr) return nullptr; diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 9a849384c..091990cf3 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -898,6 +898,10 @@ namespace Slang DeclRef<PropertyDecl> satisfyingMemberDeclRef, DeclRef<PropertyDecl> requiredMemberDeclRef, RefPtr<WitnessTable> witnessTable); + bool doesVarMatchRequirement( + DeclRef<VarDeclBase> satisfyingMemberDeclRef, + DeclRef<VarDeclBase> requiredMemberDeclRef, + RefPtr<WitnessTable> witnessTable); bool doesGenericSignatureMatchRequirement( DeclRef<GenericDecl> genDecl, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 27076681d..c3062ea4f 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -338,6 +338,8 @@ DIAGNOSTIC(32003, Error, unexpectedEnumTagExpr, "unexpected form for 'enum' // 303xx: interfaces and associated types DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.") DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'type_param' can only be defined global scope.") +DIAGNOSTIC(30302, Error, staticConstRequirementMustBeIntOrBool, "'static const' requirement can only have int or bool type.") +DIAGNOSTIC(30303, Error, valueRequirementMustBeCompileTimeConst, "requirement in the form of a simple value must be declared as 'static const'.") // Interop DIAGNOSTIC(30400, Error, cannotDefinePtrTypeToManagedResource, "pointer to a managed resource is invalid, use `NativeRef<T>` instead") diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index b351bfe21..0921d36fb 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1199,6 +1199,9 @@ bool shouldDeclBeTreatedAsInterfaceRequirement(Decl* requirementDecl) else if (auto typeConstraint = as<TypeConstraintDecl>(requirementDecl)) { } + else if (auto varDecl = as<VarDeclBase>(requirementDecl)) + { + } else if (auto genericDecl = as<GenericDecl>(requirementDecl)) { return shouldDeclBeTreatedAsInterfaceRequirement(genericDecl->inner); @@ -1287,7 +1290,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower lowerType(context, getType(context->astBuilder, val->declRef))); } - LoweredValInfo visitSomeIntVal(SomeIntVal* val) + LoweredValInfo visitFuncCallIntVal(FuncCallIntVal* val) { TryClauseEnvironment tryEnv; List<IRInst*> args; @@ -1306,6 +1309,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower tryEnv); } + LoweredValInfo visitWitnessLookupIntVal(WitnessLookupIntVal* val) + { + auto witnessVal = lowerVal(context, val->witness); + auto key = getInterfaceRequirementKey(context, val->key); + auto type = lowerType(context, val->type); + return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst( + type, witnessVal.val, key)); + } + LoweredValInfo visitPolynomialIntVal(PolynomialIntVal* val) { auto irBuilder = getBuilder(); @@ -8259,20 +8271,6 @@ bool canDeclLowerToAGeneric(Decl* decl) return false; } -static bool isInterfaceRequirement(Decl* decl) -{ - auto ancestor = decl->parentDecl; - for (; ancestor; ancestor = ancestor->parentDecl) - { - if (as<InterfaceDecl>(ancestor)) - return true; - - if (as<ExtensionDecl>(ancestor)) - return false; - } - return false; -} - /// Add flattened "leaf" elements from `val` to the `ioArgs` list static void _addFlattenedTupleArgs( List<IRInst*>& ioArgs, |
