diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-01 14:18:57 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-01 14:18:57 -0800 |
| commit | bbd1e1786401bb88c34802b987d4da72e2364503 (patch) | |
| tree | 99a4be95ae517fd710fc032a1debdac917dd3ac2 /source/slang/slang-ir-autodiff-rev.cpp | |
| parent | c5895fb0b82fd14fbe45b58d5fc7f75d67625d15 (diff) | |
Support `out` parameters in backward differentiation. (#2619)
* Support `out` parameters in backward differentiation.
* Fixes.
* Fix cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 557 |
1 files changed, 238 insertions, 319 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 0f2ceceb4..9c63a4012 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -11,6 +11,7 @@ #include "slang-ir-single-return.h" #include "slang-ir-addr-inst-elimination.h" #include "slang-ir-eliminate-multilevel-break.h" +#include "slang-ir-init-local-var.h" namespace Slang { @@ -21,32 +22,10 @@ namespace Slang for (UIndex i = 0; i < funcType->getParamCount(); i++) { - bool noDiff = false; auto origType = funcType->getParamType(i); - auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origType); - - if (auto attrType = as<IRAttributedType>(primalType)) - { - if (attrType->findAttr<IRNoDiffAttr>()) - { - noDiff = true; - primalType = attrType->getBaseType(); - } - } - if (noDiff) - { - newParameterTypes.add(primalType); - } - else - { - if (auto diffPairType = tryGetDiffPairType(builder, origType)) - { - auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); - newParameterTypes.add(inoutDiffPairType); - } - else - newParameterTypes.add(primalType); - } + auto paramType = transcribeParamTypeForPropagateFunc(builder, origType); + if (paramType) + newParameterTypes.add(paramType); } if (auto diffResultType = differentiateType(builder, funcType->getResultType())) @@ -75,7 +54,7 @@ namespace Slang for (UInt i = 0; i < funcType->getParamCount(); i++) { auto origType = funcType->getParamType(i); - auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origType); + auto primalType = transcribeParamTypeForPrimalFunc(builder, origType); paramTypes.add(primalType); } paramTypes.add(outType); @@ -252,52 +231,57 @@ namespace Slang return String(""); } - InstPair BackwardDiffTranscriberBase::transposeBlock(IRBuilder* builder, IRBlock* origBlock) + static IRType* _getPrimalTypeFromNoDiffType(BackwardDiffTranscriberBase* transcriber, IRBuilder* builder, IRType* origType) { - IRBuilder subBuilder(builder->getSharedBuilder()); - subBuilder.setInsertLoc(builder->getInsertLoc()); + IRType* valueType = origType; + auto ptrType = as<IROutTypeBase>(valueType); + if (ptrType) + valueType = ptrType->getValueType(); - IRBlock* diffBlock = subBuilder.emitBlock(); - - subBuilder.setInsertInto(diffBlock); - - // First transcribe every parameter in the block. - for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) - this->copyParam(&subBuilder, param); - - // The extra param for input gradient - auto gradParam = subBuilder.emitParam(as<IRFuncType>(origBlock->getParent()->getFullType())->getResultType()); - - // Then, run through every instruction and use the transcriber to generate the appropriate - // derivative code. - // - for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) - this->copyInst(&subBuilder, child); - - auto lastInst = diffBlock->getLastOrdinaryInst(); - List<IRInst*> grads = { gradParam }; - upperGradients.Add(lastInst, grads); - for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst()) + if (auto attrType = as<IRAttributedType>(valueType)) { - auto upperGrads = upperGradients.TryGetValue(child); - if (!upperGrads) - continue; - if (upperGrads->getCount() > 1) + if (attrType->findAttr<IRNoDiffAttr>()) { - auto sumGrad = upperGrads->getFirst(); - for (auto i = 1; i < upperGrads->getCount(); i++) - { - sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]); - } - this->transposeInstBackward(&subBuilder, child, sumGrad); + auto primalValueType = (IRType*)transcriber->findOrTranscribePrimalInst(builder, valueType); + if (ptrType) + return builder->getPtrType(ptrType->getOp(), primalValueType); + return primalValueType; } - else - this->transposeInstBackward(&subBuilder, child, upperGrads->getFirst()); } + return nullptr; + } + + IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPrimalFunc(IRBuilder* builder, IRType* paramType) + { + // If the param is marked as no_diff, return the primal type. + if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType)) + return primalNoDiffType; - subBuilder.emitReturn(); + return (IRType*)findOrTranscribePrimalInst(builder, paramType); + } - return InstPair(diffBlock, diffBlock); + IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(IRBuilder* builder, IRType* paramType) + { + if (auto outType = as<IROutType>(paramType)) + { + auto valueType = outType->getValueType(); + auto diffValueType = differentiateType(builder, valueType); + return diffValueType; + } + + // If the param is marked as no_diff, return the primal type. + if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType)) + return primalNoDiffType; + + auto diffPairType = tryGetDiffPairType(builder, paramType); + if (diffPairType) + { + if (!as<IRPtrTypeBase>(diffPairType)) + return builder->getInOutType(diffPairType); + return diffPairType; + } + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType); + return primalType; } // Create an empty func to represent the transcribed func of `origFunc`. @@ -387,39 +371,65 @@ namespace Slang IRBuilder builder(inBuilder->getSharedBuilder()); builder.setInsertInto(header.differential); builder.emitBlock(); - auto funcType = as<IRFuncType>(header.differential->getDataType()); + auto origFuncType = as<IRFuncType>(origFunc->getFullType()); List<IRInst*> primalArgs, propagateArgs; List<IRType*> primalTypes, propagateTypes; - for (UInt i = 0; i < funcType->getParamCount(); i++) + for (UInt i = 0; i < origFuncType->getParamCount(); i++) { - auto paramType = (IRType*)findOrTranscribePrimalInst(&builder, funcType->getParamType(i)); - auto param = builder.emitParam(paramType); - if (i != funcType->getParamCount() - 1) + auto primalParamType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i)); + auto propagateParamType = transcribeParamTypeForPropagateFunc(&builder, origFuncType->getParamType(i)); + if (propagateParamType) { - primalArgs.add(param); - } - propagateArgs.add(param); - propagateTypes.add(paramType); - } + auto param = builder.emitParam(propagateParamType); + propagateTypes.add(propagateParamType); + propagateArgs.add(param); - // Fetch primal values to use as arguments in primal func call. - for (auto& arg : primalArgs) - { - IRInst* valueType = arg->getDataType(); - auto inoutType = as<IRPtrTypeBase>(arg->getDataType()); - if (inoutType) + // Fetch primal values to use as arguments in primal func call. + IRInst* primalArg = param; + if (!as<IROutType>(primalParamType)) + { + // As long as the primal parameter is not an out type, + // we need to fetch the primal value from the parameter. + if (as<IRPtrTypeBase>(propagateParamType)) + { + primalArg = builder.emitLoad(param); + } + if (auto diffPairType = as<IRDifferentialPairType>(primalArg->getDataType())) + { + primalArg = builder.emitDifferentialPairGetPrimal(primalArg); + } + } + if (auto primalParamPtrType = as<IRPtrTypeBase>(primalParamType)) + { + // If primal parameter is mutable, we need to pass in a temp var. + auto tempVar = builder.emitVar(primalParamPtrType->getValueType()); + if (primalParamPtrType->getOp() == kIROp_InOutType) + { + // If the primal parameter is inout, we need to set the initial value. + builder.emitStore(tempVar, primalArg); + } + primalArgs.add(tempVar); + } + else + { + primalArgs.add(primalArg); + } + } + else { - valueType = inoutType->getValueType(); - arg = builder.emitLoad(arg); + auto var = builder.emitVar(primalParamType); + primalArgs.add(var); } - auto diffPairType = as<IRDifferentialPairType>(valueType); - if (!diffPairType) continue; - arg = builder.emitDifferentialPairGetPrimal(arg); + primalTypes.add(primalParamType); } - for (auto& arg : primalArgs) + // Add dOut argument to propagateArgs. + auto diffResultType = differentiateType(&builder, origFunc->getResultType()); + if (diffResultType) { - primalTypes.add(arg->getFullType()); + auto param = builder.emitParam(diffResultType); + propagateArgs.add(param); + propagateTypes.add(param->getFullType()); } auto outerGeneric = findOuterGeneric(origFunc); @@ -433,7 +443,6 @@ namespace Slang auto intermediateVar = builder.emitVar(intermediateType); - auto origFuncType = as<IRFuncType>(origFunc->getDataType()); auto primalFuncType = builder.getFuncType( primalTypes, origFuncType->getResultType()); @@ -486,6 +495,51 @@ namespace Slang builder.emitBranch(firstBlock); } + void insertTempVarForMutableParams(SharedIRBuilder* sharedBuilder, IRFunc* func) + { + IRBuilder builder(sharedBuilder); + auto firstBlock = func->getFirstBlock(); + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + + OrderedDictionary<IRParam*, IRVar*> mapParamToTempVar; + List<IRParam*> params; + for (auto param : firstBlock->getParams()) + { + if (auto ptrType = as<IRPtrTypeBase>(param->getDataType())) + { + params.add(param); + } + } + + for (auto param : params) + { + auto ptrType = as<IRPtrTypeBase>(param->getDataType()); + auto tempVar = builder.emitVar(ptrType->getValueType()); + mapParamToTempVar[param] = tempVar; + if (param->getOp() != kIROp_OutType) + { + builder.emitStore(tempVar, builder.emitLoad(param)); + } + param->replaceUsesWith(tempVar); + } + + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (inst->getOp() == kIROp_Return) + { + builder.setInsertBefore(inst); + for (auto& kv : mapParamToTempVar) + { + builder.emitStore(kv.Key, builder.emitLoad(kv.Value)); + } + } + } + } + } + + struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy { DifferentiableTypeConformanceContext* diffTypeContext; @@ -512,6 +566,8 @@ namespace Slang IRCFGNormalizationPass cfgPass = {this->getSink()}; normalizeCFG(autoDiffSharedContext->sharedBuilder, func); + insertTempVarForMutableParams(sharedBuilder, func); + AutoDiffAddressConversionPolicy cvtPolicty; cvtPolicty.diffTypeContext = &diffTypeContext; auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink); @@ -592,6 +648,23 @@ namespace Slang return fwdDiffFunc; } + InstPair BackwardDiffTranscriberBase::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) + { + SLANG_UNUSED(primalType); + + SLANG_RELEASE_ASSERT(origParam->getParent() && origParam->getParent()->getParent() + && origParam->getParent()->getParent()->getOp() == kIROp_Generic); + + auto primalInst = maybeCloneForPrimalInst(builder, origParam); + if (auto primalParam = as<IRParam>(primalInst)) + { + SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock()); + primalParam->removeFromParent(); + builder->getInsertLoc().getBlock()->addParam(primalParam); + } + return InstPair(primalInst, nullptr); + } + // Transcribe a function definition. void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc) { @@ -615,6 +688,8 @@ namespace Slang if (!fwdDiffFunc) return; + bool isResultDifferentiable = as<IRDifferentialPairType>(fwdDiffFunc->getResultType()); + // Split first block into a paramter block. this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc)); @@ -642,12 +717,11 @@ namespace Slang } // Transpose the first block (parameter block) - transposeParameterBlock(builder, diffPropagateFunc); + List<IRInst*> primalFuncSpecificParams; + auto dOutParameter = transposeParameterBlock(builder, diffPropagateFunc, primalFuncSpecificParams, isResultDifferentiable); builder->setInsertInto(diffPropagateFunc); - auto dOutParameter = diffPropagateFunc->getLastParam()->getPrevParam(); - // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr}; diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info); @@ -658,11 +732,32 @@ namespace Slang // with the intermediate results computed from the extracted func. IRInst* intermediateType = nullptr; auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc( - diffPropagateFunc, primalFunc, intermediateType); + diffPropagateFunc, primalFunc, isResultDifferentiable, intermediateType); // Clean up by deallocating the tempoarary forward derivative func. fwdDiffFunc->removeAndDeallocate(); + // Remove primalFuncSpecificParams. + for (auto specificParam : primalFuncSpecificParams) + { + while (auto use = specificParam->firstUse) + { + if (use->getUser()->getOp() == kIROp_Store && use == use->getUser()->getOperands()) + { + use->getUser()->removeAndDeallocate(); + } + else if (auto decor = as<IRDecoration>(use->getUser())) + { + decor->removeAndDeallocate(); + } + else + { + SLANG_UNEXPECTED("unexpected use of transcribed param."); + } + } + specificParam->removeAndDeallocate(); + } + // If primal function is nested in a generic, we want to create separate generics for all the associated things // we have just created. auto primalOuterGeneric = findOuterGeneric(primalFunc); @@ -689,9 +784,16 @@ namespace Slang auto specializedBackwardPrimalFunc = maybeSpecializeWithGeneric(*builder, primalFuncGeneric, primalOuterGeneric); builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc); } + + initializeLocalVariables(builder->getSharedBuilder(), primalFunc); + initializeLocalVariables(builder->getSharedBuilder(), diffPropagateFunc); } - void BackwardDiffTranscriberBase::transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc) + IRInst* BackwardDiffTranscriberBase::transposeParameterBlock( + IRBuilder* builder, + IRFunc* diffFunc, + List<IRInst*>& primalFuncSpecificParams, + bool isResultDifferentiable) { IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock(); @@ -699,7 +801,7 @@ namespace Slang auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator()); auto nextBlock = fwdParamBlockBranch->getTargetBlock(); - builder->setInsertInto(fwdDiffParameterBlock); + builder->setInsertBefore(fwdParamBlockBranch); List<IRParam*> fwdParams; for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam()) @@ -710,8 +812,37 @@ namespace Slang // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> for (auto fwdParam : fwdParams) { - // TODO: Handle ptr<pair> types. - if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType())) + if (auto outType = as<IROutType>(fwdParam->getDataType())) + { + IRParam* newPropParam = nullptr; + IRParam* newPrimalParam = nullptr; + auto diffPairType = as<IRDifferentialPairType>(outType->getValueType()); + if (diffPairType) + { + // Create dOut param. + auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(builder, diffPairType); + newPropParam = builder->emitParam(diffType); + newPrimalParam = builder->emitParam(builder->getOutType(diffPairType->getValueType())); + } + else + { + newPrimalParam = builder->emitParam(outType); + } + + // Create a temp var to represent the original `out` param. + auto arg = builder->emitVar(outType->getValueType()); + builder->addAutoDiffOriginalValueDecoration(arg, newPrimalParam); + if (newPropParam) + { + builder->addDecoration(arg, kIROp_OutParamReverseGradientDecoration, newPropParam); + } + + fwdParam->replaceUsesWith(arg); + fwdParam->removeAndDeallocate(); + + primalFuncSpecificParams.add(newPrimalParam); + } + else if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType())) { // Create inout version. auto inoutDiffPairType = builder->getInOutType(diffPairType); @@ -725,7 +856,7 @@ namespace Slang } else { - // Default case (parameter has nothing to do with differentiation) + // Default case (parameter is inout type or has nothing to do with differentiation) // Simply move the parameter to the end. // fwdParam->removeFromParent(); @@ -735,236 +866,24 @@ namespace Slang auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount(); - // 2. Add a parameter for 'derivative of the output' (d_out). + // 2. If the return type of the original function is differentiable, + // add a parameter for 'derivative of the output' (d_out). // The type is the second last parameter type of the function. // - auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 2); - - SLANG_ASSERT(dOutParamType); - - builder->emitParam(dOutParamType); - - // Add a parameter for intermediate val. - builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1)); - } - - IRInst* BackwardDiffTranscriberBase::copyParam(IRBuilder* builder, IRParam* origParam) - { - auto primalDataType = origParam->getDataType(); - - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); - IRInst* diffParam = builder->emitParam(inoutDiffPairType); - - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice()); - - SLANG_ASSERT(diffParam); - auto paramValue = builder->emitLoad(diffParam); - auto primal = builder->emitDifferentialPairGetPrimal(paramValue); - orginalToTranscribed.Add(origParam, primal); - primalToDiffPair.Add(primal, diffParam); - - return diffParam; - } - - return maybeCloneForPrimalInst(builder, origParam); - } - - InstPair BackwardDiffTranscriberBase::copyBinaryArith(IRBuilder* builder, IRInst* origArith) - { - SLANG_ASSERT(origArith->getOperandCount() == 2); - - auto origLeft = origArith->getOperand(0); - auto origRight = origArith->getOperand(1); - - IRInst* primalLeft; - if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft)) - { - primalLeft = origLeft; - } - IRInst* primalRight; - if (!orginalToTranscribed.TryGetValue(origRight, primalRight)) - { - primalRight = origRight; - } - - auto resultType = origArith->getDataType(); - IRInst* newInst; - switch (origArith->getOp()) - { - case kIROp_Add: - newInst = builder->emitAdd(resultType, primalLeft, primalRight); - break; - case kIROp_Mul: - newInst = builder->emitMul(resultType, primalLeft, primalRight); - break; - case kIROp_Sub: - newInst = builder->emitSub(resultType, primalLeft, primalRight); - break; - case kIROp_Div: - newInst = builder->emitDiv(resultType, primalLeft, primalRight); - break; - default: - newInst = nullptr; - getSink()->diagnose(origArith->sourceLoc, - Diagnostics::unimplemented, - "this arithmetic instruction cannot be differentiated"); - } - orginalToTranscribed.Add(origArith, newInst); - return InstPair(newInst, nullptr); - } - - IRInst* BackwardDiffTranscriberBase::transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad) - { - SLANG_ASSERT(origArith->getOperandCount() == 2); - - auto lhs = origArith->getOperand(0); - auto rhs = origArith->getOperand(1); - - if (as<IRInOutType>(lhs->getDataType())) - { - lhs = builder->emitLoad(lhs); - lhs = builder->emitDifferentialPairGetPrimal(lhs); - } - if (as<IRInOutType>(rhs->getDataType())) - { - rhs = builder->emitLoad(rhs); - rhs = builder->emitDifferentialPairGetPrimal(rhs); - } - - IRInst* leftGrad; - IRInst* rightGrad; - - - switch (origArith->getOp()) - { - case kIROp_Add: - leftGrad = grad; - rightGrad = grad; - break; - case kIROp_Mul: - leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); - rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); - break; - case kIROp_Sub: - leftGrad = grad; - rightGrad = builder->emitNeg(grad->getDataType(), grad); - break; - case kIROp_Div: - leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); - rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad - break; - default: - getSink()->diagnose(origArith->sourceLoc, - Diagnostics::unimplemented, - "this arithmetic instruction cannot be differentiated"); - } - - lhs = origArith->getOperand(0); - rhs = origArith->getOperand(1); - if (auto leftGrads = upperGradients.TryGetValue(lhs)) - { - leftGrads->add(leftGrad); - } - else + IRParam* dOutParam = nullptr; + if (isResultDifferentiable) { - upperGradients.Add(lhs, leftGrad); - } - if (auto rightGrads = upperGradients.TryGetValue(rhs)) - { - rightGrads->add(rightGrad); - } - else - { - upperGradients.Add(rhs, rightGrad); - } + auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 2); - return nullptr; - } + SLANG_ASSERT(dOutParamType); - InstPair BackwardDiffTranscriberBase::copyInst(IRBuilder* builder, IRInst* origInst) - { - // Handle common SSA-style operations - switch (origInst->getOp()) - { - case kIROp_Param: - return transcribeParam(builder, as<IRParam>(origInst)); - - case kIROp_Return: - return InstPair(nullptr, nullptr); - - case kIROp_Add: - case kIROp_Mul: - case kIROp_Sub: - case kIROp_Div: - return copyBinaryArith(builder, origInst); - - default: - // Not yet implemented - SLANG_ASSERT(0); + dOutParam = builder->emitParam(dOutParamType); } - return InstPair(nullptr, nullptr); - } - - IRInst* BackwardDiffTranscriberBase::transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad) - { - IRInOutType* inoutParam = as<IRInOutType>(param->getDataType()); - auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType()); - auto paramValue = builder->emitLoad(param); - auto primal = builder->emitDifferentialPairGetPrimal(paramValue); - auto diff = builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - paramValue - ); - auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad); - auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff); - auto store = builder->emitStore(param, updatedParam); - - return store; - } - - IRInst* BackwardDiffTranscriberBase::transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad) - { - // Handle common SSA-style operations - switch (origInst->getOp()) - { - case kIROp_Param: - return transposeParamBackward(builder, as<IRParam>(origInst), grad); - - case kIROp_Add: - case kIROp_Mul: - case kIROp_Sub: - case kIROp_Div: - return transposeBinaryArithBackward(builder, origInst, grad); - - case kIROp_DifferentialPairGetPrimal: - { - if (auto param = primalToDiffPair.TryGetValue(origInst)) - { - if (auto leftGrads = upperGradients.TryGetValue(*param)) - { - leftGrads->add(grad); - } - else - { - upperGradients.Add(*param, grad); - } - } - else - SLANG_ASSERT(0); - return nullptr; - } - - default: - // Not yet implemented - SLANG_ASSERT(0); - } + // Add a parameter for intermediate val. + builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1)); - return nullptr; + return dOutParam; } InstPair BackwardDiffTranscriberBase::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) |
