From f23e36243e9c59c02f66ec2e18b80ba4ea540f45 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 27 Feb 2023 21:21:39 -0800 Subject: Diagnose on storing differentiable value into non-differentiable location. (#2681) --- source/slang/slang-ir-check-differentiability.cpp | 62 ++++++++++++++++++++++- 1 file changed, 60 insertions(+), 2 deletions(-) (limited to 'source/slang/slang-ir-check-differentiability.cpp') diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 1ee94e67e..21f53fcbd 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -152,6 +152,40 @@ public: return false; } + bool canAddressHoldDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* addr) + { + if (!addr) + return false; + + while (addr) + { + switch (addr->getOp()) + { + case kIROp_Var: + case kIROp_GlobalVar: + case kIROp_Param: + case kIROp_GlobalParam: + return isDifferentiableType(diffTypeContext, addr->getDataType()); + case kIROp_FieldAddress: + if (!as(addr)->getField() || + as(addr) + ->getField() + ->findDecoration() == nullptr) + return false; + addr = as(addr)->getBase(); + break; + case kIROp_GetElementPtr: + if (!isDifferentiableType(diffTypeContext, as(addr)->getBase()->getDataType())) + return false; + addr = as(addr)->getBase(); + break; + default: + return false; + } + } + return false; + } + void processFunc(IRGlobalValueWithCode* funcInst) { if (!_isFuncMarkedForAutoDiff(funcInst)) @@ -197,9 +231,9 @@ public: return inst->findDecoration() || isDifferentiableFunc(as(inst)->getCallee(), requiredDiffLevel); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. - // Just assume it is producing diff. + // Just assume it is producing diff if the dest address can hold a derivative. //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter. - return isDifferentiableType(diffTypeContext, inst->getDataType()); + return canAddressHoldDerivative(diffTypeContext, as(inst)->getPtr()); default: // default case is to assume the inst produces a diff value if any // of its operands produces a diff value. @@ -224,6 +258,7 @@ public: expectDiffInstWorkList.add(inst); } }; + // Run data flow analysis and generate `produceDiffSet` and an intial `expectDiffSet`. Index lastProduceDiffCount = 0; do @@ -373,6 +408,29 @@ public: sink->diagnose(loop->sourceLoc, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters); } } + + // Make sure all stores of differentiable values are into addresses that can hold derivatives. + for (auto block : funcInst->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (auto storeInst = as(inst)) + { + if (produceDiffSet.Contains(storeInst->getVal()) && + !canAddressHoldDerivative(diffTypeContext, storeInst->getPtr())) + { + switch (storeInst->getVal()->getOp()) + { + case kIROp_DetachDerivative: + break; + default: + sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); + break; + } + } + } + } + } } void processModule() -- cgit v1.2.3