summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-17 15:14:44 -0700
committerGitHub <noreply@github.com>2023-03-17 15:14:44 -0700
commit4b55bf6d75bdeed087728505a1c9b43d3a99af8d (patch)
tree34cdae5db38ec231243fe858bf7dbd679d820a06
parent29abe397427f82f6c414d99890a3f50771703003 (diff)
Rework differentiability dataflow check. (#2711)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/core.meta.slang1
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp82
2 files changed, 71 insertions, 12 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 0a3bb885e..790aa3d55 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -83,7 +83,6 @@ interface __BuiltinType {}
/// A type that can be used for arithmetic operations
[sealed]
[builtin]
-[TreatAsDifferentiable]
interface __BuiltinArithmeticType : __BuiltinType
{
/// Initialize from a 32-bit signed integer value.
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 14178a86c..c4b09d9e8 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -220,8 +220,20 @@ public:
DifferentiableTypeConformanceContext diffTypeContext(&sharedContext);
diffTypeContext.setFunc(funcInst);
+ // We compute and track three different set of insts to complete our
+ // data flow analysis.
+ // `produceDiffSet` represents a set of insts that can provide a diff. This is conservative
+ // on the positive side: a float literal is considered to be able to provide a diff.
+ // `carryNonTrivialDiffSet` represents a set of insts that may carry a non-zero diff. This is
+ // conservative on the negative side: if the inst does not provide a diff, or if we can prove the diff
+ // is zero, we exclude the inst from the set. This makes `carryNonTrivialDiffSet` a strict subset of
+ // `produceDiffSet`.
+ // `expectDiffSet` is a set of insts that expects their operands to produce a diff. It is an error
+ // if they don't.
HashSet<IRInst*> produceDiffSet;
HashSet<IRInst*> expectDiffSet;
+ HashSet<IRInst*> carryNonTrivialDiffSet;
+
int differentiableOutputs = 0;
bool isDifferentiableReturnType = false;
for (auto param : funcInst->getFirstBlock()->getParams())
@@ -231,6 +243,7 @@ public:
if (as<IROutTypeBase>(param->getFullType()))
differentiableOutputs++;
produceDiffSet.Add(param);
+ carryNonTrivialDiffSet.Add(param);
}
}
if (auto funcType = as<IRFuncType>(funcInst->getDataType()))
@@ -256,7 +269,8 @@ public:
case kIROp_FloatLit:
return true;
case kIROp_Call:
- return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel);
+ return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel)
+ && isDifferentiableType(diffTypeContext, inst->getFullType());
case kIROp_Load:
// We don't have more knowledge on whether diff is available at the destination address.
// Just assume it is producing diff if the dest address can hold a derivative.
@@ -265,6 +279,8 @@ public:
default:
// default case is to assume the inst produces a diff value if any
// of its operands produces a diff value.
+ if (!isDifferentiableType(diffTypeContext, inst->getFullType()))
+ return false;
for (UInt i = 0; i < inst->getOperandCount(); i++)
{
if (produceDiffSet.Contains(inst->getOperand(i)))
@@ -276,6 +292,38 @@ public:
}
};
+ auto isInstCarryingOverDiff = [&](IRInst* inst) -> bool
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_DetachDerivative:
+ return false;
+ case kIROp_Call:
+ if (inst->findDecoration<IRTreatAsDifferentiableDecoration>())
+ return false;
+ return isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) &&
+ isDifferentiableType(diffTypeContext, inst->getFullType());
+ case kIROp_Load:
+ // We don't have more knowledge on whether diff is available at the destination address.
+ // Just assume it is producing diff if the dest address can hold a derivative.
+ //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter.
+ return canAddressHoldDerivative(diffTypeContext, as<IRLoad>(inst)->getPtr());
+ default:
+ // default case is to assume the inst produces a diff value if any
+ // of its operands produces a diff value.
+ if (!isDifferentiableType(diffTypeContext, inst->getFullType()))
+ return false;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (carryNonTrivialDiffSet.Contains(inst->getOperand(i)))
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+ };
+
List<IRInst*> expectDiffInstWorkList;
OrderedHashSet<IRInst*> expectDiffInstWorkListSet;
auto addToExpectDiffWorkList = [&](IRInst* inst)
@@ -283,7 +331,11 @@ public:
if (isInstInFunc(inst, funcInst))
{
if (expectDiffInstWorkListSet.Add(inst))
+ {
+ if (inst->getFullType() && inst->getFullType()->getOp() == kIROp_IntType)
+ printf("break");
expectDiffInstWorkList.add(inst);
+ }
}
};
@@ -308,10 +360,9 @@ public:
{
auto arg = branch->getArg(paramIndex);
if (produceDiffSet.Contains(arg))
- {
produceDiffSet.Add(param);
- break;
- }
+ if (carryNonTrivialDiffSet.Contains(arg))
+ carryNonTrivialDiffSet.Add(param);
}
}
}
@@ -322,6 +373,8 @@ public:
{
if (isInstProducingDiff(inst))
produceDiffSet.Add(inst);
+ if (isInstCarryingOverDiff(inst))
+ carryNonTrivialDiffSet.Add(inst);
switch (inst->getOp())
{
case kIROp_Call:
@@ -366,11 +419,17 @@ public:
{
if (auto call = as<IRCall>(inst))
{
- sink->diagnose(
- inst,
- Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
- getResolvedInstForDecorations(call->getCallee()),
- requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward");
+ // If inst's type is differentiable, and it is in expectDiffInstWorkList,
+ // then some user is expecting the result of the call to produce a derivative.
+ // In this case we need to issue a diagnostic.
+ if (isDifferentiableType(diffTypeContext, inst->getFullType()))
+ {
+ sink->diagnose(
+ inst,
+ Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
+ getResolvedInstForDecorations(call->getCallee()),
+ requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward");
+ }
}
}
switch (inst->getOp())
@@ -461,14 +520,15 @@ public:
}
// Make sure all stores of differentiable values are into addresses that can hold derivatives.
+ // If we are assigning a value to a non-differentiable location, we need to make sure
+ // that value doesn't carray a non-zero diff.
for (auto block : funcInst->getBlocks())
{
for (auto inst : block->getChildren())
{
if (auto storeInst = as<IRStore>(inst))
{
- if (produceDiffSet.Contains(storeInst->getVal()) &&
- instHasNonTrivialDerivative(diffTypeContext, storeInst->getVal()) &&
+ if (carryNonTrivialDiffSet.Contains(storeInst->getVal()) &&
!canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()))
{
sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation);