diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 1 |
3 files changed, 28 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 2ce4f81f1..76074f551 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -12114,6 +12114,27 @@ static void checkDerivativeAttribute( imaginaryArguments.directions, imaginaryArguments.thisArg, imaginaryArguments.thisArgDirection); + + // For primal-substitute we'd also want to make sure that the differentiability + // level of the target is as high as the funcDecl itself + // + if (auto declRefExpr = as<DeclRefExpr>(attr->funcExpr)) + { + if (auto declRef = declRefExpr->declRef) + { + auto targetDiffLevel = visitor->getShared()->getFuncDifferentiableLevel( + declRef.as<FunctionDeclBase>().getDecl()); + auto currDiffLevel = visitor->getShared()->getFuncDifferentiableLevel(funcDecl); + if (targetDiffLevel < currDiffLevel) + { + visitor->getSink()->diagnose( + attr->loc, + Diagnostics::primalSubstituteTargetMustHaveHigherDifferentiabilityLevel, + declRefExpr->declRef.getDecl(), + funcDecl); + } + } + } } static void checkCudaKernelAttribute( diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 59a1bbdb6..acb9beb94 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -1208,6 +1208,12 @@ DIAGNOSTIC( Error, overloadedFuncUsedWithDerivativeOfAttributes, "cannot resolve overloaded functions for derivative-of attributes.") +DIAGNOSTIC( + 31158, + Error, + primalSubstituteTargetMustHaveHigherDifferentiabilityLevel, + "primal substitute function for differentiable method must also be differentiable. Use " + "[Differentiable] or [TreatAsDifferentiable] (for empty derivatives)") DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1") DIAGNOSTIC(31201, Error, modifierNotAllowed, "modifier '$0' is not allowed here.") diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index d65a22e77..9075002e0 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -2471,6 +2471,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_BackwardDerivativePrimalDecoration: case kIROp_BackwardDerivativePrimalContextDecoration: case kIROp_BackwardDerivativePrimalReturnDecoration: + case kIROp_PrimalSubstituteDecoration: case kIROp_AutoDiffOriginalValueDecoration: case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_IntermediateContextFieldDifferentialTypeDecoration: |
