summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-autodiff-unzip.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-01 14:18:57 -0800
committerGitHub <noreply@github.com>2023-02-01 14:18:57 -0800
commitbbd1e1786401bb88c34802b987d4da72e2364503 (patch)
tree99a4be95ae517fd710fc032a1debdac917dd3ac2 /source/slang/slang-ir-autodiff-unzip.cpp
parentc5895fb0b82fd14fbe45b58d5fc7f75d67625d15 (diff)
Support `out` parameters in backward differentiation. (#2619)
* Support `out` parameters in backward differentiation. * Fixes. * Fix cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-unzip.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp82
1 files changed, 54 insertions, 28 deletions
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index daf6e44d4..378ea1cc2 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -279,7 +279,7 @@ struct ExtractPrimalFuncContext
inst);
}
- IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, IRInst*& outIntermediateType)
+ IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& outIntermediateType)
{
IRBuilder builder(sharedBuilder);
@@ -369,33 +369,59 @@ struct ExtractPrimalFuncContext
builder.emitStore(outIntermediary, defVal);
// The primal func will not have the result derivative param (second to last param), so we remove it.
- auto resultDerivativeParam = func->getLastParam()->getPrevParam();
- SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses());
- resultDerivativeParam->removeAndDeallocate();
+ if (isResultDifferentiable)
+ {
+ auto resultDerivativeParam = func->getLastParam()->getPrevParam();
+ SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses());
+ resultDerivativeParam->removeAndDeallocate();
+ }
- // Finally, go through parameters and turn DifferentiablePair<T> back to T.
- for (auto param : func->getParams())
+ // Finally, go through parameters and translate their type back to primal type.
+ for (auto param = func->getFirstParam(); param;)
{
- IRInst* valueType = param->getDataType();
- auto inoutType = as<IRPtrTypeBase>(param->getDataType());
- if (inoutType) valueType = inoutType->getValueType();
- auto diffPairType = as<IRDifferentialPairType>(valueType);
- if (!diffPairType) continue;
- builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
-
- auto originalValueType = diffPairType->getValueType();
-
- // Create a local var to act as the old param.
- auto tempVar = builder.emitVar(diffPairType);
- param->replaceUsesWith(tempVar);
- auto pairValue = builder.emitMakeDifferentialPair(
- diffPairType,
- param,
- backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType));
- builder.emitStore(tempVar, pairValue);
-
- // Change the param type to original type.
- param->setFullType(originalValueType);
+ auto next = param->getNextParam();
+ [this, firstBlock, &builder, param]()
+ {
+ for (auto use = param->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser()->getOp() == kIROp_AutoDiffOriginalValueDecoration)
+ {
+ use->getUser()->getParent()->replaceUsesWith(param);
+ return;
+ }
+ else if (use->getUser()->getOp() == kIROp_OutParamReverseGradientDecoration)
+ {
+ // This is a propagate func specific parameter, we should remove it.
+ SLANG_RELEASE_ASSERT(!param->hasMoreThanOneUse());
+ param->removeAndDeallocate();
+ return;
+ }
+ }
+
+ IRInst* valueType = param->getDataType();
+ auto inoutType = as<IRPtrTypeBase>(param->getDataType());
+ if (inoutType) valueType = inoutType->getValueType();
+ auto diffPairType = as<IRDifferentialPairType>(valueType);
+ if (!diffPairType)
+ return;
+
+ builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
+
+ auto originalValueType = diffPairType->getValueType();
+
+ // Create a local var to act as the old param.
+ auto tempVar = builder.emitVar(diffPairType);
+ param->replaceUsesWith(tempVar);
+ auto pairValue = builder.emitMakeDifferentialPair(
+ diffPairType,
+ param,
+ backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType));
+ builder.emitStore(tempVar, pairValue);
+
+ // Change the param type to original type.
+ param->setFullType(originalValueType);
+ }();
+ param = next;
}
return unzippedFunc;
@@ -420,7 +446,7 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE
}
IRFunc* DiffUnzipPass::extractPrimalFunc(
- IRFunc* func, IRFunc* originalFunc, IRInst*& intermediateType)
+ IRFunc* func, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& intermediateType)
{
IRBuilder builder(this->autodiffContext->sharedBuilder);
builder.setInsertBefore(func);
@@ -434,7 +460,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
context.init(autodiffContext->sharedBuilder, autodiffContext->transcriberSet.primalTranscriber);
intermediateType = nullptr;
- auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, intermediateType);
+ auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, isResultDifferentiable, intermediateType);
if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>())
{