From f23e36243e9c59c02f66ec2e18b80ba4ea540f45 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 27 Feb 2023 21:21:39 -0800 Subject: Diagnose on storing differentiable value into non-differentiable location. (#2681) --- source/slang/diff.meta.slang | 23 +-------- source/slang/slang-diagnostic-defs.h | 1 + source/slang/slang-ir-autodiff-fwd.cpp | 6 ++- source/slang/slang-ir-autodiff.cpp | 23 +++++++++ source/slang/slang-ir-check-differentiability.cpp | 62 ++++++++++++++++++++++- source/slang/slang-ir-inst-defs.h | 1 + source/slang/slang-ir-insts.h | 6 +++ source/slang/slang-ir.cpp | 1 + 8 files changed, 99 insertions(+), 24 deletions(-) (limited to 'source') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 859b8a488..8931cccdd 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -320,27 +320,8 @@ void mul(inout DifferentialPair> left, inout DifferentialPair -T detach(T x) -{ - return x; -} - -__generic -[ForwardDerivativeOf(detach)] -DifferentialPair __d_detach(DifferentialPair dpx) -{ - return DifferentialPair( - dpx.p, - T.dzero() - ); -} - -__generic -[BackwardDerivativeOf(detach)] -void __d_detach(inout DifferentialPair dpx, T.Differential dOut) -{ - dpx = diffPair(dpx.p, T.dzero()); -} +__intrinsic_op($(kIROp_DetachDerivative)) +T detach(T x); // Natural Exponent diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index c5f7e6cbe..214d386a2 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -588,6 +588,7 @@ DIAGNOSTIC(41011, Error, typeDoesNotFitAnyValueSize, "type '$0' does not fit in DIAGNOSTIC(41012, Note, typeAndLimit, "sizeof($0) is $1, limit is $2") DIAGNOSTIC(41012, Error, typeCannotBePackedIntoAnyValue, "type '$0' contains fields that cannot be packed into an AnyValue.") DIAGNOSTIC(41020, Error, lossOfDerivativeDueToCallOfNonDifferentiableFunction, "derivative cannot be propagated through call to non-$1-differentiable function `$0`, use 'no_diff' to clarify intention.") +DIAGNOSTIC(41024, Error, lossOfDerivativeAssigningToNonDifferentiableLocation, "derivative is lost during assignment to non-differentiable location. Use 'detach()' to clarify intention.") DIAGNOSTIC(41021, Error, differentiableFuncMustHaveOutput, "a differentiable function must have at least one differentiable output.") DIAGNOSTIC(41022, Error, differentiableFuncMustHaveInput, "a differentiable function must have at least one differentiable input.") DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can only be called when argument is statically resolvable to a string literal") diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 72bbe4530..27106b6a2 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1490,6 +1490,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_undefined: return transcribeUndefined(builder, origInst); + // Known non-differentiable insts. case kIROp_Not: case kIROp_BitAnd: case kIROp_BitNot: @@ -1507,15 +1508,18 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_ImageSubscript: case kIROp_ImageLoad: case kIROp_ImageStore: - case kIROp_CreateExistentialObject: case kIROp_PackAnyValue: case kIROp_UnpackAnyValue: case kIROp_GetNativePtr: case kIROp_CastIntToFloat: case kIROp_CastFloatToInt: + case kIROp_DetachDerivative: + return trascribeNonDiffInst(builder, origInst); + // A call to createDynamicObject(arbitraryData) cannot provide a diff value, // so we treat this inst as non differentiable. // We can extend the frontend and IR with a separate op-code that can provide an explicit diff value. + case kIROp_CreateExistentialObject: return trascribeNonDiffInst(builder, origInst); case kIROp_StructKey: diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index b630b798d..fcfbf3bee 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -933,6 +933,27 @@ bool processAutodiffCalls( return modified; } +struct RemoveDetachInstsPass : InstPassBase +{ + RemoveDetachInstsPass(IRModule* module) : + InstPassBase(module) + { + } + void processModule() + { + processInstsOfType(kIROp_DetachDerivative, [&](IRDetachDerivative* detach) + { + detach->replaceUsesWith(detach->getBase()); + }); + } +}; + +void removeDetachInsts(IRModule* module) +{ + RemoveDetachInstsPass pass(module); + pass.processModule(); +} + bool finalizeAutoDiffPass(IRModule* module) { bool modified = false; @@ -947,6 +968,8 @@ bool finalizeAutoDiffPass(IRModule* module) // modified |= processPairTypes(&autodiffContext); + removeDetachInsts(module); + stripNoDiffTypeAttribute(module); // Remove auto-diff related decorations. diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 1ee94e67e..21f53fcbd 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -152,6 +152,40 @@ public: return false; } + bool canAddressHoldDerivative(DifferentiableTypeConformanceContext& diffTypeContext, IRInst* addr) + { + if (!addr) + return false; + + while (addr) + { + switch (addr->getOp()) + { + case kIROp_Var: + case kIROp_GlobalVar: + case kIROp_Param: + case kIROp_GlobalParam: + return isDifferentiableType(diffTypeContext, addr->getDataType()); + case kIROp_FieldAddress: + if (!as(addr)->getField() || + as(addr) + ->getField() + ->findDecoration() == nullptr) + return false; + addr = as(addr)->getBase(); + break; + case kIROp_GetElementPtr: + if (!isDifferentiableType(diffTypeContext, as(addr)->getBase()->getDataType())) + return false; + addr = as(addr)->getBase(); + break; + default: + return false; + } + } + return false; + } + void processFunc(IRGlobalValueWithCode* funcInst) { if (!_isFuncMarkedForAutoDiff(funcInst)) @@ -197,9 +231,9 @@ public: return inst->findDecoration() || isDifferentiableFunc(as(inst)->getCallee(), requiredDiffLevel); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. - // Just assume it is producing diff. + // Just assume it is producing diff if the dest address can hold a derivative. //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter. - return isDifferentiableType(diffTypeContext, inst->getDataType()); + return canAddressHoldDerivative(diffTypeContext, as(inst)->getPtr()); default: // default case is to assume the inst produces a diff value if any // of its operands produces a diff value. @@ -224,6 +258,7 @@ public: expectDiffInstWorkList.add(inst); } }; + // Run data flow analysis and generate `produceDiffSet` and an intial `expectDiffSet`. Index lastProduceDiffCount = 0; do @@ -373,6 +408,29 @@ public: sink->diagnose(loop->sourceLoc, Diagnostics::loopInDiffFuncRequireUnrollOrMaxIters); } } + + // Make sure all stores of differentiable values are into addresses that can hold derivatives. + for (auto block : funcInst->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (auto storeInst = as(inst)) + { + if (produceDiffSet.Contains(storeInst->getVal()) && + !canAddressHoldDerivative(diffTypeContext, storeInst->getPtr())) + { + switch (storeInst->getVal()->getOp()) + { + case kIROp_DetachDerivative: + break; + default: + sink->diagnose(storeInst->sourceLoc, Diagnostics::lossOfDerivativeAssigningToNonDifferentiableLocation); + break; + } + } + } + } + } } void processModule() diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 4b1037240..c704359e6 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -313,6 +313,7 @@ INST(RTTIObject, rtti_object, 0, 0) INST(Alloca, alloca, 1, 0) INST(UpdateElement, updateElement, 2, 0) +INST(DetachDerivative, detachDerivative, 1, 0) INST(PackAnyValue, packAnyValue, 1, 0) INST(UnpackAnyValue, unpackAnyValue, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index f2e4e05d3..5269ae02f 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2253,6 +2253,12 @@ struct IRDifferentialPairGetPrimal : IRInst IRInst* getBase() { return getOperand(0); } }; +struct IRDetachDerivative : IRInst +{ + IR_LEAF_ISA(DetachDerivative) + IRInst* getBase() { return getOperand(0); } +}; + struct IRUpdateElement : IRInst { IR_LEAF_ISA(UpdateElement) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index fd211d05c..0aa2dc607 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -7044,6 +7044,7 @@ namespace Slang case kIROp_BackwardDifferentiate: case kIROp_BackwardDifferentiatePrimal: case kIROp_BackwardDifferentiatePropagate: + case kIROp_DetachDerivative: return false; } return true; -- cgit v1.2.3