diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-03 16:44:33 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-03 16:44:33 -0800 |
| commit | 228e71dab7dfa18ece979f4099ec0c7d1e37e5ff (patch) | |
| tree | ff357f4aaed2dab25ae9e3665a97a7f3e6be32ef /source/slang/slang-ir-autodiff-rev.cpp | |
| parent | ee49a62083d28353812185fd0f0c04fb50ca6be0 (diff) | |
Overhaul `transposeParameterBlock` to support `inout` params. (#2621)
* Overhaul `transposeParameterBlock` to support `inout` params.
* Small bug fixes.
* Bug fix on differentiable intrinsic specialization.
* Fixes.
* Run autodiff tests on CPU.
* Clean up.
* More bug fixes.,
* Add test coverage on inout param.
* Fix language server hinting for transcribed mutable params.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 477 |
1 files changed, 406 insertions, 71 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 9c63a4012..67387e83a 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -417,7 +417,10 @@ namespace Slang } else { - auto var = builder.emitVar(primalParamType); + auto primalPtrType = as<IRPtrTypeBase>(primalParamType); + SLANG_RELEASE_ASSERT(primalPtrType); + auto primalValueType = primalPtrType->getValueType(); + auto var = builder.emitVar(primalValueType); primalArgs.add(var); } primalTypes.add(primalParamType); @@ -515,12 +518,12 @@ namespace Slang { auto ptrType = as<IRPtrTypeBase>(param->getDataType()); auto tempVar = builder.emitVar(ptrType->getValueType()); + param->replaceUsesWith(tempVar); mapParamToTempVar[param] = tempVar; - if (param->getOp() != kIROp_OutType) + if (ptrType->getOp() != kIROp_OutType) { builder.emitStore(tempVar, builder.emitLoad(param)); } - param->replaceUsesWith(tempVar); } for (auto block : func->getBlocks()) @@ -578,6 +581,43 @@ namespace Slang return result; } + void eliminateRedundantLoad(IRFunc* func) + { + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstInst(); inst;) + { + auto nextInst = inst->getNextInst(); + if (auto load = as<IRLoad>(inst)) + { + for (auto prev = inst->getPrevInst(); prev; prev = prev->getPrevInst()) + { + if (auto store = as<IRStore>(prev)) + { + if (store->getPtr() == load->getPtr()) + { + // If the load is preceeded by a store without any side-effect insts in-between, remove the load. + auto value = store->getVal(); + load->replaceUsesWith(value); + load->removeAndDeallocate(); + break; + } + } + else if (as<IRCall>(prev)) + { + break; + } + else if (prev->mightHaveSideEffects()) + { + break; + } + } + } + inst = nextInst; + } + } + } + // Create a copy of originalFunc's forward derivative in the same generic context (if any) of // `diffPropagateFunc`. IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc( @@ -621,6 +661,9 @@ namespace Slang // Remove the clone of original func. primalOuterParent->removeAndDeallocate(); + // Remove redundant loads since they interfere with transposition logic. + eliminateRedundantLoad(fwdDiffFunc); + // Migrate the new forward derivative function into the generic parent of `diffPropagateFunc`. if (auto fwdParentGeneric = as<IRGeneric>(findOuterGeneric(fwdDiffFunc))) { @@ -645,6 +688,7 @@ namespace Slang } fwdParentGeneric->removeAndDeallocate(); } + return fwdDiffFunc; } @@ -665,6 +709,20 @@ namespace Slang return InstPair(primalInst, nullptr); } + // Keep primal param replacement insts alive during DCE. + static void _lockPrimalParamReplacementInsts(IRBuilder* builder, ParameterBlockTransposeInfo& paramInfo) + { + for (auto& kv : paramInfo.mapPrimalSpecificParamToReplacementInPropFunc) + builder->addKeepAliveDecoration(kv.Value); + } + + // Remove [KeepAlive] decorations for primal param replacement insts. + static void _unlockPrimalParamReplacementInsts(ParameterBlockTransposeInfo& paramInfo) + { + for (auto& kv : paramInfo.mapPrimalSpecificParamToReplacementInPropFunc) + kv.Value->findDecoration<IRKeepAliveDecoration>()->removeAndDeallocate(); + } + // Transcribe a function definition. void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc) { @@ -702,9 +760,9 @@ namespace Slang // Copy primal insts to the first block of the unzipped function, copy diff insts to the // second block of the unzipped function. // - IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc); + diffUnzipPass->unzipDiffInsts(fwdDiffFunc); + IRFunc* unzippedFwdDiffFunc = fwdDiffFunc; - // Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell. builder->setInsertInto(diffPropagateFunc->getParent()); { @@ -717,13 +775,19 @@ namespace Slang } // Transpose the first block (parameter block) - List<IRInst*> primalFuncSpecificParams; - auto dOutParameter = transposeParameterBlock(builder, diffPropagateFunc, primalFuncSpecificParams, isResultDifferentiable); + auto paramTransposeInfo = + splitAndTransposeParameterBlock(builder, diffPropagateFunc, isResultDifferentiable); + + // The insts we inserted in paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc + // may be used by write back logic that we are going to insert later. + // Before then we want to keep them alive. + _lockPrimalParamReplacementInsts(builder, paramTransposeInfo); builder->setInsertInto(diffPropagateFunc); - // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the - DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr}; + // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the + // derivative of the return value. + DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam, nullptr}; diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info); eliminateDeadCode(diffPropagateFunc); @@ -732,32 +796,36 @@ namespace Slang // with the intermediate results computed from the extracted func. IRInst* intermediateType = nullptr; auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc( - diffPropagateFunc, primalFunc, isResultDifferentiable, intermediateType); + diffPropagateFunc, primalFunc, paramTransposeInfo, intermediateType); - // Clean up by deallocating the tempoarary forward derivative func. - fwdDiffFunc->removeAndDeallocate(); + // At this point the unzipped func is just an empty shell + // and we can simply remove it. + unzippedFwdDiffFunc->removeAndDeallocate(); - // Remove primalFuncSpecificParams. - for (auto specificParam : primalFuncSpecificParams) + // Write back derivatives to inout parameters. + writeBackDerivativeToInOutParams(paramTransposeInfo, diffPropagateFunc); + + // Remove primalFunc specific params. + List<IRInst*> paramsToRemove; + for (auto param : diffPropagateFunc->getParams()) + { + if (!paramTransposeInfo.propagateFuncParams.Contains(param)) + paramsToRemove.add(param); + } + for (auto param : paramsToRemove) { - while (auto use = specificParam->firstUse) + if (param->hasUses()) { - if (use->getUser()->getOp() == kIROp_Store && use == use->getUser()->getOperands()) - { - use->getUser()->removeAndDeallocate(); - } - else if (auto decor = as<IRDecoration>(use->getUser())) - { - decor->removeAndDeallocate(); - } - else - { - SLANG_UNEXPECTED("unexpected use of transcribed param."); - } + IRInst* replacement = nullptr; + paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc.TryGetValue(param, replacement); + SLANG_RELEASE_ASSERT(replacement); + param->replaceUsesWith(replacement); } - specificParam->removeAndDeallocate(); + param->removeAndDeallocate(); } + _unlockPrimalParamReplacementInsts(paramTransposeInfo); + // If primal function is nested in a generic, we want to create separate generics for all the associated things // we have just created. auto primalOuterGeneric = findOuterGeneric(primalFunc); @@ -789,87 +857,322 @@ namespace Slang initializeLocalVariables(builder->getSharedBuilder(), diffPropagateFunc); } - IRInst* BackwardDiffTranscriberBase::transposeParameterBlock( + ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock( IRBuilder* builder, IRFunc* diffFunc, - List<IRInst*>& primalFuncSpecificParams, bool isResultDifferentiable) { + // This method splits transposes the all the parameters for both the primal and propagate computation. + // At the end of this method, the parameter block will contain a combination of parameters for + // both the to-be-primal function and to-be-propagate function. + // We use ParameterBlockTransposeInfo::primalFuncParams and ParameterBlockTransposeInfo::propagateFuncParams + // to track which parameters are dedicated to the future primal or propagate func. + // A later step will then split the parameters out to each new function. + + ParameterBlockTransposeInfo result; + + // First, we initialize the IR builders and locate the import code insertion points that will + // be used for the rest of this method. + IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock(); // Find the 'next' block using the terminator inst of the parameter block. auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator()); auto nextBlock = fwdParamBlockBranch->getTargetBlock(); - builder->setInsertBefore(fwdParamBlockBranch); + auto nextBlockBuilder = *builder; + nextBlockBuilder.setInsertBefore(nextBlock->getFirstOrdinaryInst()); - List<IRParam*> fwdParams; - for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam()) + IRBlock* firstDiffBlock = nullptr; + for (auto block : diffFunc->getBlocks()) { - fwdParams.add(child); + if (isDifferentialInst(block)) + { + firstDiffBlock = block; + break; + } } - // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> + SLANG_RELEASE_ASSERT(firstDiffBlock); + + auto diffBuilder = *builder; + diffBuilder.setInsertBefore(firstDiffBlock->getFirstOrdinaryInst()); + + builder->setInsertBefore(fwdParamBlockBranch); + + // Collect all the original parameters. + List<IRParam*> fwdParams; + for (auto param : diffFunc->getParams()) + fwdParams.add(param); + + // Maintain a set for insts pending removal. + OrderedHashSet<IRInst*> instsToRemove; + + // Now we begin the actual processing. + // The first step is to transcribe all the existing parameters from the original function. + // There are many cases to handle, including different combinations of parameter directions and + // whether or not the parameter is differentiable. + // To normalize the process for all these cases, we determine the following actions for each parameter: + // 1. Should this original parameter be translated to a parameter in the primal func and the propagate func? + // if so, we emit a param inst representing the final parameter for that func. If the parameter should be + // mapped to both the primal func and the propagate func, we will emit two separate params with their + // final type. + // 2. If this parameter has a corresponding primal func parameter, we replace all uses of the original + // parameter in the primal computation code to the new primal parameter. If any initialization logic + // is needed to convert the type of the new primal parameter to what the code was expecting, we insert + // that code in the first block. + // 3. If this parameter has a correponding propagate func parameter, we replace all uses of the original parameter + // in the diff computation code to the new propagate parameter. We insert necessary initialization diff block or the first block + // depending on whether we want that logic go through the transposition pass. We may need to replace the uses + // to different values/variables depending on whether that use is a read or write. + // 4. If the parameter has both corresponding primal and propagate parameters, we also need to consider + // how the future propagate function access the primal parameter. We will insert necessary preparation code + // that constructs temp vars or values to replace the primal parameter after we remove it from the + // propagate func. + // Base on above discussion, we need to compute the following values for each parameter: + // - diffRefReplacement. What should all read(load) references to this parameter from differential code be replaced to. + // - diffRefWriteReplacement. What should all write references to this parameter from differential code be replaced to. + // - primalRefReplacement. What should all references to this parameter from primal code be replaced to. + // - mapPrimalSpecificParamToReplacementInPropFunc[param]. What should all references to this parameter + // from the primal compuation logic in the future propagate function be replaced to. for (auto fwdParam : fwdParams) { - if (auto outType = as<IROutType>(fwdParam->getDataType())) + // Define the replacement insts that we are going to fill in for each case. + IRInst* diffRefReplacement = nullptr; + IRInst* primalRefReplacement = nullptr; + IRInst* diffWriteRefReplacement = nullptr; + + // Common logic that computes all the important types we care about. + IRDifferentialPairType* diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType()); + auto inoutType = as<IRInOutType>(fwdParam->getDataType()); + auto outType = as<IROutType>(fwdParam->getDataType()); + if (inoutType) + diffPairType = as<IRDifferentialPairType>(inoutType->getValueType()); + else if (outType) + diffPairType = as<IRDifferentialPairType>(outType->getValueType()); + IRType* primalType = nullptr; + IRType* diffType = nullptr; + if (diffPairType) + { + primalType = diffPairType->getValueType(); + diffType = (IRType*)differentiableTypeConformanceContext + .getDifferentialTypeFromDiffPairType(builder, diffPairType); + } + + // Now we handle each combination of parameter direction x differentiability. + if (outType) { - IRParam* newPropParam = nullptr; - IRParam* newPrimalParam = nullptr; - auto diffPairType = as<IRDifferentialPairType>(outType->getValueType()); + // Case 1: out parameters. + // Out parameters need to be handled differently whether or not it is differentiable, + // since the propagate function will not have a corresponding output. if (diffPairType) { // Create dOut param. - auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(builder, diffPairType); - newPropParam = builder->emitParam(diffType); - newPrimalParam = builder->emitParam(builder->getOutType(diffPairType->getValueType())); + auto diffParam = builder->emitParam(diffType); + result.propagateFuncParams.Add(diffParam); + primalRefReplacement = builder->emitParam(builder->getOutType(primalType)); + + // Create a local var for read access in pre-transpose code. + // This will the var from which we will fetch the final resulting derivative + // after transposition. + auto tempVar = nextBlockBuilder.emitVar(diffType); + nextBlockBuilder.markInstAsDifferential(tempVar, diffPairType); + + // Initialize the var with input diff param at start. + // Note that we insert the store in the primal block so it won't get transposed. + auto storeInst = nextBlockBuilder.emitStore(tempVar, diffParam); + nextBlockBuilder.markInstAsDifferential(storeInst, diffPairType); + // Since this store inst is specific to propagate function, we track it in a + // set so we can remove it when we generate the primal func. + result.propagateFuncSpecificPrimalInsts.add(storeInst); + + diffWriteRefReplacement = tempVar; + diffRefReplacement = tempVar; } else { - newPrimalParam = builder->emitParam(outType); - } - - // Create a temp var to represent the original `out` param. - auto arg = builder->emitVar(outType->getValueType()); - builder->addAutoDiffOriginalValueDecoration(arg, newPrimalParam); - if (newPropParam) - { - builder->addDecoration(arg, kIROp_OutParamReverseGradientDecoration, newPropParam); + primalRefReplacement = builder->emitParam(outType); } + result.primalFuncParams.Add(primalRefReplacement); - fwdParam->replaceUsesWith(arg); - fwdParam->removeAndDeallocate(); + // Create a local var for the out param for the primal part of the prop func. + auto tempPrimalVar = nextBlockBuilder.emitVar(outType->getValueType()); + result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = tempPrimalVar; - primalFuncSpecificParams.add(newPrimalParam); + instsToRemove.Add(fwdParam); } - else if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType())) + else if (!isRelevantDifferentialPair(fwdParam->getDataType())) { + // Case 2: non differentiable, non output parameters. + // If parameter is not an out param and has nothing to do with differentiation, + // simply move the parameter to the end. + // + fwdParam->removeFromParent(); + fwdDiffParameterBlock->addParam(fwdParam); + result.primalFuncParams.Add(fwdParam); + result.propagateFuncParams.Add(fwdParam); + continue; + } + else if(!inoutType) + { + // Case 4: `in` differentiable parameters. + + SLANG_RELEASE_ASSERT(diffPairType); + // Create inout version. auto inoutDiffPairType = builder->getInOutType(diffPairType); - auto newParam = builder->emitParam(inoutDiffPairType); - - // Map the _load_ of the new parameter as the clone of the old one. - auto newParamLoad = builder->emitLoad(newParam); - newParamLoad->insertAtStart(nextBlock); // Move to first block _after_ the parameter block. - fwdParam->replaceUsesWith(newParamLoad); - fwdParam->removeAndDeallocate(); + primalRefReplacement = builder->emitParam(primalType); + result.primalFuncParams.Add(primalRefReplacement); + auto propParam = builder->emitParam(inoutDiffPairType); + result.propagateFuncParams.Add(propParam); + + // A reference to this parameter from the diff blocks should be replaced with a load + // of the differential component of the pair. + auto newParamLoad = diffBuilder.emitLoad(propParam); + diffBuilder.markInstAsDifferential(newParamLoad, primalType); + + diffRefReplacement = diffBuilder.emitDifferentialPairGetDifferential(diffType, newParamLoad); + diffBuilder.markInstAsDifferential(diffRefReplacement, primalType); + + // Load the primal component from the prop param and use it as replacement for the + // primal param in the primal part of the prop func. + // Since these are logic specific to propagate function, we will add them to the + // `propagateFuncSpecificPrimalInsts` set so we can remove them when we generate the primal func. + auto primalReplacementLoad = nextBlockBuilder.emitLoad(propParam); + result.propagateFuncSpecificPrimalInsts.add(primalReplacementLoad); + auto primalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(primalReplacementLoad); + result.propagateFuncSpecificPrimalInsts.add(primalVal); + result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = primalVal; + + instsToRemove.Add(fwdParam); } else { - // Default case (parameter is inout type or has nothing to do with differentiation) - // Simply move the parameter to the end. - // - fwdParam->removeFromParent(); - fwdDiffParameterBlock->addParam(fwdParam); + // Case 5: `inout` differentiable parameters. + SLANG_ASSERT(inoutType && diffPairType); + + // Process differentiable inout parameters. + auto primalParam = builder->emitParam(builder->getInOutType(primalType)); + result.primalFuncParams.Add(primalParam); + + auto diffParam = builder->emitParam(inoutType); + result.propagateFuncParams.Add(diffParam); + + // Primal references to this param is the new primal param. + primalRefReplacement = primalParam; + + // Diff references to this param should be replaced with one local temp var + // for read and one separate temp var for write. + + // Load the inital diff value. + auto loadedParam = nextBlockBuilder.emitLoad(diffParam); + auto initDiff = nextBlockBuilder.emitDifferentialPairGetDifferential(diffType, loadedParam); + + // Create a local var for diff read access. + auto diffVar = nextBlockBuilder.emitVar(diffType); + result.propagateFuncSpecificPrimalInsts.add(diffVar); + diffBuilder.markInstAsDifferential(diffVar, diffPairType); + diffRefReplacement = diffVar; + + // Clear the diff read var to zero at start of the function. + auto dzero = getDifferentialZeroOfType(&nextBlockBuilder, primalType); + result.propagateFuncSpecificPrimalInsts.add(dzero); + auto initDiffStore = nextBlockBuilder.emitStore(diffVar, dzero); + result.propagateFuncSpecificPrimalInsts.add(initDiffStore); + + // Create a local var for diff write access. + auto diffWriteVar = nextBlockBuilder.emitVar(diffType); + // Initialize write var to 0. + auto writeStore = nextBlockBuilder.emitStore(diffWriteVar, initDiff); + result.propagateFuncSpecificPrimalInsts.add(writeStore); + + diffWriteRefReplacement = diffWriteVar; + + // Create a local var for the primal logic in the propagate func. + auto primalVar = nextBlockBuilder.emitVar(primalType); + result.propagateFuncSpecificPrimalInsts.add(primalVar); + auto initPrimalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(loadedParam); + result.propagateFuncSpecificPrimalInsts.add(initPrimalVal); + auto storeInst = nextBlockBuilder.emitStore(primalVar, initPrimalVal); + result.propagateFuncSpecificPrimalInsts.add(storeInst); + result.mapPrimalSpecificParamToReplacementInPropFunc[primalParam] = primalVar; + result.outDiffWritebacks[diffParam] = InstPair(initPrimalVal, diffVar); + + instsToRemove.Add(fwdParam); + } + + // We have emitted all the new parameters and computed the replacements for the original + // parameter. Now we perform that replacement. + List<IRUse*> uses; + for (auto use = fwdParam->firstUse; use; use = use->nextUse) + uses.add(use); + for (auto use : uses) + { + if (auto primalRef = as<IRPrimalParamRef>(use->getUser())) + { + SLANG_RELEASE_ASSERT(primalRefReplacement); + primalRef->replaceUsesWith(primalRefReplacement); + instsToRemove.Add(primalRef); + } + else if (auto getPrimal = as<IRDifferentialPairGetPrimal>(use->getUser())) + { + SLANG_RELEASE_ASSERT(primalRefReplacement); + getPrimal->replaceUsesWith(primalRefReplacement); + instsToRemove.Add(getPrimal); + } + else if (auto propagateRef = as<IRDiffParamRef>(use->getUser())) + { + SLANG_RELEASE_ASSERT(diffRefReplacement); + auto refUse = propagateRef->firstUse; + while (refUse) + { + auto nextUse = refUse->nextUse; + switch (refUse->getUser()->getOp()) + { + case kIROp_Load: + refUse->set(diffRefReplacement); + break; + case kIROp_Store: + refUse->set(diffWriteRefReplacement); + break; + default: + SLANG_RELEASE_ASSERT(!diffWriteRefReplacement); + refUse->set(diffRefReplacement); + } + refUse = nextUse; + } + instsToRemove.Add(propagateRef); + } + else if (auto getDiff = as<IRDifferentialPairGetDifferential>(use->getUser())) + { + SLANG_RELEASE_ASSERT(diffRefReplacement); + getDiff->replaceUsesWith(diffRefReplacement); + instsToRemove.Add(getDiff); + } + else + { + // If the user is something else, it'd better be a non relevant parameter. + if (diffRefReplacement || diffWriteRefReplacement) + SLANG_UNEXPECTED("unknown use of parameter."); + use->set(primalRefReplacement); + } } } - auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount(); + // Actually remove all the insts that we decided to remove in the process. + for (auto inst : instsToRemove) + { + inst->removeAndDeallocate(); + } + - // 2. If the return type of the original function is differentiable, + // The next step is to insert new parameters that is not related to any existing parameters. + // + // If the return type of the original function is differentiable, // add a parameter for 'derivative of the output' (d_out). // The type is the second last parameter type of the function. // + auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount(); IRParam* dOutParam = nullptr; if (isResultDifferentiable) { @@ -878,12 +1181,44 @@ namespace Slang SLANG_ASSERT(dOutParamType); dOutParam = builder->emitParam(dOutParamType); + result.propagateFuncParams.Add(dOutParam); } // Add a parameter for intermediate val. - builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1)); + auto ctxParam = builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1)); + result.primalFuncParams.Add(ctxParam); + result.propagateFuncParams.Add(ctxParam); + result.dOutParam = dOutParam; + return result; + } - return dOutParam; + void BackwardDiffTranscriberBase::writeBackDerivativeToInOutParams(ParameterBlockTransposeInfo& info, IRFunc* diffFunc) + { + IRInst* returnInst = nullptr; + for (auto block : diffFunc->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (inst->getOp() == kIROp_Return) + { + returnInst = inst; + break; + } + } + } + SLANG_RELEASE_ASSERT(returnInst); + + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(returnInst); + for (auto& wb : info.outDiffWritebacks) + { + auto dest = wb.Key; + auto srcPrimalVal = wb.Value.primal; + auto srcDiffAddr = wb.Value.differential; + auto srcDiffVal = builder.emitLoad(srcDiffAddr); + auto destVal = builder.emitMakeDifferentialPair(as<IRPtrTypeBase>(dest->getFullType())->getValueType(), srcPrimalVal, srcDiffVal); + builder.emitStore(dest, destVal); + } } InstPair BackwardDiffTranscriberBase::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) |
