summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-rev.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-01 14:18:57 -0800
committerGitHub <noreply@github.com>2023-02-01 14:18:57 -0800
commitbbd1e1786401bb88c34802b987d4da72e2364503 (patch)
tree99a4be95ae517fd710fc032a1debdac917dd3ac2 /source/slang/slang-ir-autodiff-rev.cpp
parentc5895fb0b82fd14fbe45b58d5fc7f75d67625d15 (diff)
Support `out` parameters in backward differentiation. (#2619)
* Support `out` parameters in backward differentiation. * Fixes. * Fix cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp557
1 files changed, 238 insertions, 319 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 0f2ceceb4..9c63a4012 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -11,6 +11,7 @@
#include "slang-ir-single-return.h"
#include "slang-ir-addr-inst-elimination.h"
#include "slang-ir-eliminate-multilevel-break.h"
+#include "slang-ir-init-local-var.h"
namespace Slang
{
@@ -21,32 +22,10 @@ namespace Slang
for (UIndex i = 0; i < funcType->getParamCount(); i++)
{
- bool noDiff = false;
auto origType = funcType->getParamType(i);
- auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origType);
-
- if (auto attrType = as<IRAttributedType>(primalType))
- {
- if (attrType->findAttr<IRNoDiffAttr>())
- {
- noDiff = true;
- primalType = attrType->getBaseType();
- }
- }
- if (noDiff)
- {
- newParameterTypes.add(primalType);
- }
- else
- {
- if (auto diffPairType = tryGetDiffPairType(builder, origType))
- {
- auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
- newParameterTypes.add(inoutDiffPairType);
- }
- else
- newParameterTypes.add(primalType);
- }
+ auto paramType = transcribeParamTypeForPropagateFunc(builder, origType);
+ if (paramType)
+ newParameterTypes.add(paramType);
}
if (auto diffResultType = differentiateType(builder, funcType->getResultType()))
@@ -75,7 +54,7 @@ namespace Slang
for (UInt i = 0; i < funcType->getParamCount(); i++)
{
auto origType = funcType->getParamType(i);
- auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origType);
+ auto primalType = transcribeParamTypeForPrimalFunc(builder, origType);
paramTypes.add(primalType);
}
paramTypes.add(outType);
@@ -252,52 +231,57 @@ namespace Slang
return String("");
}
- InstPair BackwardDiffTranscriberBase::transposeBlock(IRBuilder* builder, IRBlock* origBlock)
+ static IRType* _getPrimalTypeFromNoDiffType(BackwardDiffTranscriberBase* transcriber, IRBuilder* builder, IRType* origType)
{
- IRBuilder subBuilder(builder->getSharedBuilder());
- subBuilder.setInsertLoc(builder->getInsertLoc());
+ IRType* valueType = origType;
+ auto ptrType = as<IROutTypeBase>(valueType);
+ if (ptrType)
+ valueType = ptrType->getValueType();
- IRBlock* diffBlock = subBuilder.emitBlock();
-
- subBuilder.setInsertInto(diffBlock);
-
- // First transcribe every parameter in the block.
- for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
- this->copyParam(&subBuilder, param);
-
- // The extra param for input gradient
- auto gradParam = subBuilder.emitParam(as<IRFuncType>(origBlock->getParent()->getFullType())->getResultType());
-
- // Then, run through every instruction and use the transcriber to generate the appropriate
- // derivative code.
- //
- for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
- this->copyInst(&subBuilder, child);
-
- auto lastInst = diffBlock->getLastOrdinaryInst();
- List<IRInst*> grads = { gradParam };
- upperGradients.Add(lastInst, grads);
- for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst())
+ if (auto attrType = as<IRAttributedType>(valueType))
{
- auto upperGrads = upperGradients.TryGetValue(child);
- if (!upperGrads)
- continue;
- if (upperGrads->getCount() > 1)
+ if (attrType->findAttr<IRNoDiffAttr>())
{
- auto sumGrad = upperGrads->getFirst();
- for (auto i = 1; i < upperGrads->getCount(); i++)
- {
- sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]);
- }
- this->transposeInstBackward(&subBuilder, child, sumGrad);
+ auto primalValueType = (IRType*)transcriber->findOrTranscribePrimalInst(builder, valueType);
+ if (ptrType)
+ return builder->getPtrType(ptrType->getOp(), primalValueType);
+ return primalValueType;
}
- else
- this->transposeInstBackward(&subBuilder, child, upperGrads->getFirst());
}
+ return nullptr;
+ }
+
+ IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPrimalFunc(IRBuilder* builder, IRType* paramType)
+ {
+ // If the param is marked as no_diff, return the primal type.
+ if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType))
+ return primalNoDiffType;
- subBuilder.emitReturn();
+ return (IRType*)findOrTranscribePrimalInst(builder, paramType);
+ }
- return InstPair(diffBlock, diffBlock);
+ IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(IRBuilder* builder, IRType* paramType)
+ {
+ if (auto outType = as<IROutType>(paramType))
+ {
+ auto valueType = outType->getValueType();
+ auto diffValueType = differentiateType(builder, valueType);
+ return diffValueType;
+ }
+
+ // If the param is marked as no_diff, return the primal type.
+ if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType))
+ return primalNoDiffType;
+
+ auto diffPairType = tryGetDiffPairType(builder, paramType);
+ if (diffPairType)
+ {
+ if (!as<IRPtrTypeBase>(diffPairType))
+ return builder->getInOutType(diffPairType);
+ return diffPairType;
+ }
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType);
+ return primalType;
}
// Create an empty func to represent the transcribed func of `origFunc`.
@@ -387,39 +371,65 @@ namespace Slang
IRBuilder builder(inBuilder->getSharedBuilder());
builder.setInsertInto(header.differential);
builder.emitBlock();
- auto funcType = as<IRFuncType>(header.differential->getDataType());
+ auto origFuncType = as<IRFuncType>(origFunc->getFullType());
List<IRInst*> primalArgs, propagateArgs;
List<IRType*> primalTypes, propagateTypes;
- for (UInt i = 0; i < funcType->getParamCount(); i++)
+ for (UInt i = 0; i < origFuncType->getParamCount(); i++)
{
- auto paramType = (IRType*)findOrTranscribePrimalInst(&builder, funcType->getParamType(i));
- auto param = builder.emitParam(paramType);
- if (i != funcType->getParamCount() - 1)
+ auto primalParamType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i));
+ auto propagateParamType = transcribeParamTypeForPropagateFunc(&builder, origFuncType->getParamType(i));
+ if (propagateParamType)
{
- primalArgs.add(param);
- }
- propagateArgs.add(param);
- propagateTypes.add(paramType);
- }
+ auto param = builder.emitParam(propagateParamType);
+ propagateTypes.add(propagateParamType);
+ propagateArgs.add(param);
- // Fetch primal values to use as arguments in primal func call.
- for (auto& arg : primalArgs)
- {
- IRInst* valueType = arg->getDataType();
- auto inoutType = as<IRPtrTypeBase>(arg->getDataType());
- if (inoutType)
+ // Fetch primal values to use as arguments in primal func call.
+ IRInst* primalArg = param;
+ if (!as<IROutType>(primalParamType))
+ {
+ // As long as the primal parameter is not an out type,
+ // we need to fetch the primal value from the parameter.
+ if (as<IRPtrTypeBase>(propagateParamType))
+ {
+ primalArg = builder.emitLoad(param);
+ }
+ if (auto diffPairType = as<IRDifferentialPairType>(primalArg->getDataType()))
+ {
+ primalArg = builder.emitDifferentialPairGetPrimal(primalArg);
+ }
+ }
+ if (auto primalParamPtrType = as<IRPtrTypeBase>(primalParamType))
+ {
+ // If primal parameter is mutable, we need to pass in a temp var.
+ auto tempVar = builder.emitVar(primalParamPtrType->getValueType());
+ if (primalParamPtrType->getOp() == kIROp_InOutType)
+ {
+ // If the primal parameter is inout, we need to set the initial value.
+ builder.emitStore(tempVar, primalArg);
+ }
+ primalArgs.add(tempVar);
+ }
+ else
+ {
+ primalArgs.add(primalArg);
+ }
+ }
+ else
{
- valueType = inoutType->getValueType();
- arg = builder.emitLoad(arg);
+ auto var = builder.emitVar(primalParamType);
+ primalArgs.add(var);
}
- auto diffPairType = as<IRDifferentialPairType>(valueType);
- if (!diffPairType) continue;
- arg = builder.emitDifferentialPairGetPrimal(arg);
+ primalTypes.add(primalParamType);
}
- for (auto& arg : primalArgs)
+ // Add dOut argument to propagateArgs.
+ auto diffResultType = differentiateType(&builder, origFunc->getResultType());
+ if (diffResultType)
{
- primalTypes.add(arg->getFullType());
+ auto param = builder.emitParam(diffResultType);
+ propagateArgs.add(param);
+ propagateTypes.add(param->getFullType());
}
auto outerGeneric = findOuterGeneric(origFunc);
@@ -433,7 +443,6 @@ namespace Slang
auto intermediateVar = builder.emitVar(intermediateType);
- auto origFuncType = as<IRFuncType>(origFunc->getDataType());
auto primalFuncType = builder.getFuncType(
primalTypes,
origFuncType->getResultType());
@@ -486,6 +495,51 @@ namespace Slang
builder.emitBranch(firstBlock);
}
+ void insertTempVarForMutableParams(SharedIRBuilder* sharedBuilder, IRFunc* func)
+ {
+ IRBuilder builder(sharedBuilder);
+ auto firstBlock = func->getFirstBlock();
+ builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
+
+ OrderedDictionary<IRParam*, IRVar*> mapParamToTempVar;
+ List<IRParam*> params;
+ for (auto param : firstBlock->getParams())
+ {
+ if (auto ptrType = as<IRPtrTypeBase>(param->getDataType()))
+ {
+ params.add(param);
+ }
+ }
+
+ for (auto param : params)
+ {
+ auto ptrType = as<IRPtrTypeBase>(param->getDataType());
+ auto tempVar = builder.emitVar(ptrType->getValueType());
+ mapParamToTempVar[param] = tempVar;
+ if (param->getOp() != kIROp_OutType)
+ {
+ builder.emitStore(tempVar, builder.emitLoad(param));
+ }
+ param->replaceUsesWith(tempVar);
+ }
+
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_Return)
+ {
+ builder.setInsertBefore(inst);
+ for (auto& kv : mapParamToTempVar)
+ {
+ builder.emitStore(kv.Key, builder.emitLoad(kv.Value));
+ }
+ }
+ }
+ }
+ }
+
+
struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy
{
DifferentiableTypeConformanceContext* diffTypeContext;
@@ -512,6 +566,8 @@ namespace Slang
IRCFGNormalizationPass cfgPass = {this->getSink()};
normalizeCFG(autoDiffSharedContext->sharedBuilder, func);
+ insertTempVarForMutableParams(sharedBuilder, func);
+
AutoDiffAddressConversionPolicy cvtPolicty;
cvtPolicty.diffTypeContext = &diffTypeContext;
auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink);
@@ -592,6 +648,23 @@ namespace Slang
return fwdDiffFunc;
}
+ InstPair BackwardDiffTranscriberBase::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType)
+ {
+ SLANG_UNUSED(primalType);
+
+ SLANG_RELEASE_ASSERT(origParam->getParent() && origParam->getParent()->getParent()
+ && origParam->getParent()->getParent()->getOp() == kIROp_Generic);
+
+ auto primalInst = maybeCloneForPrimalInst(builder, origParam);
+ if (auto primalParam = as<IRParam>(primalInst))
+ {
+ SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock());
+ primalParam->removeFromParent();
+ builder->getInsertLoc().getBlock()->addParam(primalParam);
+ }
+ return InstPair(primalInst, nullptr);
+ }
+
// Transcribe a function definition.
void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc)
{
@@ -615,6 +688,8 @@ namespace Slang
if (!fwdDiffFunc)
return;
+ bool isResultDifferentiable = as<IRDifferentialPairType>(fwdDiffFunc->getResultType());
+
// Split first block into a paramter block.
this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc));
@@ -642,12 +717,11 @@ namespace Slang
}
// Transpose the first block (parameter block)
- transposeParameterBlock(builder, diffPropagateFunc);
+ List<IRInst*> primalFuncSpecificParams;
+ auto dOutParameter = transposeParameterBlock(builder, diffPropagateFunc, primalFuncSpecificParams, isResultDifferentiable);
builder->setInsertInto(diffPropagateFunc);
- auto dOutParameter = diffPropagateFunc->getLastParam()->getPrevParam();
-
// Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr};
diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info);
@@ -658,11 +732,32 @@ namespace Slang
// with the intermediate results computed from the extracted func.
IRInst* intermediateType = nullptr;
auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(
- diffPropagateFunc, primalFunc, intermediateType);
+ diffPropagateFunc, primalFunc, isResultDifferentiable, intermediateType);
// Clean up by deallocating the tempoarary forward derivative func.
fwdDiffFunc->removeAndDeallocate();
+ // Remove primalFuncSpecificParams.
+ for (auto specificParam : primalFuncSpecificParams)
+ {
+ while (auto use = specificParam->firstUse)
+ {
+ if (use->getUser()->getOp() == kIROp_Store && use == use->getUser()->getOperands())
+ {
+ use->getUser()->removeAndDeallocate();
+ }
+ else if (auto decor = as<IRDecoration>(use->getUser()))
+ {
+ decor->removeAndDeallocate();
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unexpected use of transcribed param.");
+ }
+ }
+ specificParam->removeAndDeallocate();
+ }
+
// If primal function is nested in a generic, we want to create separate generics for all the associated things
// we have just created.
auto primalOuterGeneric = findOuterGeneric(primalFunc);
@@ -689,9 +784,16 @@ namespace Slang
auto specializedBackwardPrimalFunc = maybeSpecializeWithGeneric(*builder, primalFuncGeneric, primalOuterGeneric);
builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc);
}
+
+ initializeLocalVariables(builder->getSharedBuilder(), primalFunc);
+ initializeLocalVariables(builder->getSharedBuilder(), diffPropagateFunc);
}
- void BackwardDiffTranscriberBase::transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc)
+ IRInst* BackwardDiffTranscriberBase::transposeParameterBlock(
+ IRBuilder* builder,
+ IRFunc* diffFunc,
+ List<IRInst*>& primalFuncSpecificParams,
+ bool isResultDifferentiable)
{
IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock();
@@ -699,7 +801,7 @@ namespace Slang
auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator());
auto nextBlock = fwdParamBlockBranch->getTargetBlock();
- builder->setInsertInto(fwdDiffParameterBlock);
+ builder->setInsertBefore(fwdParamBlockBranch);
List<IRParam*> fwdParams;
for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam())
@@ -710,8 +812,37 @@ namespace Slang
// 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<>
for (auto fwdParam : fwdParams)
{
- // TODO: Handle ptr<pair> types.
- if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType()))
+ if (auto outType = as<IROutType>(fwdParam->getDataType()))
+ {
+ IRParam* newPropParam = nullptr;
+ IRParam* newPrimalParam = nullptr;
+ auto diffPairType = as<IRDifferentialPairType>(outType->getValueType());
+ if (diffPairType)
+ {
+ // Create dOut param.
+ auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(builder, diffPairType);
+ newPropParam = builder->emitParam(diffType);
+ newPrimalParam = builder->emitParam(builder->getOutType(diffPairType->getValueType()));
+ }
+ else
+ {
+ newPrimalParam = builder->emitParam(outType);
+ }
+
+ // Create a temp var to represent the original `out` param.
+ auto arg = builder->emitVar(outType->getValueType());
+ builder->addAutoDiffOriginalValueDecoration(arg, newPrimalParam);
+ if (newPropParam)
+ {
+ builder->addDecoration(arg, kIROp_OutParamReverseGradientDecoration, newPropParam);
+ }
+
+ fwdParam->replaceUsesWith(arg);
+ fwdParam->removeAndDeallocate();
+
+ primalFuncSpecificParams.add(newPrimalParam);
+ }
+ else if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType()))
{
// Create inout version.
auto inoutDiffPairType = builder->getInOutType(diffPairType);
@@ -725,7 +856,7 @@ namespace Slang
}
else
{
- // Default case (parameter has nothing to do with differentiation)
+ // Default case (parameter is inout type or has nothing to do with differentiation)
// Simply move the parameter to the end.
//
fwdParam->removeFromParent();
@@ -735,236 +866,24 @@ namespace Slang
auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount();
- // 2. Add a parameter for 'derivative of the output' (d_out).
+ // 2. If the return type of the original function is differentiable,
+ // add a parameter for 'derivative of the output' (d_out).
// The type is the second last parameter type of the function.
//
- auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 2);
-
- SLANG_ASSERT(dOutParamType);
-
- builder->emitParam(dOutParamType);
-
- // Add a parameter for intermediate val.
- builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
- }
-
- IRInst* BackwardDiffTranscriberBase::copyParam(IRBuilder* builder, IRParam* origParam)
- {
- auto primalDataType = origParam->getDataType();
-
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
- {
- auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
- IRInst* diffParam = builder->emitParam(inoutDiffPairType);
-
- auto diffPairVarName = makeDiffPairName(origParam);
- if (diffPairVarName.getLength() > 0)
- builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice());
-
- SLANG_ASSERT(diffParam);
- auto paramValue = builder->emitLoad(diffParam);
- auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
- orginalToTranscribed.Add(origParam, primal);
- primalToDiffPair.Add(primal, diffParam);
-
- return diffParam;
- }
-
- return maybeCloneForPrimalInst(builder, origParam);
- }
-
- InstPair BackwardDiffTranscriberBase::copyBinaryArith(IRBuilder* builder, IRInst* origArith)
- {
- SLANG_ASSERT(origArith->getOperandCount() == 2);
-
- auto origLeft = origArith->getOperand(0);
- auto origRight = origArith->getOperand(1);
-
- IRInst* primalLeft;
- if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft))
- {
- primalLeft = origLeft;
- }
- IRInst* primalRight;
- if (!orginalToTranscribed.TryGetValue(origRight, primalRight))
- {
- primalRight = origRight;
- }
-
- auto resultType = origArith->getDataType();
- IRInst* newInst;
- switch (origArith->getOp())
- {
- case kIROp_Add:
- newInst = builder->emitAdd(resultType, primalLeft, primalRight);
- break;
- case kIROp_Mul:
- newInst = builder->emitMul(resultType, primalLeft, primalRight);
- break;
- case kIROp_Sub:
- newInst = builder->emitSub(resultType, primalLeft, primalRight);
- break;
- case kIROp_Div:
- newInst = builder->emitDiv(resultType, primalLeft, primalRight);
- break;
- default:
- newInst = nullptr;
- getSink()->diagnose(origArith->sourceLoc,
- Diagnostics::unimplemented,
- "this arithmetic instruction cannot be differentiated");
- }
- orginalToTranscribed.Add(origArith, newInst);
- return InstPair(newInst, nullptr);
- }
-
- IRInst* BackwardDiffTranscriberBase::transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad)
- {
- SLANG_ASSERT(origArith->getOperandCount() == 2);
-
- auto lhs = origArith->getOperand(0);
- auto rhs = origArith->getOperand(1);
-
- if (as<IRInOutType>(lhs->getDataType()))
- {
- lhs = builder->emitLoad(lhs);
- lhs = builder->emitDifferentialPairGetPrimal(lhs);
- }
- if (as<IRInOutType>(rhs->getDataType()))
- {
- rhs = builder->emitLoad(rhs);
- rhs = builder->emitDifferentialPairGetPrimal(rhs);
- }
-
- IRInst* leftGrad;
- IRInst* rightGrad;
-
-
- switch (origArith->getOp())
- {
- case kIROp_Add:
- leftGrad = grad;
- rightGrad = grad;
- break;
- case kIROp_Mul:
- leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
- rightGrad = builder->emitMul(grad->getDataType(), lhs, grad);
- break;
- case kIROp_Sub:
- leftGrad = grad;
- rightGrad = builder->emitNeg(grad->getDataType(), grad);
- break;
- case kIROp_Div:
- leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
- rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad
- break;
- default:
- getSink()->diagnose(origArith->sourceLoc,
- Diagnostics::unimplemented,
- "this arithmetic instruction cannot be differentiated");
- }
-
- lhs = origArith->getOperand(0);
- rhs = origArith->getOperand(1);
- if (auto leftGrads = upperGradients.TryGetValue(lhs))
- {
- leftGrads->add(leftGrad);
- }
- else
+ IRParam* dOutParam = nullptr;
+ if (isResultDifferentiable)
{
- upperGradients.Add(lhs, leftGrad);
- }
- if (auto rightGrads = upperGradients.TryGetValue(rhs))
- {
- rightGrads->add(rightGrad);
- }
- else
- {
- upperGradients.Add(rhs, rightGrad);
- }
+ auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 2);
- return nullptr;
- }
+ SLANG_ASSERT(dOutParamType);
- InstPair BackwardDiffTranscriberBase::copyInst(IRBuilder* builder, IRInst* origInst)
- {
- // Handle common SSA-style operations
- switch (origInst->getOp())
- {
- case kIROp_Param:
- return transcribeParam(builder, as<IRParam>(origInst));
-
- case kIROp_Return:
- return InstPair(nullptr, nullptr);
-
- case kIROp_Add:
- case kIROp_Mul:
- case kIROp_Sub:
- case kIROp_Div:
- return copyBinaryArith(builder, origInst);
-
- default:
- // Not yet implemented
- SLANG_ASSERT(0);
+ dOutParam = builder->emitParam(dOutParamType);
}
- return InstPair(nullptr, nullptr);
- }
-
- IRInst* BackwardDiffTranscriberBase::transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad)
- {
- IRInOutType* inoutParam = as<IRInOutType>(param->getDataType());
- auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType());
- auto paramValue = builder->emitLoad(param);
- auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
- auto diff = builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
- paramValue
- );
- auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad);
- auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff);
- auto store = builder->emitStore(param, updatedParam);
-
- return store;
- }
-
- IRInst* BackwardDiffTranscriberBase::transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad)
- {
- // Handle common SSA-style operations
- switch (origInst->getOp())
- {
- case kIROp_Param:
- return transposeParamBackward(builder, as<IRParam>(origInst), grad);
-
- case kIROp_Add:
- case kIROp_Mul:
- case kIROp_Sub:
- case kIROp_Div:
- return transposeBinaryArithBackward(builder, origInst, grad);
-
- case kIROp_DifferentialPairGetPrimal:
- {
- if (auto param = primalToDiffPair.TryGetValue(origInst))
- {
- if (auto leftGrads = upperGradients.TryGetValue(*param))
- {
- leftGrads->add(grad);
- }
- else
- {
- upperGradients.Add(*param, grad);
- }
- }
- else
- SLANG_ASSERT(0);
- return nullptr;
- }
-
- default:
- // Not yet implemented
- SLANG_ASSERT(0);
- }
+ // Add a parameter for intermediate val.
+ builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
- return nullptr;
+ return dOutParam;
}
InstPair BackwardDiffTranscriberBase::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)