diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-17 15:14:44 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-17 15:14:44 -0700 |
| commit | 4b55bf6d75bdeed087728505a1c9b43d3a99af8d (patch) | |
| tree | 34cdae5db38ec231243fe858bf7dbd679d820a06 | |
| parent | 29abe397427f82f6c414d99890a3f50771703003 (diff) | |
Rework differentiability dataflow check. (#2711)
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/core.meta.slang | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 82 |
2 files changed, 71 insertions, 12 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 0a3bb885e..790aa3d55 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -83,7 +83,6 @@ interface __BuiltinType {} /// A type that can be used for arithmetic operations [sealed] [builtin] -[TreatAsDifferentiable] interface __BuiltinArithmeticType : __BuiltinType { /// Initialize from a 32-bit signed integer value. diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 14178a86c..c4b09d9e8 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -220,8 +220,20 @@ public: DifferentiableTypeConformanceContext diffTypeContext(&sharedContext); diffTypeContext.setFunc(funcInst); + // We compute and track three different set of insts to complete our + // data flow analysis. + // `produceDiffSet` represents a set of insts that can provide a diff. This is conservative + // on the positive side: a float literal is considered to be able to provide a diff. + // `carryNonTrivialDiffSet` represents a set of insts that may carry a non-zero diff. This is + // conservative on the negative side: if the inst does not provide a diff, or if we can prove the diff + // is zero, we exclude the inst from the set. This makes `carryNonTrivialDiffSet` a strict subset of + // `produceDiffSet`. + // `expectDiffSet` is a set of insts that expects their operands to produce a diff. It is an error + // if they don't. HashSet<IRInst*> produceDiffSet; HashSet<IRInst*> expectDiffSet; + HashSet<IRInst*> carryNonTrivialDiffSet; + int differentiableOutputs = 0; bool isDifferentiableReturnType = false; for (auto param : funcInst->getFirstBlock()->getParams()) @@ -231,6 +243,7 @@ public: if (as<IROutTypeBase>(param->getFullType())) differentiableOutputs++; produceDiffSet.Add(param); + carryNonTrivialDiffSet.Add(param); } } if (auto funcType = as<IRFuncType>(funcInst->getDataType())) @@ -256,7 +269,8 @@ public: case kIROp_FloatLit: return true; case kIROp_Call: - return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel); + return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) + && isDifferentiableType(diffTypeContext, inst->getFullType()); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. // Just assume it is producing diff if the dest address can hold a derivative. @@ -265,6 +279,8 @@ public: default: // default case is to assume the inst produces a diff value if any // of its operands produces a diff value. + if (!isDifferentiableType(diffTypeContext, inst->getFullType())) + return false; for (UInt i = 0; i < inst->getOperandCount(); i++) { if (produceDiffSet.Contains(inst->getOperand(i))) @@ -276,6 +292,38 @@ public: } }; + auto isInstCarryingOverDiff = [&](IRInst* inst) -> bool + { + switch (inst->getOp()) + { + case kIROp_DetachDerivative: + return false; + case kIROp_Call: + if (inst->findDecoration<IRTreatAsDifferentiableDecoration>()) + return false; + return isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && + isDifferentiableType(diffTypeContext, inst->getFullType()); + case kIROp_Load: + // We don't have more knowledge on whether diff is available at the destination address. + // Just assume it is producing diff if the dest address can hold a derivative. + //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter. + return canAddressHoldDerivative(diffTypeContext, as<IRLoad>(inst)->getPtr()); + default: + // default case is to assume the inst produces a diff value if any + // of its operands produces a diff value. + if (!isDifferentiableType(diffTypeContext, inst->getFullType())) + return false; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (carryNonTrivialDiffSet.Contains(inst->getOperand(i))) + { + return true; + } + } + return false; + } + }; + List<IRInst*> expectDiffInstWorkList; OrderedHashSet<IRInst*> expectDiffInstWorkListSet; auto addToExpectDiffWorkList = [&](IRInst* inst) @@ -283,7 +331,11 @@ public: if (isInstInFunc(inst, funcInst)) { if (expectDiffInstWorkListSet.Add(inst)) + { + if (inst->getFullType() && inst->getFullType()->getOp() == kIROp_IntType) + printf("break"); expectDiffInstWorkList.add(inst); + } } }; @@ -308,10 +360,9 @@ public: { auto arg = branch->getArg(paramIndex); if (produceDiffSet.Contains(arg)) - { produceDiffSet.Add(param); - break; - } + if (carryNonTrivialDiffSet.Contains(arg)) + carryNonTrivialDiffSet.Add(param); } } } @@ -322,6 +373,8 @@ public: { if (isInstProducingDiff(inst)) produceDiffSet.Add(inst); + if (isInstCarryingOverDiff(inst)) + carryNonTrivialDiffSet.Add(inst); switch (inst->getOp()) { case kIROp_Call: @@ -366,11 +419,17 @@ public: { if (auto call = as<IRCall>(inst)) { - sink->diagnose( - inst, - Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, - getResolvedInstForDecorations(call->getCallee()), - requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward"); + // If inst's type is differentiable, and it is in expectDiffInstWorkList, + // then some user is expecting the result of the call to produce a derivative. + // In this case we need to issue a diagnostic. + if (isDifferentiableType(diffTypeContext, inst->getFullType())) + { + sink->diagnose( + inst, + Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, + getResolvedInstForDecorations(call->getCallee()), + requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward"); + } } } switch (inst->getOp()) @@ -461,14 +520,15 @@ public: } // Make sure all stores of differentiable values are into addresses that can hold derivatives. + // If we are assigning a value to a non-differentiable location, we need to make sure + // that value doesn't carray a non-zero diff. for (auto block : funcInst->getBlocks()) { for (auto inst : block->getChildren()) { if (auto storeInst = as<IRStore>(inst)) { - if (produceDiffSet.Contains(storeInst->getVal()) && - instHasNonTrivialDerivative(diffTypeContext, storeInst->getVal()) && + if (carryNonTrivialDiffSet.Contains(storeInst->getVal()) && !canAddressHoldDerivative(diffTypeContext, storeInst->getPtr())) { sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); |
