diff options
| author | kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> | 2025-07-09 11:25:29 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-09 09:25:29 -0700 |
| commit | a670bafc121c20168624f70a388dbe8556402c7f (patch) | |
| tree | 79b48a80e7abc0744193716e400bb57a6c026bad | |
| parent | a7cb36901ccaf8297136c58c1451d6e04420af73 (diff) | |
no_diff diagnostics improvement (#7655)
close #6286.
This PR is to improve the diagnostics for no_diff usage.
In a differentiable function, any calls to a non-diff function with constant arguments should not require no_diff attribute.
This PR adds this extra check at `checkAutoDiffUsages` where it checks the differentiability on IR.
In a differentiable method, we will force to use `[NoDiffThis]` attribute if there is access to non-differentiable `This` type. Once this access is detected we will report a warning to bring users attention that this access won't generate any derivative, they have to use `[NoDiffThis]` to suppress that warning.
This PR adds this check at type checking stage, because it's the easiest way to find out all the `This` accesses.
| -rw-r--r-- | source/slang/diff.meta.slang | 8 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 31 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 46 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-bwd-diff.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-generic-2.slang | 3 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-object-bwd-diff.slang | 4 | ||||
| -rw-r--r-- | tests/autodiff/material2/DiffuseMaterial.slang | 1 | ||||
| -rw-r--r-- | tests/autodiff/material2/GlossyMaterial.slang | 1 | ||||
| -rw-r--r-- | tests/autodiff/material2/MxLayeredMaterial.slang | 1 | ||||
| -rw-r--r-- | tests/autodiff/member-func-extension-custom-derivative.slang | 1 | ||||
| -rw-r--r-- | tests/autodiff/nondiff-call.slang | 1 | ||||
| -rw-r--r-- | tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang | 40 | ||||
| -rw-r--r-- | tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang | 48 | ||||
| -rw-r--r-- | tests/diagnostics/force-no-diff-this.slang | 42 |
16 files changed, 239 insertions, 1 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 542983049..b22d91595 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -992,10 +992,12 @@ struct DiffTensorView [BackwardDerivative(__load_backward)] [ForwardDerivative(__load_forward)] + [NoDiffThis] T load(uint i) { return primal.load(i); } [BackwardDerivative(__load_backward)] [ForwardDerivative(__load_forward)] + [NoDiffThis] __generic<let N : int> T load(vector<uint, N> i) { return primal.load(i); } @@ -1026,10 +1028,12 @@ struct DiffTensorView [BackwardDerivative(__store_backward)] [ForwardDerivative(__store_forward)] + [NoDiffThis] void store(uint x, T val) { primal.store(x, val); } [BackwardDerivative(__store_backward)] [ForwardDerivative(__store_forward)] + [NoDiffThis] __generic<let N : int> void store(vector<uint, N> x, T val) { primal.store(x, val); } @@ -1135,10 +1139,12 @@ struct DiffTensorView [BackwardDerivative(__loadOnce_backward)] [ForwardDerivative(__loadOnce_forward)] + [NoDiffThis] T loadOnce(uint i) { return primal.load(i); } [BackwardDerivative(__loadOnce_backward)] [ForwardDerivative(__loadOnce_forward)] + [NoDiffThis] __generic<let N : int> T loadOnce(vector<uint, N> i) { return primal.load(i); } @@ -1168,10 +1174,12 @@ struct DiffTensorView [BackwardDerivative(__storeOnce_backward)] [ForwardDerivative(__storeOnce_forward)] + [NoDiffThis] void storeOnce(uint x, T val) { primal.store(x, val); } [BackwardDerivative(__storeOnce_backward)] [ForwardDerivative(__storeOnce_forward)] + [NoDiffThis] __generic<let N : int> void storeOnce(vector<uint, N> x, T val) { primal.store(x, val); } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index b28b458da..c7e58a888 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1632,6 +1632,31 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* } } +// This checks that if a differentiable function access a non-diff type "This", in such case we +// want to provide a non-error diagnostic to the user to notify that there could be an unexpected +// behavior because every member access will not have derivative computed for it. User can use +// [NoDiffThis] to clarify that this is intended. +void SemanticsVisitor::maybeCheckMissingNoDiffThis(Expr* expr) +{ + if (auto memberExpr = as<MemberExpr>(expr)) + { + auto thisExpr = as<ThisExpr>(memberExpr->baseExpression); + if (thisExpr && isTypeDifferentiable(memberExpr->type.type)) + { + if (isTypeDifferentiable(calcThisType(thisExpr->type.type)) || + this->m_parentFunc->findModifier<NoDiffThisAttribute>()) + { + return; + } + + getSink()->diagnose( + memberExpr->loc, + Diagnostics::noDerivativeOnNonDifferentiableThisType, + memberExpr->declRef.getDecl(), + this->m_parentFunc); + } + } +} Expr* SemanticsVisitor::CheckTerm(Expr* term) { @@ -1649,7 +1674,13 @@ Expr* SemanticsVisitor::CheckTerm(Expr* term) if (this->m_parentFunc && this->m_parentFunc->findModifier<DifferentiableAttribute>()) { maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); + + if (!this->m_parentFunc->findModifier<TreatAsDifferentiableAttribute>()) + { + maybeCheckMissingNoDiffThis(checkedTerm); + } } + return checkedTerm; } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 4a6ccfe17..75d0bfc90 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2041,6 +2041,8 @@ public: // Check and register a type if it is differentiable. void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type); + void maybeCheckMissingNoDiffThis(Expr* expr); + // Find the default implementation of an interface requirement, // and insert it to the witness table, if it exists. bool findDefaultInterfaceImpl( diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index bf0e91150..249bc4f88 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -1399,6 +1399,15 @@ DIAGNOSTIC( primalSubstituteTargetMustHaveHigherDifferentiabilityLevel, "primal substitute function for differentiable method must also be differentiable. Use " "[Differentiable] or [TreatAsDifferentiable] (for empty derivatives)") +DIAGNOSTIC( + 31159, + Warning, + noDerivativeOnNonDifferentiableThisType, + "There is no derivative calculated for member '$0' because the parent struct is not " + "differentiable. " + "If this is intended, consider using [NoDiffThis] on the function '$1' to suppress this " + "warning. Alternatively, users can mark the parent struct as [Differentiable] to propagate " + "derivatives.") DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1") DIAGNOSTIC(31201, Error, modifierNotAllowed, "modifier '$0' is not allowed here.") diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index e9cb7e1f1..d83d7bb76 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -82,6 +82,49 @@ public: callInst->findDecoration<IRDifferentiableCallDecoration>()); } + // If a function call takes all literals as arguments, it will implies that this function will + // not be expected to any gradients, in this case, this call should be treated as no_diff even + // there is no 'no_diff' decorated on it explicitly. In the actual check, we only need to check + // the argument corresponding to the differentiable parameters, because non-differentiable + // parameter are not expected to produce any gradients anyway. + bool shouldCallImpliesNoDiff( + DifferentiableTypeConformanceContext& diffTypeContext, + IRCall* callInst) + { + if (shouldTreatCallAsDifferentiable(callInst)) + { + return true; + } + + auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType()); + if (!calleeFuncType) + return false; + + SLANG_RELEASE_ASSERT(calleeFuncType->getParamCount() == callInst->getArgCount()); + + bool doesImplyNoDiff = true; + UInt paramIndex = 0; + for (auto paramType : calleeFuncType->getParamTypes()) + { + if (isDifferentiableType(diffTypeContext, paramType)) + { + auto arg = callInst->getArg(paramIndex); + if (!as<IRConstant>(arg)) + { + doesImplyNoDiff = false; + } + } + paramIndex++; + } + + if (doesImplyNoDiff) + { + IRBuilder irBuilder(callInst->getModule()); + irBuilder.addDecoration(callInst, kIROp_TreatCallAsDifferentiableDecoration); + } + return doesImplyNoDiff; + } + bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level) { switch (func->getOp()) @@ -497,7 +540,8 @@ public: // No need to fail here if the function is no_diff in // both inputs and all outputs, this is equivalent of // inserting no_diff on this inst. - if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType()))) + if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType())) && + !shouldCallImpliesNoDiff(diffTypeContext, call)) { sink->diagnose( inst, diff --git a/tests/autodiff/dynamic-dispatch-bwd-diff.slang b/tests/autodiff/dynamic-dispatch-bwd-diff.slang index 5945c22cd..9941aa7b1 100644 --- a/tests/autodiff/dynamic-dispatch-bwd-diff.slang +++ b/tests/autodiff/dynamic-dispatch-bwd-diff.slang @@ -17,6 +17,7 @@ struct A : IInterface { float a; [BackwardDifferentiable] + [NoDiffThis] float calc(float x) { return a*x*x; } }; @@ -24,6 +25,7 @@ struct B : IInterface { float a; [BackwardDifferentiable] + [NoDiffThis] float calc(float x) { return a*x*x*x; } }; diff --git a/tests/autodiff/dynamic-dispatch-generic-2.slang b/tests/autodiff/dynamic-dispatch-generic-2.slang index bbf7c7da1..97cb5f42c 100644 --- a/tests/autodiff/dynamic-dispatch-generic-2.slang +++ b/tests/autodiff/dynamic-dispatch-generic-2.slang @@ -16,7 +16,9 @@ interface IInterface struct A : IInterface { float z; + [ForwardDifferentiable] + [NoDiffThis] float calc(float x) { return x * x * x; } }; @@ -25,6 +27,7 @@ struct B : IInterface float z; [ForwardDifferentiable] + [NoDiffThis] float calc(float x) { return x * x + z; } }; diff --git a/tests/autodiff/dynamic-object-bwd-diff.slang b/tests/autodiff/dynamic-object-bwd-diff.slang index a10c48f9b..a80025d52 100644 --- a/tests/autodiff/dynamic-object-bwd-diff.slang +++ b/tests/autodiff/dynamic-object-bwd-diff.slang @@ -26,7 +26,9 @@ struct C : IInterface2 struct A : IInterface { float a; + [BackwardDifferentiable] + [NoDiffThis] float calc(IInterface2 i2, float x) { float b = no_diff(i2.innerCalc(x)); @@ -37,7 +39,9 @@ struct A : IInterface struct B : IInterface { float a; + [BackwardDifferentiable] + [NoDiffThis] float calc(IInterface2 i2, float x) { float b = no_diff(i2.innerCalc(x)); diff --git a/tests/autodiff/material2/DiffuseMaterial.slang b/tests/autodiff/material2/DiffuseMaterial.slang index 1422ee30c..a5c0aaa30 100644 --- a/tests/autodiff/material2/DiffuseMaterial.slang +++ b/tests/autodiff/material2/DiffuseMaterial.slang @@ -16,6 +16,7 @@ public struct DiffuseMaterial : IMaterial } [BackwardDifferentiable] + [NoDiffThis] public DiffuseMaterialInstance setupMaterialInstance(out MaterialInstanceData miData) { float3 albedo = getAlbedo(baseColor); diff --git a/tests/autodiff/material2/GlossyMaterial.slang b/tests/autodiff/material2/GlossyMaterial.slang index 1070c6e63..12fbc5f73 100644 --- a/tests/autodiff/material2/GlossyMaterial.slang +++ b/tests/autodiff/material2/GlossyMaterial.slang @@ -30,6 +30,7 @@ public struct GlossyMaterial : IMaterial } [BackwardDifferentiable] + [NoDiffThis] public GlossyMaterialInstance setupMaterialInstance(out MaterialInstanceData miData) { float3 albedo = getAlbedo(baseColor); diff --git a/tests/autodiff/material2/MxLayeredMaterial.slang b/tests/autodiff/material2/MxLayeredMaterial.slang index fabfde80c..9e88d2ce9 100644 --- a/tests/autodiff/material2/MxLayeredMaterial.slang +++ b/tests/autodiff/material2/MxLayeredMaterial.slang @@ -52,6 +52,7 @@ public struct MxLayeredMaterial : IMaterial } [Differentiable] + [NoDiffThis] public UsedMaterialInstance setupMaterialInstance(out MaterialInstanceData miData) { float3 albedo = getAlbedo(baseColor); diff --git a/tests/autodiff/member-func-extension-custom-derivative.slang b/tests/autodiff/member-func-extension-custom-derivative.slang index 8752dfff5..72fbcf9a1 100644 --- a/tests/autodiff/member-func-extension-custom-derivative.slang +++ b/tests/autodiff/member-func-extension-custom-derivative.slang @@ -32,6 +32,7 @@ struct Foo<T : IDifferentiable> extension<T : IDifferentiable> Foo<T> { [ForwardDerivativeOf(doThing)] + [NoDiffThis] DifferentialPair<T> diff_doThing() { return diffPair(value, T.dzero()); diff --git a/tests/autodiff/nondiff-call.slang b/tests/autodiff/nondiff-call.slang index d62de1b78..79d9b9f15 100644 --- a/tests/autodiff/nondiff-call.slang +++ b/tests/autodiff/nondiff-call.slang @@ -35,6 +35,7 @@ struct A float o; [ForwardDifferentiable] + [NoDiffThis] float doSomethingDifferentiable(float b) { return o + b; diff --git a/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang new file mode 100644 index 000000000..961ac75d5 --- /dev/null +++ b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang @@ -0,0 +1,40 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): + +float someNoDiffFunc(float x, no_diff float y) +{ + return x * x + y * y; +} + +// Previously, when we call a no-diff function side a differntiable function, we will have to use no_diff to tell compiler that this is intended. +// However, if the parameter is just a constant, there is no need to use no_diff, because constant won't carry any derivative information. +// Therefore, this test is to check we won't report any error when the parameter is a constant in this case. +[Differentiable] +float eval(float x) +{ + // CHECK-NOT: ([[# @LINE+1]]): error 41020 + return exp(x) - someNoDiffFunc(1.0f, x); +} + +[Differentiable] +float eval1(float x) +{ + // CHECK: ([[# @LINE+1]]): error 41020 + return exp(x) - someNoDiffFunc(x, 1.0); +} + +RWStructuredBuffer<float> output; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + var x = diffPair(2.0f); + bwd_diff(eval)(x, 1.0f); + + output[0] = x.d; + + var x1 = diffPair(2.0f); + bwd_diff(eval1)(x1, 1.0f); + output[1] = x1.d; +} + diff --git a/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang new file mode 100644 index 000000000..f27c6ec6b --- /dev/null +++ b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang @@ -0,0 +1,48 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): + + +// Similar to const-to-nodiff-function-diagnostic-improvement.slang, but with a CoopVec type +// to reproduce a more realistic scenario. +extension<T : __BuiltinFloatingPointType, let K : int> CoopVec<T, K> : IDifferentiable +{ + typealias Differential = CoopVec<T, K>; +}; + +[BackwardDerivativeOf(exp)] +void exp_BackwardAutoDiff<T : __BuiltinFloatingPointType, let K : int>(inout DifferentialPair<CoopVec<T, K>> p0, CoopVec<T, K>.Differential dResult) +{ + p0 = diffPair(p0.p, dResult * exp(p0.p)); +} + +[Differentiable] +CoopVec<T, K> eval<T : __BuiltinFloatingPointType, let K : int>(CoopVec<T, K> x) +{ + // CHECK-NOT: ([[# @LINE+1]]): error 41020 + return exp(x) - CoopVec<T, K>(1.); +} + +[Differentiable] +CoopVec<T, K> eval1<T : __BuiltinFloatingPointType, let K : int>(CoopVec<T, K> x) +{ + // test.slang(25): error 41020: derivative cannot be propagated through call to non-backward-differentiable function `CoopVec.$init`, use 'no_diff' to clarify intention. + // CHECK: ([[# @LINE+1]]): error 41020 + return exp(x) - CoopVec<T, K>(x[0]); +} + + +RWStructuredBuffer<float> output; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + var x = diffPair(CoopVec<float, 2>(2.0f), CoopVec<float, 2>(1.0f)); + bwd_diff(eval)(x, CoopVec<float, 2>(1.0f)); + + output[0] = x.d[0]; + + var x1 = diffPair(CoopVec<float, 2>(2.0f), CoopVec<float, 2>(1.0f)); + bwd_diff(eval1)(x1, CoopVec<float, 2>(1.0f)); + output[1] = x1.d[1]; +} + diff --git a/tests/diagnostics/force-no-diff-this.slang b/tests/diagnostics/force-no-diff-this.slang new file mode 100644 index 000000000..ae1464ffb --- /dev/null +++ b/tests/diagnostics/force-no-diff-this.slang @@ -0,0 +1,42 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): + +struct MyStruct<T> where T: __BuiltinFloatingPointType +{ + float a; + __init(float a) { this.a = a;} + + [Differentiable] + T eval(T x) + { + //CHECK: ([[# @LINE+1]]): warning 31159 + return exp(x * T(a) * T(a)); + } + + [Differentiable] + [NoDiffThis] + T eval1(T x) + { + //CHECK-NOT: ([[# @LINE+1]]): warning 31159 + return exp(x * T(a) * T(a)); + } +}; + +[Differentiable] +float evalFunc(float x) +{ + MyStruct<float> s = {x}; + return s.eval(x) + s.eval1(x); +} + +RWStructuredBuffer<float> output; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + var x = diffPair(2.0f); + bwd_diff(evalFunc)(x, 1.0f); + + output[0] = x.d; +} + |
