diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-01 12:59:51 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-01 12:59:51 -0800 |
| commit | 6c26aa1f7e3e28e3053dffe686baa8e0499c624d (patch) | |
| tree | 4c7268615f1b880866498f2dff0ab580932bfb75 | |
| parent | 3c32dd951c5d69b5568929e0038e693553efca79 (diff) | |
Improve diagnostic on differentiablitiy check. (#2687)
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 77 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff-data-flow-3.slang | 13 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff-data-flow-3.slang.expected | 5 |
4 files changed, 88 insertions, 10 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 214d386a2..aaf09a8be 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -588,7 +588,8 @@ DIAGNOSTIC(41011, Error, typeDoesNotFitAnyValueSize, "type '$0' does not fit in DIAGNOSTIC(41012, Note, typeAndLimit, "sizeof($0) is $1, limit is $2") DIAGNOSTIC(41012, Error, typeCannotBePackedIntoAnyValue, "type '$0' contains fields that cannot be packed into an AnyValue.") DIAGNOSTIC(41020, Error, lossOfDerivativeDueToCallOfNonDifferentiableFunction, "derivative cannot be propagated through call to non-$1-differentiable function `$0`, use 'no_diff' to clarify intention.") -DIAGNOSTIC(41024, Error, lossOfDerivativeAssigningToNonDifferentiableLocation, "derivative is lost during assignment to non-differentiable location. Use 'detach()' to clarify intention.") +DIAGNOSTIC(41024, Error, lossOfDerivativeAssigningToNonDifferentiableLocation, "derivative is lost during assignment to non-differentiable location, use 'detach()' to clarify intention.") +DIAGNOSTIC(41025, Error, lossOfDerivativeUsingNonDifferentiableLocationAsOutArg, "derivative is lost when passing a non-differentiable location to an `out` or `inout` parameter, consider passing a temporary variable instead.") DIAGNOSTIC(41021, Error, differentiableFuncMustHaveOutput, "a differentiable function must have at least one differentiable output.") DIAGNOSTIC(41022, Error, differentiableFuncMustHaveInput, "a differentiable function must have at least one differentiable input.") DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can only be called when argument is statically resolvable to a string literal") diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 21f53fcbd..186b0cc03 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -186,6 +186,22 @@ public: return false; } + bool instHasNonTrivialDerivative(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_DetachDerivative: + return false; + case kIROp_Call: + { + auto call = as<IRCall>(inst); + return isDifferentiableFunc(call->getCallee(), CheckDifferentiabilityPassContext::DifferentiableLevel::Forward); + } + default: + return true; + } + } + void processFunc(IRGlobalValueWithCode* funcInst) { if (!_isFuncMarkedForAutoDiff(funcInst)) @@ -199,6 +215,7 @@ public: HashSet<IRInst*> produceDiffSet; HashSet<IRInst*> expectDiffSet; int differentiableOutputs = 0; + bool isDifferentiableReturnType = false; for (auto param : funcInst->getFirstBlock()->getParams()) { if (isDifferentiableType(diffTypeContext, param->getFullType())) @@ -211,7 +228,10 @@ public: if (auto funcType = as<IRFuncType>(funcInst->getDataType())) { if (isDifferentiableType(diffTypeContext, funcType->getResultType())) + { differentiableOutputs++; + isDifferentiableReturnType = true; + } } if (differentiableOutputs == 0) @@ -305,7 +325,8 @@ public: case kIROp_Store: { auto storeInst = as<IRStore>(inst); - if (isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType())) + if (canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()) && + isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType())) { addToExpectDiffWorkList(storeInst->getVal()); } @@ -314,7 +335,8 @@ public: case kIROp_Return: if (auto returnVal = as<IRReturn>(inst)->getVal()) { - if (isDifferentiableType(diffTypeContext, returnVal->getDataType())) + if (isDifferentiableReturnType && + isDifferentiableType(diffTypeContext, returnVal->getDataType())) { addToExpectDiffWorkList(inst); } @@ -369,6 +391,27 @@ public: } break; } + case kIROp_Call: + { + auto callInst = as<IRCall>(inst); + if (callInst->findDecoration<IRTreatAsDifferentiableDecoration>()) + continue; + if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward)) + continue; + auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType()); + if (!calleeFuncType) continue; + if (calleeFuncType->getParamCount() != callInst->getArgCount()) + continue; + for (UInt a = 0; a < callInst->getArgCount(); a++) + { + auto arg = callInst->getArg(a); + auto paramType = calleeFuncType->getParamType(a); + if (!isDifferentiableType(diffTypeContext, paramType)) + continue; + addToExpectDiffWorkList(arg); + } + break; + } default: // Default behavior is to request all differentiable operands to provide differential. for (UInt opIndex = 0; opIndex < inst->getOperandCount(); opIndex++) @@ -417,15 +460,33 @@ public: if (auto storeInst = as<IRStore>(inst)) { if (produceDiffSet.Contains(storeInst->getVal()) && + instHasNonTrivialDerivative(storeInst->getVal()) && !canAddressHoldDerivative(diffTypeContext, storeInst->getPtr())) { - switch (storeInst->getVal()->getOp()) + sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); + } + } + else if (auto callInst = as<IRCall>(inst)) + { + if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward)) + continue; + auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType()); + if (!calleeFuncType) + continue; + if (calleeFuncType->getParamCount() != callInst->getArgCount()) + continue; + for (UInt a = 0; a < callInst->getArgCount(); a++) + { + auto arg = callInst->getArg(a); + auto paramType = calleeFuncType->getParamType(a); + if (!isDifferentiableType(diffTypeContext, paramType)) + continue; + if (as<IROutTypeBase>(paramType)) { - case kIROp_DetachDerivative: - break; - default: - sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); - break; + if (!canAddressHoldDerivative(diffTypeContext, arg)) + { + sink->diagnose(arg->sourceLoc, Diagnostics::lossOfDerivativeUsingNonDifferentiableLocationAsOutArg); + } } } } diff --git a/tests/diagnostics/autodiff-data-flow-3.slang b/tests/diagnostics/autodiff-data-flow-3.slang index 0a8e3e58a..21dd9f76c 100644 --- a/tests/diagnostics/autodiff-data-flow-3.slang +++ b/tests/diagnostics/autodiff-data-flow-3.slang @@ -18,9 +18,22 @@ float g(float x) } [BackwardDifferentiable] +void diffOut(inout float x) +{ + x = 2; +} + +float noDiffFunc(float x) +{ + return 0.0; +} + +[BackwardDifferentiable] float h(float x) { NoDiffField obj; obj.fp.f = detach(x * x); // OK. + obj.fp.f = noDiffFunc(x); // OK. + diffOut(obj.fp.f); // Error. return obj.fp.f; } diff --git a/tests/diagnostics/autodiff-data-flow-3.slang.expected b/tests/diagnostics/autodiff-data-flow-3.slang.expected index 73381cf58..817b595a6 100644 --- a/tests/diagnostics/autodiff-data-flow-3.slang.expected +++ b/tests/diagnostics/autodiff-data-flow-3.slang.expected @@ -1,8 +1,11 @@ result code = -1 standard error = { -tests/diagnostics/autodiff-data-flow-3.slang(16): error 41024: derivative is lost during assignment to non-differentiable location. Use 'detach()' to clarify intention. +tests/diagnostics/autodiff-data-flow-3.slang(16): error 41024: derivative is lost during assignment to non-differentiable location, use 'detach()' to clarify intention. obj.fp.f = x * x; // Error, this location cannot hold derivative. ^ +tests/diagnostics/autodiff-data-flow-3.slang(37): error 41025: derivative is lost when passing a non-differentiable location to an `out` or `inout` parameter, consider passing a temporary variable instead. + diffOut(obj.fp.f); // Error. + ^ } standard output = { } |
