diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-19 08:58:20 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-19 08:58:20 -0800 |
| commit | 6fae15cd1210d8b664243d640e70ca47dccf9752 (patch) | |
| tree | d3235149f587ed18147f7a0d916932e199dce888 /source/slang/slang-ir-check-differentiability.cpp | |
| parent | 0586f3298fa7d554fa2682103eefba88740d6758 (diff) | |
Add diagnostic for calling non-bwd-diff func from bwd-diff func. (#2602)
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 82 |
1 files changed, 59 insertions, 23 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 8413e7e79..cb7290036 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -12,7 +12,11 @@ public: DiagnosticSink* sink; AutoDiffSharedContext sharedContext; - HashSet<IRInst*> differentiableFunctions; + enum DifferentiableLevel + { + Forward, Backward + }; + Dictionary<IRInst*, DifferentiableLevel> differentiableFunctions; CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink) : InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst()) @@ -59,7 +63,7 @@ public: } - bool _isDifferentiableFuncImpl(IRInst* func) + bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level) { func = getLeafFunc(func); if (!func) @@ -71,32 +75,41 @@ public: { case kIROp_ForwardDerivativeDecoration: case kIROp_ForwardDifferentiableDecoration: + if (level == DifferentiableLevel::Forward) + return true; + break; case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_BackwardDerivativeDecoration: case kIROp_BackwardDifferentiableDecoration: return true; + default: + break; } } return false; } - bool isDifferentiableFunc(IRInst* func) + bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level) { - switch (func->getOp()) + if (level == DifferentiableLevel::Forward) { - case kIROp_ForwardDifferentiate: - case kIROp_BackwardDifferentiate: - return true; - default: - break; + switch (func->getOp()) + { + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + return true; + default: + break; + } } - func = getSpecializedVal(func); + func = getLeafFunc(func); if (!func) return false; - if (differentiableFunctions.Contains(func)) - return true; + + if (auto existingLevel = differentiableFunctions.TryGetValue(func)) + return *existingLevel >= level; if (func->findDecoration<IRTreatAsDifferentiableDecoration>()) return true; @@ -125,7 +138,10 @@ public: { if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey()) { - return true; + if (as<IRBackwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Backward) + return true; + if (as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child) && level == DifferentiableLevel::Forward) + return true; } } } @@ -135,7 +151,11 @@ public: { if (as<IRGeneric>(func)) { - return differentiableFunctions.Contains(func); + if (auto existingLevel = differentiableFunctions.TryGetValue(func)) + { + if (*existingLevel >= level) + return true; + } } } return false; @@ -235,6 +255,10 @@ public: if (differentiableInputs == 0) sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveInput); + DifferentiableLevel requiredDiffLevel = DifferentiableLevel::Forward; + if (isBackwardDifferentiableFunc(funcInst)) + requiredDiffLevel = DifferentiableLevel::Backward; + auto isInstProducingDiff = [&](IRInst* inst) -> bool { switch (inst->getOp()) @@ -242,7 +266,7 @@ public: case kIROp_FloatLit: return true; case kIROp_Call: - return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee()); + return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. // Just assume it is producing diff. @@ -310,7 +334,7 @@ public: switch (inst->getOp()) { case kIROp_Call: - if (isDifferentiableFunc(as<IRCall>(inst)->getCallee())) + if (isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel)) { addToExpectDiffWorkList(inst); } @@ -349,7 +373,11 @@ public: { if (auto call = as<IRCall>(inst)) { - sink->diagnose(inst, Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, getLeafFunc(call->getCallee())); + sink->diagnose( + inst, + Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, + getLeafFunc(call->getCallee()), + requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward"); } } switch (inst->getOp()) @@ -395,22 +423,30 @@ public: void processModule() { // Collect set of differentiable functions. - HashSet<UnownedStringSlice> differentiableSymbolNames; + HashSet<UnownedStringSlice> fwdDifferentiableSymbolNames, bwdDifferentiableSymbolNames; for (auto inst : module->getGlobalInsts()) { - if (_isDifferentiableFuncImpl(inst)) + if (_isDifferentiableFuncImpl(inst, DifferentiableLevel::Backward)) + { + if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>()) + bwdDifferentiableSymbolNames.Add(linkageDecor->getMangledName()); + differentiableFunctions.Add(inst, DifferentiableLevel::Backward); + } + else if (_isDifferentiableFuncImpl(inst, DifferentiableLevel::Forward)) { if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>()) - differentiableSymbolNames.Add(linkageDecor->getMangledName()); - differentiableFunctions.Add(inst); + fwdDifferentiableSymbolNames.Add(linkageDecor->getMangledName()); + differentiableFunctions.Add(inst, DifferentiableLevel::Forward); } } for (auto inst : module->getGlobalInsts()) { if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>()) { - if (differentiableSymbolNames.Contains(linkageDecor->getMangledName())) - differentiableFunctions.Add(inst); + if (bwdDifferentiableSymbolNames.Contains(linkageDecor->getMangledName())) + differentiableFunctions[inst] = DifferentiableLevel::Backward; + else if (fwdDifferentiableSymbolNames.Contains(linkageDecor->getMangledName())) + differentiableFunctions.AddIfNotExists(inst, DifferentiableLevel::Forward); } } |
