diff options
| author | Yong He <yonghe@outlook.com> | 2023-11-03 12:49:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-11-03 12:49:23 -0700 |
| commit | cc222702a8d7a1fccf8ad14b256570bcec1554ae (patch) | |
| tree | ddc6aeba05fe3cbc8618d7b1cb5e63a9d711667d /source/slang/slang-check-decl.cpp | |
| parent | 911a4401b08f6199e18b32349c236c186a2dd128 (diff) | |
Add more diagnostics on invalid custom derivative use. (#3309)
* Add more diagnostics on invalid custom derivative use.
* cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 80 |
1 files changed, 77 insertions, 3 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 15831ba26..64db4cdc5 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -6538,6 +6538,25 @@ namespace Slang return parentGeneric; } + Decl* SemanticsVisitor::getOuterGenericOrSelf(Decl* decl) + { + auto parentDecl = decl->parentDecl; + if (!parentDecl) return decl; + auto parentGeneric = as<GenericDecl>(parentDecl); + if (!parentGeneric) return decl; + return parentGeneric; + } + + GenericDecl* SemanticsVisitor::findNextOuterGeneric(Decl* decl) + { + for (auto p = decl->parentDecl; p; p = p->parentDecl) + { + if (auto genDecl = as<GenericDecl>(p)) + return genDecl; + } + return nullptr; + } + DeclRef<ExtensionDecl> SemanticsVisitor::applyExtensionToType( ExtensionDecl* extDecl, Type* type) @@ -7301,6 +7320,40 @@ namespace Slang funcType->getParamType(ii)->toString()); } } + // The `imaginaryArguments` list does not include the `this` parameter. + // So we need to check that `this` type matches. + bool funcIsStatic = isEffectivelyStatic(funcDecl); + bool derivativeFuncIsStatic = isEffectivelyStatic(calleeDeclRef->declRef.getDecl()); + if (funcIsStatic != derivativeFuncIsStatic) + { + visitor->getSink()->diagnose( + attr, + Diagnostics::customDerivativeSignatureThisParamMismatch); + return; + } + if (!funcIsStatic) + { + auto defaultFuncDeclRef = createDefaultSubstitutionsIfNeeded( + visitor->getASTBuilder(), + visitor, + makeDeclRef(funcDecl)); + auto funcThisType = visitor->calcThisType(defaultFuncDeclRef); + auto derivativeFuncThisType = visitor->calcThisType(calleeDeclRef->declRef); + if (!funcThisType->equals(derivativeFuncThisType)) + { + visitor->getSink()->diagnose( + attr, + Diagnostics::customDerivativeSignatureThisParamMismatch); + return; + } + if (visitor->isTypeDifferentiable(funcThisType)) + { + visitor->getSink()->diagnose( + attr, + Diagnostics::customDerivativeNotAllowedForMemberFunctionsOfDifferentiableType); + return; + } + } attr->funcExpr = calleeDeclRef; if (attr->args.getCount()) @@ -7489,9 +7542,10 @@ namespace Slang TDerivativeOfAttr* derivativeOfAttr, DeclAssociationKind assocKind) { + auto astBuilder = visitor->getASTBuilder(); DeclRef<Decl> calleeDeclRef; DeclRefExpr* calleeDeclRefExpr = nullptr; - HigherOrderInvokeExpr* higherOrderFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>(); + HigherOrderInvokeExpr* higherOrderFuncExpr = astBuilder->create<TDifferentiateExpr>(); higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr; if (derivativeOfAttr->args.getCount() > 0) higherOrderFuncExpr->loc = derivativeOfAttr->args[0]->loc; @@ -7501,7 +7555,7 @@ namespace Slang visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } - List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc).args; + List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(astBuilder, funcDecl, derivativeOfAttr->loc).args; auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedHigherOrderFuncExpr, imaginaryArgs); SemanticsContext::ExprLocalScope scope; auto ctx = visitor->withExprLocalScope(&scope); @@ -7532,6 +7586,21 @@ namespace Slang visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } + + // For now, if calleeFunc or funcDecl is nested inside some generic aggregate, + // they must be the same generic decl. For example, using B<T>.f() as the original function + // for C<T>.derivative() is not allowed. + // We may relax this restriction in the future by solving the "inverse" generic arguments + // from the `calleeDeclRef`, and use them to create a declRef to funcDecl from the original + // func. + auto originalNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(calleeFunc)); + auto derivativeNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(funcDecl)); + if (originalNextGeneric != derivativeNextGeneric) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + return; + } + if (isInterfaceRequirement(calleeFunc)) { visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative); @@ -7555,10 +7624,15 @@ namespace Slang } derivativeOfAttr->funcExpr = calleeDeclRefExpr; - auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>(); + auto derivativeAttr = astBuilder->create<TDerivativeAttr>(); derivativeAttr->loc = derivativeOfAttr->loc; auto outterGeneric = visitor->GetOuterGeneric(funcDecl); auto declRef = makeDeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl)); + + // If both the derivative and the original function are defined in the same outer generic + // aggregate type, we want to form a full declref with default arguments. + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, visitor, declRef); + auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); declRefExpr->type.type = nullptr; derivativeAttr->args.add(declRefExpr); |
