From 6c26aa1f7e3e28e3053dffe686baa8e0499c624d Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 1 Mar 2023 12:59:51 -0800 Subject: Improve diagnostic on differentiablitiy check. (#2687) Co-authored-by: Yong He --- source/slang/slang-ir-check-differentiability.cpp | 77 ++++++++++++++++++++--- 1 file changed, 69 insertions(+), 8 deletions(-) (limited to 'source/slang/slang-ir-check-differentiability.cpp') diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 21f53fcbd..186b0cc03 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -186,6 +186,22 @@ public: return false; } + bool instHasNonTrivialDerivative(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_DetachDerivative: + return false; + case kIROp_Call: + { + auto call = as(inst); + return isDifferentiableFunc(call->getCallee(), CheckDifferentiabilityPassContext::DifferentiableLevel::Forward); + } + default: + return true; + } + } + void processFunc(IRGlobalValueWithCode* funcInst) { if (!_isFuncMarkedForAutoDiff(funcInst)) @@ -199,6 +215,7 @@ public: HashSet produceDiffSet; HashSet expectDiffSet; int differentiableOutputs = 0; + bool isDifferentiableReturnType = false; for (auto param : funcInst->getFirstBlock()->getParams()) { if (isDifferentiableType(diffTypeContext, param->getFullType())) @@ -211,7 +228,10 @@ public: if (auto funcType = as(funcInst->getDataType())) { if (isDifferentiableType(diffTypeContext, funcType->getResultType())) + { differentiableOutputs++; + isDifferentiableReturnType = true; + } } if (differentiableOutputs == 0) @@ -305,7 +325,8 @@ public: case kIROp_Store: { auto storeInst = as(inst); - if (isDifferentiableType(diffTypeContext, as(inst)->getPtr()->getDataType())) + if (canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()) && + isDifferentiableType(diffTypeContext, as(inst)->getPtr()->getDataType())) { addToExpectDiffWorkList(storeInst->getVal()); } @@ -314,7 +335,8 @@ public: case kIROp_Return: if (auto returnVal = as(inst)->getVal()) { - if (isDifferentiableType(diffTypeContext, returnVal->getDataType())) + if (isDifferentiableReturnType && + isDifferentiableType(diffTypeContext, returnVal->getDataType())) { addToExpectDiffWorkList(inst); } @@ -369,6 +391,27 @@ public: } break; } + case kIROp_Call: + { + auto callInst = as(inst); + if (callInst->findDecoration()) + continue; + if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward)) + continue; + auto calleeFuncType = as(callInst->getCallee()->getFullType()); + if (!calleeFuncType) continue; + if (calleeFuncType->getParamCount() != callInst->getArgCount()) + continue; + for (UInt a = 0; a < callInst->getArgCount(); a++) + { + auto arg = callInst->getArg(a); + auto paramType = calleeFuncType->getParamType(a); + if (!isDifferentiableType(diffTypeContext, paramType)) + continue; + addToExpectDiffWorkList(arg); + } + break; + } default: // Default behavior is to request all differentiable operands to provide differential. for (UInt opIndex = 0; opIndex < inst->getOperandCount(); opIndex++) @@ -417,15 +460,33 @@ public: if (auto storeInst = as(inst)) { if (produceDiffSet.Contains(storeInst->getVal()) && + instHasNonTrivialDerivative(storeInst->getVal()) && !canAddressHoldDerivative(diffTypeContext, storeInst->getPtr())) { - switch (storeInst->getVal()->getOp()) + sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); + } + } + else if (auto callInst = as(inst)) + { + if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward)) + continue; + auto calleeFuncType = as(callInst->getCallee()->getFullType()); + if (!calleeFuncType) + continue; + if (calleeFuncType->getParamCount() != callInst->getArgCount()) + continue; + for (UInt a = 0; a < callInst->getArgCount(); a++) + { + auto arg = callInst->getArg(a); + auto paramType = calleeFuncType->getParamType(a); + if (!isDifferentiableType(diffTypeContext, paramType)) + continue; + if (as(paramType)) { - case kIROp_DetachDerivative: - break; - default: - sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); - break; + if (!canAddressHoldDerivative(diffTypeContext, arg)) + { + sink->diagnose(arg->sourceLoc, Diagnostics::lossOfDerivativeUsingNonDifferentiableLocationAsOutArg); + } } } } -- cgit v1.2.3