diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-23 17:50:02 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-23 17:50:02 -0800 |
| commit | 1b40fe56725eeefe9c601461278376b697d4d35a (patch) | |
| tree | 2bdd321eed24e6e313839fe45aa84b23daa643fe /source/slang/slang-ir-check-differentiability.cpp | |
| parent | d4787e92253cf963f590d62522e82ce8285fc751 (diff) | |
Make differentiable data-flow pass recognize interface methods. (#2530)
* Make differentiable data-flow pass recognize interface methods.
* Make existing test to work with `[TreatAsDifferentiable]`.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 35 |
1 files changed, 34 insertions, 1 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index f4f61d7e9..83351d07b 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -97,6 +97,39 @@ public: if (differentiableFunctions.Contains(func)) return true; + if (func->findDecoration<IRTreatAsDifferentiableDecoration>()) + return true; + + if (auto lookupInterfaceMethod = as<IRLookupWitnessMethod>(func)) + { + auto wit = lookupInterfaceMethod->getWitnessTable(); + if (!wit) + return false; + auto witType = as<IRWitnessTableTypeBase>(wit->getDataType()); + if (!witType) + return false; + auto interfaceType = witType->getConformanceType(); + if (!interfaceType) + return false; + if (interfaceType->findDecoration<IRTreatAsDifferentiableDecoration>()) + return true; + if (sharedContext.differentiableInterfaceType && interfaceType == sharedContext.differentiableInterfaceType) + return true; + auto dictDecor = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>(); + if (!dictDecor) + return false; + for (auto child : dictDecor->getChildren()) + { + if (auto entry = as<IRDifferentiableMethodRequirementDictionaryItem>(child)) + { + if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey()) + { + return true; + } + } + } + } + for (; func; func = func->parent) { if (as<IRGeneric>(func)) @@ -222,7 +255,7 @@ public: case kIROp_FloatLit: return true; case kIROp_Call: - return inst->findDecoration<IRTreatAsDifferentiableCallDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee()); + return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee()); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. // Just assume it is producing diff. |
