summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-19 08:58:20 -0800
committerGitHub <noreply@github.com>2023-01-19 08:58:20 -0800
commit6fae15cd1210d8b664243d640e70ca47dccf9752 (patch)
treed3235149f587ed18147f7a0d916932e199dce888 /source/slang/slang-ir-check-differentiability.cpp
parent0586f3298fa7d554fa2682103eefba88740d6758 (diff)
Add diagnostic for calling non-bwd-diff func from bwd-diff func. (#2602)
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp82
1 files changed, 59 insertions, 23 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 8413e7e79..cb7290036 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -12,7 +12,11 @@ public:
DiagnosticSink* sink;
AutoDiffSharedContext sharedContext;
- HashSet<IRInst*> differentiableFunctions;
+ enum DifferentiableLevel
+ {
+ Forward, Backward
+ };
+ Dictionary<IRInst*, DifferentiableLevel> differentiableFunctions;
CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink)
: InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst())
@@ -59,7 +63,7 @@ public:
}
- bool _isDifferentiableFuncImpl(IRInst* func)
+ bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level)
{
func = getLeafFunc(func);
if (!func)
@@ -71,32 +75,41 @@ public:
{
case kIROp_ForwardDerivativeDecoration:
case kIROp_ForwardDifferentiableDecoration:
+ if (level == DifferentiableLevel::Forward)
+ return true;
+ break;
case kIROp_UserDefinedBackwardDerivativeDecoration:
case kIROp_BackwardDerivativeDecoration:
case kIROp_BackwardDifferentiableDecoration:
return true;
+ default:
+ break;
}
}
return false;
}
- bool isDifferentiableFunc(IRInst* func)
+ bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level)
{
- switch (func->getOp())
+ if (level == DifferentiableLevel::Forward)
{
- case kIROp_ForwardDifferentiate:
- case kIROp_BackwardDifferentiate:
- return true;
- default:
- break;
+ switch (func->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ return true;
+ default:
+ break;
+ }
}
- func = getSpecializedVal(func);
+ func = getLeafFunc(func);
if (!func)
return false;
- if (differentiableFunctions.Contains(func))
- return true;
+
+ if (auto existingLevel = differentiableFunctions.TryGetValue(func))
+ return *existingLevel >= level;
if (func->findDecoration<IRTreatAsDifferentiableDecoration>())
return true;
@@ -125,7 +138,10 @@ public:
{
if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey())
{
- return true;
+ if (as<IRBackwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Backward)
+ return true;
+ if (as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Forward)
+ return true;
}
}
}
@@ -135,7 +151,11 @@ public:
{
if (as<IRGeneric>(func))
{
- return differentiableFunctions.Contains(func);
+ if (auto existingLevel = differentiableFunctions.TryGetValue(func))
+ {
+ if (*existingLevel >= level)
+ return true;
+ }
}
}
return false;
@@ -235,6 +255,10 @@ public:
if (differentiableInputs == 0)
sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveInput);
+ DifferentiableLevel requiredDiffLevel = DifferentiableLevel::Forward;
+ if (isBackwardDifferentiableFunc(funcInst))
+ requiredDiffLevel = DifferentiableLevel::Backward;
+
auto isInstProducingDiff = [&](IRInst* inst) -> bool
{
switch (inst->getOp())
@@ -242,7 +266,7 @@ public:
case kIROp_FloatLit:
return true;
case kIROp_Call:
- return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee());
+ return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel);
case kIROp_Load:
// We don't have more knowledge on whether diff is available at the destination address.
// Just assume it is producing diff.
@@ -310,7 +334,7 @@ public:
switch (inst->getOp())
{
case kIROp_Call:
- if (isDifferentiableFunc(as<IRCall>(inst)->getCallee()))
+ if (isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel))
{
addToExpectDiffWorkList(inst);
}
@@ -349,7 +373,11 @@ public:
{
if (auto call = as<IRCall>(inst))
{
- sink->diagnose(inst, Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, getLeafFunc(call->getCallee()));
+ sink->diagnose(
+ inst,
+ Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
+ getLeafFunc(call->getCallee()),
+ requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward");
}
}
switch (inst->getOp())
@@ -395,22 +423,30 @@ public:
void processModule()
{
// Collect set of differentiable functions.
- HashSet<UnownedStringSlice> differentiableSymbolNames;
+ HashSet<UnownedStringSlice> fwdDifferentiableSymbolNames, bwdDifferentiableSymbolNames;
for (auto inst : module->getGlobalInsts())
{
- if (_isDifferentiableFuncImpl(inst))
+ if (_isDifferentiableFuncImpl(inst, DifferentiableLevel::Backward))
+ {
+ if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
+ bwdDifferentiableSymbolNames.Add(linkageDecor->getMangledName());
+ differentiableFunctions.Add(inst, DifferentiableLevel::Backward);
+ }
+ else if (_isDifferentiableFuncImpl(inst, DifferentiableLevel::Forward))
{
if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
- differentiableSymbolNames.Add(linkageDecor->getMangledName());
- differentiableFunctions.Add(inst);
+ fwdDifferentiableSymbolNames.Add(linkageDecor->getMangledName());
+ differentiableFunctions.Add(inst, DifferentiableLevel::Forward);
}
}
for (auto inst : module->getGlobalInsts())
{
if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
{
- if (differentiableSymbolNames.Contains(linkageDecor->getMangledName()))
- differentiableFunctions.Add(inst);
+ if (bwdDifferentiableSymbolNames.Contains(linkageDecor->getMangledName()))
+ differentiableFunctions[inst] = DifferentiableLevel::Backward;
+ else if (fwdDifferentiableSymbolNames.Contains(linkageDecor->getMangledName()))
+ differentiableFunctions.AddIfNotExists(inst, DifferentiableLevel::Forward);
}
}