diff options
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 62 |
1 files changed, 60 insertions, 2 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 1ee94e67e..21f53fcbd 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -152,6 +152,40 @@ public: return false; } + bool canAddressHoldDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* addr) + { + if (!addr) + return false; + + while (addr) + { + switch (addr->getOp()) + { + case kIROp_Var: + case kIROp_GlobalVar: + case kIROp_Param: + case kIROp_GlobalParam: + return isDifferentiableType(diffTypeContext, addr->getDataType()); + case kIROp_FieldAddress: + if (!as<IRFieldAddress>(addr)->getField() || + as<IRFieldAddress>(addr) + ->getField() + ->findDecoration<IRDerivativeMemberDecoration>() == nullptr) + return false; + addr = as<IRFieldAddress>(addr)->getBase(); + break; + case kIROp_GetElementPtr: + if (!isDifferentiableType(diffTypeContext, as<IRGetElementPtr>(addr)->getBase()->getDataType())) + return false; + addr = as<IRGetElementPtr>(addr)->getBase(); + break; + default: + return false; + } + } + return false; + } + void processFunc(IRGlobalValueWithCode* funcInst) { if (!_isFuncMarkedForAutoDiff(funcInst)) @@ -197,9 +231,9 @@ public: return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. - // Just assume it is producing diff. + // Just assume it is producing diff if the dest address can hold a derivative. //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter. - return isDifferentiableType(diffTypeContext, inst->getDataType()); + return canAddressHoldDerivative(diffTypeContext, as<IRLoad>(inst)->getPtr()); default: // default case is to assume the inst produces a diff value if any // of its operands produces a diff value. @@ -224,6 +258,7 @@ public: expectDiffInstWorkList.add(inst); } }; + // Run data flow analysis and generate `produceDiffSet` and an intial `expectDiffSet`. Index lastProduceDiffCount = 0; do @@ -373,6 +408,29 @@ public: sink->diagnose(loop->sourceLoc, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters); } } + + // Make sure all stores of differentiable values are into addresses that can hold derivatives. + for (auto block : funcInst->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (auto storeInst = as<IRStore>(inst)) + { + if (produceDiffSet.Contains(storeInst->getVal()) && + !canAddressHoldDerivative(diffTypeContext, storeInst->getPtr())) + { + switch (storeInst->getVal()->getOp()) + { + case kIROp_DetachDerivative: + break; + default: + sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); + break; + } + } + } + } + } } void processModule() |
