summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-09 19:19:17 -0800
committerGitHub <noreply@github.com>2022-11-09 19:19:17 -0800
commit004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 (patch)
treecbc942746bab043da0eb5298993d95f9665dfddf /source/slang/slang-ir-diff-jvp.cpp
parentcedd93690c63188cf98e452c9d104cf51aad6c4e (diff)
Add `[ForwardDerivativeOf]` attribute. (#2501)
* Add [ForwardDerivativeOf] attribute. * Fix handling around phi nodes. * Fixes. * Remove IR opcode for ForwardDerivativeOfDecoration. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp112
1 files changed, 66 insertions, 46 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 574db2036..4c7a132d0 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -847,37 +847,51 @@ struct JVPTranscriber
cloneInst(&cloneEnv, builder, origParam),
nullptr);
}
-
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+
+ // Is this param a phi node or a function parameter?
+ auto func = as<IRGlobalValueWithCode>(origParam->getParent()->getParent());
+ bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock());
+ if (isFuncParam)
{
- IRInst* diffPairParam = builder->emitParam(diffPairType);
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+ {
+ IRInst* diffPairParam = builder->emitParam(diffPairType);
- auto diffPairVarName = makeDiffPairName(origParam);
- if (diffPairVarName.getLength() > 0)
- builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
- SLANG_ASSERT(diffPairParam);
+ SLANG_ASSERT(diffPairParam);
- if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType()))
+ if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType()))
+ {
+ return InstPair(
+ builder->emitDifferentialPairGetPrimal(diffPairParam),
+ builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ diffPairParam));
+ }
+ // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
+ // its primal and diff parts right now because they would represent a reference
+ // to a pair field, which doesn't make sense since pair types are considered mutable.
+ // We encode the result as if the param is non-differentiable, and handle it
+ // with special care at load/store.
+ return InstPair(diffPairParam, nullptr);
+ }
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+ else
+ {
+ auto primal = cloneInst(&cloneEnv, builder, origParam);
+ IRInst* diff = nullptr;
+ if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType))
{
- return InstPair(
- builder->emitDifferentialPairGetPrimal(diffPairParam),
- builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
- diffPairParam));
+ diff = builder->emitParam(diffType);
}
- // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
- // its primal and diff parts right now because they would represent a reference
- // to a pair field, which doesn't make sense since pair types are considered mutable.
- // We encode the result as if the param is non-differentiable, and handle it
- // with special care at load/store.
- return InstPair(diffPairParam, nullptr);
+ return InstPair(primal, diff);
}
-
-
- return InstPair(
- cloneInst(&cloneEnv, builder, origParam),
- nullptr);
}
// Returns "d<var-name>" to use as a name hint for variables and parameters.
@@ -1313,42 +1327,49 @@ struct JVPTranscriber
switch(origInst->getOp())
{
case kIROp_unconditionalBranch:
+ case kIROp_loop:
auto origBranch = as<IRUnconditionalBranch>(origInst);
// Grab the differentials for any phi nodes.
- List<IRInst*> pairArgs;
+ List<IRInst*> newArgs;
for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++)
{
auto origArg = origBranch->getArg(ii);
+ auto primalArg = lookupPrimalInst(origArg);
+ newArgs.add(primalArg);
- IRInst* pairArg = nullptr;
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)origArg->getDataType()))
+ if (differentiateType(builder, primalArg->getDataType()))
{
auto diffArg = lookupDiffInst(origArg, nullptr);
- if (!diffArg)
- {
- diffArg = getDifferentialZeroOfType(builder, (IRType*)origArg->getDataType());
- }
-
- pairArg = builder->emitMakeDifferentialPair(
- diffPairType,
- lookupPrimalInst(origArg),
- diffArg);
- }
- else
- {
- pairArg = lookupPrimalInst(origArg);
+ if (diffArg)
+ newArgs.add(diffArg);
}
- pairArgs.add(pairArg);
}
IRInst* diffBranch = nullptr;
if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock()))
{
- diffBranch = builder->emitBranch(
- as<IRBlock>(diffBlock),
- pairArgs.getCount(),
- pairArgs.getBuffer());
+ if (auto origLoop = as<IRLoop>(origInst))
+ {
+ auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
+ auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
+ List<IRInst*> operands;
+ operands.add(breakBlock);
+ operands.add(continueBlock);
+ operands.addRange(newArgs);
+ diffBranch = builder->emitIntrinsicInst(
+ nullptr,
+ kIROp_loop,
+ operands.getCount(),
+ operands.getBuffer());
+ }
+ else
+ {
+ diffBranch = builder->emitBranch(
+ as<IRBlock>(diffBlock),
+ newArgs.getCount(),
+ newArgs.getBuffer());
+ }
}
// For now, every block in the original fn must have a corresponding
@@ -2517,5 +2538,4 @@ void stripAutoDiffDecorations(IRModule* module)
stripAutoDiffDecorationsFromChildren(module->getModuleInst());
}
-
}