diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-01 14:18:57 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-01 14:18:57 -0800 |
| commit | bbd1e1786401bb88c34802b987d4da72e2364503 (patch) | |
| tree | 99a4be95ae517fd710fc032a1debdac917dd3ac2 /source/slang/slang-ir-autodiff-unzip.cpp | |
| parent | c5895fb0b82fd14fbe45b58d5fc7f75d67625d15 (diff) | |
Support `out` parameters in backward differentiation. (#2619)
* Support `out` parameters in backward differentiation.
* Fixes.
* Fix cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-unzip.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 82 |
1 files changed, 54 insertions, 28 deletions
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index daf6e44d4..378ea1cc2 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -279,7 +279,7 @@ struct ExtractPrimalFuncContext inst); } - IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, IRInst*& outIntermediateType) + IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& outIntermediateType) { IRBuilder builder(sharedBuilder); @@ -369,33 +369,59 @@ struct ExtractPrimalFuncContext builder.emitStore(outIntermediary, defVal); // The primal func will not have the result derivative param (second to last param), so we remove it. - auto resultDerivativeParam = func->getLastParam()->getPrevParam(); - SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses()); - resultDerivativeParam->removeAndDeallocate(); + if (isResultDifferentiable) + { + auto resultDerivativeParam = func->getLastParam()->getPrevParam(); + SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses()); + resultDerivativeParam->removeAndDeallocate(); + } - // Finally, go through parameters and turn DifferentiablePair<T> back to T. - for (auto param : func->getParams()) + // Finally, go through parameters and translate their type back to primal type. + for (auto param = func->getFirstParam(); param;) { - IRInst* valueType = param->getDataType(); - auto inoutType = as<IRPtrTypeBase>(param->getDataType()); - if (inoutType) valueType = inoutType->getValueType(); - auto diffPairType = as<IRDifferentialPairType>(valueType); - if (!diffPairType) continue; - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - - auto originalValueType = diffPairType->getValueType(); - - // Create a local var to act as the old param. - auto tempVar = builder.emitVar(diffPairType); - param->replaceUsesWith(tempVar); - auto pairValue = builder.emitMakeDifferentialPair( - diffPairType, - param, - backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType)); - builder.emitStore(tempVar, pairValue); - - // Change the param type to original type. - param->setFullType(originalValueType); + auto next = param->getNextParam(); + [this, firstBlock, &builder, param]() + { + for (auto use = param->firstUse; use; use = use->nextUse) + { + if (use->getUser()->getOp() == kIROp_AutoDiffOriginalValueDecoration) + { + use->getUser()->getParent()->replaceUsesWith(param); + return; + } + else if (use->getUser()->getOp() == kIROp_OutParamReverseGradientDecoration) + { + // This is a propagate func specific parameter, we should remove it. + SLANG_RELEASE_ASSERT(!param->hasMoreThanOneUse()); + param->removeAndDeallocate(); + return; + } + } + + IRInst* valueType = param->getDataType(); + auto inoutType = as<IRPtrTypeBase>(param->getDataType()); + if (inoutType) valueType = inoutType->getValueType(); + auto diffPairType = as<IRDifferentialPairType>(valueType); + if (!diffPairType) + return; + + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + + auto originalValueType = diffPairType->getValueType(); + + // Create a local var to act as the old param. + auto tempVar = builder.emitVar(diffPairType); + param->replaceUsesWith(tempVar); + auto pairValue = builder.emitMakeDifferentialPair( + diffPairType, + param, + backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType)); + builder.emitStore(tempVar, pairValue); + + // Change the param type to original type. + param->setFullType(originalValueType); + }(); + param = next; } return unzippedFunc; @@ -420,7 +446,7 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE } IRFunc* DiffUnzipPass::extractPrimalFunc( - IRFunc* func, IRFunc* originalFunc, IRInst*& intermediateType) + IRFunc* func, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& intermediateType) { IRBuilder builder(this->autodiffContext->sharedBuilder); builder.setInsertBefore(func); @@ -434,7 +460,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( context.init(autodiffContext->sharedBuilder, autodiffContext->transcriberSet.primalTranscriber); intermediateType = nullptr; - auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, intermediateType); + auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, isResultDifferentiable, intermediateType); if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>()) { |
