diff options
| author | kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> | 2025-07-09 11:25:29 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-09 09:25:29 -0700 |
| commit | a670bafc121c20168624f70a388dbe8556402c7f (patch) | |
| tree | 79b48a80e7abc0744193716e400bb57a6c026bad /source/slang/slang-ir-check-differentiability.cpp | |
| parent | a7cb36901ccaf8297136c58c1451d6e04420af73 (diff) | |
no_diff diagnostics improvement (#7655)
close #6286.
This PR is to improve the diagnostics for no_diff usage.
In a differentiable function, any calls to a non-diff function with constant arguments should not require no_diff attribute.
This PR adds this extra check at `checkAutoDiffUsages` where it checks the differentiability on IR.
In a differentiable method, we will force to use `[NoDiffThis]` attribute if there is access to non-differentiable `This` type. Once this access is detected we will report a warning to bring users attention that this access won't generate any derivative, they have to use `[NoDiffThis]` to suppress that warning.
This PR adds this check at type checking stage, because it's the easiest way to find out all the `This` accesses.
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 46 |
1 files changed, 45 insertions, 1 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index e9cb7e1f1..d83d7bb76 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -82,6 +82,49 @@ public: callInst->findDecoration<IRDifferentiableCallDecoration>()); } + // If a function call takes all literals as arguments, it will implies that this function will + // not be expected to any gradients, in this case, this call should be treated as no_diff even + // there is no 'no_diff' decorated on it explicitly. In the actual check, we only need to check + // the argument corresponding to the differentiable parameters, because non-differentiable + // parameter are not expected to produce any gradients anyway. + bool shouldCallImpliesNoDiff( + DifferentiableTypeConformanceContext& diffTypeContext, + IRCall* callInst) + { + if (shouldTreatCallAsDifferentiable(callInst)) + { + return true; + } + + auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType()); + if (!calleeFuncType) + return false; + + SLANG_RELEASE_ASSERT(calleeFuncType->getParamCount() == callInst->getArgCount()); + + bool doesImplyNoDiff = true; + UInt paramIndex = 0; + for (auto paramType : calleeFuncType->getParamTypes()) + { + if (isDifferentiableType(diffTypeContext, paramType)) + { + auto arg = callInst->getArg(paramIndex); + if (!as<IRConstant>(arg)) + { + doesImplyNoDiff = false; + } + } + paramIndex++; + } + + if (doesImplyNoDiff) + { + IRBuilder irBuilder(callInst->getModule()); + irBuilder.addDecoration(callInst, kIROp_TreatCallAsDifferentiableDecoration); + } + return doesImplyNoDiff; + } + bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level) { switch (func->getOp()) @@ -497,7 +540,8 @@ public: // No need to fail here if the function is no_diff in // both inputs and all outputs, this is equivalent of // inserting no_diff on this inst. - if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType()))) + if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType())) && + !shouldCallImpliesNoDiff(diffTypeContext, call)) { sink->diagnose( inst, |
