diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-01 14:18:57 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-01 14:18:57 -0800 |
| commit | bbd1e1786401bb88c34802b987d4da72e2364503 (patch) | |
| tree | 99a4be95ae517fd710fc032a1debdac917dd3ac2 /source/slang/slang-ir-autodiff-transcriber-base.cpp | |
| parent | c5895fb0b82fd14fbe45b58d5fc7f75d67625d15 (diff) | |
Support `out` parameters in backward differentiation. (#2619)
* Support `out` parameters in backward differentiation.
* Fixes.
* Fix cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 65 |
1 files changed, 11 insertions, 54 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 520c6d276..8f21e8c62 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -17,16 +17,6 @@ DiagnosticSink* AutoDiffTranscriberBase::getSink() return sink; } -String AutoDiffTranscriberBase::makeDiffPairName(IRInst* origVar) -{ - if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) - { - return ("dp" + String(namehintDecoration->getName())); - } - - return String(""); -} - void AutoDiffTranscriberBase::mapDifferentialInst(IRInst* origInst, IRInst* diffInst) { if (hasDifferentialInst(origInst)) @@ -523,46 +513,7 @@ InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* o bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock()); if (isFuncParam) { - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - IRInst* diffPairParam = builder->emitParam(diffPairType); - - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); - - SLANG_ASSERT(diffPairParam); - - if (auto pairType = as<IRDifferentialPairType>(diffPairType)) - { - return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - diffPairParam)); - } - else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType)) - { - auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType()); - - return InstPair( - builder->emitDifferentialPairAddressPrimal(diffPairParam), - builder->emitDifferentialPairAddressDifferential( - builder->getPtrType( - kIROp_PtrType, - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)), - diffPairParam)); - } - } - - auto primalInst = cloneInst(&cloneEnv, builder, origParam); - if (auto primalParam = as<IRParam>(primalInst)) - { - SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock()); - primalParam->removeFromParent(); - builder->getInsertLoc().getBlock()->addParam(primalParam); - } - return InstPair(primalInst, nullptr); + return transcribeFuncParam(builder, origParam, primalDataType); } else { @@ -617,10 +568,14 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I switch (diffType->getOp()) { case kIROp_DifferentialPairType: - return builder->emitMakeDifferentialPair( - diffType, - getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), - getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); + { + auto makeDiffPair = builder->emitMakeDifferentialPair( + diffType, + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); + builder->markInstAsDifferential(makeDiffPair, as<IRDifferentialPairType>(diffType)->getValueType()); + return makeDiffPair; + } } if (auto arrayType = as<IRArrayType>(primalType)) @@ -647,6 +602,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I { auto wt = lookupInterface->getWitnessTable(); zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey); + builder->markInstAsDifferential(zeroMethod); } } SLANG_RELEASE_ASSERT(zeroMethod); @@ -759,6 +715,7 @@ InstPair AutoDiffTranscriberBase::transcribeReturn(IRBuilder* builder, IRReturn* IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); IRInst* primalReturn = builder->emitReturn(primalReturnVal); + builder->markInstAsMixedDifferential(primalReturn, nullptr); return InstPair(primalReturn, nullptr); } |
