summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-01 12:59:51 -0800
committerGitHub <noreply@github.com>2023-03-01 12:59:51 -0800
commit6c26aa1f7e3e28e3053dffe686baa8e0499c624d (patch)
tree4c7268615f1b880866498f2dff0ab580932bfb75
parent3c32dd951c5d69b5568929e0038e693553efca79 (diff)
Improve diagnostic on differentiablitiy check. (#2687)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-diagnostic-defs.h3
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp77
-rw-r--r--tests/diagnostics/autodiff-data-flow-3.slang13
-rw-r--r--tests/diagnostics/autodiff-data-flow-3.slang.expected5
4 files changed, 88 insertions, 10 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 214d386a2..aaf09a8be 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -588,7 +588,8 @@ DIAGNOSTIC(41011, Error, typeDoesNotFitAnyValueSize, "type '$0' does not fit in
DIAGNOSTIC(41012, Note, typeAndLimit, "sizeof($0) is $1, limit is $2")
DIAGNOSTIC(41012, Error, typeCannotBePackedIntoAnyValue, "type '$0' contains fields that cannot be packed into an AnyValue.")
DIAGNOSTIC(41020, Error, lossOfDerivativeDueToCallOfNonDifferentiableFunction, "derivative cannot be propagated through call to non-$1-differentiable function `$0`, use 'no_diff' to clarify intention.")
-DIAGNOSTIC(41024, Error, lossOfDerivativeAssigningToNonDifferentiableLocation, "derivative is lost during assignment to non-differentiable location. Use 'detach()' to clarify intention.")
+DIAGNOSTIC(41024, Error, lossOfDerivativeAssigningToNonDifferentiableLocation, "derivative is lost during assignment to non-differentiable location, use 'detach()' to clarify intention.")
+DIAGNOSTIC(41025, Error, lossOfDerivativeUsingNonDifferentiableLocationAsOutArg, "derivative is lost when passing a non-differentiable location to an `out` or `inout` parameter, consider passing a temporary variable instead.")
DIAGNOSTIC(41021, Error, differentiableFuncMustHaveOutput, "a differentiable function must have at least one differentiable output.")
DIAGNOSTIC(41022, Error, differentiableFuncMustHaveInput, "a differentiable function must have at least one differentiable input.")
DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can only be called when argument is statically resolvable to a string literal")
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 21f53fcbd..186b0cc03 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -186,6 +186,22 @@ public:
return false;
}
+ bool instHasNonTrivialDerivative(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_DetachDerivative:
+ return false;
+ case kIROp_Call:
+ {
+ auto call = as<IRCall>(inst);
+ return isDifferentiableFunc(call->getCallee(), CheckDifferentiabilityPassContext::DifferentiableLevel::Forward);
+ }
+ default:
+ return true;
+ }
+ }
+
void processFunc(IRGlobalValueWithCode* funcInst)
{
if (!_isFuncMarkedForAutoDiff(funcInst))
@@ -199,6 +215,7 @@ public:
HashSet<IRInst*> produceDiffSet;
HashSet<IRInst*> expectDiffSet;
int differentiableOutputs = 0;
+ bool isDifferentiableReturnType = false;
for (auto param : funcInst->getFirstBlock()->getParams())
{
if (isDifferentiableType(diffTypeContext, param->getFullType()))
@@ -211,7 +228,10 @@ public:
if (auto funcType = as<IRFuncType>(funcInst->getDataType()))
{
if (isDifferentiableType(diffTypeContext, funcType->getResultType()))
+ {
differentiableOutputs++;
+ isDifferentiableReturnType = true;
+ }
}
if (differentiableOutputs == 0)
@@ -305,7 +325,8 @@ public:
case kIROp_Store:
{
auto storeInst = as<IRStore>(inst);
- if (isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType()))
+ if (canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()) &&
+ isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType()))
{
addToExpectDiffWorkList(storeInst->getVal());
}
@@ -314,7 +335,8 @@ public:
case kIROp_Return:
if (auto returnVal = as<IRReturn>(inst)->getVal())
{
- if (isDifferentiableType(diffTypeContext, returnVal->getDataType()))
+ if (isDifferentiableReturnType &&
+ isDifferentiableType(diffTypeContext, returnVal->getDataType()))
{
addToExpectDiffWorkList(inst);
}
@@ -369,6 +391,27 @@ public:
}
break;
}
+ case kIROp_Call:
+ {
+ auto callInst = as<IRCall>(inst);
+ if (callInst->findDecoration<IRTreatAsDifferentiableDecoration>())
+ continue;
+ if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward))
+ continue;
+ auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType());
+ if (!calleeFuncType) continue;
+ if (calleeFuncType->getParamCount() != callInst->getArgCount())
+ continue;
+ for (UInt a = 0; a < callInst->getArgCount(); a++)
+ {
+ auto arg = callInst->getArg(a);
+ auto paramType = calleeFuncType->getParamType(a);
+ if (!isDifferentiableType(diffTypeContext, paramType))
+ continue;
+ addToExpectDiffWorkList(arg);
+ }
+ break;
+ }
default:
// Default behavior is to request all differentiable operands to provide differential.
for (UInt opIndex = 0; opIndex < inst->getOperandCount(); opIndex++)
@@ -417,15 +460,33 @@ public:
if (auto storeInst = as<IRStore>(inst))
{
if (produceDiffSet.Contains(storeInst->getVal()) &&
+ instHasNonTrivialDerivative(storeInst->getVal()) &&
!canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()))
{
- switch (storeInst->getVal()->getOp())
+ sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation);
+ }
+ }
+ else if (auto callInst = as<IRCall>(inst))
+ {
+ if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward))
+ continue;
+ auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType());
+ if (!calleeFuncType)
+ continue;
+ if (calleeFuncType->getParamCount() != callInst->getArgCount())
+ continue;
+ for (UInt a = 0; a < callInst->getArgCount(); a++)
+ {
+ auto arg = callInst->getArg(a);
+ auto paramType = calleeFuncType->getParamType(a);
+ if (!isDifferentiableType(diffTypeContext, paramType))
+ continue;
+ if (as<IROutTypeBase>(paramType))
{
- case kIROp_DetachDerivative:
- break;
- default:
- sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation);
- break;
+ if (!canAddressHoldDerivative(diffTypeContext, arg))
+ {
+ sink->diagnose(arg->sourceLoc, Diagnostics::lossOfDerivativeUsingNonDifferentiableLocationAsOutArg);
+ }
}
}
}
diff --git a/tests/diagnostics/autodiff-data-flow-3.slang b/tests/diagnostics/autodiff-data-flow-3.slang
index 0a8e3e58a..21dd9f76c 100644
--- a/tests/diagnostics/autodiff-data-flow-3.slang
+++ b/tests/diagnostics/autodiff-data-flow-3.slang
@@ -18,9 +18,22 @@ float g(float x)
}
[BackwardDifferentiable]
+void diffOut(inout float x)
+{
+ x = 2;
+}
+
+float noDiffFunc(float x)
+{
+ return 0.0;
+}
+
+[BackwardDifferentiable]
float h(float x)
{
NoDiffField obj;
obj.fp.f = detach(x * x); // OK.
+ obj.fp.f = noDiffFunc(x); // OK.
+ diffOut(obj.fp.f); // Error.
return obj.fp.f;
}
diff --git a/tests/diagnostics/autodiff-data-flow-3.slang.expected b/tests/diagnostics/autodiff-data-flow-3.slang.expected
index 73381cf58..817b595a6 100644
--- a/tests/diagnostics/autodiff-data-flow-3.slang.expected
+++ b/tests/diagnostics/autodiff-data-flow-3.slang.expected
@@ -1,8 +1,11 @@
result code = -1
standard error = {
-tests/diagnostics/autodiff-data-flow-3.slang(16): error 41024: derivative is lost during assignment to non-differentiable location. Use 'detach()' to clarify intention.
+tests/diagnostics/autodiff-data-flow-3.slang(16): error 41024: derivative is lost during assignment to non-differentiable location, use 'detach()' to clarify intention.
obj.fp.f = x * x; // Error, this location cannot hold derivative.
^
+tests/diagnostics/autodiff-data-flow-3.slang(37): error 41025: derivative is lost when passing a non-differentiable location to an `out` or `inout` parameter, consider passing a temporary variable instead.
+ diffOut(obj.fp.f); // Error.
+ ^
}
standard output = {
}