diff options
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 302 |
1 files changed, 229 insertions, 73 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index fa05dde11..f28f46deb 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -81,8 +81,6 @@ namespace Slang void checkCallableDeclCommon(CallableDecl* decl); - void maybeCheckDifferentiableAccessorSignature(FuncDecl* funcDecl); - void visitFuncDecl(FuncDecl* funcDecl); void visitParamDecl(ParamDecl* paramDecl); @@ -889,6 +887,7 @@ namespace Slang auto diffTypeDict = subDiffTypeContext->makeDifferentiableTypeDictionaryNode(visitor->getASTBuilder()); diffTypeDict->parentDecl = containerDecl; containerDecl->members.add(diffTypeDict); + containerDecl->invalidateMemberDictionary(); } } } @@ -1072,8 +1071,11 @@ namespace Slang auto initExpr = varDecl->initExpr; if(!initExpr) { - getSink()->diagnose(varDecl, Diagnostics::varWithoutTypeMustHaveInitializer); - varDecl->type.type = m_astBuilder->getErrorType(); + if (!varDecl->type.type) + { + getSink()->diagnose(varDecl, Diagnostics::varWithoutTypeMustHaveInitializer); + varDecl->type.type = m_astBuilder->getErrorType(); + } } else { @@ -1390,7 +1392,7 @@ namespace Slang context->parentDecl->members.add((aggTypeDecl)); aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName(); aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc; - context->parentDecl->getMemberDictionary().Add(aggTypeDecl->getName(), aggTypeDecl); + context->parentDecl->invalidateMemberDictionary(); } // TODO: if we want to make the synthesized type itself to be differentiable, @@ -1409,6 +1411,7 @@ namespace Slang diffField->checkState = DeclCheckState::SignatureChecked; diffField->parentDecl = aggTypeDecl; aggTypeDecl->members.add(diffField); + aggTypeDecl->invalidateMemberDictionary(); // Inject a `DerivativeMember` modifier on the original decl. auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>(); @@ -2412,54 +2415,12 @@ namespace Slang return false; } - bool SemanticsVisitor::trySynthesizeMethodRequirementWitness( + FuncDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness( ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef<FuncDecl> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable) + DeclRef<FuncDecl> requiredMemberDeclRef, + List<Expr*>& synArgs, + ThisExpr*& synThis) { - // The situation here is that the context of an inheritance - // declaration didn't provide an exact match for a required - // method. E.g.: - // - // interface ICounter { [mutating] int increment(); } - // struct MyCounter : ICounter - // { - // [murtating] int increment(int val = 1) { ... } - // } - // - // It is clear in this case that the `MyCounter` type *can* - // satisfy the signature required by `ICounter`, but it has - // no explicit method declaration that is a perfect match. - // - // The approach in this function will be to construct a - // synthesized method along the lines of: - // - // struct MyCounter ... - // { - // ... - // [murtating] int synthesized() - // { - // return this.increment(); - // } - // } - // - // That is, we construct a method with the exact signature - // of the requirement (same parameter and result types), - // and then provide it with a body that simple `return`s - // the result of applying the desired requirement name - // (`increment` in this case) to those parameters. - // - // If the synthesized method type-checks, then we can say - // that the type must satisfy the requirement structurally, - // even if there isn't an exact signature match. More - // importantly, the method we just synthesized can be - // used as a witness to the fact that the requirement is - // satisfied. - - // With the big picture spelled out, we can settle into - // the work of constructing our synthesized method. - // auto synFuncDecl = m_astBuilder->create<FuncDecl>(); // For now our synthesized method will use the name and source @@ -2497,8 +2458,7 @@ namespace Slang // that reference those parametesr as arguments for the call expresison // that makes up the body. // - List<Expr*> synArgs; - for( auto paramDeclRef : getParameters(requiredMemberDeclRef) ) + for (auto paramDeclRef : getParameters(requiredMemberDeclRef)) { auto paramType = getType(m_astBuilder, paramDeclRef); @@ -2524,13 +2484,13 @@ namespace Slang synArgs.add(synArg); } + // Required interface methods can be `static` or non-`static`, // and non-`static` methods can be `[mutating]` or non-`[mutating]`. // All of these details affect how we introduce our `this` parameter, // if any. // - ThisExpr* synThis = nullptr; - if( requiredMemberDeclRef.getDecl()->hasModifier<HLSLStaticModifier>() ) + if (requiredMemberDeclRef.getDecl()->hasModifier<HLSLStaticModifier>()) { auto synStaticModifier = m_astBuilder->create<HLSLStaticModifier>(); synFuncDecl->modifiers.first = synStaticModifier; @@ -2546,7 +2506,7 @@ namespace Slang // synThis->type.type = context->conformingType; - if( requiredMemberDeclRef.getDecl()->hasModifier<MutatingAttribute>() ) + if (requiredMemberDeclRef.getDecl()->hasModifier<MutatingAttribute>()) { // If the interface requirement is `[mutating]` then our // synthesized method should be too, and also the `this` @@ -2559,6 +2519,64 @@ namespace Slang } } + return synFuncDecl; + } + + bool SemanticsVisitor::trySynthesizeMethodRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& lookupResult, + DeclRef<FuncDecl> requiredMemberDeclRef, + RefPtr<WitnessTable> witnessTable) + { + // The situation here is that the context of an inheritance + // declaration didn't provide an exact match for a required + // method. E.g.: + // + // interface ICounter { [mutating] int increment(); } + // struct MyCounter : ICounter + // { + // [murtating] int increment(int val = 1) { ... } + // } + // + // It is clear in this case that the `MyCounter` type *can* + // satisfy the signature required by `ICounter`, but it has + // no explicit method declaration that is a perfect match. + // + // The approach in this function will be to construct a + // synthesized method along the lines of: + // + // struct MyCounter ... + // { + // ... + // [murtating] int synthesized() + // { + // return this.increment(); + // } + // } + // + // That is, we construct a method with the exact signature + // of the requirement (same parameter and result types), + // and then provide it with a body that simple `return`s + // the result of applying the desired requirement name + // (`increment` in this case) to those parameters. + // + // If the synthesized method type-checks, then we can say + // that the type must satisfy the requirement structurally, + // even if there isn't an exact signature match. More + // importantly, the method we just synthesized can be + // used as a witness to the fact that the requirement is + // satisfied. + + // With the big picture spelled out, we can settle into + // the work of constructing our synthesized method. + // + ThisExpr* synThis = nullptr; + List<Expr*> synArgs; + auto synFuncDecl = synthesizeMethodSignatureForRequirementWitness( + context, requiredMemberDeclRef, synArgs, synThis); + + auto resultType = synFuncDecl->returnType.type; + // The body of our synthesized method is going to try to // make a call using the name of the method requirement (e.g., // the name `increment` in our example at the top of this function). @@ -3036,11 +3054,27 @@ namespace Slang if (auto requiredFuncDeclRef = requiredMemberDeclRef.as<FuncDecl>()) { // Check signature match. - return trySynthesizeMethodRequirementWitness( + if (trySynthesizeMethodRequirementWitness( context, lookupResult, requiredFuncDeclRef, - witnessTable); + witnessTable)) + return true; + + if (auto builtinAttr = requiredFuncDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>()) + { + switch (builtinAttr->kind) + { + case BuiltinRequirementKind::DAddFunc: + case BuiltinRequirementKind::DMulFunc: + case BuiltinRequirementKind::DZeroFunc: + return trySynthesizeDifferentialMethodRequirementWitness( + context, + requiredFuncDeclRef, + witnessTable); + } + } + return false; } if( auto requiredPropertyDeclRef = requiredMemberDeclRef.as<PropertyDecl>() ) @@ -3054,11 +3088,11 @@ namespace Slang if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>()) { - if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinAssociatedTypeRequirementAttribute>()) + if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>()) { switch (builtinAttr->kind) { - case BuiltinAssociatedTypeRequirementKind::Differential: + case BuiltinRequirementKind::DifferentialType: return trySynthesizeDifferentialAssociatedTypeRequirementWitness( context, requiredAssocTypeDeclRef, @@ -3088,6 +3122,130 @@ namespace Slang return false; } + bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( + ConformanceCheckingContext* context, + DeclRef<Decl> requirementDeclRef, + RefPtr<WitnessTable> witnessTable) + { + // This method implements a general code synthesis pattern. + // For requirement of the form: + // ``` + // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) + // ``` + // Where TResult, TParam1, TParam2 is either `This` or `Differential`, + // We synthesize a memberwise dispatch to compute each field of `TResult`, + // resulting an implementation of the form: + // ``` + // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) + // { + // TResult result; + // result.member0 = decltype(result.member0).requiredMethod(p0.member0, p1.member0); + // result.member1 = decltype(result.member1).requiredMethod(p0.member1, p1.member1); + // ... + // return result; + // } + // ``` + List<Expr*> synArgs; + ThisExpr* synThis = nullptr; + auto synFunc = synthesizeMethodSignatureForRequirementWitness( + context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis); + + auto blockStmt = m_astBuilder->create<BlockStmt>(); + synFunc->body = blockStmt; + auto seqStmt = m_astBuilder->create<SeqStmt>(); + blockStmt->body = seqStmt; + + // Create a variable for return value. + auto scopeDecl = m_astBuilder->create<ScopeDecl>(); + synFunc->members.add(scopeDecl); + scopeDecl->parentDecl = synFunc; + auto varStmt = m_astBuilder->create<DeclStmt>(); + seqStmt->stmts.add(varStmt); + + auto returnVar = m_astBuilder->create<VarDecl>(); + returnVar->parentDecl = scopeDecl; + scopeDecl->members.add(returnVar); + + returnVar->type.type = synFunc->returnType.type; + returnVar->nameAndLoc.name = getName("result"); + varStmt->decl = returnVar; + auto resultVarExpr = m_astBuilder->create<VarExpr>(); + resultVarExpr->declRef = makeDeclRef(returnVar); + resultVarExpr->type.type = synFunc->returnType.type; + resultVarExpr->type.isLeftValue = true; + + for (auto member : context->parentDecl->members) + { + auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>(); + if (!derivativeAttr) + continue; + auto varMember = as<VarDeclBase>(member); + if (!varMember) + continue; + ensureDecl(varMember, DeclCheckState::ReadyForReference); + auto memberType = varMember->getType(); + auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType); + if (!diffMemberType) + continue; + + // Construct reference exprs to the member's corresponding fields in each parameter. + List<Expr*> paramFields; + int paramIndex = 0; + for (auto arg : synArgs) + { + auto memberExpr = m_astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = arg; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is + // Differential type. + memberExpr->name = varMember->getName(); + paramFields.add(memberExpr); + paramIndex++; + } + + // Invoke the method for the field. + auto callee = m_astBuilder->create<StaticMemberExpr>(); + auto baseSharedType = m_astBuilder->create<SharedTypeExpr>(); + auto baseSharedTypeType = m_astBuilder->create<TypeType>(); + baseSharedTypeType->type = memberType; + baseSharedType->type = baseSharedTypeType; + baseSharedType->base.type = memberType; + callee->baseExpression = baseSharedType; + callee->name = requirementDeclRef.getName(); + callee->loc = synFunc->loc; + auto invokeExpr = m_astBuilder->create<InvokeExpr>(); + invokeExpr->functionExpr = callee; + invokeExpr->arguments = _Move(paramFields); + + // Assign the value to resultVar. + auto leftVal = m_astBuilder->create<MemberExpr>(); + leftVal->baseExpression = resultVarExpr; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` + // is Differential type. + leftVal->name = varMember->getName(); + + auto assignExpr = m_astBuilder->create<AssignExpr>(); + assignExpr->left = leftVal; + assignExpr->right = invokeExpr; + auto assignStmt = m_astBuilder->create<ExpressionStmt>(); + assignStmt->expression = assignExpr; + seqStmt->stmts.add(assignStmt); + } + + // TODO: synthesize assignments for inherited members here. + + auto synReturn = m_astBuilder->create<ReturnStmt>(); + synReturn->expression = resultVarExpr; + seqStmt->stmts.add(synReturn); + + synFunc->parentDecl = context->parentDecl; + context->parentDecl->members.add(synFunc); + context->parentDecl->invalidateMemberDictionary(); + addModifier(synFunc, m_astBuilder->create<SynthesizedModifier>()); + + witnessTable->add(requirementDeclRef, RequirementWitness(makeDeclRef(synFunc))); + return true; + } + bool SemanticsVisitor::findWitnessForInterfaceRequirement( ConformanceCheckingContext* context, Type* subType, @@ -3210,18 +3368,16 @@ namespace Slang if(!lookupResult.isValid()) { - // If we failed to even look up a member with the name of the - // requirement, then we can be certain that the type doesn't - // satisfy the requirement. - // - // TODO: If we ever allowed certain kinds of requirements to - // be inferred (e.g., inferring associated types from the - // signatures of methods, as is done for Swift), we'd - // need to revisit this step. - // - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); - getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); - return false; + // If we failed to look up a member with the name of the + // requirement, it may be possible that we can still synthesis the + // implementation if this is one of the known builtin requirements. + // Otherwise, report diagnostic now. + if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementAttribute>()) + { + getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); + getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); + return false; + } } // Iterate over the members and look for one that matches @@ -5103,7 +5259,7 @@ namespace Slang { resultType = CheckProperType(resultType); } - else + else if (!funcDecl->returnType.type) { resultType = TypeExp(m_astBuilder->getVoidType()); } |
