diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-09 19:19:17 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-09 19:19:17 -0800 |
| commit | 004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 (patch) | |
| tree | cbc942746bab043da0eb5298993d95f9665dfddf /source/slang/slang-ir-diff-jvp.cpp | |
| parent | cedd93690c63188cf98e452c9d104cf51aad6c4e (diff) | |
Add `[ForwardDerivativeOf]` attribute. (#2501)
* Add [ForwardDerivativeOf] attribute.
* Fix handling around phi nodes.
* Fixes.
* Remove IR opcode for ForwardDerivativeOfDecoration.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 112 |
1 files changed, 66 insertions, 46 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 574db2036..4c7a132d0 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -847,37 +847,51 @@ struct JVPTranscriber cloneInst(&cloneEnv, builder, origParam), nullptr); } - - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + + // Is this param a phi node or a function parameter? + auto func = as<IRGlobalValueWithCode>(origParam->getParent()->getParent()); + bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock()); + if (isFuncParam) { - IRInst* diffPairParam = builder->emitParam(diffPairType); + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + { + IRInst* diffPairParam = builder->emitParam(diffPairType); - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); - SLANG_ASSERT(diffPairParam); + SLANG_ASSERT(diffPairParam); - if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType())) + if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType())) + { + return InstPair( + builder->emitDifferentialPairGetPrimal(diffPairParam), + builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + diffPairParam)); + } + // If this is an `in/inout DifferentialPair<>` parameter, we can't produce + // its primal and diff parts right now because they would represent a reference + // to a pair field, which doesn't make sense since pair types are considered mutable. + // We encode the result as if the param is non-differentiable, and handle it + // with special care at load/store. + return InstPair(diffPairParam, nullptr); + } + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + else + { + auto primal = cloneInst(&cloneEnv, builder, origParam); + IRInst* diff = nullptr; + if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType)) { - return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - diffPairParam)); + diff = builder->emitParam(diffType); } - // If this is an `in/inout DifferentialPair<>` parameter, we can't produce - // its primal and diff parts right now because they would represent a reference - // to a pair field, which doesn't make sense since pair types are considered mutable. - // We encode the result as if the param is non-differentiable, and handle it - // with special care at load/store. - return InstPair(diffPairParam, nullptr); + return InstPair(primal, diff); } - - - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); } // Returns "d<var-name>" to use as a name hint for variables and parameters. @@ -1313,42 +1327,49 @@ struct JVPTranscriber switch(origInst->getOp()) { case kIROp_unconditionalBranch: + case kIROp_loop: auto origBranch = as<IRUnconditionalBranch>(origInst); // Grab the differentials for any phi nodes. - List<IRInst*> pairArgs; + List<IRInst*> newArgs; for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++) { auto origArg = origBranch->getArg(ii); + auto primalArg = lookupPrimalInst(origArg); + newArgs.add(primalArg); - IRInst* pairArg = nullptr; - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)origArg->getDataType())) + if (differentiateType(builder, primalArg->getDataType())) { auto diffArg = lookupDiffInst(origArg, nullptr); - if (!diffArg) - { - diffArg = getDifferentialZeroOfType(builder, (IRType*)origArg->getDataType()); - } - - pairArg = builder->emitMakeDifferentialPair( - diffPairType, - lookupPrimalInst(origArg), - diffArg); - } - else - { - pairArg = lookupPrimalInst(origArg); + if (diffArg) + newArgs.add(diffArg); } - pairArgs.add(pairArg); } IRInst* diffBranch = nullptr; if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock())) { - diffBranch = builder->emitBranch( - as<IRBlock>(diffBlock), - pairArgs.getCount(), - pairArgs.getBuffer()); + if (auto origLoop = as<IRLoop>(origInst)) + { + auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); + auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); + List<IRInst*> operands; + operands.add(breakBlock); + operands.add(continueBlock); + operands.addRange(newArgs); + diffBranch = builder->emitIntrinsicInst( + nullptr, + kIROp_loop, + operands.getCount(), + operands.getBuffer()); + } + else + { + diffBranch = builder->emitBranch( + as<IRBlock>(diffBlock), + newArgs.getCount(), + newArgs.getBuffer()); + } } // For now, every block in the original fn must have a corresponding @@ -2517,5 +2538,4 @@ void stripAutoDiffDecorations(IRModule* module) stripAutoDiffDecorationsFromChildren(module->getModuleInst()); } - } |
