summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2025-07-09 11:25:29 -0500
committerGitHub <noreply@github.com>2025-07-09 09:25:29 -0700
commita670bafc121c20168624f70a388dbe8556402c7f (patch)
tree79b48a80e7abc0744193716e400bb57a6c026bad
parenta7cb36901ccaf8297136c58c1451d6e04420af73 (diff)
no_diff diagnostics improvement (#7655)
close #6286. This PR is to improve the diagnostics for no_diff usage. In a differentiable function, any calls to a non-diff function with constant arguments should not require no_diff attribute. This PR adds this extra check at `checkAutoDiffUsages` where it checks the differentiability on IR. In a differentiable method, we will force to use `[NoDiffThis]` attribute if there is access to non-differentiable `This` type. Once this access is detected we will report a warning to bring users attention that this access won't generate any derivative, they have to use `[NoDiffThis]` to suppress that warning. This PR adds this check at type checking stage, because it's the easiest way to find out all the `This` accesses.
-rw-r--r--source/slang/diff.meta.slang8
-rw-r--r--source/slang/slang-check-expr.cpp31
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-diagnostic-defs.h9
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp46
-rw-r--r--tests/autodiff/dynamic-dispatch-bwd-diff.slang2
-rw-r--r--tests/autodiff/dynamic-dispatch-generic-2.slang3
-rw-r--r--tests/autodiff/dynamic-object-bwd-diff.slang4
-rw-r--r--tests/autodiff/material2/DiffuseMaterial.slang1
-rw-r--r--tests/autodiff/material2/GlossyMaterial.slang1
-rw-r--r--tests/autodiff/material2/MxLayeredMaterial.slang1
-rw-r--r--tests/autodiff/member-func-extension-custom-derivative.slang1
-rw-r--r--tests/autodiff/nondiff-call.slang1
-rw-r--r--tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang40
-rw-r--r--tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang48
-rw-r--r--tests/diagnostics/force-no-diff-this.slang42
16 files changed, 239 insertions, 1 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 542983049..b22d91595 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -992,10 +992,12 @@ struct DiffTensorView
[BackwardDerivative(__load_backward)]
[ForwardDerivative(__load_forward)]
+ [NoDiffThis]
T load(uint i) { return primal.load(i); }
[BackwardDerivative(__load_backward)]
[ForwardDerivative(__load_forward)]
+ [NoDiffThis]
__generic<let N : int>
T load(vector<uint, N> i) { return primal.load(i); }
@@ -1026,10 +1028,12 @@ struct DiffTensorView
[BackwardDerivative(__store_backward)]
[ForwardDerivative(__store_forward)]
+ [NoDiffThis]
void store(uint x, T val) { primal.store(x, val); }
[BackwardDerivative(__store_backward)]
[ForwardDerivative(__store_forward)]
+ [NoDiffThis]
__generic<let N : int>
void store(vector<uint, N> x, T val) { primal.store(x, val); }
@@ -1135,10 +1139,12 @@ struct DiffTensorView
[BackwardDerivative(__loadOnce_backward)]
[ForwardDerivative(__loadOnce_forward)]
+ [NoDiffThis]
T loadOnce(uint i) { return primal.load(i); }
[BackwardDerivative(__loadOnce_backward)]
[ForwardDerivative(__loadOnce_forward)]
+ [NoDiffThis]
__generic<let N : int>
T loadOnce(vector<uint, N> i) { return primal.load(i); }
@@ -1168,10 +1174,12 @@ struct DiffTensorView
[BackwardDerivative(__storeOnce_backward)]
[ForwardDerivative(__storeOnce_forward)]
+ [NoDiffThis]
void storeOnce(uint x, T val) { primal.store(x, val); }
[BackwardDerivative(__storeOnce_backward)]
[ForwardDerivative(__storeOnce_forward)]
+ [NoDiffThis]
__generic<let N : int>
void storeOnce(vector<uint, N> x, T val) { primal.store(x, val); }
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index b28b458da..c7e58a888 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1632,6 +1632,31 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder*
}
}
+// This checks that if a differentiable function access a non-diff type "This", in such case we
+// want to provide a non-error diagnostic to the user to notify that there could be an unexpected
+// behavior because every member access will not have derivative computed for it. User can use
+// [NoDiffThis] to clarify that this is intended.
+void SemanticsVisitor::maybeCheckMissingNoDiffThis(Expr* expr)
+{
+ if (auto memberExpr = as<MemberExpr>(expr))
+ {
+ auto thisExpr = as<ThisExpr>(memberExpr->baseExpression);
+ if (thisExpr && isTypeDifferentiable(memberExpr->type.type))
+ {
+ if (isTypeDifferentiable(calcThisType(thisExpr->type.type)) ||
+ this->m_parentFunc->findModifier<NoDiffThisAttribute>())
+ {
+ return;
+ }
+
+ getSink()->diagnose(
+ memberExpr->loc,
+ Diagnostics::noDerivativeOnNonDifferentiableThisType,
+ memberExpr->declRef.getDecl(),
+ this->m_parentFunc);
+ }
+ }
+}
Expr* SemanticsVisitor::CheckTerm(Expr* term)
{
@@ -1649,7 +1674,13 @@ Expr* SemanticsVisitor::CheckTerm(Expr* term)
if (this->m_parentFunc && this->m_parentFunc->findModifier<DifferentiableAttribute>())
{
maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
+
+ if (!this->m_parentFunc->findModifier<TreatAsDifferentiableAttribute>())
+ {
+ maybeCheckMissingNoDiffThis(checkedTerm);
+ }
}
+
return checkedTerm;
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 4a6ccfe17..75d0bfc90 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2041,6 +2041,8 @@ public:
// Check and register a type if it is differentiable.
void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type);
+ void maybeCheckMissingNoDiffThis(Expr* expr);
+
// Find the default implementation of an interface requirement,
// and insert it to the witness table, if it exists.
bool findDefaultInterfaceImpl(
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index bf0e91150..249bc4f88 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -1399,6 +1399,15 @@ DIAGNOSTIC(
primalSubstituteTargetMustHaveHigherDifferentiabilityLevel,
"primal substitute function for differentiable method must also be differentiable. Use "
"[Differentiable] or [TreatAsDifferentiable] (for empty derivatives)")
+DIAGNOSTIC(
+ 31159,
+ Warning,
+ noDerivativeOnNonDifferentiableThisType,
+ "There is no derivative calculated for member '$0' because the parent struct is not "
+ "differentiable. "
+ "If this is intended, consider using [NoDiffThis] on the function '$1' to suppress this "
+ "warning. Alternatively, users can mark the parent struct as [Differentiable] to propagate "
+ "derivatives.")
DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1")
DIAGNOSTIC(31201, Error, modifierNotAllowed, "modifier '$0' is not allowed here.")
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index e9cb7e1f1..d83d7bb76 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -82,6 +82,49 @@ public:
callInst->findDecoration<IRDifferentiableCallDecoration>());
}
+ // If a function call takes all literals as arguments, it will implies that this function will
+ // not be expected to any gradients, in this case, this call should be treated as no_diff even
+ // there is no 'no_diff' decorated on it explicitly. In the actual check, we only need to check
+ // the argument corresponding to the differentiable parameters, because non-differentiable
+ // parameter are not expected to produce any gradients anyway.
+ bool shouldCallImpliesNoDiff(
+ DifferentiableTypeConformanceContext& diffTypeContext,
+ IRCall* callInst)
+ {
+ if (shouldTreatCallAsDifferentiable(callInst))
+ {
+ return true;
+ }
+
+ auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType());
+ if (!calleeFuncType)
+ return false;
+
+ SLANG_RELEASE_ASSERT(calleeFuncType->getParamCount() == callInst->getArgCount());
+
+ bool doesImplyNoDiff = true;
+ UInt paramIndex = 0;
+ for (auto paramType : calleeFuncType->getParamTypes())
+ {
+ if (isDifferentiableType(diffTypeContext, paramType))
+ {
+ auto arg = callInst->getArg(paramIndex);
+ if (!as<IRConstant>(arg))
+ {
+ doesImplyNoDiff = false;
+ }
+ }
+ paramIndex++;
+ }
+
+ if (doesImplyNoDiff)
+ {
+ IRBuilder irBuilder(callInst->getModule());
+ irBuilder.addDecoration(callInst, kIROp_TreatCallAsDifferentiableDecoration);
+ }
+ return doesImplyNoDiff;
+ }
+
bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level)
{
switch (func->getOp())
@@ -497,7 +540,8 @@ public:
// No need to fail here if the function is no_diff in
// both inputs and all outputs, this is equivalent of
// inserting no_diff on this inst.
- if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType())))
+ if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType())) &&
+ !shouldCallImpliesNoDiff(diffTypeContext, call))
{
sink->diagnose(
inst,
diff --git a/tests/autodiff/dynamic-dispatch-bwd-diff.slang b/tests/autodiff/dynamic-dispatch-bwd-diff.slang
index 5945c22cd..9941aa7b1 100644
--- a/tests/autodiff/dynamic-dispatch-bwd-diff.slang
+++ b/tests/autodiff/dynamic-dispatch-bwd-diff.slang
@@ -17,6 +17,7 @@ struct A : IInterface
{
float a;
[BackwardDifferentiable]
+ [NoDiffThis]
float calc(float x) { return a*x*x; }
};
@@ -24,6 +25,7 @@ struct B : IInterface
{
float a;
[BackwardDifferentiable]
+ [NoDiffThis]
float calc(float x) { return a*x*x*x; }
};
diff --git a/tests/autodiff/dynamic-dispatch-generic-2.slang b/tests/autodiff/dynamic-dispatch-generic-2.slang
index bbf7c7da1..97cb5f42c 100644
--- a/tests/autodiff/dynamic-dispatch-generic-2.slang
+++ b/tests/autodiff/dynamic-dispatch-generic-2.slang
@@ -16,7 +16,9 @@ interface IInterface
struct A : IInterface
{
float z;
+
[ForwardDifferentiable]
+ [NoDiffThis]
float calc(float x) { return x * x * x; }
};
@@ -25,6 +27,7 @@ struct B : IInterface
float z;
[ForwardDifferentiable]
+ [NoDiffThis]
float calc(float x) { return x * x + z; }
};
diff --git a/tests/autodiff/dynamic-object-bwd-diff.slang b/tests/autodiff/dynamic-object-bwd-diff.slang
index a10c48f9b..a80025d52 100644
--- a/tests/autodiff/dynamic-object-bwd-diff.slang
+++ b/tests/autodiff/dynamic-object-bwd-diff.slang
@@ -26,7 +26,9 @@ struct C : IInterface2
struct A : IInterface
{
float a;
+
[BackwardDifferentiable]
+ [NoDiffThis]
float calc(IInterface2 i2, float x)
{
float b = no_diff(i2.innerCalc(x));
@@ -37,7 +39,9 @@ struct A : IInterface
struct B : IInterface
{
float a;
+
[BackwardDifferentiable]
+ [NoDiffThis]
float calc(IInterface2 i2, float x)
{
float b = no_diff(i2.innerCalc(x));
diff --git a/tests/autodiff/material2/DiffuseMaterial.slang b/tests/autodiff/material2/DiffuseMaterial.slang
index 1422ee30c..a5c0aaa30 100644
--- a/tests/autodiff/material2/DiffuseMaterial.slang
+++ b/tests/autodiff/material2/DiffuseMaterial.slang
@@ -16,6 +16,7 @@ public struct DiffuseMaterial : IMaterial
}
[BackwardDifferentiable]
+ [NoDiffThis]
public DiffuseMaterialInstance setupMaterialInstance(out MaterialInstanceData miData)
{
float3 albedo = getAlbedo(baseColor);
diff --git a/tests/autodiff/material2/GlossyMaterial.slang b/tests/autodiff/material2/GlossyMaterial.slang
index 1070c6e63..12fbc5f73 100644
--- a/tests/autodiff/material2/GlossyMaterial.slang
+++ b/tests/autodiff/material2/GlossyMaterial.slang
@@ -30,6 +30,7 @@ public struct GlossyMaterial : IMaterial
}
[BackwardDifferentiable]
+ [NoDiffThis]
public GlossyMaterialInstance setupMaterialInstance(out MaterialInstanceData miData)
{
float3 albedo = getAlbedo(baseColor);
diff --git a/tests/autodiff/material2/MxLayeredMaterial.slang b/tests/autodiff/material2/MxLayeredMaterial.slang
index fabfde80c..9e88d2ce9 100644
--- a/tests/autodiff/material2/MxLayeredMaterial.slang
+++ b/tests/autodiff/material2/MxLayeredMaterial.slang
@@ -52,6 +52,7 @@ public struct MxLayeredMaterial : IMaterial
}
[Differentiable]
+ [NoDiffThis]
public UsedMaterialInstance setupMaterialInstance(out MaterialInstanceData miData)
{
float3 albedo = getAlbedo(baseColor);
diff --git a/tests/autodiff/member-func-extension-custom-derivative.slang b/tests/autodiff/member-func-extension-custom-derivative.slang
index 8752dfff5..72fbcf9a1 100644
--- a/tests/autodiff/member-func-extension-custom-derivative.slang
+++ b/tests/autodiff/member-func-extension-custom-derivative.slang
@@ -32,6 +32,7 @@ struct Foo<T : IDifferentiable>
extension<T : IDifferentiable> Foo<T>
{
[ForwardDerivativeOf(doThing)]
+ [NoDiffThis]
DifferentialPair<T> diff_doThing()
{
return diffPair(value, T.dzero());
diff --git a/tests/autodiff/nondiff-call.slang b/tests/autodiff/nondiff-call.slang
index d62de1b78..79d9b9f15 100644
--- a/tests/autodiff/nondiff-call.slang
+++ b/tests/autodiff/nondiff-call.slang
@@ -35,6 +35,7 @@ struct A
float o;
[ForwardDifferentiable]
+ [NoDiffThis]
float doSomethingDifferentiable(float b)
{
return o + b;
diff --git a/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang
new file mode 100644
index 000000000..961ac75d5
--- /dev/null
+++ b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang
@@ -0,0 +1,40 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+
+float someNoDiffFunc(float x, no_diff float y)
+{
+ return x * x + y * y;
+}
+
+// Previously, when we call a no-diff function side a differntiable function, we will have to use no_diff to tell compiler that this is intended.
+// However, if the parameter is just a constant, there is no need to use no_diff, because constant won't carry any derivative information.
+// Therefore, this test is to check we won't report any error when the parameter is a constant in this case.
+[Differentiable]
+float eval(float x)
+{
+ // CHECK-NOT: ([[# @LINE+1]]): error 41020
+ return exp(x) - someNoDiffFunc(1.0f, x);
+}
+
+[Differentiable]
+float eval1(float x)
+{
+ // CHECK: ([[# @LINE+1]]): error 41020
+ return exp(x) - someNoDiffFunc(x, 1.0);
+}
+
+RWStructuredBuffer<float> output;
+
+[shader("compute")]
+[numthreads(1,1,1)]
+void computeMain(uint id : SV_DispatchThreadID)
+{
+ var x = diffPair(2.0f);
+ bwd_diff(eval)(x, 1.0f);
+
+ output[0] = x.d;
+
+ var x1 = diffPair(2.0f);
+ bwd_diff(eval1)(x1, 1.0f);
+ output[1] = x1.d;
+}
+
diff --git a/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang
new file mode 100644
index 000000000..f27c6ec6b
--- /dev/null
+++ b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang
@@ -0,0 +1,48 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+
+
+// Similar to const-to-nodiff-function-diagnostic-improvement.slang, but with a CoopVec type
+// to reproduce a more realistic scenario.
+extension<T : __BuiltinFloatingPointType, let K : int> CoopVec<T, K> : IDifferentiable
+{
+ typealias Differential = CoopVec<T, K>;
+};
+
+[BackwardDerivativeOf(exp)]
+void exp_BackwardAutoDiff<T : __BuiltinFloatingPointType, let K : int>(inout DifferentialPair<CoopVec<T, K>> p0, CoopVec<T, K>.Differential dResult)
+{
+ p0 = diffPair(p0.p, dResult * exp(p0.p));
+}
+
+[Differentiable]
+CoopVec<T, K> eval<T : __BuiltinFloatingPointType, let K : int>(CoopVec<T, K> x)
+{
+ // CHECK-NOT: ([[# @LINE+1]]): error 41020
+ return exp(x) - CoopVec<T, K>(1.);
+}
+
+[Differentiable]
+CoopVec<T, K> eval1<T : __BuiltinFloatingPointType, let K : int>(CoopVec<T, K> x)
+{
+ // test.slang(25): error 41020: derivative cannot be propagated through call to non-backward-differentiable function `CoopVec.$init`, use 'no_diff' to clarify intention.
+ // CHECK: ([[# @LINE+1]]): error 41020
+ return exp(x) - CoopVec<T, K>(x[0]);
+}
+
+
+RWStructuredBuffer<float> output;
+
+[shader("compute")]
+[numthreads(1,1,1)]
+void computeMain(uint id : SV_DispatchThreadID)
+{
+ var x = diffPair(CoopVec<float, 2>(2.0f), CoopVec<float, 2>(1.0f));
+ bwd_diff(eval)(x, CoopVec<float, 2>(1.0f));
+
+ output[0] = x.d[0];
+
+ var x1 = diffPair(CoopVec<float, 2>(2.0f), CoopVec<float, 2>(1.0f));
+ bwd_diff(eval1)(x1, CoopVec<float, 2>(1.0f));
+ output[1] = x1.d[1];
+}
+
diff --git a/tests/diagnostics/force-no-diff-this.slang b/tests/diagnostics/force-no-diff-this.slang
new file mode 100644
index 000000000..ae1464ffb
--- /dev/null
+++ b/tests/diagnostics/force-no-diff-this.slang
@@ -0,0 +1,42 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+
+struct MyStruct<T> where T: __BuiltinFloatingPointType
+{
+ float a;
+ __init(float a) { this.a = a;}
+
+ [Differentiable]
+ T eval(T x)
+ {
+ //CHECK: ([[# @LINE+1]]): warning 31159
+ return exp(x * T(a) * T(a));
+ }
+
+ [Differentiable]
+ [NoDiffThis]
+ T eval1(T x)
+ {
+ //CHECK-NOT: ([[# @LINE+1]]): warning 31159
+ return exp(x * T(a) * T(a));
+ }
+};
+
+[Differentiable]
+float evalFunc(float x)
+{
+ MyStruct<float> s = {x};
+ return s.eval(x) + s.eval1(x);
+}
+
+RWStructuredBuffer<float> output;
+
+[shader("compute")]
+[numthreads(1,1,1)]
+void computeMain(uint id : SV_DispatchThreadID)
+{
+ var x = diffPair(2.0f);
+ bwd_diff(evalFunc)(x, 1.0f);
+
+ output[0] = x.d;
+}
+