summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-transcriber-base.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-01 14:18:57 -0800
committerGitHub <noreply@github.com>2023-02-01 14:18:57 -0800
commitbbd1e1786401bb88c34802b987d4da72e2364503 (patch)
tree99a4be95ae517fd710fc032a1debdac917dd3ac2 /source/slang/slang-ir-autodiff-transcriber-base.cpp
parentc5895fb0b82fd14fbe45b58d5fc7f75d67625d15 (diff)
Support `out` parameters in backward differentiation. (#2619)
* Support `out` parameters in backward differentiation. * Fixes. * Fix cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp65
1 files changed, 11 insertions, 54 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 520c6d276..8f21e8c62 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -17,16 +17,6 @@ DiagnosticSink* AutoDiffTranscriberBase::getSink()
return sink;
}
-String AutoDiffTranscriberBase::makeDiffPairName(IRInst* origVar)
-{
- if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
- {
- return ("dp" + String(namehintDecoration->getName()));
- }
-
- return String("");
-}
-
void AutoDiffTranscriberBase::mapDifferentialInst(IRInst* origInst, IRInst* diffInst)
{
if (hasDifferentialInst(origInst))
@@ -523,46 +513,7 @@ InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* o
bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock());
if (isFuncParam)
{
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
- {
- IRInst* diffPairParam = builder->emitParam(diffPairType);
-
- auto diffPairVarName = makeDiffPairName(origParam);
- if (diffPairVarName.getLength() > 0)
- builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
-
- SLANG_ASSERT(diffPairParam);
-
- if (auto pairType = as<IRDifferentialPairType>(diffPairType))
- {
- return InstPair(
- builder->emitDifferentialPairGetPrimal(diffPairParam),
- builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
- diffPairParam));
- }
- else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
- {
- auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType());
-
- return InstPair(
- builder->emitDifferentialPairAddressPrimal(diffPairParam),
- builder->emitDifferentialPairAddressDifferential(
- builder->getPtrType(
- kIROp_PtrType,
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)),
- diffPairParam));
- }
- }
-
- auto primalInst = cloneInst(&cloneEnv, builder, origParam);
- if (auto primalParam = as<IRParam>(primalInst))
- {
- SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock());
- primalParam->removeFromParent();
- builder->getInsertLoc().getBlock()->addParam(primalParam);
- }
- return InstPair(primalInst, nullptr);
+ return transcribeFuncParam(builder, origParam, primalDataType);
}
else
{
@@ -617,10 +568,14 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
switch (diffType->getOp())
{
case kIROp_DifferentialPairType:
- return builder->emitMakeDifferentialPair(
- diffType,
- getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
- getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
+ {
+ auto makeDiffPair = builder->emitMakeDifferentialPair(
+ diffType,
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
+ builder->markInstAsDifferential(makeDiffPair, as<IRDifferentialPairType>(diffType)->getValueType());
+ return makeDiffPair;
+ }
}
if (auto arrayType = as<IRArrayType>(primalType))
@@ -647,6 +602,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
{
auto wt = lookupInterface->getWitnessTable();
zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey);
+ builder->markInstAsDifferential(zeroMethod);
}
}
SLANG_RELEASE_ASSERT(zeroMethod);
@@ -759,6 +715,7 @@ InstPair AutoDiffTranscriberBase::transcribeReturn(IRBuilder* builder, IRReturn*
IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
IRInst* primalReturn = builder->emitReturn(primalReturnVal);
+ builder->markInstAsMixedDifferential(primalReturn, nullptr);
return InstPair(primalReturn, nullptr);
}