summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2025-07-09 11:25:29 -0500
committerGitHub <noreply@github.com>2025-07-09 09:25:29 -0700
commita670bafc121c20168624f70a388dbe8556402c7f (patch)
tree79b48a80e7abc0744193716e400bb57a6c026bad /source/slang/slang-ir-check-differentiability.cpp
parenta7cb36901ccaf8297136c58c1451d6e04420af73 (diff)
no_diff diagnostics improvement (#7655)
close #6286. This PR is to improve the diagnostics for no_diff usage. In a differentiable function, any calls to a non-diff function with constant arguments should not require no_diff attribute. This PR adds this extra check at `checkAutoDiffUsages` where it checks the differentiability on IR. In a differentiable method, we will force to use `[NoDiffThis]` attribute if there is access to non-differentiable `This` type. Once this access is detected we will report a warning to bring users attention that this access won't generate any derivative, they have to use `[NoDiffThis]` to suppress that warning. This PR adds this check at type checking stage, because it's the easiest way to find out all the `This` accesses.
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp46
1 files changed, 45 insertions, 1 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index e9cb7e1f1..d83d7bb76 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -82,6 +82,49 @@ public:
callInst->findDecoration<IRDifferentiableCallDecoration>());
}
+ // If a function call takes all literals as arguments, it will implies that this function will
+ // not be expected to any gradients, in this case, this call should be treated as no_diff even
+ // there is no 'no_diff' decorated on it explicitly. In the actual check, we only need to check
+ // the argument corresponding to the differentiable parameters, because non-differentiable
+ // parameter are not expected to produce any gradients anyway.
+ bool shouldCallImpliesNoDiff(
+ DifferentiableTypeConformanceContext& diffTypeContext,
+ IRCall* callInst)
+ {
+ if (shouldTreatCallAsDifferentiable(callInst))
+ {
+ return true;
+ }
+
+ auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType());
+ if (!calleeFuncType)
+ return false;
+
+ SLANG_RELEASE_ASSERT(calleeFuncType->getParamCount() == callInst->getArgCount());
+
+ bool doesImplyNoDiff = true;
+ UInt paramIndex = 0;
+ for (auto paramType : calleeFuncType->getParamTypes())
+ {
+ if (isDifferentiableType(diffTypeContext, paramType))
+ {
+ auto arg = callInst->getArg(paramIndex);
+ if (!as<IRConstant>(arg))
+ {
+ doesImplyNoDiff = false;
+ }
+ }
+ paramIndex++;
+ }
+
+ if (doesImplyNoDiff)
+ {
+ IRBuilder irBuilder(callInst->getModule());
+ irBuilder.addDecoration(callInst, kIROp_TreatCallAsDifferentiableDecoration);
+ }
+ return doesImplyNoDiff;
+ }
+
bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level)
{
switch (func->getOp())
@@ -497,7 +540,8 @@ public:
// No need to fail here if the function is no_diff in
// both inputs and all outputs, this is equivalent of
// inserting no_diff on this inst.
- if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType())))
+ if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType())) &&
+ !shouldCallImpliesNoDiff(diffTypeContext, call))
{
sink->diagnose(
inst,