summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp62
1 files changed, 60 insertions, 2 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 1ee94e67e..21f53fcbd 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -152,6 +152,40 @@ public:
return false;
}
+ bool canAddressHoldDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* addr)
+ {
+ if (!addr)
+ return false;
+
+ while (addr)
+ {
+ switch (addr->getOp())
+ {
+ case kIROp_Var:
+ case kIROp_GlobalVar:
+ case kIROp_Param:
+ case kIROp_GlobalParam:
+ return isDifferentiableType(diffTypeContext, addr->getDataType());
+ case kIROp_FieldAddress:
+ if (!as<IRFieldAddress>(addr)->getField() ||
+ as<IRFieldAddress>(addr)
+ ->getField()
+ ->findDecoration<IRDerivativeMemberDecoration>() == nullptr)
+ return false;
+ addr = as<IRFieldAddress>(addr)->getBase();
+ break;
+ case kIROp_GetElementPtr:
+ if (!isDifferentiableType(diffTypeContext, as<IRGetElementPtr>(addr)->getBase()->getDataType()))
+ return false;
+ addr = as<IRGetElementPtr>(addr)->getBase();
+ break;
+ default:
+ return false;
+ }
+ }
+ return false;
+ }
+
void processFunc(IRGlobalValueWithCode* funcInst)
{
if (!_isFuncMarkedForAutoDiff(funcInst))
@@ -197,9 +231,9 @@ public:
return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel);
case kIROp_Load:
// We don't have more knowledge on whether diff is available at the destination address.
- // Just assume it is producing diff.
+ // Just assume it is producing diff if the dest address can hold a derivative.
//TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter.
- return isDifferentiableType(diffTypeContext, inst->getDataType());
+ return canAddressHoldDerivative(diffTypeContext, as<IRLoad>(inst)->getPtr());
default:
// default case is to assume the inst produces a diff value if any
// of its operands produces a diff value.
@@ -224,6 +258,7 @@ public:
expectDiffInstWorkList.add(inst);
}
};
+
// Run data flow analysis and generate `produceDiffSet` and an intial `expectDiffSet`.
Index lastProduceDiffCount = 0;
do
@@ -373,6 +408,29 @@ public:
sink->diagnose(loop->sourceLoc, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters);
}
}
+
+ // Make sure all stores of differentiable values are into addresses that can hold derivatives.
+ for (auto block : funcInst->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (auto storeInst = as<IRStore>(inst))
+ {
+ if (produceDiffSet.Contains(storeInst->getVal()) &&
+ !canAddressHoldDerivative(diffTypeContext, storeInst->getPtr()))
+ {
+ switch (storeInst->getVal()->getOp())
+ {
+ case kIROp_DetachDerivative:
+ break;
+ default:
+ sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation);
+ break;
+ }
+ }
+ }
+ }
+ }
}
void processModule()