diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-03 16:44:33 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-03 16:44:33 -0800 |
| commit | 228e71dab7dfa18ece979f4099ec0c7d1e37e5ff (patch) | |
| tree | ff357f4aaed2dab25ae9e3665a97a7f3e6be32ef /source/slang/slang-ir-autodiff-unzip.cpp | |
| parent | ee49a62083d28353812185fd0f0c04fb50ca6be0 (diff) | |
Overhaul `transposeParameterBlock` to support `inout` params. (#2621)
* Overhaul `transposeParameterBlock` to support `inout` params.
* Small bug fixes.
* Bug fix on differentiable intrinsic specialization.
* Fixes.
* Run autodiff tests on CPU.
* Clean up.
* More bug fixes.,
* Add test coverage on inout param.
* Fix language server hinting for transcribed mutable params.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-unzip.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 108 |
1 files changed, 48 insertions, 60 deletions
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 378ea1cc2..be20d8aa8 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -1,6 +1,7 @@ #include "slang-ir-autodiff-unzip.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-util.h" +#include "slang-ir-autodiff-rev.h" namespace Slang { @@ -279,7 +280,7 @@ struct ExtractPrimalFuncContext inst); } - IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& outIntermediateType) + IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, HashSet<IRInst*>& primalParams, IRInst*& outIntermediateType) { IRBuilder builder(sharedBuilder); @@ -294,6 +295,7 @@ struct ExtractPrimalFuncContext auto oldIntermediateParam = func->getLastParam(); auto outIntermediary = builder.emitParam(builder.getInOutType((IRType*)intermediateType)); + primalParams.Add(outIntermediary); oldIntermediateParam->replaceUsesWith(outIntermediary); oldIntermediateParam->removeAndDeallocate(); @@ -368,62 +370,18 @@ struct ExtractPrimalFuncContext auto defVal = builder.emitDefaultConstructRaw((IRType*)intermediateType); builder.emitStore(outIntermediary, defVal); - // The primal func will not have the result derivative param (second to last param), so we remove it. - if (isResultDifferentiable) - { - auto resultDerivativeParam = func->getLastParam()->getPrevParam(); - SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses()); - resultDerivativeParam->removeAndDeallocate(); - } - - // Finally, go through parameters and translate their type back to primal type. + // Remove any parameters not in `primalParams` set. + List<IRInst*> params; for (auto param = func->getFirstParam(); param;) { - auto next = param->getNextParam(); - [this, firstBlock, &builder, param]() + auto nextParam = param->getNextParam(); + if (!primalParams.Contains(param)) { - for (auto use = param->firstUse; use; use = use->nextUse) - { - if (use->getUser()->getOp() == kIROp_AutoDiffOriginalValueDecoration) - { - use->getUser()->getParent()->replaceUsesWith(param); - return; - } - else if (use->getUser()->getOp() == kIROp_OutParamReverseGradientDecoration) - { - // This is a propagate func specific parameter, we should remove it. - SLANG_RELEASE_ASSERT(!param->hasMoreThanOneUse()); - param->removeAndDeallocate(); - return; - } - } - - IRInst* valueType = param->getDataType(); - auto inoutType = as<IRPtrTypeBase>(param->getDataType()); - if (inoutType) valueType = inoutType->getValueType(); - auto diffPairType = as<IRDifferentialPairType>(valueType); - if (!diffPairType) - return; - - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - - auto originalValueType = diffPairType->getValueType(); - - // Create a local var to act as the old param. - auto tempVar = builder.emitVar(diffPairType); - param->replaceUsesWith(tempVar); - auto pairValue = builder.emitMakeDifferentialPair( - diffPairType, - param, - backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType)); - builder.emitStore(tempVar, pairValue); - - // Change the param type to original type. - param->setFullType(originalValueType); - }(); - param = next; + param->replaceUsesWith(builder.getVoidValue()); + param->removeAndDeallocate(); + } + param = nextParam; } - return unzippedFunc; } }; @@ -446,7 +404,10 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE } IRFunc* DiffUnzipPass::extractPrimalFunc( - IRFunc* func, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& intermediateType) + IRFunc* func, + IRFunc* originalFunc, + ParameterBlockTransposeInfo& paramInfo, + IRInst*& intermediateType) { IRBuilder builder(this->autodiffContext->sharedBuilder); builder.setInsertBefore(func); @@ -456,11 +417,31 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( subEnv.parent = &cloneEnv; auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func)); + // Remove [KeepAlive] decorations in clonedFunc. + for (auto block : clonedFunc->getBlocks()) + for (auto inst : block->getChildren()) + if (auto decor = inst->findDecoration<IRKeepAliveDecoration>()) + decor->removeAndDeallocate(); + + // Remove propagate func specific primal insts from cloned func. + for (auto inst : paramInfo.propagateFuncSpecificPrimalInsts) + { + auto newInst = subEnv.mapOldValToNew[inst].GetValue(); + newInst->removeAndDeallocate(); + } + + HashSet<IRInst*> newPrimalParams; + for (auto param : func->getParams()) + { + if (paramInfo.primalFuncParams.Contains(param)) + newPrimalParams.Add(subEnv.mapOldValToNew[param].GetValue()); + } + ExtractPrimalFuncContext context; context.init(autodiffContext->sharedBuilder, autodiffContext->transcriberSet.primalTranscriber); intermediateType = nullptr; - auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, isResultDifferentiable, intermediateType); + auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, newPrimalParams, intermediateType); if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>()) { @@ -489,21 +470,26 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( if (auto structKeyDecor = inst->findDecoration<IRPrimalValueStructKeyDecoration>()) { builder.setInsertBefore(inst); - auto val = builder.emitFieldExtract( - inst->getDataType(), - intermediateVar, - structKeyDecor->getStructKey()); if (inst->getOp() == kIROp_Var) { // This is a var for intermediate context. + auto valType = cast<IRPtrTypeBase>(inst->getFullType())->getValueType(); + auto val = builder.emitFieldExtract( + valType, + intermediateVar, + structKeyDecor->getStructKey()); auto tempVar = - builder.emitVar(cast<IRPtrTypeBase>(inst->getFullType())->getValueType()); + builder.emitVar(valType); builder.emitStore(tempVar, val); inst->replaceUsesWith(tempVar); } else { // Orindary value. + auto val = builder.emitFieldExtract( + inst->getFullType(), + intermediateVar, + structKeyDecor->getStructKey()); inst->replaceUsesWith(val); } instsToRemove.add(inst); @@ -522,6 +508,8 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( { inst->removeAndDeallocate(); } + + stripTempDecorations(func); // Run simplification to DCE unnecessary insts. eliminateDeadCode(func); |
