diff options
| author | Yong He <yonghe@outlook.com> | 2023-11-03 12:49:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-11-03 12:49:23 -0700 |
| commit | cc222702a8d7a1fccf8ad14b256570bcec1554ae (patch) | |
| tree | ddc6aeba05fe3cbc8618d7b1cb5e63a9d711667d | |
| parent | 911a4401b08f6199e18b32349c236c186a2dd128 (diff) | |
Add more diagnostics on invalid custom derivative use. (#3309)
* Add more diagnostics on invalid custom derivative use.
* cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 80 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | tests/autodiff/custom-derivative-generic.slang | 32 |
4 files changed, 122 insertions, 4 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 15831ba26..64db4cdc5 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -6538,6 +6538,25 @@ namespace Slang return parentGeneric; } + Decl* SemanticsVisitor::getOuterGenericOrSelf(Decl* decl) + { + auto parentDecl = decl->parentDecl; + if (!parentDecl) return decl; + auto parentGeneric = as<GenericDecl>(parentDecl); + if (!parentGeneric) return decl; + return parentGeneric; + } + + GenericDecl* SemanticsVisitor::findNextOuterGeneric(Decl* decl) + { + for (auto p = decl->parentDecl; p; p = p->parentDecl) + { + if (auto genDecl = as<GenericDecl>(p)) + return genDecl; + } + return nullptr; + } + DeclRef<ExtensionDecl> SemanticsVisitor::applyExtensionToType( ExtensionDecl* extDecl, Type* type) @@ -7301,6 +7320,40 @@ namespace Slang funcType->getParamType(ii)->toString()); } } + // The `imaginaryArguments` list does not include the `this` parameter. + // So we need to check that `this` type matches. + bool funcIsStatic = isEffectivelyStatic(funcDecl); + bool derivativeFuncIsStatic = isEffectivelyStatic(calleeDeclRef->declRef.getDecl()); + if (funcIsStatic != derivativeFuncIsStatic) + { + visitor->getSink()->diagnose( + attr, + Diagnostics::customDerivativeSignatureThisParamMismatch); + return; + } + if (!funcIsStatic) + { + auto defaultFuncDeclRef = createDefaultSubstitutionsIfNeeded( + visitor->getASTBuilder(), + visitor, + makeDeclRef(funcDecl)); + auto funcThisType = visitor->calcThisType(defaultFuncDeclRef); + auto derivativeFuncThisType = visitor->calcThisType(calleeDeclRef->declRef); + if (!funcThisType->equals(derivativeFuncThisType)) + { + visitor->getSink()->diagnose( + attr, + Diagnostics::customDerivativeSignatureThisParamMismatch); + return; + } + if (visitor->isTypeDifferentiable(funcThisType)) + { + visitor->getSink()->diagnose( + attr, + Diagnostics::customDerivativeNotAllowedForMemberFunctionsOfDifferentiableType); + return; + } + } attr->funcExpr = calleeDeclRef; if (attr->args.getCount()) @@ -7489,9 +7542,10 @@ namespace Slang TDerivativeOfAttr* derivativeOfAttr, DeclAssociationKind assocKind) { + auto astBuilder = visitor->getASTBuilder(); DeclRef<Decl> calleeDeclRef; DeclRefExpr* calleeDeclRefExpr = nullptr; - HigherOrderInvokeExpr* higherOrderFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>(); + HigherOrderInvokeExpr* higherOrderFuncExpr = astBuilder->create<TDifferentiateExpr>(); higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr; if (derivativeOfAttr->args.getCount() > 0) higherOrderFuncExpr->loc = derivativeOfAttr->args[0]->loc; @@ -7501,7 +7555,7 @@ namespace Slang visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } - List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc).args; + List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(astBuilder, funcDecl, derivativeOfAttr->loc).args; auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedHigherOrderFuncExpr, imaginaryArgs); SemanticsContext::ExprLocalScope scope; auto ctx = visitor->withExprLocalScope(&scope); @@ -7532,6 +7586,21 @@ namespace Slang visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } + + // For now, if calleeFunc or funcDecl is nested inside some generic aggregate, + // they must be the same generic decl. For example, using B<T>.f() as the original function + // for C<T>.derivative() is not allowed. + // We may relax this restriction in the future by solving the "inverse" generic arguments + // from the `calleeDeclRef`, and use them to create a declRef to funcDecl from the original + // func. + auto originalNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(calleeFunc)); + auto derivativeNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(funcDecl)); + if (originalNextGeneric != derivativeNextGeneric) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + return; + } + if (isInterfaceRequirement(calleeFunc)) { visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotAssociateInterfaceRequirementWithDerivative); @@ -7555,10 +7624,15 @@ namespace Slang } derivativeOfAttr->funcExpr = calleeDeclRefExpr; - auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>(); + auto derivativeAttr = astBuilder->create<TDerivativeAttr>(); derivativeAttr->loc = derivativeOfAttr->loc; auto outterGeneric = visitor->GetOuterGeneric(funcDecl); auto declRef = makeDeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl)); + + // If both the derivative and the original function are defined in the same outer generic + // aggregate type, we want to form a full declref with default arguments. + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, visitor, declRef); + auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); declRefExpr->type.type = nullptr; derivativeAttr->args.add(declRefExpr); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 7dd46037b..5b67dc413 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2151,9 +2151,19 @@ namespace Slang // If the given declaration has generic parameters, then // return the corresponding `GenericDecl` that holds the - // parameters, etc. + // parameters, etc. This returns the immediate generic parent + // of `decl`, e.g. the generic for f<T>, and *not* any indirect + // generic parents, such as P<T>.f(). GenericDecl* GetOuterGeneric(Decl* decl); + // If `decl` is inside a generic, return that outer generic, + // otherwise returns `decl`. + Decl* getOuterGenericOrSelf(Decl* decl); + + // Find the next outer generic parent of `decl`, including + // indirect parents. + GenericDecl* findNextOuterGeneric(Decl* decl); + // Try to find a unification for two values bool TryUnifyVals( ConstraintSystem& constraints, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index da91449f4..2dc3c7388 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -400,6 +400,8 @@ DIAGNOSTIC(31151, Error, cannotResolveGenericArgumentForDerivativeFunction, "[BackwardDerivativeOf], and [PrimalSubstituteOf] attributes are not supported when the generic arguments to the derivatives cannot be automatically deduced.") DIAGNOSTIC(31152, Error, cannotAssociateInterfaceRequirementWithDerivative, "cannot associate an interface requirement with a derivative.") DIAGNOSTIC(31153, Error, cannotUseInterfaceRequirementAsDerivative, "cannot use an interface requirement as a derivative.") +DIAGNOSTIC(31154, Error, customDerivativeSignatureThisParamMismatch, "custom derivative does not match expected signature on `this`. Either both the original and the derivative function are static, or they must have the same `this` type.") +DIAGNOSTIC(31155, Error, customDerivativeNotAllowedForMemberFunctionsOfDifferentiableType, "custom derivative is not allowed for non-static member functions of a differentiable type.") DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1") // Enums diff --git a/tests/autodiff/custom-derivative-generic.slang b/tests/autodiff/custom-derivative-generic.slang new file mode 100644 index 000000000..4ecde65cf --- /dev/null +++ b/tests/autodiff/custom-derivative-generic.slang @@ -0,0 +1,32 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type + +struct Buggy<let N : int> +{ + float m(float x) { return N * x; } + + [BackwardDerivativeOf(m)] + void mDiff(inout DifferentialPair<float> x, float dResult) + { + updateDiff(x, N * dResult); + } +} + +[Differentiable] +float test(float x) +{ + Buggy<2> b; + return b.m(x); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + var a = diffPair(3.0); + __bwd_diff(test)(a, 1.0); + outputBuffer[dispatchThreadID.x] = a.d; + // CHECK: 2.0 +} |
