summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-rev.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-03 16:44:33 -0800
committerGitHub <noreply@github.com>2023-02-03 16:44:33 -0800
commit228e71dab7dfa18ece979f4099ec0c7d1e37e5ff (patch)
treeff357f4aaed2dab25ae9e3665a97a7f3e6be32ef /source/slang/slang-ir-autodiff-rev.cpp
parentee49a62083d28353812185fd0f0c04fb50ca6be0 (diff)
Overhaul `transposeParameterBlock` to support `inout` params. (#2621)
* Overhaul `transposeParameterBlock` to support `inout` params. * Small bug fixes. * Bug fix on differentiable intrinsic specialization. * Fixes. * Run autodiff tests on CPU. * Clean up. * More bug fixes., * Add test coverage on inout param. * Fix language server hinting for transcribed mutable params. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp477
1 files changed, 406 insertions, 71 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 9c63a4012..67387e83a 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -417,7 +417,10 @@ namespace Slang
}
else
{
- auto var = builder.emitVar(primalParamType);
+ auto primalPtrType = as<IRPtrTypeBase>(primalParamType);
+ SLANG_RELEASE_ASSERT(primalPtrType);
+ auto primalValueType = primalPtrType->getValueType();
+ auto var = builder.emitVar(primalValueType);
primalArgs.add(var);
}
primalTypes.add(primalParamType);
@@ -515,12 +518,12 @@ namespace Slang
{
auto ptrType = as<IRPtrTypeBase>(param->getDataType());
auto tempVar = builder.emitVar(ptrType->getValueType());
+ param->replaceUsesWith(tempVar);
mapParamToTempVar[param] = tempVar;
- if (param->getOp() != kIROp_OutType)
+ if (ptrType->getOp() != kIROp_OutType)
{
builder.emitStore(tempVar, builder.emitLoad(param));
}
- param->replaceUsesWith(tempVar);
}
for (auto block : func->getBlocks())
@@ -578,6 +581,43 @@ namespace Slang
return result;
}
+ void eliminateRedundantLoad(IRFunc* func)
+ {
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst = block->getFirstInst(); inst;)
+ {
+ auto nextInst = inst->getNextInst();
+ if (auto load = as<IRLoad>(inst))
+ {
+ for (auto prev = inst->getPrevInst(); prev; prev = prev->getPrevInst())
+ {
+ if (auto store = as<IRStore>(prev))
+ {
+ if (store->getPtr() == load->getPtr())
+ {
+ // If the load is preceeded by a store without any side-effect insts in-between, remove the load.
+ auto value = store->getVal();
+ load->replaceUsesWith(value);
+ load->removeAndDeallocate();
+ break;
+ }
+ }
+ else if (as<IRCall>(prev))
+ {
+ break;
+ }
+ else if (prev->mightHaveSideEffects())
+ {
+ break;
+ }
+ }
+ }
+ inst = nextInst;
+ }
+ }
+ }
+
// Create a copy of originalFunc's forward derivative in the same generic context (if any) of
// `diffPropagateFunc`.
IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc(
@@ -621,6 +661,9 @@ namespace Slang
// Remove the clone of original func.
primalOuterParent->removeAndDeallocate();
+ // Remove redundant loads since they interfere with transposition logic.
+ eliminateRedundantLoad(fwdDiffFunc);
+
// Migrate the new forward derivative function into the generic parent of `diffPropagateFunc`.
if (auto fwdParentGeneric = as<IRGeneric>(findOuterGeneric(fwdDiffFunc)))
{
@@ -645,6 +688,7 @@ namespace Slang
}
fwdParentGeneric->removeAndDeallocate();
}
+
return fwdDiffFunc;
}
@@ -665,6 +709,20 @@ namespace Slang
return InstPair(primalInst, nullptr);
}
+ // Keep primal param replacement insts alive during DCE.
+ static void _lockPrimalParamReplacementInsts(IRBuilder* builder, ParameterBlockTransposeInfo& paramInfo)
+ {
+ for (auto& kv : paramInfo.mapPrimalSpecificParamToReplacementInPropFunc)
+ builder->addKeepAliveDecoration(kv.Value);
+ }
+
+ // Remove [KeepAlive] decorations for primal param replacement insts.
+ static void _unlockPrimalParamReplacementInsts(ParameterBlockTransposeInfo& paramInfo)
+ {
+ for (auto& kv : paramInfo.mapPrimalSpecificParamToReplacementInPropFunc)
+ kv.Value->findDecoration<IRKeepAliveDecoration>()->removeAndDeallocate();
+ }
+
// Transcribe a function definition.
void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc)
{
@@ -702,9 +760,9 @@ namespace Slang
// Copy primal insts to the first block of the unzipped function, copy diff insts to the
// second block of the unzipped function.
//
- IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
+ diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
+ IRFunc* unzippedFwdDiffFunc = fwdDiffFunc;
-
// Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell.
builder->setInsertInto(diffPropagateFunc->getParent());
{
@@ -717,13 +775,19 @@ namespace Slang
}
// Transpose the first block (parameter block)
- List<IRInst*> primalFuncSpecificParams;
- auto dOutParameter = transposeParameterBlock(builder, diffPropagateFunc, primalFuncSpecificParams, isResultDifferentiable);
+ auto paramTransposeInfo =
+ splitAndTransposeParameterBlock(builder, diffPropagateFunc, isResultDifferentiable);
+
+ // The insts we inserted in paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc
+ // may be used by write back logic that we are going to insert later.
+ // Before then we want to keep them alive.
+ _lockPrimalParamReplacementInsts(builder, paramTransposeInfo);
builder->setInsertInto(diffPropagateFunc);
- // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
- DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr};
+ // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
+ // derivative of the return value.
+ DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam, nullptr};
diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info);
eliminateDeadCode(diffPropagateFunc);
@@ -732,32 +796,36 @@ namespace Slang
// with the intermediate results computed from the extracted func.
IRInst* intermediateType = nullptr;
auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(
- diffPropagateFunc, primalFunc, isResultDifferentiable, intermediateType);
+ diffPropagateFunc, primalFunc, paramTransposeInfo, intermediateType);
- // Clean up by deallocating the tempoarary forward derivative func.
- fwdDiffFunc->removeAndDeallocate();
+ // At this point the unzipped func is just an empty shell
+ // and we can simply remove it.
+ unzippedFwdDiffFunc->removeAndDeallocate();
- // Remove primalFuncSpecificParams.
- for (auto specificParam : primalFuncSpecificParams)
+ // Write back derivatives to inout parameters.
+ writeBackDerivativeToInOutParams(paramTransposeInfo, diffPropagateFunc);
+
+ // Remove primalFunc specific params.
+ List<IRInst*> paramsToRemove;
+ for (auto param : diffPropagateFunc->getParams())
+ {
+ if (!paramTransposeInfo.propagateFuncParams.Contains(param))
+ paramsToRemove.add(param);
+ }
+ for (auto param : paramsToRemove)
{
- while (auto use = specificParam->firstUse)
+ if (param->hasUses())
{
- if (use->getUser()->getOp() == kIROp_Store && use == use->getUser()->getOperands())
- {
- use->getUser()->removeAndDeallocate();
- }
- else if (auto decor = as<IRDecoration>(use->getUser()))
- {
- decor->removeAndDeallocate();
- }
- else
- {
- SLANG_UNEXPECTED("unexpected use of transcribed param.");
- }
+ IRInst* replacement = nullptr;
+ paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc.TryGetValue(param, replacement);
+ SLANG_RELEASE_ASSERT(replacement);
+ param->replaceUsesWith(replacement);
}
- specificParam->removeAndDeallocate();
+ param->removeAndDeallocate();
}
+ _unlockPrimalParamReplacementInsts(paramTransposeInfo);
+
// If primal function is nested in a generic, we want to create separate generics for all the associated things
// we have just created.
auto primalOuterGeneric = findOuterGeneric(primalFunc);
@@ -789,87 +857,322 @@ namespace Slang
initializeLocalVariables(builder->getSharedBuilder(), diffPropagateFunc);
}
- IRInst* BackwardDiffTranscriberBase::transposeParameterBlock(
+ ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock(
IRBuilder* builder,
IRFunc* diffFunc,
- List<IRInst*>& primalFuncSpecificParams,
bool isResultDifferentiable)
{
+ // This method splits transposes the all the parameters for both the primal and propagate computation.
+ // At the end of this method, the parameter block will contain a combination of parameters for
+ // both the to-be-primal function and to-be-propagate function.
+ // We use ParameterBlockTransposeInfo::primalFuncParams and ParameterBlockTransposeInfo::propagateFuncParams
+ // to track which parameters are dedicated to the future primal or propagate func.
+ // A later step will then split the parameters out to each new function.
+
+ ParameterBlockTransposeInfo result;
+
+ // First, we initialize the IR builders and locate the import code insertion points that will
+ // be used for the rest of this method.
+
IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock();
// Find the 'next' block using the terminator inst of the parameter block.
auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator());
auto nextBlock = fwdParamBlockBranch->getTargetBlock();
- builder->setInsertBefore(fwdParamBlockBranch);
+ auto nextBlockBuilder = *builder;
+ nextBlockBuilder.setInsertBefore(nextBlock->getFirstOrdinaryInst());
- List<IRParam*> fwdParams;
- for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam())
+ IRBlock* firstDiffBlock = nullptr;
+ for (auto block : diffFunc->getBlocks())
{
- fwdParams.add(child);
+ if (isDifferentialInst(block))
+ {
+ firstDiffBlock = block;
+ break;
+ }
}
- // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<>
+ SLANG_RELEASE_ASSERT(firstDiffBlock);
+
+ auto diffBuilder = *builder;
+ diffBuilder.setInsertBefore(firstDiffBlock->getFirstOrdinaryInst());
+
+ builder->setInsertBefore(fwdParamBlockBranch);
+
+ // Collect all the original parameters.
+ List<IRParam*> fwdParams;
+ for (auto param : diffFunc->getParams())
+ fwdParams.add(param);
+
+ // Maintain a set for insts pending removal.
+ OrderedHashSet<IRInst*> instsToRemove;
+
+ // Now we begin the actual processing.
+ // The first step is to transcribe all the existing parameters from the original function.
+ // There are many cases to handle, including different combinations of parameter directions and
+ // whether or not the parameter is differentiable.
+ // To normalize the process for all these cases, we determine the following actions for each parameter:
+ // 1. Should this original parameter be translated to a parameter in the primal func and the propagate func?
+ // if so, we emit a param inst representing the final parameter for that func. If the parameter should be
+ // mapped to both the primal func and the propagate func, we will emit two separate params with their
+ // final type.
+ // 2. If this parameter has a corresponding primal func parameter, we replace all uses of the original
+ // parameter in the primal computation code to the new primal parameter. If any initialization logic
+ // is needed to convert the type of the new primal parameter to what the code was expecting, we insert
+ // that code in the first block.
+ // 3. If this parameter has a correponding propagate func parameter, we replace all uses of the original parameter
+ // in the diff computation code to the new propagate parameter. We insert necessary initialization diff block or the first block
+ // depending on whether we want that logic go through the transposition pass. We may need to replace the uses
+ // to different values/variables depending on whether that use is a read or write.
+ // 4. If the parameter has both corresponding primal and propagate parameters, we also need to consider
+ // how the future propagate function access the primal parameter. We will insert necessary preparation code
+ // that constructs temp vars or values to replace the primal parameter after we remove it from the
+ // propagate func.
+ // Base on above discussion, we need to compute the following values for each parameter:
+ // - diffRefReplacement. What should all read(load) references to this parameter from differential code be replaced to.
+ // - diffRefWriteReplacement. What should all write references to this parameter from differential code be replaced to.
+ // - primalRefReplacement. What should all references to this parameter from primal code be replaced to.
+ // - mapPrimalSpecificParamToReplacementInPropFunc[param]. What should all references to this parameter
+ // from the primal compuation logic in the future propagate function be replaced to.
for (auto fwdParam : fwdParams)
{
- if (auto outType = as<IROutType>(fwdParam->getDataType()))
+ // Define the replacement insts that we are going to fill in for each case.
+ IRInst* diffRefReplacement = nullptr;
+ IRInst* primalRefReplacement = nullptr;
+ IRInst* diffWriteRefReplacement = nullptr;
+
+ // Common logic that computes all the important types we care about.
+ IRDifferentialPairType* diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType());
+ auto inoutType = as<IRInOutType>(fwdParam->getDataType());
+ auto outType = as<IROutType>(fwdParam->getDataType());
+ if (inoutType)
+ diffPairType = as<IRDifferentialPairType>(inoutType->getValueType());
+ else if (outType)
+ diffPairType = as<IRDifferentialPairType>(outType->getValueType());
+ IRType* primalType = nullptr;
+ IRType* diffType = nullptr;
+ if (diffPairType)
+ {
+ primalType = diffPairType->getValueType();
+ diffType = (IRType*)differentiableTypeConformanceContext
+ .getDifferentialTypeFromDiffPairType(builder, diffPairType);
+ }
+
+ // Now we handle each combination of parameter direction x differentiability.
+ if (outType)
{
- IRParam* newPropParam = nullptr;
- IRParam* newPrimalParam = nullptr;
- auto diffPairType = as<IRDifferentialPairType>(outType->getValueType());
+ // Case 1: out parameters.
+ // Out parameters need to be handled differently whether or not it is differentiable,
+ // since the propagate function will not have a corresponding output.
if (diffPairType)
{
// Create dOut param.
- auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(builder, diffPairType);
- newPropParam = builder->emitParam(diffType);
- newPrimalParam = builder->emitParam(builder->getOutType(diffPairType->getValueType()));
+ auto diffParam = builder->emitParam(diffType);
+ result.propagateFuncParams.Add(diffParam);
+ primalRefReplacement = builder->emitParam(builder->getOutType(primalType));
+
+ // Create a local var for read access in pre-transpose code.
+ // This will the var from which we will fetch the final resulting derivative
+ // after transposition.
+ auto tempVar = nextBlockBuilder.emitVar(diffType);
+ nextBlockBuilder.markInstAsDifferential(tempVar, diffPairType);
+
+ // Initialize the var with input diff param at start.
+ // Note that we insert the store in the primal block so it won't get transposed.
+ auto storeInst = nextBlockBuilder.emitStore(tempVar, diffParam);
+ nextBlockBuilder.markInstAsDifferential(storeInst, diffPairType);
+ // Since this store inst is specific to propagate function, we track it in a
+ // set so we can remove it when we generate the primal func.
+ result.propagateFuncSpecificPrimalInsts.add(storeInst);
+
+ diffWriteRefReplacement = tempVar;
+ diffRefReplacement = tempVar;
}
else
{
- newPrimalParam = builder->emitParam(outType);
- }
-
- // Create a temp var to represent the original `out` param.
- auto arg = builder->emitVar(outType->getValueType());
- builder->addAutoDiffOriginalValueDecoration(arg, newPrimalParam);
- if (newPropParam)
- {
- builder->addDecoration(arg, kIROp_OutParamReverseGradientDecoration, newPropParam);
+ primalRefReplacement = builder->emitParam(outType);
}
+ result.primalFuncParams.Add(primalRefReplacement);
- fwdParam->replaceUsesWith(arg);
- fwdParam->removeAndDeallocate();
+ // Create a local var for the out param for the primal part of the prop func.
+ auto tempPrimalVar = nextBlockBuilder.emitVar(outType->getValueType());
+ result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = tempPrimalVar;
- primalFuncSpecificParams.add(newPrimalParam);
+ instsToRemove.Add(fwdParam);
}
- else if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType()))
+ else if (!isRelevantDifferentialPair(fwdParam->getDataType()))
{
+ // Case 2: non differentiable, non output parameters.
+ // If parameter is not an out param and has nothing to do with differentiation,
+ // simply move the parameter to the end.
+ //
+ fwdParam->removeFromParent();
+ fwdDiffParameterBlock->addParam(fwdParam);
+ result.primalFuncParams.Add(fwdParam);
+ result.propagateFuncParams.Add(fwdParam);
+ continue;
+ }
+ else if(!inoutType)
+ {
+ // Case 4: `in` differentiable parameters.
+
+ SLANG_RELEASE_ASSERT(diffPairType);
+
// Create inout version.
auto inoutDiffPairType = builder->getInOutType(diffPairType);
- auto newParam = builder->emitParam(inoutDiffPairType);
-
- // Map the _load_ of the new parameter as the clone of the old one.
- auto newParamLoad = builder->emitLoad(newParam);
- newParamLoad->insertAtStart(nextBlock); // Move to first block _after_ the parameter block.
- fwdParam->replaceUsesWith(newParamLoad);
- fwdParam->removeAndDeallocate();
+ primalRefReplacement = builder->emitParam(primalType);
+ result.primalFuncParams.Add(primalRefReplacement);
+ auto propParam = builder->emitParam(inoutDiffPairType);
+ result.propagateFuncParams.Add(propParam);
+
+ // A reference to this parameter from the diff blocks should be replaced with a load
+ // of the differential component of the pair.
+ auto newParamLoad = diffBuilder.emitLoad(propParam);
+ diffBuilder.markInstAsDifferential(newParamLoad, primalType);
+
+ diffRefReplacement = diffBuilder.emitDifferentialPairGetDifferential(diffType, newParamLoad);
+ diffBuilder.markInstAsDifferential(diffRefReplacement, primalType);
+
+ // Load the primal component from the prop param and use it as replacement for the
+ // primal param in the primal part of the prop func.
+ // Since these are logic specific to propagate function, we will add them to the
+ // `propagateFuncSpecificPrimalInsts` set so we can remove them when we generate the primal func.
+ auto primalReplacementLoad = nextBlockBuilder.emitLoad(propParam);
+ result.propagateFuncSpecificPrimalInsts.add(primalReplacementLoad);
+ auto primalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(primalReplacementLoad);
+ result.propagateFuncSpecificPrimalInsts.add(primalVal);
+ result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = primalVal;
+
+ instsToRemove.Add(fwdParam);
}
else
{
- // Default case (parameter is inout type or has nothing to do with differentiation)
- // Simply move the parameter to the end.
- //
- fwdParam->removeFromParent();
- fwdDiffParameterBlock->addParam(fwdParam);
+ // Case 5: `inout` differentiable parameters.
+ SLANG_ASSERT(inoutType && diffPairType);
+
+ // Process differentiable inout parameters.
+ auto primalParam = builder->emitParam(builder->getInOutType(primalType));
+ result.primalFuncParams.Add(primalParam);
+
+ auto diffParam = builder->emitParam(inoutType);
+ result.propagateFuncParams.Add(diffParam);
+
+ // Primal references to this param is the new primal param.
+ primalRefReplacement = primalParam;
+
+ // Diff references to this param should be replaced with one local temp var
+ // for read and one separate temp var for write.
+
+ // Load the inital diff value.
+ auto loadedParam = nextBlockBuilder.emitLoad(diffParam);
+ auto initDiff = nextBlockBuilder.emitDifferentialPairGetDifferential(diffType, loadedParam);
+
+ // Create a local var for diff read access.
+ auto diffVar = nextBlockBuilder.emitVar(diffType);
+ result.propagateFuncSpecificPrimalInsts.add(diffVar);
+ diffBuilder.markInstAsDifferential(diffVar, diffPairType);
+ diffRefReplacement = diffVar;
+
+ // Clear the diff read var to zero at start of the function.
+ auto dzero = getDifferentialZeroOfType(&nextBlockBuilder, primalType);
+ result.propagateFuncSpecificPrimalInsts.add(dzero);
+ auto initDiffStore = nextBlockBuilder.emitStore(diffVar, dzero);
+ result.propagateFuncSpecificPrimalInsts.add(initDiffStore);
+
+ // Create a local var for diff write access.
+ auto diffWriteVar = nextBlockBuilder.emitVar(diffType);
+ // Initialize write var to 0.
+ auto writeStore = nextBlockBuilder.emitStore(diffWriteVar, initDiff);
+ result.propagateFuncSpecificPrimalInsts.add(writeStore);
+
+ diffWriteRefReplacement = diffWriteVar;
+
+ // Create a local var for the primal logic in the propagate func.
+ auto primalVar = nextBlockBuilder.emitVar(primalType);
+ result.propagateFuncSpecificPrimalInsts.add(primalVar);
+ auto initPrimalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(loadedParam);
+ result.propagateFuncSpecificPrimalInsts.add(initPrimalVal);
+ auto storeInst = nextBlockBuilder.emitStore(primalVar, initPrimalVal);
+ result.propagateFuncSpecificPrimalInsts.add(storeInst);
+ result.mapPrimalSpecificParamToReplacementInPropFunc[primalParam] = primalVar;
+ result.outDiffWritebacks[diffParam] = InstPair(initPrimalVal, diffVar);
+
+ instsToRemove.Add(fwdParam);
+ }
+
+ // We have emitted all the new parameters and computed the replacements for the original
+ // parameter. Now we perform that replacement.
+ List<IRUse*> uses;
+ for (auto use = fwdParam->firstUse; use; use = use->nextUse)
+ uses.add(use);
+ for (auto use : uses)
+ {
+ if (auto primalRef = as<IRPrimalParamRef>(use->getUser()))
+ {
+ SLANG_RELEASE_ASSERT(primalRefReplacement);
+ primalRef->replaceUsesWith(primalRefReplacement);
+ instsToRemove.Add(primalRef);
+ }
+ else if (auto getPrimal = as<IRDifferentialPairGetPrimal>(use->getUser()))
+ {
+ SLANG_RELEASE_ASSERT(primalRefReplacement);
+ getPrimal->replaceUsesWith(primalRefReplacement);
+ instsToRemove.Add(getPrimal);
+ }
+ else if (auto propagateRef = as<IRDiffParamRef>(use->getUser()))
+ {
+ SLANG_RELEASE_ASSERT(diffRefReplacement);
+ auto refUse = propagateRef->firstUse;
+ while (refUse)
+ {
+ auto nextUse = refUse->nextUse;
+ switch (refUse->getUser()->getOp())
+ {
+ case kIROp_Load:
+ refUse->set(diffRefReplacement);
+ break;
+ case kIROp_Store:
+ refUse->set(diffWriteRefReplacement);
+ break;
+ default:
+ SLANG_RELEASE_ASSERT(!diffWriteRefReplacement);
+ refUse->set(diffRefReplacement);
+ }
+ refUse = nextUse;
+ }
+ instsToRemove.Add(propagateRef);
+ }
+ else if (auto getDiff = as<IRDifferentialPairGetDifferential>(use->getUser()))
+ {
+ SLANG_RELEASE_ASSERT(diffRefReplacement);
+ getDiff->replaceUsesWith(diffRefReplacement);
+ instsToRemove.Add(getDiff);
+ }
+ else
+ {
+ // If the user is something else, it'd better be a non relevant parameter.
+ if (diffRefReplacement || diffWriteRefReplacement)
+ SLANG_UNEXPECTED("unknown use of parameter.");
+ use->set(primalRefReplacement);
+ }
}
}
- auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount();
+ // Actually remove all the insts that we decided to remove in the process.
+ for (auto inst : instsToRemove)
+ {
+ inst->removeAndDeallocate();
+ }
+
- // 2. If the return type of the original function is differentiable,
+ // The next step is to insert new parameters that is not related to any existing parameters.
+ //
+ // If the return type of the original function is differentiable,
// add a parameter for 'derivative of the output' (d_out).
// The type is the second last parameter type of the function.
//
+ auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount();
IRParam* dOutParam = nullptr;
if (isResultDifferentiable)
{
@@ -878,12 +1181,44 @@ namespace Slang
SLANG_ASSERT(dOutParamType);
dOutParam = builder->emitParam(dOutParamType);
+ result.propagateFuncParams.Add(dOutParam);
}
// Add a parameter for intermediate val.
- builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
+ auto ctxParam = builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
+ result.primalFuncParams.Add(ctxParam);
+ result.propagateFuncParams.Add(ctxParam);
+ result.dOutParam = dOutParam;
+ return result;
+ }
- return dOutParam;
+ void BackwardDiffTranscriberBase::writeBackDerivativeToInOutParams(ParameterBlockTransposeInfo& info, IRFunc* diffFunc)
+ {
+ IRInst* returnInst = nullptr;
+ for (auto block : diffFunc->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_Return)
+ {
+ returnInst = inst;
+ break;
+ }
+ }
+ }
+ SLANG_RELEASE_ASSERT(returnInst);
+
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(returnInst);
+ for (auto& wb : info.outDiffWritebacks)
+ {
+ auto dest = wb.Key;
+ auto srcPrimalVal = wb.Value.primal;
+ auto srcDiffAddr = wb.Value.differential;
+ auto srcDiffVal = builder.emitLoad(srcDiffAddr);
+ auto destVal = builder.emitMakeDifferentialPair(as<IRPtrTypeBase>(dest->getFullType())->getValueType(), srcPrimalVal, srcDiffVal);
+ builder.emitStore(dest, destVal);
+ }
}
InstPair BackwardDiffTranscriberBase::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)