From bbd1e1786401bb88c34802b987d4da72e2364503 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 1 Feb 2023 14:18:57 -0800 Subject: Support `out` parameters in backward differentiation. (#2619) * Support `out` parameters in backward differentiation. * Fixes. * Fix cleanup. --------- Co-authored-by: Yong He --- .../slang/slang-ir-autodiff-transcriber-base.cpp | 65 ++++------------------ 1 file changed, 11 insertions(+), 54 deletions(-) (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp') diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 520c6d276..8f21e8c62 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -17,16 +17,6 @@ DiagnosticSink* AutoDiffTranscriberBase::getSink() return sink; } -String AutoDiffTranscriberBase::makeDiffPairName(IRInst* origVar) -{ - if (auto namehintDecoration = origVar->findDecoration()) - { - return ("dp" + String(namehintDecoration->getName())); - } - - return String(""); -} - void AutoDiffTranscriberBase::mapDifferentialInst(IRInst* origInst, IRInst* diffInst) { if (hasDifferentialInst(origInst)) @@ -523,46 +513,7 @@ InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* o bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock()); if (isFuncParam) { - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - IRInst* diffPairParam = builder->emitParam(diffPairType); - - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); - - SLANG_ASSERT(diffPairParam); - - if (auto pairType = as(diffPairType)) - { - return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - diffPairParam)); - } - else if (auto pairPtrType = as(diffPairType)) - { - auto ptrInnerPairType = as(pairPtrType->getValueType()); - - return InstPair( - builder->emitDifferentialPairAddressPrimal(diffPairParam), - builder->emitDifferentialPairAddressDifferential( - builder->getPtrType( - kIROp_PtrType, - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)), - diffPairParam)); - } - } - - auto primalInst = cloneInst(&cloneEnv, builder, origParam); - if (auto primalParam = as(primalInst)) - { - SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock()); - primalParam->removeFromParent(); - builder->getInsertLoc().getBlock()->addParam(primalParam); - } - return InstPair(primalInst, nullptr); + return transcribeFuncParam(builder, origParam, primalDataType); } else { @@ -617,10 +568,14 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I switch (diffType->getOp()) { case kIROp_DifferentialPairType: - return builder->emitMakeDifferentialPair( - diffType, - getDifferentialZeroOfType(builder, as(diffType)->getValueType()), - getDifferentialZeroOfType(builder, as(diffType)->getValueType())); + { + auto makeDiffPair = builder->emitMakeDifferentialPair( + diffType, + getDifferentialZeroOfType(builder, as(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as(diffType)->getValueType())); + builder->markInstAsDifferential(makeDiffPair, as(diffType)->getValueType()); + return makeDiffPair; + } } if (auto arrayType = as(primalType)) @@ -647,6 +602,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I { auto wt = lookupInterface->getWitnessTable(); zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey); + builder->markInstAsDifferential(zeroMethod); } } SLANG_RELEASE_ASSERT(zeroMethod); @@ -759,6 +715,7 @@ InstPair AutoDiffTranscriberBase::transcribeReturn(IRBuilder* builder, IRReturn* IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); IRInst* primalReturn = builder->emitReturn(primalReturnVal); + builder->markInstAsMixedDifferential(primalReturn, nullptr); return InstPair(primalReturn, nullptr); } -- cgit v1.2.3