diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-01 08:46:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-01 08:46:57 -0700 |
| commit | cbc1eff56057f199183bb7c17d8a360326512367 (patch) | |
| tree | 487865e928cd2ceecbb509f0bfd06aa8d9584411 /source/slang/slang-check-decl.cpp | |
| parent | b707a07b1de3535cb0a8ccb6fe2ed4afa4a016d1 (diff) | |
Make `DifferentialPair` able to nest. (#2477)
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 438 |
1 files changed, 311 insertions, 127 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 457ae229b..f60fbcc2c 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -11,9 +11,8 @@ // and when things get checked. #include "slang-lookup.h" - #include "slang-syntax.h" - +#include "slang-ast-synthesis.h" #include <limits> namespace Slang @@ -166,6 +165,65 @@ namespace Slang void visitExtensionDecl(ExtensionDecl* decl); }; + struct SemanticsDeclTypeResolutionVisitor + : public SemanticsDeclVisitorBase + , public DeclVisitor<SemanticsDeclTypeResolutionVisitor> + { + SemanticsDeclTypeResolutionVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + {} + + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} + + Val* resolveVal(Val* val); + Type* resolveType(Type* type) + { + return (Type*)resolveVal(type); + } + + void visitTypeExp(TypeExp& exp) + { + exp.type = resolveType(exp.type); + } + + void visitVarDeclBase(VarDeclBase* varDecl) + { + visitTypeExp(varDecl->type); + } + + void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) + { + visitTypeExp(decl->sup); + } + + void visitTypeDefDecl(TypeDefDecl* decl) + { + visitTypeExp(decl->type); + } + + void visitGenericTypeParamDecl(GenericTypeParamDecl* paramDecl) + { + visitTypeExp(paramDecl->initType); + } + + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + { + visitTypeExp(inheritanceDecl->base); + } + + void visitCallableDecl(CallableDecl* decl) + { + visitTypeExp(decl->returnType); + visitTypeExp(decl->errorType); + } + + void visitPropertyDecl(PropertyDecl* decl) + { + visitTypeExp(decl->type); + } + }; + struct SemanticsDeclBodyVisitor : public SemanticsDeclVisitorBase , public DeclVisitor<SemanticsDeclBodyVisitor> @@ -1363,27 +1421,30 @@ namespace Slang bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness( ConformanceCheckingContext* context, - DeclRef<Decl> requirementDeclRef, + DeclRef<AssocTypeDecl> requirementDeclRef, RefPtr<WitnessTable> witnessTable) { - // We currently can't handle generic types. - if (GetOuterGeneric(context->parentDecl) != nullptr) - { - return false; - } - + ASTSynthesizer synth(m_astBuilder, getNamePool()); Decl* existingDecl = nullptr; AggTypeDecl* aggTypeDecl = nullptr; if (context->parentDecl->getMemberDictionary().TryGetValue(requirementDeclRef.getName(), existingDecl)) { - aggTypeDecl = as<AggTypeDecl>(existingDecl); - SLANG_RELEASE_ASSERT(aggTypeDecl); - // Remove the `ToBeSynthesizedModifier`. - if (as<ToBeSynthesizedModifier>(aggTypeDecl->modifiers.first)) + if (as<ToBeSynthesizedModifier>(existingDecl->modifiers.first)) { - aggTypeDecl->modifiers.first = aggTypeDecl->modifiers.first->next; + existingDecl->modifiers.first = existingDecl->modifiers.first->next; } + else + { + // The user has defined an associatedtype explicitly but that we reach here because + // that type failed to satisfy the `IDifferential` requirement. + // We stop the synthesis and let the follow-up logic to report a diagnostic. + return false; + } + + aggTypeDecl = as<AggTypeDecl>(existingDecl); + SLANG_RELEASE_ASSERT(aggTypeDecl); + synth.pushContainerScope(aggTypeDecl); } else { @@ -1393,15 +1454,12 @@ namespace Slang aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName(); aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc; context->parentDecl->invalidateMemberDictionary(); + synth.pushScopeForContainer(aggTypeDecl); } - // TODO: if we want to make the synthesized type itself to be differentiable, - // add an inheritance decl here. Need to be careful to avoid infinite recursion - // trying to synthesize the higher order differential types. - // Helper function to add a `diffType` field into the synthesized type for the original // `member`. - auto differentialType = GetTypeForDeclRef(makeDeclRef(aggTypeDecl), context->parentDecl->loc); + auto differentialType = DeclRefType::create(m_astBuilder, makeDeclRef(aggTypeDecl)); auto addDiffMember = [&](Decl* member, Type* diffMemberType) { // If the field is differentiable, add a corresponding field in the associated Differential type. @@ -1452,12 +1510,35 @@ namespace Slang addDiffMember(member, diffType); } - // In the future when the Differential type itself needs to conform to some interface, - // this is the place to synthesize requirements for them. addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>()); - auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, nullptr); - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); - return true; + + // If `This` is nested inside a generic, we need to form a complete declref type to the + // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution + // from requirementDeclRef to get the generic substitution for outer generic parameters, and + // apply it to the newly synthesized decl. + SubstitutionSet substSet; + if (auto thisTypeSusbt = findThisTypeSubstitution( + requirementDeclRef.substitutions, + as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) + { + if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) + { + substSet = declRefType->declRef.substitutions; + } + } + + auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet); + + if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable)) + { + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); + return true; + } + + // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always succeed. + // If not, there is something wrong with the code synthesis logic. For now we just return false + // instead of crashing so the user can work around the issues. + return false; } void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*) @@ -2242,22 +2323,8 @@ namespace Slang witnessTable); } - bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeRequirement( - Type* satisfyingType, - DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, - RefPtr<WitnessTable> witnessTable) + bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeConstraintRequirement(Type* satisfyingType, DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, RefPtr<WitnessTable> witnessTable) { - if (auto declRefType = as<DeclRefType>(satisfyingType)) - { - // If we are seeing a placeholder that awaits synthesis, return false now to trigger - // auto synthesis. - if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>()) - return false; - } - // We need to confirm that the chosen type `satisfyingType`, - // meets all the constraints placed on the associated type - // requirement `requiredAssociatedTypeDeclRef`. - // // We will enumerate the type constraints placed on the // associated type and see if they can be satisfied. // @@ -2269,7 +2336,7 @@ namespace Slang // Perform a search for a witness to the subtype relationship. auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); - if(witness) + if (witness) { // If a subtype witness was found, then the conformance // appears to hold, and we can satisfy that requirement. @@ -2282,6 +2349,30 @@ namespace Slang conformance = false; } } + return conformance; + } + + bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeRequirement( + Type* satisfyingType, + DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, + RefPtr<WitnessTable> witnessTable) + { + if (auto declRefType = as<DeclRefType>(satisfyingType)) + { + // If we are seeing a placeholder that awaits synthesis, return false now to trigger + // auto synthesis. + if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>()) + return false; + } + // We need to confirm that the chosen type `satisfyingType`, + // meets all the constraints placed on the associated type + // requirement `requiredAssociatedTypeDeclRef`. + // + // We will enumerate the type constraints placed on the + // associated type and see if they can be satisfied. + // + bool conformance = doesTypeSatisfyAssociatedTypeConstraintRequirement( + satisfyingType, requiredAssociatedTypeDeclRef, witnessTable); // TODO: if any conformance check failed, we should probably include // that in an error message produced about not satisfying the requirement. @@ -3122,12 +3213,43 @@ namespace Slang return false; } + Stmt* _synthesizeMemberAssignMemberHelper(ASTSynthesizer& synth, Name* funcName, Type* leftType, Expr* leftValue, List<Expr*>&& args, int nestingLevel = 0) + { + if (nestingLevel > 16) + return nullptr; + + // If field type is an array, assign each element individually. + if (auto arrayType = as<ArrayExpressionType>(leftType)) + { + VarDecl* indexVar = nullptr; + auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar); + auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar)); + for (auto& arg : args) + { + arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar)); + } + auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->baseType, innerLeft, _Move(args), nestingLevel + 1); + synth.popScope(); + if (!assignStmt) + return nullptr; + forStmt->statement = assignStmt; + return forStmt; + } + + auto callee = synth.emitMemberExpr(leftType, funcName); + return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args))); + } + bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( ConformanceCheckingContext* context, DeclRef<Decl> requirementDeclRef, RefPtr<WitnessTable> witnessTable) { - // This method implements a general code synthesis pattern. + // We support two cases of synthesis here. + // Case 1 is that there the associated Differential type is defined to be `DifferentialBottom`. + // In this case we just trivially return `DifferentialBottom` in all synthesized methods. + // Case 2 is that the `Differential` type contains members corresponding to each primal member. + // We will apply a general code synthesis pattern to reflect that structure. // For requirement of the form: // ``` // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) @@ -3145,104 +3267,123 @@ namespace Slang // return result; // } // ``` + + // First we need to make sure the associated `Differential` type requirement is satisfied. + bool hasDifferentialAssocType = false; + for (auto existingEntry : witnessTable->requirementList) + { + if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementAttribute>()) + { + if (builtinReqAttr->kind == BuiltinRequirementKind::DifferentialType && + existingEntry.Value.getFlavor() != RequirementWitness::Flavor::none) + { + hasDifferentialAssocType = true; + } + } + } + if (!hasDifferentialAssocType) + return false; + + ASTSynthesizer synth(m_astBuilder, getNamePool()); List<Expr*> synArgs; ThisExpr* synThis = nullptr; auto synFunc = synthesizeMethodSignatureForRequirementWitness( context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis); - + synFunc->parentDecl = context->parentDecl; + synth.pushContainerScope(synFunc); auto blockStmt = m_astBuilder->create<BlockStmt>(); synFunc->body = blockStmt; - auto seqStmt = m_astBuilder->create<SeqStmt>(); + auto seqStmt = synth.pushSeqStmtScope(); blockStmt->body = seqStmt; - // Create a variable for return value. - auto scopeDecl = m_astBuilder->create<ScopeDecl>(); - synFunc->members.add(scopeDecl); - scopeDecl->parentDecl = synFunc; - auto varStmt = m_astBuilder->create<DeclStmt>(); - seqStmt->stmts.add(varStmt); - - auto returnVar = m_astBuilder->create<VarDecl>(); - returnVar->parentDecl = scopeDecl; - scopeDecl->members.add(returnVar); - - returnVar->type.type = synFunc->returnType.type; - returnVar->nameAndLoc.name = getName("result"); - varStmt->decl = returnVar; - auto resultVarExpr = m_astBuilder->create<VarExpr>(); - resultVarExpr->declRef = makeDeclRef(returnVar); - resultVarExpr->type.type = synFunc->returnType.type; - resultVarExpr->type.isLeftValue = true; - - for (auto member : context->parentDecl->members) - { - auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>(); - if (!derivativeAttr) - continue; - auto varMember = as<VarDeclBase>(member); - if (!varMember) - continue; - ensureDecl(varMember, DeclCheckState::ReadyForReference); - auto memberType = varMember->getType(); - auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType); - if (!diffMemberType) - continue; + if (synFunc->returnType.type->equals(m_astBuilder->getDifferentialBottomType())) + { + // Trivial case, the `Differential` type is `DifferentialBottom`. + // We will just return `DifferentialBottom.dzero()`. + auto resultExpr = m_astBuilder->create<InvokeExpr>(); + auto dzeroMember = m_astBuilder->create<StaticMemberExpr>(); + auto base = m_astBuilder->create<SharedTypeExpr>(); + auto typetype = m_astBuilder->create<TypeType>(); + typetype->type = m_astBuilder->getDifferentialBottomType(); + base->type.type = typetype; + dzeroMember->baseExpression = base; + dzeroMember->name = getName("dzero"); + resultExpr->functionExpr = dzeroMember; + auto synReturn = m_astBuilder->create<ReturnStmt>(); + synReturn->expression = resultExpr; + seqStmt->stmts.add(synReturn); + } + else + { + // The general case. + // Create a variable for return value. + synth.pushVarScope(); + auto varStmt = synth.emitVarDeclStmt(synFunc->returnType.type, getName("result")); + auto resultVarExpr = synth.emitVarExpr(varStmt, synFunc->returnType.type); - // Construct reference exprs to the member's corresponding fields in each parameter. - List<Expr*> paramFields; - int paramIndex = 0; - for (auto arg : synArgs) - { - auto memberExpr = m_astBuilder->create<MemberExpr>(); - memberExpr->baseExpression = arg; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is - // Differential type. - memberExpr->name = varMember->getName(); - paramFields.add(memberExpr); - paramIndex++; - } - - // Invoke the method for the field. - auto callee = m_astBuilder->create<StaticMemberExpr>(); - auto baseSharedType = m_astBuilder->create<SharedTypeExpr>(); - auto baseSharedTypeType = m_astBuilder->create<TypeType>(); - baseSharedTypeType->type = memberType; - baseSharedType->type = baseSharedTypeType; - baseSharedType->base.type = memberType; - callee->baseExpression = baseSharedType; - callee->name = requirementDeclRef.getName(); - callee->loc = synFunc->loc; - auto invokeExpr = m_astBuilder->create<InvokeExpr>(); - invokeExpr->functionExpr = callee; - invokeExpr->arguments = _Move(paramFields); - - // Assign the value to resultVar. - auto leftVal = m_astBuilder->create<MemberExpr>(); - leftVal->baseExpression = resultVarExpr; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` - // is Differential type. - leftVal->name = varMember->getName(); - - auto assignExpr = m_astBuilder->create<AssignExpr>(); - assignExpr->left = leftVal; - assignExpr->right = invokeExpr; - auto assignStmt = m_astBuilder->create<ExpressionStmt>(); - assignStmt->expression = assignExpr; - seqStmt->stmts.add(assignStmt); - } - - // TODO: synthesize assignments for inherited members here. - - auto synReturn = m_astBuilder->create<ReturnStmt>(); - synReturn->expression = resultVarExpr; - seqStmt->stmts.add(synReturn); + for (auto member : context->parentDecl->members) + { + auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>(); + if (!derivativeAttr) + continue; + auto varMember = as<VarDeclBase>(member); + if (!varMember) + continue; + ensureDecl(varMember, DeclCheckState::ReadyForReference); + auto memberType = varMember->getType(); + auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType); + if (!diffMemberType) + continue; - synFunc->parentDecl = context->parentDecl; + // Construct reference exprs to the member's corresponding fields in each parameter. + List<Expr*> paramFields; + int paramIndex = 0; + for (auto arg : synArgs) + { + auto memberExpr = m_astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = arg; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is + // Differential type. + memberExpr->name = varMember->getName(); + paramFields.add(memberExpr); + paramIndex++; + } + + // Invoke the method for the field and assign the value to resultVar. + // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` + // is Differential type. + auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName()); + if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields))) + return false; + } + + // TODO: synthesize assignments for inherited members here. + + auto synReturn = m_astBuilder->create<ReturnStmt>(); + synReturn->expression = resultVarExpr; + seqStmt->stmts.add(synReturn); + } + context->parentDecl->members.add(synFunc); context->parentDecl->invalidateMemberDictionary(); addModifier(synFunc, m_astBuilder->create<SynthesizedModifier>()); - witnessTable->add(requirementDeclRef, RequirementWitness(makeDeclRef(synFunc))); + // If `This` is nested inside a generic, we need to form a complete declref type to the + // newly synthesized method here in order to fill into the witness table. + // This can be done by obtaining ThisTypeSubstitution from requirementDeclRef to get the + // generic substitution for outer generic parameters, and apply it here. + SubstitutionSet substSet; + if (auto thisTypeSusbt = findThisTypeSubstitution( + requirementDeclRef.substitutions, + as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) + { + if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) + { + substSet = declRefType->declRef.substitutions; + } + } + + witnessTable->add(requirementDeclRef, RequirementWitness(DeclRef<Decl>(synFunc, substSet))); return true; } @@ -3801,7 +3942,10 @@ namespace Slang // be required to implement all interface requirements, // just with `abstract` methods that replicate things? // (That's what C# does). - for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>()) + + // Make a copy of inhertanceDecls firstsince `checkConformance` may modify decl->members. + auto inheritanceDecls = decl->getMembersOfType<InheritanceDecl>().toList(); + for (auto inheritanceDecl : inheritanceDecls) { checkConformance(type, inheritanceDecl, decl); } @@ -5230,7 +5374,7 @@ namespace Slang void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) { - if (decl->findModifier<ForwardDifferentiableAttribute>()) + if (decl->findModifier<DifferentiableAttribute>()) { this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); } @@ -6274,6 +6418,10 @@ namespace Slang SemanticsDeclConformancesVisitor(shared).dispatch(decl); break; + case DeclCheckState::TypesFullyResolved: + SemanticsDeclTypeResolutionVisitor(shared).dispatch(decl); + break; + case DeclCheckState::Checked: SemanticsDeclBodyVisitor(shared).dispatch(decl); break; @@ -6325,4 +6473,40 @@ namespace Slang return result; } + Val* SemanticsDeclTypeResolutionVisitor::resolveVal(Val* val) + { + if (auto declRefType = as<DeclRefType>(val)) + { + if (auto concreteType = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(m_astBuilder, declRefType->declRef)) + return as<Type>(concreteType); + for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst=subst->outer) + { + if (auto genericSubst = as<GenericSubstitution>(subst)) + { + ShortList<Val*> newArgs; + for (auto& arg : genericSubst->getArgs()) + { + arg = resolveVal(arg); + SLANG_RELEASE_ASSERT(arg); + } + } + } + } + else if (auto subtypeWitness = as<SubtypeWitness>(val)) + { + auto sub = as<Type>(resolveVal(subtypeWitness->sub)); + auto sup = as<Type>(resolveVal(subtypeWitness->sup)); + if (sub && sup) + { + if (sub != subtypeWitness->sub || sup != subtypeWitness->sup) + { + auto newVal = tryGetSubtypeWitness(as<Type>(sub), as<Type>(sup)); + if (newVal) + val = newVal; + } + } + } + return val; + } + } |
