summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-unzip.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-03 16:44:33 -0800
committerGitHub <noreply@github.com>2023-02-03 16:44:33 -0800
commit228e71dab7dfa18ece979f4099ec0c7d1e37e5ff (patch)
treeff357f4aaed2dab25ae9e3665a97a7f3e6be32ef /source/slang/slang-ir-autodiff-unzip.cpp
parentee49a62083d28353812185fd0f0c04fb50ca6be0 (diff)
Overhaul `transposeParameterBlock` to support `inout` params. (#2621)
* Overhaul `transposeParameterBlock` to support `inout` params. * Small bug fixes. * Bug fix on differentiable intrinsic specialization. * Fixes. * Run autodiff tests on CPU. * Clean up. * More bug fixes., * Add test coverage on inout param. * Fix language server hinting for transcribed mutable params. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-unzip.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp108
1 files changed, 48 insertions, 60 deletions
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 378ea1cc2..be20d8aa8 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -1,6 +1,7 @@
#include "slang-ir-autodiff-unzip.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-util.h"
+#include "slang-ir-autodiff-rev.h"
namespace Slang
{
@@ -279,7 +280,7 @@ struct ExtractPrimalFuncContext
inst);
}
- IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& outIntermediateType)
+ IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, HashSet<IRInst*>& primalParams, IRInst*& outIntermediateType)
{
IRBuilder builder(sharedBuilder);
@@ -294,6 +295,7 @@ struct ExtractPrimalFuncContext
auto oldIntermediateParam = func->getLastParam();
auto outIntermediary =
builder.emitParam(builder.getInOutType((IRType*)intermediateType));
+ primalParams.Add(outIntermediary);
oldIntermediateParam->replaceUsesWith(outIntermediary);
oldIntermediateParam->removeAndDeallocate();
@@ -368,62 +370,18 @@ struct ExtractPrimalFuncContext
auto defVal = builder.emitDefaultConstructRaw((IRType*)intermediateType);
builder.emitStore(outIntermediary, defVal);
- // The primal func will not have the result derivative param (second to last param), so we remove it.
- if (isResultDifferentiable)
- {
- auto resultDerivativeParam = func->getLastParam()->getPrevParam();
- SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses());
- resultDerivativeParam->removeAndDeallocate();
- }
-
- // Finally, go through parameters and translate their type back to primal type.
+ // Remove any parameters not in `primalParams` set.
+ List<IRInst*> params;
for (auto param = func->getFirstParam(); param;)
{
- auto next = param->getNextParam();
- [this, firstBlock, &builder, param]()
+ auto nextParam = param->getNextParam();
+ if (!primalParams.Contains(param))
{
- for (auto use = param->firstUse; use; use = use->nextUse)
- {
- if (use->getUser()->getOp() == kIROp_AutoDiffOriginalValueDecoration)
- {
- use->getUser()->getParent()->replaceUsesWith(param);
- return;
- }
- else if (use->getUser()->getOp() == kIROp_OutParamReverseGradientDecoration)
- {
- // This is a propagate func specific parameter, we should remove it.
- SLANG_RELEASE_ASSERT(!param->hasMoreThanOneUse());
- param->removeAndDeallocate();
- return;
- }
- }
-
- IRInst* valueType = param->getDataType();
- auto inoutType = as<IRPtrTypeBase>(param->getDataType());
- if (inoutType) valueType = inoutType->getValueType();
- auto diffPairType = as<IRDifferentialPairType>(valueType);
- if (!diffPairType)
- return;
-
- builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
-
- auto originalValueType = diffPairType->getValueType();
-
- // Create a local var to act as the old param.
- auto tempVar = builder.emitVar(diffPairType);
- param->replaceUsesWith(tempVar);
- auto pairValue = builder.emitMakeDifferentialPair(
- diffPairType,
- param,
- backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType));
- builder.emitStore(tempVar, pairValue);
-
- // Change the param type to original type.
- param->setFullType(originalValueType);
- }();
- param = next;
+ param->replaceUsesWith(builder.getVoidValue());
+ param->removeAndDeallocate();
+ }
+ param = nextParam;
}
-
return unzippedFunc;
}
};
@@ -446,7 +404,10 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE
}
IRFunc* DiffUnzipPass::extractPrimalFunc(
- IRFunc* func, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& intermediateType)
+ IRFunc* func,
+ IRFunc* originalFunc,
+ ParameterBlockTransposeInfo& paramInfo,
+ IRInst*& intermediateType)
{
IRBuilder builder(this->autodiffContext->sharedBuilder);
builder.setInsertBefore(func);
@@ -456,11 +417,31 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
subEnv.parent = &cloneEnv;
auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func));
+ // Remove [KeepAlive] decorations in clonedFunc.
+ for (auto block : clonedFunc->getBlocks())
+ for (auto inst : block->getChildren())
+ if (auto decor = inst->findDecoration<IRKeepAliveDecoration>())
+ decor->removeAndDeallocate();
+
+ // Remove propagate func specific primal insts from cloned func.
+ for (auto inst : paramInfo.propagateFuncSpecificPrimalInsts)
+ {
+ auto newInst = subEnv.mapOldValToNew[inst].GetValue();
+ newInst->removeAndDeallocate();
+ }
+
+ HashSet<IRInst*> newPrimalParams;
+ for (auto param : func->getParams())
+ {
+ if (paramInfo.primalFuncParams.Contains(param))
+ newPrimalParams.Add(subEnv.mapOldValToNew[param].GetValue());
+ }
+
ExtractPrimalFuncContext context;
context.init(autodiffContext->sharedBuilder, autodiffContext->transcriberSet.primalTranscriber);
intermediateType = nullptr;
- auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, isResultDifferentiable, intermediateType);
+ auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, newPrimalParams, intermediateType);
if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>())
{
@@ -489,21 +470,26 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
if (auto structKeyDecor = inst->findDecoration<IRPrimalValueStructKeyDecoration>())
{
builder.setInsertBefore(inst);
- auto val = builder.emitFieldExtract(
- inst->getDataType(),
- intermediateVar,
- structKeyDecor->getStructKey());
if (inst->getOp() == kIROp_Var)
{
// This is a var for intermediate context.
+ auto valType = cast<IRPtrTypeBase>(inst->getFullType())->getValueType();
+ auto val = builder.emitFieldExtract(
+ valType,
+ intermediateVar,
+ structKeyDecor->getStructKey());
auto tempVar =
- builder.emitVar(cast<IRPtrTypeBase>(inst->getFullType())->getValueType());
+ builder.emitVar(valType);
builder.emitStore(tempVar, val);
inst->replaceUsesWith(tempVar);
}
else
{
// Orindary value.
+ auto val = builder.emitFieldExtract(
+ inst->getFullType(),
+ intermediateVar,
+ structKeyDecor->getStructKey());
inst->replaceUsesWith(val);
}
instsToRemove.add(inst);
@@ -522,6 +508,8 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
{
inst->removeAndDeallocate();
}
+
+ stripTempDecorations(func);
// Run simplification to DCE unnecessary insts.
eliminateDeadCode(func);