summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-01 12:59:51 -0800
committerGitHub <noreply@github.com>2023-03-01 12:59:51 -0800
commit6c26aa1f7e3e28e3053dffe686baa8e0499c624d (patch)
tree4c7268615f1b880866498f2dff0ab580932bfb75 /source/slang/slang-ir-check-differentiability.cpp
parent3c32dd951c5d69b5568929e0038e693553efca79 (diff)
Improve diagnostic on differentiablitiy check. (#2687)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp77
1 files changed, 69 insertions, 8 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 21f53fcbd..186b0cc03 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -186,6 +186,22 @@ public:
return false;
}
+ bool instHasNonTrivialDerivative(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_DetachDerivative:
+ return false;
+ case kIROp_Call:
+ {
+ auto call = as<IRCall>(inst);
+ return isDifferentiableFunc(call->getCallee(), CheckDifferentiabilityPassContext::DifferentiableLevel::Forward);
+ }
+ default:
+ return true;
+ }
+ }
+
void processFunc(IRGlobalValueWithCode* funcInst)
{
if (!_isFuncMarkedForAutoDiff(funcInst))
@@ -199,6 +215,7 @@ public:
HashSet<IRInst*> produceDiffSet;
HashSet<IRInst*> expectDiffSet;
int differentiableOutputs = 0;
+ bool isDifferentiableReturnType = false;
for (auto param : funcInst->getFirstBlock()->getParams())
{
if (isDifferentiableType(diffTypeContext, param->getFullType()))
@@ -211,7 +228,10 @@ public:
if (auto funcType = as<IRFuncType>(funcInst->getDataType()))
{
if (isDifferentiableType(diffTypeContext, funcType->getResultType()))
+ {
differentiableOutputs++;
+ isDifferentiableReturnType = true;
+ }
}
if (differentiableOutputs == 0)
@@ -305,7 +325,8 @@ public:
case kIROp_Store:
{
auto storeInst = as<IRStore>(inst);
- if (isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType()))
+ if (canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()) &&
+ isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType()))
{
addToExpectDiffWorkList(storeInst->getVal());
}
@@ -314,7 +335,8 @@ public:
case kIROp_Return:
if (auto returnVal = as<IRReturn>(inst)->getVal())
{
- if (isDifferentiableType(diffTypeContext, returnVal->getDataType()))
+ if (isDifferentiableReturnType &&
+ isDifferentiableType(diffTypeContext, returnVal->getDataType()))
{
addToExpectDiffWorkList(inst);
}
@@ -369,6 +391,27 @@ public:
}
break;
}
+ case kIROp_Call:
+ {
+ auto callInst = as<IRCall>(inst);
+ if (callInst->findDecoration<IRTreatAsDifferentiableDecoration>())
+ continue;
+ if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward))
+ continue;
+ auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType());
+ if (!calleeFuncType) continue;
+ if (calleeFuncType->getParamCount() != callInst->getArgCount())
+ continue;
+ for (UInt a = 0; a < callInst->getArgCount(); a++)
+ {
+ auto arg = callInst->getArg(a);
+ auto paramType = calleeFuncType->getParamType(a);
+ if (!isDifferentiableType(diffTypeContext, paramType))
+ continue;
+ addToExpectDiffWorkList(arg);
+ }
+ break;
+ }
default:
// Default behavior is to request all differentiable operands to provide differential.
for (UInt opIndex = 0; opIndex < inst->getOperandCount(); opIndex++)
@@ -417,15 +460,33 @@ public:
if (auto storeInst = as<IRStore>(inst))
{
if (produceDiffSet.Contains(storeInst->getVal()) &&
+ instHasNonTrivialDerivative(storeInst->getVal()) &&
!canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()))
{
- switch (storeInst->getVal()->getOp())
+ sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation);
+ }
+ }
+ else if (auto callInst = as<IRCall>(inst))
+ {
+ if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward))
+ continue;
+ auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType());
+ if (!calleeFuncType)
+ continue;
+ if (calleeFuncType->getParamCount() != callInst->getArgCount())
+ continue;
+ for (UInt a = 0; a < callInst->getArgCount(); a++)
+ {
+ auto arg = callInst->getArg(a);
+ auto paramType = calleeFuncType->getParamType(a);
+ if (!isDifferentiableType(diffTypeContext, paramType))
+ continue;
+ if (as<IROutTypeBase>(paramType))
{
- case kIROp_DetachDerivative:
- break;
- default:
- sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation);
- break;
+ if (!canAddressHoldDerivative(diffTypeContext, arg))
+ {
+ sink->diagnose(arg->sourceLoc, Diagnostics::lossOfDerivativeUsingNonDifferentiableLocationAsOutArg);
+ }
}
}
}