summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp35
1 files changed, 34 insertions, 1 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index f4f61d7e9..83351d07b 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -97,6 +97,39 @@ public:
if (differentiableFunctions.Contains(func))
return true;
+ if (func->findDecoration<IRTreatAsDifferentiableDecoration>())
+ return true;
+
+ if (auto lookupInterfaceMethod = as<IRLookupWitnessMethod>(func))
+ {
+ auto wit = lookupInterfaceMethod->getWitnessTable();
+ if (!wit)
+ return false;
+ auto witType = as<IRWitnessTableTypeBase>(wit->getDataType());
+ if (!witType)
+ return false;
+ auto interfaceType = witType->getConformanceType();
+ if (!interfaceType)
+ return false;
+ if (interfaceType->findDecoration<IRTreatAsDifferentiableDecoration>())
+ return true;
+ if (sharedContext.differentiableInterfaceType && interfaceType == sharedContext.differentiableInterfaceType)
+ return true;
+ auto dictDecor = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>();
+ if (!dictDecor)
+ return false;
+ for (auto child : dictDecor->getChildren())
+ {
+ if (auto entry = as<IRDifferentiableMethodRequirementDictionaryItem>(child))
+ {
+ if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey())
+ {
+ return true;
+ }
+ }
+ }
+ }
+
for (; func; func = func->parent)
{
if (as<IRGeneric>(func))
@@ -222,7 +255,7 @@ public:
case kIROp_FloatLit:
return true;
case kIROp_Call:
- return inst->findDecoration<IRTreatAsDifferentiableCallDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee());
+ return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee());
case kIROp_Load:
// We don't have more knowledge on whether diff is available at the destination address.
// Just assume it is producing diff.