From 004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 9 Nov 2022 19:19:17 -0800 Subject: Add `[ForwardDerivativeOf]` attribute. (#2501) * Add [ForwardDerivativeOf] attribute. * Fix handling around phi nodes. * Fixes. * Remove IR opcode for ForwardDerivativeOfDecoration. Co-authored-by: Yong He --- source/slang/slang-ir-diff-jvp.cpp | 112 ++++++++++++++++++++++--------------- 1 file changed, 66 insertions(+), 46 deletions(-) (limited to 'source/slang/slang-ir-diff-jvp.cpp') diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 574db2036..4c7a132d0 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -847,37 +847,51 @@ struct JVPTranscriber cloneInst(&cloneEnv, builder, origParam), nullptr); } - - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + + // Is this param a phi node or a function parameter? + auto func = as(origParam->getParent()->getParent()); + bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock()); + if (isFuncParam) { - IRInst* diffPairParam = builder->emitParam(diffPairType); + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + { + IRInst* diffPairParam = builder->emitParam(diffPairType); - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); - SLANG_ASSERT(diffPairParam); + SLANG_ASSERT(diffPairParam); - if (auto pairType = as(diffPairParam->getDataType())) + if (auto pairType = as(diffPairParam->getDataType())) + { + return InstPair( + builder->emitDifferentialPairGetPrimal(diffPairParam), + builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + diffPairParam)); + } + // If this is an `in/inout DifferentialPair<>` parameter, we can't produce + // its primal and diff parts right now because they would represent a reference + // to a pair field, which doesn't make sense since pair types are considered mutable. + // We encode the result as if the param is non-differentiable, and handle it + // with special care at load/store. + return InstPair(diffPairParam, nullptr); + } + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + else + { + auto primal = cloneInst(&cloneEnv, builder, origParam); + IRInst* diff = nullptr; + if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType)) { - return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - diffPairParam)); + diff = builder->emitParam(diffType); } - // If this is an `in/inout DifferentialPair<>` parameter, we can't produce - // its primal and diff parts right now because they would represent a reference - // to a pair field, which doesn't make sense since pair types are considered mutable. - // We encode the result as if the param is non-differentiable, and handle it - // with special care at load/store. - return InstPair(diffPairParam, nullptr); + return InstPair(primal, diff); } - - - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); } // Returns "d" to use as a name hint for variables and parameters. @@ -1313,42 +1327,49 @@ struct JVPTranscriber switch(origInst->getOp()) { case kIROp_unconditionalBranch: + case kIROp_loop: auto origBranch = as(origInst); // Grab the differentials for any phi nodes. - List pairArgs; + List newArgs; for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++) { auto origArg = origBranch->getArg(ii); + auto primalArg = lookupPrimalInst(origArg); + newArgs.add(primalArg); - IRInst* pairArg = nullptr; - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)origArg->getDataType())) + if (differentiateType(builder, primalArg->getDataType())) { auto diffArg = lookupDiffInst(origArg, nullptr); - if (!diffArg) - { - diffArg = getDifferentialZeroOfType(builder, (IRType*)origArg->getDataType()); - } - - pairArg = builder->emitMakeDifferentialPair( - diffPairType, - lookupPrimalInst(origArg), - diffArg); - } - else - { - pairArg = lookupPrimalInst(origArg); + if (diffArg) + newArgs.add(diffArg); } - pairArgs.add(pairArg); } IRInst* diffBranch = nullptr; if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock())) { - diffBranch = builder->emitBranch( - as(diffBlock), - pairArgs.getCount(), - pairArgs.getBuffer()); + if (auto origLoop = as(origInst)) + { + auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); + auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); + List operands; + operands.add(breakBlock); + operands.add(continueBlock); + operands.addRange(newArgs); + diffBranch = builder->emitIntrinsicInst( + nullptr, + kIROp_loop, + operands.getCount(), + operands.getBuffer()); + } + else + { + diffBranch = builder->emitBranch( + as(diffBlock), + newArgs.getCount(), + newArgs.getBuffer()); + } } // For now, every block in the original fn must have a corresponding @@ -2517,5 +2538,4 @@ void stripAutoDiffDecorations(IRModule* module) stripAutoDiffDecorationsFromChildren(module->getModuleInst()); } - } -- cgit v1.2.3