diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-01 12:59:51 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-01 12:59:51 -0800 |
| commit | 6c26aa1f7e3e28e3053dffe686baa8e0499c624d (patch) | |
| tree | 4c7268615f1b880866498f2dff0ab580932bfb75 /source/slang/slang-ir-check-differentiability.cpp | |
| parent | 3c32dd951c5d69b5568929e0038e693553efca79 (diff) | |
Improve diagnostic on differentiablitiy check. (#2687)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 77 |
1 files changed, 69 insertions, 8 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 21f53fcbd..186b0cc03 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -186,6 +186,22 @@ public: return false; } + bool instHasNonTrivialDerivative(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_DetachDerivative: + return false; + case kIROp_Call: + { + auto call = as<IRCall>(inst); + return isDifferentiableFunc(call->getCallee(), CheckDifferentiabilityPassContext::DifferentiableLevel::Forward); + } + default: + return true; + } + } + void processFunc(IRGlobalValueWithCode* funcInst) { if (!_isFuncMarkedForAutoDiff(funcInst)) @@ -199,6 +215,7 @@ public: HashSet<IRInst*> produceDiffSet; HashSet<IRInst*> expectDiffSet; int differentiableOutputs = 0; + bool isDifferentiableReturnType = false; for (auto param : funcInst->getFirstBlock()->getParams()) { if (isDifferentiableType(diffTypeContext, param->getFullType())) @@ -211,7 +228,10 @@ public: if (auto funcType = as<IRFuncType>(funcInst->getDataType())) { if (isDifferentiableType(diffTypeContext, funcType->getResultType())) + { differentiableOutputs++; + isDifferentiableReturnType = true; + } } if (differentiableOutputs == 0) @@ -305,7 +325,8 @@ public: case kIROp_Store: { auto storeInst = as<IRStore>(inst); - if (isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType())) + if (canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()) && + isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType())) { addToExpectDiffWorkList(storeInst->getVal()); } @@ -314,7 +335,8 @@ public: case kIROp_Return: if (auto returnVal = as<IRReturn>(inst)->getVal()) { - if (isDifferentiableType(diffTypeContext, returnVal->getDataType())) + if (isDifferentiableReturnType && + isDifferentiableType(diffTypeContext, returnVal->getDataType())) { addToExpectDiffWorkList(inst); } @@ -369,6 +391,27 @@ public: } break; } + case kIROp_Call: + { + auto callInst = as<IRCall>(inst); + if (callInst->findDecoration<IRTreatAsDifferentiableDecoration>()) + continue; + if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward)) + continue; + auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType()); + if (!calleeFuncType) continue; + if (calleeFuncType->getParamCount() != callInst->getArgCount()) + continue; + for (UInt a = 0; a < callInst->getArgCount(); a++) + { + auto arg = callInst->getArg(a); + auto paramType = calleeFuncType->getParamType(a); + if (!isDifferentiableType(diffTypeContext, paramType)) + continue; + addToExpectDiffWorkList(arg); + } + break; + } default: // Default behavior is to request all differentiable operands to provide differential. for (UInt opIndex = 0; opIndex < inst->getOperandCount(); opIndex++) @@ -417,15 +460,33 @@ public: if (auto storeInst = as<IRStore>(inst)) { if (produceDiffSet.Contains(storeInst->getVal()) && + instHasNonTrivialDerivative(storeInst->getVal()) && !canAddressHoldDerivative(diffTypeContext, storeInst->getPtr())) { - switch (storeInst->getVal()->getOp()) + sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); + } + } + else if (auto callInst = as<IRCall>(inst)) + { + if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward)) + continue; + auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType()); + if (!calleeFuncType) + continue; + if (calleeFuncType->getParamCount() != callInst->getArgCount()) + continue; + for (UInt a = 0; a < callInst->getArgCount(); a++) + { + auto arg = callInst->getArg(a); + auto paramType = calleeFuncType->getParamType(a); + if (!isDifferentiableType(diffTypeContext, paramType)) + continue; + if (as<IROutTypeBase>(paramType)) { - case kIROp_DetachDerivative: - break; - default: - sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); - break; + if (!canAddressHoldDerivative(diffTypeContext, arg)) + { + sink->diagnose(arg->sourceLoc, Diagnostics::lossOfDerivativeUsingNonDifferentiableLocationAsOutArg); + } } } } |
