summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/diff.meta.slang36
-rw-r--r--source/slang/slang-check-expr.cpp17
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp11
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp477
-rw-r--r--source/slang/slang-ir-autodiff-rev.h40
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp72
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h99
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp108
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h101
-rw-r--r--source/slang/slang-ir-autodiff.cpp2
-rw-r--r--source/slang/slang-ir-autodiff.h17
-rw-r--r--source/slang/slang-ir-clone.cpp3
-rw-r--r--source/slang/slang-ir-clone.h1
-rw-r--r--source/slang/slang-ir-inst-defs.h18
-rw-r--r--source/slang/slang-ir-insts.h36
-rw-r--r--source/slang/slang-ir-specialize.cpp7
-rw-r--r--source/slang/slang-ir-ssa.cpp2
-rw-r--r--source/slang/slang-ir.cpp50
18 files changed, 848 insertions, 249 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index adbf8ae48..055c44135 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -313,6 +313,24 @@ DifferentialPair<vector<T, N>> __d_sin_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_UNARY(T, N, __d_sin, dpx);
}
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(sin)]
+void __d_sin(inout DifferentialPair<T> dpx, T.Differential dOut)
+{
+ dpx = diffPair(
+ dpx.p,
+ T.dmul(cos(dpx.p), dOut));
+}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(sin)]
+void __d_sin_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
+{
+ dpx = diffPair(
+ dpx.p,
+ vector<T, N>.dmul(cos(dpx.p), dOut));
+}
+
// Cosine
__generic<T : __BuiltinFloatingPointType>
@@ -331,6 +349,24 @@ DifferentialPair<vector<T, N>> __d_cos_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_UNARY(T, N, __d_cos, dpx);
}
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(cos)]
+void __d_cos(inout DifferentialPair<T> dpx, T.Differential dOut)
+{
+ dpx = diffPair(
+ dpx.p,
+ T.dmul(-sin(dpx.p), dOut));
+}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(cos)]
+void __d_cos_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
+{
+ dpx = diffPair(
+ dpx.p,
+ vector<T, N>.dmul(-sin(dpx.p), dOut));
+}
+
// Base-e logarithm
__generic<T : __BuiltinFloatingPointType>
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index b89eb85c4..a52a08f15 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -478,6 +478,7 @@ namespace Slang
if (!parent)
return nullptr;
+
// If we reach here, we are expecting a synthesized decl defined in `subType`.
// Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder decl
// in `subType` and return a DeclRefExpr to the synthesized decl.
@@ -862,6 +863,15 @@ namespace Slang
if (auto declRefType = as<DeclRefType>(type))
{
+ if (auto builtinRequirement = declRefType->declRef.getDecl()->findModifier<BuiltinRequirementModifier>())
+ {
+ if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType)
+ {
+ // We are trying to get differential type from a differential type.
+ // The result is itself.
+ return type;
+ }
+ }
if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterface())))
{
auto diffTypeLookupResult = lookUpMember(
@@ -2328,6 +2338,13 @@ namespace Slang
{
for (auto param : funcDecl->getParameters())
{
+ if (param->findModifier<NoDiffModifier>())
+ {
+ if (param->findModifier<OutModifier>() &&
+ !param->findModifier<InModifier>() &&
+ !param->findModifier<InOutModifier>())
+ continue;
+ }
resultDiffExpr->newParameterNames.add(param->getName());
}
resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient"));
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index a9e716ce4..16b7a977c 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -402,6 +402,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
auto primalArg = findOrTranscribePrimalInst(&argBuilder, origArg);
SLANG_ASSERT(primalArg);
+ auto origType = origCall->getArg(ii)->getDataType();
auto primalType = primalArg->getDataType();
auto paramType = calleeType->getParamType(ii);
if (!isNoDiffType(paramType))
@@ -410,8 +411,10 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
{
while (auto attrType = as<IRAttributedType>(primalType))
primalType = attrType->getBaseType();
+ while (auto attrType = as<IRAttributedType>(origType))
+ origType = attrType->getBaseType();
}
- if (auto pairType = tryGetDiffPairType(&argBuilder, primalType))
+ if (auto pairType = tryGetDiffPairType(&argBuilder, origType))
{
auto pairPtrType = as<IRPtrTypeBase>(pairType);
auto pairValType = as<IRDifferentialPairType>(
@@ -1201,7 +1204,7 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I
// Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
{
- cloneDecoration(dictDecor, diffFunc);
+ cloneDecoration(&cloneEnv, dictDecor, diffFunc, diffFunc->getModule());
}
return diffFunc;
}
@@ -1434,8 +1437,10 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam
auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType());
// Make a local copy of the parameter for primal and diff parts.
auto primal = builder->emitVar(ptrInnerPairType->getValueType());
+
auto diffType = differentiateType(builder, cast<IRPtrTypeBase>(origParam->getDataType())->getValueType());
auto diff = builder->emitVar(diffType);
+ builder->markInstAsDifferential(diff, ptrInnerPairType->getValueType());
IRInst* primalInitVal = nullptr;
IRInst* diffInitVal = nullptr;
@@ -1447,6 +1452,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam
else
{
auto initVal = builder->emitLoad(diffPairParam);
+ builder->markInstAsMixedDifferential(initVal, ptrInnerPairType);
+
primalInitVal = builder->emitDifferentialPairGetPrimal(initVal);
diffInitVal = builder->emitDifferentialPairGetDifferential(diffType, initVal);
}
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 9c63a4012..67387e83a 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -417,7 +417,10 @@ namespace Slang
}
else
{
- auto var = builder.emitVar(primalParamType);
+ auto primalPtrType = as<IRPtrTypeBase>(primalParamType);
+ SLANG_RELEASE_ASSERT(primalPtrType);
+ auto primalValueType = primalPtrType->getValueType();
+ auto var = builder.emitVar(primalValueType);
primalArgs.add(var);
}
primalTypes.add(primalParamType);
@@ -515,12 +518,12 @@ namespace Slang
{
auto ptrType = as<IRPtrTypeBase>(param->getDataType());
auto tempVar = builder.emitVar(ptrType->getValueType());
+ param->replaceUsesWith(tempVar);
mapParamToTempVar[param] = tempVar;
- if (param->getOp() != kIROp_OutType)
+ if (ptrType->getOp() != kIROp_OutType)
{
builder.emitStore(tempVar, builder.emitLoad(param));
}
- param->replaceUsesWith(tempVar);
}
for (auto block : func->getBlocks())
@@ -578,6 +581,43 @@ namespace Slang
return result;
}
+ void eliminateRedundantLoad(IRFunc* func)
+ {
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst = block->getFirstInst(); inst;)
+ {
+ auto nextInst = inst->getNextInst();
+ if (auto load = as<IRLoad>(inst))
+ {
+ for (auto prev = inst->getPrevInst(); prev; prev = prev->getPrevInst())
+ {
+ if (auto store = as<IRStore>(prev))
+ {
+ if (store->getPtr() == load->getPtr())
+ {
+ // If the load is preceeded by a store without any side-effect insts in-between, remove the load.
+ auto value = store->getVal();
+ load->replaceUsesWith(value);
+ load->removeAndDeallocate();
+ break;
+ }
+ }
+ else if (as<IRCall>(prev))
+ {
+ break;
+ }
+ else if (prev->mightHaveSideEffects())
+ {
+ break;
+ }
+ }
+ }
+ inst = nextInst;
+ }
+ }
+ }
+
// Create a copy of originalFunc's forward derivative in the same generic context (if any) of
// `diffPropagateFunc`.
IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc(
@@ -621,6 +661,9 @@ namespace Slang
// Remove the clone of original func.
primalOuterParent->removeAndDeallocate();
+ // Remove redundant loads since they interfere with transposition logic.
+ eliminateRedundantLoad(fwdDiffFunc);
+
// Migrate the new forward derivative function into the generic parent of `diffPropagateFunc`.
if (auto fwdParentGeneric = as<IRGeneric>(findOuterGeneric(fwdDiffFunc)))
{
@@ -645,6 +688,7 @@ namespace Slang
}
fwdParentGeneric->removeAndDeallocate();
}
+
return fwdDiffFunc;
}
@@ -665,6 +709,20 @@ namespace Slang
return InstPair(primalInst, nullptr);
}
+ // Keep primal param replacement insts alive during DCE.
+ static void _lockPrimalParamReplacementInsts(IRBuilder* builder, ParameterBlockTransposeInfo& paramInfo)
+ {
+ for (auto& kv : paramInfo.mapPrimalSpecificParamToReplacementInPropFunc)
+ builder->addKeepAliveDecoration(kv.Value);
+ }
+
+ // Remove [KeepAlive] decorations for primal param replacement insts.
+ static void _unlockPrimalParamReplacementInsts(ParameterBlockTransposeInfo& paramInfo)
+ {
+ for (auto& kv : paramInfo.mapPrimalSpecificParamToReplacementInPropFunc)
+ kv.Value->findDecoration<IRKeepAliveDecoration>()->removeAndDeallocate();
+ }
+
// Transcribe a function definition.
void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc)
{
@@ -702,9 +760,9 @@ namespace Slang
// Copy primal insts to the first block of the unzipped function, copy diff insts to the
// second block of the unzipped function.
//
- IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
+ diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
+ IRFunc* unzippedFwdDiffFunc = fwdDiffFunc;
-
// Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell.
builder->setInsertInto(diffPropagateFunc->getParent());
{
@@ -717,13 +775,19 @@ namespace Slang
}
// Transpose the first block (parameter block)
- List<IRInst*> primalFuncSpecificParams;
- auto dOutParameter = transposeParameterBlock(builder, diffPropagateFunc, primalFuncSpecificParams, isResultDifferentiable);
+ auto paramTransposeInfo =
+ splitAndTransposeParameterBlock(builder, diffPropagateFunc, isResultDifferentiable);
+
+ // The insts we inserted in paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc
+ // may be used by write back logic that we are going to insert later.
+ // Before then we want to keep them alive.
+ _lockPrimalParamReplacementInsts(builder, paramTransposeInfo);
builder->setInsertInto(diffPropagateFunc);
- // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
- DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr};
+ // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
+ // derivative of the return value.
+ DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam, nullptr};
diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info);
eliminateDeadCode(diffPropagateFunc);
@@ -732,32 +796,36 @@ namespace Slang
// with the intermediate results computed from the extracted func.
IRInst* intermediateType = nullptr;
auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(
- diffPropagateFunc, primalFunc, isResultDifferentiable, intermediateType);
+ diffPropagateFunc, primalFunc, paramTransposeInfo, intermediateType);
- // Clean up by deallocating the tempoarary forward derivative func.
- fwdDiffFunc->removeAndDeallocate();
+ // At this point the unzipped func is just an empty shell
+ // and we can simply remove it.
+ unzippedFwdDiffFunc->removeAndDeallocate();
- // Remove primalFuncSpecificParams.
- for (auto specificParam : primalFuncSpecificParams)
+ // Write back derivatives to inout parameters.
+ writeBackDerivativeToInOutParams(paramTransposeInfo, diffPropagateFunc);
+
+ // Remove primalFunc specific params.
+ List<IRInst*> paramsToRemove;
+ for (auto param : diffPropagateFunc->getParams())
+ {
+ if (!paramTransposeInfo.propagateFuncParams.Contains(param))
+ paramsToRemove.add(param);
+ }
+ for (auto param : paramsToRemove)
{
- while (auto use = specificParam->firstUse)
+ if (param->hasUses())
{
- if (use->getUser()->getOp() == kIROp_Store && use == use->getUser()->getOperands())
- {
- use->getUser()->removeAndDeallocate();
- }
- else if (auto decor = as<IRDecoration>(use->getUser()))
- {
- decor->removeAndDeallocate();
- }
- else
- {
- SLANG_UNEXPECTED("unexpected use of transcribed param.");
- }
+ IRInst* replacement = nullptr;
+ paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc.TryGetValue(param, replacement);
+ SLANG_RELEASE_ASSERT(replacement);
+ param->replaceUsesWith(replacement);
}
- specificParam->removeAndDeallocate();
+ param->removeAndDeallocate();
}
+ _unlockPrimalParamReplacementInsts(paramTransposeInfo);
+
// If primal function is nested in a generic, we want to create separate generics for all the associated things
// we have just created.
auto primalOuterGeneric = findOuterGeneric(primalFunc);
@@ -789,87 +857,322 @@ namespace Slang
initializeLocalVariables(builder->getSharedBuilder(), diffPropagateFunc);
}
- IRInst* BackwardDiffTranscriberBase::transposeParameterBlock(
+ ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock(
IRBuilder* builder,
IRFunc* diffFunc,
- List<IRInst*>& primalFuncSpecificParams,
bool isResultDifferentiable)
{
+ // This method splits transposes the all the parameters for both the primal and propagate computation.
+ // At the end of this method, the parameter block will contain a combination of parameters for
+ // both the to-be-primal function and to-be-propagate function.
+ // We use ParameterBlockTransposeInfo::primalFuncParams and ParameterBlockTransposeInfo::propagateFuncParams
+ // to track which parameters are dedicated to the future primal or propagate func.
+ // A later step will then split the parameters out to each new function.
+
+ ParameterBlockTransposeInfo result;
+
+ // First, we initialize the IR builders and locate the import code insertion points that will
+ // be used for the rest of this method.
+
IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock();
// Find the 'next' block using the terminator inst of the parameter block.
auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator());
auto nextBlock = fwdParamBlockBranch->getTargetBlock();
- builder->setInsertBefore(fwdParamBlockBranch);
+ auto nextBlockBuilder = *builder;
+ nextBlockBuilder.setInsertBefore(nextBlock->getFirstOrdinaryInst());
- List<IRParam*> fwdParams;
- for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam())
+ IRBlock* firstDiffBlock = nullptr;
+ for (auto block : diffFunc->getBlocks())
{
- fwdParams.add(child);
+ if (isDifferentialInst(block))
+ {
+ firstDiffBlock = block;
+ break;
+ }
}
- // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<>
+ SLANG_RELEASE_ASSERT(firstDiffBlock);
+
+ auto diffBuilder = *builder;
+ diffBuilder.setInsertBefore(firstDiffBlock->getFirstOrdinaryInst());
+
+ builder->setInsertBefore(fwdParamBlockBranch);
+
+ // Collect all the original parameters.
+ List<IRParam*> fwdParams;
+ for (auto param : diffFunc->getParams())
+ fwdParams.add(param);
+
+ // Maintain a set for insts pending removal.
+ OrderedHashSet<IRInst*> instsToRemove;
+
+ // Now we begin the actual processing.
+ // The first step is to transcribe all the existing parameters from the original function.
+ // There are many cases to handle, including different combinations of parameter directions and
+ // whether or not the parameter is differentiable.
+ // To normalize the process for all these cases, we determine the following actions for each parameter:
+ // 1. Should this original parameter be translated to a parameter in the primal func and the propagate func?
+ // if so, we emit a param inst representing the final parameter for that func. If the parameter should be
+ // mapped to both the primal func and the propagate func, we will emit two separate params with their
+ // final type.
+ // 2. If this parameter has a corresponding primal func parameter, we replace all uses of the original
+ // parameter in the primal computation code to the new primal parameter. If any initialization logic
+ // is needed to convert the type of the new primal parameter to what the code was expecting, we insert
+ // that code in the first block.
+ // 3. If this parameter has a correponding propagate func parameter, we replace all uses of the original parameter
+ // in the diff computation code to the new propagate parameter. We insert necessary initialization diff block or the first block
+ // depending on whether we want that logic go through the transposition pass. We may need to replace the uses
+ // to different values/variables depending on whether that use is a read or write.
+ // 4. If the parameter has both corresponding primal and propagate parameters, we also need to consider
+ // how the future propagate function access the primal parameter. We will insert necessary preparation code
+ // that constructs temp vars or values to replace the primal parameter after we remove it from the
+ // propagate func.
+ // Base on above discussion, we need to compute the following values for each parameter:
+ // - diffRefReplacement. What should all read(load) references to this parameter from differential code be replaced to.
+ // - diffRefWriteReplacement. What should all write references to this parameter from differential code be replaced to.
+ // - primalRefReplacement. What should all references to this parameter from primal code be replaced to.
+ // - mapPrimalSpecificParamToReplacementInPropFunc[param]. What should all references to this parameter
+ // from the primal compuation logic in the future propagate function be replaced to.
for (auto fwdParam : fwdParams)
{
- if (auto outType = as<IROutType>(fwdParam->getDataType()))
+ // Define the replacement insts that we are going to fill in for each case.
+ IRInst* diffRefReplacement = nullptr;
+ IRInst* primalRefReplacement = nullptr;
+ IRInst* diffWriteRefReplacement = nullptr;
+
+ // Common logic that computes all the important types we care about.
+ IRDifferentialPairType* diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType());
+ auto inoutType = as<IRInOutType>(fwdParam->getDataType());
+ auto outType = as<IROutType>(fwdParam->getDataType());
+ if (inoutType)
+ diffPairType = as<IRDifferentialPairType>(inoutType->getValueType());
+ else if (outType)
+ diffPairType = as<IRDifferentialPairType>(outType->getValueType());
+ IRType* primalType = nullptr;
+ IRType* diffType = nullptr;
+ if (diffPairType)
+ {
+ primalType = diffPairType->getValueType();
+ diffType = (IRType*)differentiableTypeConformanceContext
+ .getDifferentialTypeFromDiffPairType(builder, diffPairType);
+ }
+
+ // Now we handle each combination of parameter direction x differentiability.
+ if (outType)
{
- IRParam* newPropParam = nullptr;
- IRParam* newPrimalParam = nullptr;
- auto diffPairType = as<IRDifferentialPairType>(outType->getValueType());
+ // Case 1: out parameters.
+ // Out parameters need to be handled differently whether or not it is differentiable,
+ // since the propagate function will not have a corresponding output.
if (diffPairType)
{
// Create dOut param.
- auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(builder, diffPairType);
- newPropParam = builder->emitParam(diffType);
- newPrimalParam = builder->emitParam(builder->getOutType(diffPairType->getValueType()));
+ auto diffParam = builder->emitParam(diffType);
+ result.propagateFuncParams.Add(diffParam);
+ primalRefReplacement = builder->emitParam(builder->getOutType(primalType));
+
+ // Create a local var for read access in pre-transpose code.
+ // This will the var from which we will fetch the final resulting derivative
+ // after transposition.
+ auto tempVar = nextBlockBuilder.emitVar(diffType);
+ nextBlockBuilder.markInstAsDifferential(tempVar, diffPairType);
+
+ // Initialize the var with input diff param at start.
+ // Note that we insert the store in the primal block so it won't get transposed.
+ auto storeInst = nextBlockBuilder.emitStore(tempVar, diffParam);
+ nextBlockBuilder.markInstAsDifferential(storeInst, diffPairType);
+ // Since this store inst is specific to propagate function, we track it in a
+ // set so we can remove it when we generate the primal func.
+ result.propagateFuncSpecificPrimalInsts.add(storeInst);
+
+ diffWriteRefReplacement = tempVar;
+ diffRefReplacement = tempVar;
}
else
{
- newPrimalParam = builder->emitParam(outType);
- }
-
- // Create a temp var to represent the original `out` param.
- auto arg = builder->emitVar(outType->getValueType());
- builder->addAutoDiffOriginalValueDecoration(arg, newPrimalParam);
- if (newPropParam)
- {
- builder->addDecoration(arg, kIROp_OutParamReverseGradientDecoration, newPropParam);
+ primalRefReplacement = builder->emitParam(outType);
}
+ result.primalFuncParams.Add(primalRefReplacement);
- fwdParam->replaceUsesWith(arg);
- fwdParam->removeAndDeallocate();
+ // Create a local var for the out param for the primal part of the prop func.
+ auto tempPrimalVar = nextBlockBuilder.emitVar(outType->getValueType());
+ result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = tempPrimalVar;
- primalFuncSpecificParams.add(newPrimalParam);
+ instsToRemove.Add(fwdParam);
}
- else if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType()))
+ else if (!isRelevantDifferentialPair(fwdParam->getDataType()))
{
+ // Case 2: non differentiable, non output parameters.
+ // If parameter is not an out param and has nothing to do with differentiation,
+ // simply move the parameter to the end.
+ //
+ fwdParam->removeFromParent();
+ fwdDiffParameterBlock->addParam(fwdParam);
+ result.primalFuncParams.Add(fwdParam);
+ result.propagateFuncParams.Add(fwdParam);
+ continue;
+ }
+ else if(!inoutType)
+ {
+ // Case 4: `in` differentiable parameters.
+
+ SLANG_RELEASE_ASSERT(diffPairType);
+
// Create inout version.
auto inoutDiffPairType = builder->getInOutType(diffPairType);
- auto newParam = builder->emitParam(inoutDiffPairType);
-
- // Map the _load_ of the new parameter as the clone of the old one.
- auto newParamLoad = builder->emitLoad(newParam);
- newParamLoad->insertAtStart(nextBlock); // Move to first block _after_ the parameter block.
- fwdParam->replaceUsesWith(newParamLoad);
- fwdParam->removeAndDeallocate();
+ primalRefReplacement = builder->emitParam(primalType);
+ result.primalFuncParams.Add(primalRefReplacement);
+ auto propParam = builder->emitParam(inoutDiffPairType);
+ result.propagateFuncParams.Add(propParam);
+
+ // A reference to this parameter from the diff blocks should be replaced with a load
+ // of the differential component of the pair.
+ auto newParamLoad = diffBuilder.emitLoad(propParam);
+ diffBuilder.markInstAsDifferential(newParamLoad, primalType);
+
+ diffRefReplacement = diffBuilder.emitDifferentialPairGetDifferential(diffType, newParamLoad);
+ diffBuilder.markInstAsDifferential(diffRefReplacement, primalType);
+
+ // Load the primal component from the prop param and use it as replacement for the
+ // primal param in the primal part of the prop func.
+ // Since these are logic specific to propagate function, we will add them to the
+ // `propagateFuncSpecificPrimalInsts` set so we can remove them when we generate the primal func.
+ auto primalReplacementLoad = nextBlockBuilder.emitLoad(propParam);
+ result.propagateFuncSpecificPrimalInsts.add(primalReplacementLoad);
+ auto primalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(primalReplacementLoad);
+ result.propagateFuncSpecificPrimalInsts.add(primalVal);
+ result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = primalVal;
+
+ instsToRemove.Add(fwdParam);
}
else
{
- // Default case (parameter is inout type or has nothing to do with differentiation)
- // Simply move the parameter to the end.
- //
- fwdParam->removeFromParent();
- fwdDiffParameterBlock->addParam(fwdParam);
+ // Case 5: `inout` differentiable parameters.
+ SLANG_ASSERT(inoutType && diffPairType);
+
+ // Process differentiable inout parameters.
+ auto primalParam = builder->emitParam(builder->getInOutType(primalType));
+ result.primalFuncParams.Add(primalParam);
+
+ auto diffParam = builder->emitParam(inoutType);
+ result.propagateFuncParams.Add(diffParam);
+
+ // Primal references to this param is the new primal param.
+ primalRefReplacement = primalParam;
+
+ // Diff references to this param should be replaced with one local temp var
+ // for read and one separate temp var for write.
+
+ // Load the inital diff value.
+ auto loadedParam = nextBlockBuilder.emitLoad(diffParam);
+ auto initDiff = nextBlockBuilder.emitDifferentialPairGetDifferential(diffType, loadedParam);
+
+ // Create a local var for diff read access.
+ auto diffVar = nextBlockBuilder.emitVar(diffType);
+ result.propagateFuncSpecificPrimalInsts.add(diffVar);
+ diffBuilder.markInstAsDifferential(diffVar, diffPairType);
+ diffRefReplacement = diffVar;
+
+ // Clear the diff read var to zero at start of the function.
+ auto dzero = getDifferentialZeroOfType(&nextBlockBuilder, primalType);
+ result.propagateFuncSpecificPrimalInsts.add(dzero);
+ auto initDiffStore = nextBlockBuilder.emitStore(diffVar, dzero);
+ result.propagateFuncSpecificPrimalInsts.add(initDiffStore);
+
+ // Create a local var for diff write access.
+ auto diffWriteVar = nextBlockBuilder.emitVar(diffType);
+ // Initialize write var to 0.
+ auto writeStore = nextBlockBuilder.emitStore(diffWriteVar, initDiff);
+ result.propagateFuncSpecificPrimalInsts.add(writeStore);
+
+ diffWriteRefReplacement = diffWriteVar;
+
+ // Create a local var for the primal logic in the propagate func.
+ auto primalVar = nextBlockBuilder.emitVar(primalType);
+ result.propagateFuncSpecificPrimalInsts.add(primalVar);
+ auto initPrimalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(loadedParam);
+ result.propagateFuncSpecificPrimalInsts.add(initPrimalVal);
+ auto storeInst = nextBlockBuilder.emitStore(primalVar, initPrimalVal);
+ result.propagateFuncSpecificPrimalInsts.add(storeInst);
+ result.mapPrimalSpecificParamToReplacementInPropFunc[primalParam] = primalVar;
+ result.outDiffWritebacks[diffParam] = InstPair(initPrimalVal, diffVar);
+
+ instsToRemove.Add(fwdParam);
+ }
+
+ // We have emitted all the new parameters and computed the replacements for the original
+ // parameter. Now we perform that replacement.
+ List<IRUse*> uses;
+ for (auto use = fwdParam->firstUse; use; use = use->nextUse)
+ uses.add(use);
+ for (auto use : uses)
+ {
+ if (auto primalRef = as<IRPrimalParamRef>(use->getUser()))
+ {
+ SLANG_RELEASE_ASSERT(primalRefReplacement);
+ primalRef->replaceUsesWith(primalRefReplacement);
+ instsToRemove.Add(primalRef);
+ }
+ else if (auto getPrimal = as<IRDifferentialPairGetPrimal>(use->getUser()))
+ {
+ SLANG_RELEASE_ASSERT(primalRefReplacement);
+ getPrimal->replaceUsesWith(primalRefReplacement);
+ instsToRemove.Add(getPrimal);
+ }
+ else if (auto propagateRef = as<IRDiffParamRef>(use->getUser()))
+ {
+ SLANG_RELEASE_ASSERT(diffRefReplacement);
+ auto refUse = propagateRef->firstUse;
+ while (refUse)
+ {
+ auto nextUse = refUse->nextUse;
+ switch (refUse->getUser()->getOp())
+ {
+ case kIROp_Load:
+ refUse->set(diffRefReplacement);
+ break;
+ case kIROp_Store:
+ refUse->set(diffWriteRefReplacement);
+ break;
+ default:
+ SLANG_RELEASE_ASSERT(!diffWriteRefReplacement);
+ refUse->set(diffRefReplacement);
+ }
+ refUse = nextUse;
+ }
+ instsToRemove.Add(propagateRef);
+ }
+ else if (auto getDiff = as<IRDifferentialPairGetDifferential>(use->getUser()))
+ {
+ SLANG_RELEASE_ASSERT(diffRefReplacement);
+ getDiff->replaceUsesWith(diffRefReplacement);
+ instsToRemove.Add(getDiff);
+ }
+ else
+ {
+ // If the user is something else, it'd better be a non relevant parameter.
+ if (diffRefReplacement || diffWriteRefReplacement)
+ SLANG_UNEXPECTED("unknown use of parameter.");
+ use->set(primalRefReplacement);
+ }
}
}
- auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount();
+ // Actually remove all the insts that we decided to remove in the process.
+ for (auto inst : instsToRemove)
+ {
+ inst->removeAndDeallocate();
+ }
+
- // 2. If the return type of the original function is differentiable,
+ // The next step is to insert new parameters that is not related to any existing parameters.
+ //
+ // If the return type of the original function is differentiable,
// add a parameter for 'derivative of the output' (d_out).
// The type is the second last parameter type of the function.
//
+ auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount();
IRParam* dOutParam = nullptr;
if (isResultDifferentiable)
{
@@ -878,12 +1181,44 @@ namespace Slang
SLANG_ASSERT(dOutParamType);
dOutParam = builder->emitParam(dOutParamType);
+ result.propagateFuncParams.Add(dOutParam);
}
// Add a parameter for intermediate val.
- builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
+ auto ctxParam = builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
+ result.primalFuncParams.Add(ctxParam);
+ result.propagateFuncParams.Add(ctxParam);
+ result.dOutParam = dOutParam;
+ return result;
+ }
- return dOutParam;
+ void BackwardDiffTranscriberBase::writeBackDerivativeToInOutParams(ParameterBlockTransposeInfo& info, IRFunc* diffFunc)
+ {
+ IRInst* returnInst = nullptr;
+ for (auto block : diffFunc->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_Return)
+ {
+ returnInst = inst;
+ break;
+ }
+ }
+ }
+ SLANG_RELEASE_ASSERT(returnInst);
+
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(returnInst);
+ for (auto& wb : info.outDiffWritebacks)
+ {
+ auto dest = wb.Key;
+ auto srcPrimalVal = wb.Value.primal;
+ auto srcDiffAddr = wb.Value.differential;
+ auto srcDiffVal = builder.emitLoad(srcDiffAddr);
+ auto destVal = builder.emitMakeDifferentialPair(as<IRPtrTypeBase>(dest->getFullType())->getValueType(), srcPrimalVal, srcDiffVal);
+ builder.emitStore(dest, destVal);
+ }
}
InstPair BackwardDiffTranscriberBase::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index 617e6b79b..c8e115f0e 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -20,6 +20,31 @@ struct IRReverseDerivativePassOptions
// Nothing for now..
};
+// The result of function parameter transposition.
+// Contains necessary info for future processing in the backward differentation pass.
+struct ParameterBlockTransposeInfo
+{
+ // Parameters that should be in the furture primal function.
+ HashSet<IRInst*> primalFuncParams;
+
+ // Parameters that should be in the furture propagate function.
+ HashSet<IRInst*> propagateFuncParams;
+
+ // The value with which a primal specific parameter should be replaced in propagate func.
+ OrderedDictionary<IRInst*, IRInst*> mapPrimalSpecificParamToReplacementInPropFunc;
+
+ // The insts added that is specific for propagate functions and should be removed
+ // from the future primal func.
+ List<IRInst*> propagateFuncSpecificPrimalInsts;
+
+ // Write backs to perform at the end of the back-prop function in order to return the
+ // computed output derivatives for an inout parameter.
+ OrderedDictionary<IRInst*, InstPair> outDiffWritebacks;
+
+ // The dOut parameter representing the result derivative to propagate backwards through.
+ IRInst* dOutParam;
+};
+
struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
{
FuncBodyTranscriptionTaskType diffTaskType;
@@ -70,9 +95,18 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
// Transcribe a function definition.
virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) = 0;
- // Transcribes the parameter block and returns the dOut param if exists.
- IRInst* transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc, List<IRInst*>& primalFuncSpecificParams, bool isResultDifferentiable);
-
+ // Splits and transpose the parameter block.
+ // After this operation, the parameter block will contain parameters for both the future
+ // primal func and the future propagate func.
+ // Additional info is returned in `ParameterBlockTransposeInfo` for future processing such
+ // as inserting write-back logic or splitting them into different functions.
+ ParameterBlockTransposeInfo splitAndTransposeParameterBlock(
+ IRBuilder* builder,
+ IRFunc* diffFunc,
+ bool isResultDifferentiable);
+
+ void writeBackDerivativeToInOutParams(ParameterBlockTransposeInfo& info, IRFunc* diffFunc);
+
InstPair transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType);
InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 8f21e8c62..31a3072c0 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -221,7 +221,19 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI
auto primalType = lookupPrimalInst(builder, originalType, nullptr);
SLANG_RELEASE_ASSERT(primalType);
- IRInst* witness = tryGetDifferentiableWitness(builder, originalType);
+ IRInst* witness = nullptr;
+ if (auto lookup = as<IRLookupWitnessMethod>(primalType))
+ {
+ if (lookup->getRequirementKey() == autoDiffSharedContext->differentialAssocTypeStructKey)
+ {
+ witness = builder->emitLookupInterfaceMethodInst(
+ lookup->getWitnessTable()->getDataType(),
+ lookup->getWitnessTable(),
+ autoDiffSharedContext->differentialAssocTypeWitnessStructKey);
+ }
+ }
+ if (!witness)
+ witness = tryGetDifferentiableWitness(builder, originalType);
SLANG_RELEASE_ASSERT(witness);
return builder->getDifferentialPairType(
@@ -239,6 +251,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o
auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialForType(builder, origType);
return (IRType*)findOrTranscribePrimalInst(builder, diffType);
}
+ else if (origType->getOp() == kIROp_LookupWitness)
+ {
+ return (IRType*)findOrTranscribePrimalInst(builder, (IRInst*)primalType);
+ }
return (IRType*)transcribe(builder, origType);
}
@@ -539,6 +555,39 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui
{
return InstPair(primal, nullptr);
}
+ if (interfaceType == autoDiffSharedContext->differentiableInterfaceType)
+ {
+ if (primalKey == autoDiffSharedContext->differentialAssocTypeStructKey)
+ {
+ return InstPair(primal, primal);
+ }
+ else if (primalKey == autoDiffSharedContext->differentialAssocTypeWitnessStructKey)
+ {
+ return InstPair(primal, primal);
+ }
+ else
+ {
+ // We can't really differentiate a call to a IDifferentiable method here.
+ // They need to be specialized first.
+ return InstPair(primal, nullptr);
+ }
+ }
+ else if (auto returnWitnessType = as<IRWitnessTableTypeBase>(lookupInst->getDataType()))
+ {
+ // T.Diff_Is_IDifferential ==> T.Diff_Is_IDifferential.Diff_Is_IDifferential
+ if (returnWitnessType->getConformanceType() == autoDiffSharedContext->differentiableInterfaceType)
+ {
+ auto primalDiffType = builder->emitLookupInterfaceMethodInst(
+ builder->getTypeKind(),
+ primal,
+ autoDiffSharedContext->differentialAssocTypeStructKey);
+ auto diffWitness = builder->emitLookupInterfaceMethodInst(
+ (IRType*)primalDiffType,
+ primal,
+ autoDiffSharedContext->differentialAssocTypeWitnessStructKey);
+ return InstPair(primal, diffWitness);
+ }
+ }
auto decor =
lookupInst->getRequirementKey()->findDecorationImpl(
getInterfaceRequirementDerivativeDecorationOp());
@@ -563,6 +612,8 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui
//
IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
{
+ primalType = (IRType*)unwrapAttributedType(primalType);
+
if (auto diffType = differentiateType(builder, primalType))
{
switch (diffType->getOp())
@@ -593,17 +644,18 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
// Since primalType has a corresponding differential type, we can lookup the
// definition for zero().
- auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
- if (!zeroMethod)
+ IRInst* zeroMethod = nullptr;
+ if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType))
{
// if the differential type itself comes from a witness lookup, we can just lookup the
// zero method from the same witness table.
- if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType))
- {
- auto wt = lookupInterface->getWitnessTable();
- zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey);
- builder->markInstAsDifferential(zeroMethod);
- }
+ auto wt = lookupInterface->getWitnessTable();
+ zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey);
+ builder->markInstAsDifferential(zeroMethod);
+ }
+ else
+ {
+ zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
}
SLANG_RELEASE_ASSERT(zeroMethod);
@@ -747,6 +799,8 @@ static void _markGenericChildrenWithoutRelaventUse(IRGeneric* origGeneric, HashS
case kIROp_UserDefinedBackwardDerivativeDecoration:
case kIROp_ForwardDerivativeDecoration:
case kIROp_BackwardDerivativeDecoration:
+ case kIROp_BackwardDerivativeIntermediateTypeDecoration:
+ case kIROp_BackwardDerivativePrimalContextDecoration:
case kIROp_BackwardDerivativePrimalDecoration:
case kIROp_BackwardDerivativePropagateDecoration:
break;
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 2a341ed38..92bd0b0a8 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -786,6 +786,12 @@ struct DiffTransposePass
for (auto externInst : externInsts)
{
+ if (isNoDiffType(externInst->getDataType()))
+ {
+ popRevGradients(externInst);
+ continue;
+ }
+
auto primalType = tryGetPrimalTypeFromDiffInst(externInst);
SLANG_ASSERT(primalType);
@@ -960,11 +966,24 @@ struct DiffTransposePass
return as<IRDifferentialPairType>(type);
};
+ struct DiffValWriteBack
+ {
+ IRInst* destVar;
+ IRInst* srcTempPairVar;
+ };
+ List<DiffValWriteBack> writebacks;
+
+ auto baseFnType = as<IRFuncType>(baseFn->getDataType());
+
+ SLANG_RELEASE_ASSERT(baseFnType);
+ SLANG_RELEASE_ASSERT(fwdCall->getArgCount() == baseFnType->getParamCount());
+
for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++)
{
auto arg = fwdCall->getArg(ii);
-
- if (arg->getOp() == kIROp_LoadReverseGradient)
+ auto paramType = baseFnType->getParamType(ii);
+
+ if (as<IRLoadReverseGradient>(arg))
{
// Original parameters that are `out DifferentiableType` will turn into
// a `in Differential` parameter. The split logic will insert LoadReverseGradient insts
@@ -972,6 +991,20 @@ struct DiffTransposePass
// and use it as the final argument.
args.add(builder->emitLoad(arg->getOperand(0)));
}
+ else if (auto instPair = as<IRReverseGradientDiffPairRef>(arg))
+ {
+ // An argument to an inout parameter will come in the form of a ReverseGradientDiffPairRef(primalVar, diffVar) inst
+ // after splitting.
+ // In order to perform the call, we need a temporary var to store the DiffPair.
+ auto pairType = as<IRPtrTypeBase>(arg->getDataType())->getValueType();
+ auto tempVar = builder->emitVar(pairType);
+ auto primalVal = builder->emitLoad(instPair->getPrimal());
+ auto diffVal = builder->emitLoad(instPair->getDiff());
+ auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal);
+ builder->emitStore(tempVar, pairVal);
+ args.add(tempVar);
+ writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar});
+ }
else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType()))
{
// Normal differentiable input parameter will become an inout DiffPair parameter
@@ -1002,9 +1035,17 @@ struct DiffTransposePass
}
else
{
- args.add(arg);
- argTypes.add(arg->getDataType());
- argRequiresLoad.add(false);
+ if (as<IROutType>(paramType))
+ {
+ args.add(nullptr);
+ argRequiresLoad.add(false);
+ }
+ else
+ {
+ args.add(arg);
+ argTypes.add(arg->getDataType());
+ argRequiresLoad.add(false);
+ }
}
}
@@ -1015,11 +1056,10 @@ struct DiffTransposePass
argRequiresLoad.add(false);
}
- args.add(primalContextDecor->getBackwardDerivativePrimalContextVar());
- argTypes.add(builder->getOutType(
- as<IRPtrTypeBase>(
+ args.add(builder->emitLoad(primalContextDecor->getBackwardDerivativePrimalContextVar()));
+ argTypes.add(as<IRPtrTypeBase>(
primalContextDecor->getBackwardDerivativePrimalContextVar()->getDataType())
- ->getValueType()));
+ ->getValueType());
argRequiresLoad.add(false);
auto revFnType = builder->getFuncType(argTypes, builder->getVoidType());
@@ -1027,11 +1067,27 @@ struct DiffTransposePass
revFnType,
baseFn);
- builder->emitCallInst(revFnType->getResultType(), revCallee, args);
+ List<IRInst*> callArgs;
+ for (auto arg : args)
+ if (arg)
+ callArgs.add(arg);
+ builder->emitCallInst(revFnType->getResultType(), revCallee, callArgs);
+
+ // Writeback result gradient to their corresponding splitted variable.
+ for (auto wb : writebacks)
+ {
+ auto loadedPair = builder->emitLoad(wb.srcTempPairVar);
+ auto diffType = as<IRPtrTypeBase>(wb.destVar->getDataType())->getValueType();
+ auto loadedDiff = builder->emitDifferentialPairGetDifferential(diffType, loadedPair);
+ builder->emitStore(wb.destVar, loadedDiff);
+ }
List<RevGradient> gradients;
- for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++)
+ for (Index ii = 0; ii < args.getCount(); ii++)
{
+ if (!args[ii])
+ continue;
+
// Is this arg relevant to auto-diff?
if (auto diffPairType = getDiffPairType(args[ii]->getDataType()))
{
@@ -1043,13 +1099,11 @@ struct DiffTransposePass
auto diffArgType = (IRType*)diffTypeContext.getDifferentialForType(
builder,
diffPairType->getValueType());
- auto diffArgPtrType = builder->getPtrType(kIROp_PtrType, diffArgType);
-
gradients.add(RevGradient(
RevGradient::Flavor::Simple,
fwdCall->getArg(ii),
builder->emitDifferentialPairGetDifferential(
- diffArgPtrType, builder->emitLoad(args[ii])),
+ diffArgType, builder->emitLoad(args[ii])),
nullptr));
}
}
@@ -1308,14 +1362,13 @@ struct DiffTransposePass
TranspositionResult transposeStore(IRBuilder* builder, IRStore* fwdStore, IRInst*)
{
- IRInst* revVal = nullptr;
- if (auto revGradDecor = fwdStore->getPtr()->findDecoration<IROutParamReverseGradientDecoration>())
+ IRInst* revVal = builder->emitLoad(fwdStore->getPtr());
+ if (auto diffPairType = as<IRDifferentialPairType>(revVal->getDataType()))
{
- revVal = revGradDecor->getValue();
- }
- else
- {
- revVal = builder->emitLoad(fwdStore->getPtr());
+ revVal = builder->emitDifferentialPairGetDifferential(
+ (IRType*)diffTypeContext.getDifferentialTypeFromDiffPairType(
+ builder, diffPairType),
+ revVal);
}
return TranspositionResult(
List<RevGradient>(
@@ -1747,7 +1800,8 @@ struct DiffTransposePass
for (UIndex ii = 0; ii < fwdInst->getOperandCount(); ii++)
{
auto operand = fwdInst->getOperand(ii);
- if (operand->getDataType() != targetType)
+ auto operandType = unwrapAttributedType(operand->getDataType());
+ if (operandType != targetType)
{
// Insert new operand just after the old operand, so we have the old
// operands available.
@@ -2271,7 +2325,6 @@ struct DiffTransposePass
{
gradientsMap[fwdInst] = List<RevGradient>();
}
-
gradientsMap[fwdInst].GetValue().add(assignment);
}
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 378ea1cc2..be20d8aa8 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -1,6 +1,7 @@
#include "slang-ir-autodiff-unzip.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-util.h"
+#include "slang-ir-autodiff-rev.h"
namespace Slang
{
@@ -279,7 +280,7 @@ struct ExtractPrimalFuncContext
inst);
}
- IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& outIntermediateType)
+ IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, HashSet<IRInst*>& primalParams, IRInst*& outIntermediateType)
{
IRBuilder builder(sharedBuilder);
@@ -294,6 +295,7 @@ struct ExtractPrimalFuncContext
auto oldIntermediateParam = func->getLastParam();
auto outIntermediary =
builder.emitParam(builder.getInOutType((IRType*)intermediateType));
+ primalParams.Add(outIntermediary);
oldIntermediateParam->replaceUsesWith(outIntermediary);
oldIntermediateParam->removeAndDeallocate();
@@ -368,62 +370,18 @@ struct ExtractPrimalFuncContext
auto defVal = builder.emitDefaultConstructRaw((IRType*)intermediateType);
builder.emitStore(outIntermediary, defVal);
- // The primal func will not have the result derivative param (second to last param), so we remove it.
- if (isResultDifferentiable)
- {
- auto resultDerivativeParam = func->getLastParam()->getPrevParam();
- SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses());
- resultDerivativeParam->removeAndDeallocate();
- }
-
- // Finally, go through parameters and translate their type back to primal type.
+ // Remove any parameters not in `primalParams` set.
+ List<IRInst*> params;
for (auto param = func->getFirstParam(); param;)
{
- auto next = param->getNextParam();
- [this, firstBlock, &builder, param]()
+ auto nextParam = param->getNextParam();
+ if (!primalParams.Contains(param))
{
- for (auto use = param->firstUse; use; use = use->nextUse)
- {
- if (use->getUser()->getOp() == kIROp_AutoDiffOriginalValueDecoration)
- {
- use->getUser()->getParent()->replaceUsesWith(param);
- return;
- }
- else if (use->getUser()->getOp() == kIROp_OutParamReverseGradientDecoration)
- {
- // This is a propagate func specific parameter, we should remove it.
- SLANG_RELEASE_ASSERT(!param->hasMoreThanOneUse());
- param->removeAndDeallocate();
- return;
- }
- }
-
- IRInst* valueType = param->getDataType();
- auto inoutType = as<IRPtrTypeBase>(param->getDataType());
- if (inoutType) valueType = inoutType->getValueType();
- auto diffPairType = as<IRDifferentialPairType>(valueType);
- if (!diffPairType)
- return;
-
- builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
-
- auto originalValueType = diffPairType->getValueType();
-
- // Create a local var to act as the old param.
- auto tempVar = builder.emitVar(diffPairType);
- param->replaceUsesWith(tempVar);
- auto pairValue = builder.emitMakeDifferentialPair(
- diffPairType,
- param,
- backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType));
- builder.emitStore(tempVar, pairValue);
-
- // Change the param type to original type.
- param->setFullType(originalValueType);
- }();
- param = next;
+ param->replaceUsesWith(builder.getVoidValue());
+ param->removeAndDeallocate();
+ }
+ param = nextParam;
}
-
return unzippedFunc;
}
};
@@ -446,7 +404,10 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE
}
IRFunc* DiffUnzipPass::extractPrimalFunc(
- IRFunc* func, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& intermediateType)
+ IRFunc* func,
+ IRFunc* originalFunc,
+ ParameterBlockTransposeInfo& paramInfo,
+ IRInst*& intermediateType)
{
IRBuilder builder(this->autodiffContext->sharedBuilder);
builder.setInsertBefore(func);
@@ -456,11 +417,31 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
subEnv.parent = &cloneEnv;
auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func));
+ // Remove [KeepAlive] decorations in clonedFunc.
+ for (auto block : clonedFunc->getBlocks())
+ for (auto inst : block->getChildren())
+ if (auto decor = inst->findDecoration<IRKeepAliveDecoration>())
+ decor->removeAndDeallocate();
+
+ // Remove propagate func specific primal insts from cloned func.
+ for (auto inst : paramInfo.propagateFuncSpecificPrimalInsts)
+ {
+ auto newInst = subEnv.mapOldValToNew[inst].GetValue();
+ newInst->removeAndDeallocate();
+ }
+
+ HashSet<IRInst*> newPrimalParams;
+ for (auto param : func->getParams())
+ {
+ if (paramInfo.primalFuncParams.Contains(param))
+ newPrimalParams.Add(subEnv.mapOldValToNew[param].GetValue());
+ }
+
ExtractPrimalFuncContext context;
context.init(autodiffContext->sharedBuilder, autodiffContext->transcriberSet.primalTranscriber);
intermediateType = nullptr;
- auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, isResultDifferentiable, intermediateType);
+ auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, newPrimalParams, intermediateType);
if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>())
{
@@ -489,21 +470,26 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
if (auto structKeyDecor = inst->findDecoration<IRPrimalValueStructKeyDecoration>())
{
builder.setInsertBefore(inst);
- auto val = builder.emitFieldExtract(
- inst->getDataType(),
- intermediateVar,
- structKeyDecor->getStructKey());
if (inst->getOp() == kIROp_Var)
{
// This is a var for intermediate context.
+ auto valType = cast<IRPtrTypeBase>(inst->getFullType())->getValueType();
+ auto val = builder.emitFieldExtract(
+ valType,
+ intermediateVar,
+ structKeyDecor->getStructKey());
auto tempVar =
- builder.emitVar(cast<IRPtrTypeBase>(inst->getFullType())->getValueType());
+ builder.emitVar(valType);
builder.emitStore(tempVar, val);
inst->replaceUsesWith(tempVar);
}
else
{
// Orindary value.
+ auto val = builder.emitFieldExtract(
+ inst->getFullType(),
+ intermediateVar,
+ structKeyDecor->getStructKey());
inst->replaceUsesWith(val);
}
instsToRemove.add(inst);
@@ -522,6 +508,8 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
{
inst->removeAndDeallocate();
}
+
+ stripTempDecorations(func);
// Run simplification to DCE unnecessary insts.
eliminateDeadCode(func);
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 3055d057b..057ff53c4 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -14,6 +14,8 @@
namespace Slang
{
+struct ParameterBlockTransposeInfo;
+
struct DiffUnzipPass
{
AutoDiffSharedContext* autodiffContext;
@@ -120,7 +122,7 @@ struct DiffUnzipPass
return diffMap[inst];
}
- IRFunc* unzipDiffInsts(IRFunc* func)
+ void unzipDiffInsts(IRFunc* func)
{
diffTypeContext.setFunc(func);
@@ -129,21 +131,33 @@ struct DiffUnzipPass
IRBuilder* builder = &builderStorage;
- // Clone the entire function.
- // TODO: Maybe don't clone? The reverse-mode process seems to clone several times.
- // TODO: Looks like we get a copy of the decorations?
- IRCloneEnv subEnv;
- subEnv.parent = &cloneEnv;
- builder->setInsertBefore(func);
- IRFunc* unzippedFunc = as<IRFunc>(cloneInst(&subEnv, builder, func));
+ IRFunc* unzippedFunc = func;
- builder->setInsertInto(unzippedFunc);
+ // Initialize the primal/diff map for parameters.
+ // Generate distinct references for parameters that should be split.
+ // We don't actually modify the parameter list here, instead we emit
+ // PrimalParamRef(param) and DiffParamRef(param) and use those to represent
+ // a use from the primal or diff part of the program.
+ builder->setInsertBefore(unzippedFunc->getFirstBlock()->getTerminator());
- auto originalParam = func->getFirstParam();
for (auto primalParam = unzippedFunc->getFirstParam(); primalParam; primalParam = primalParam->getNextParam())
{
- primalMap[originalParam] = primalParam;
- originalParam = originalParam->getNextParam();
+ auto type = primalParam->getFullType();
+ if (auto ptrType = as<IRPtrTypeBase>(type))
+ {
+ type = ptrType->getValueType();
+ }
+ if (auto pairType = as<IRDifferentialPairType>(type))
+ {
+ IRInst* diffType = diffTypeContext.getDifferentialTypeFromDiffPairType(builder, pairType);
+ if (as<IRPtrTypeBase>(primalParam->getFullType()))
+ diffType = builder->getPtrType(primalParam->getFullType()->getOp(), (IRType*)diffType);
+ auto primalRef = builder->emitPrimalParamRef(primalParam);
+ auto diffRef = builder->emitDiffParamRef((IRType*)diffType, primalParam);
+ builder->markInstAsDifferential(diffRef, pairType->getValueType());
+ primalMap[primalParam] = primalRef;
+ diffMap[primalParam] = diffRef;
+ }
}
// Functions need to have at least two blocks at this point (one for parameters,
@@ -239,8 +253,6 @@ struct DiffUnzipPass
// Remove old blocks.
for (auto block : mixedBlocks)
block->removeAndDeallocate();
-
- return unzippedFunc;
}
IRBlock* getInitializerBlock(IndexedRegion* region)
@@ -476,25 +488,12 @@ struct DiffUnzipPass
}
}
- IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& intermediateType);
-
- bool isRelevantDifferentialPair(IRType* type)
- {
- if (as<IRDifferentialPairType>(type))
- {
- return true;
- }
- else if (auto argPtrType = as<IRPtrTypeBase>(type))
- {
- if (as<IRDifferentialPairType>(argPtrType->getValueType()))
- {
- return true;
- }
- }
-
- return false;
- }
-
+ IRFunc* extractPrimalFunc(
+ IRFunc* func,
+ IRFunc* originalFunc,
+ ParameterBlockTransposeInfo& paramInfo,
+ IRInst*& intermediateType);
+
static IRInst* _getOriginalFunc(IRInst* call)
{
if (auto decor = call->findDecoration<IRAutoDiffOriginalValueDecoration>())
@@ -606,7 +605,13 @@ struct DiffUnzipPass
}
else if (auto inoutType = as<IRInOutType>(primalParamType))
{
- SLANG_UNIMPLEMENTED_X("nested call inout parameter");
+ // Since arg is split into separate vars, we need a new temp var that represents
+ // the remerged diff pair.
+ auto diffPairType = as<IRDifferentialPairType>(as<IRPtrTypeBase>(arg->getDataType())->getValueType());
+ auto primalValueType = diffPairType->getValueType();
+ auto diffPairRef = diffBuilder->emitReverseGradientDiffPairRef(arg->getDataType(), primalArg, diffArg);
+ diffBuilder->markInstAsDifferential(diffPairRef, primalValueType);
+ diffArgs.add(diffPairRef);
}
else
{
@@ -661,20 +666,6 @@ struct DiffUnzipPass
InstPair splitLoad(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRLoad* mixedLoad)
{
- if (auto param = as<IRParam>(mixedLoad->getPtr()))
- {
- auto diffPairPtrType = as<IRPtrTypeBase>(param->getFullType());
- SLANG_RELEASE_ASSERT(diffPairPtrType);
- auto diffPairType = as<IRDifferentialPairType>(diffPairPtrType->getValueType());
- SLANG_RELEASE_ASSERT(diffPairType);
- auto diffType = (IRType*)diffTypeContext.getDifferentialTypeFromDiffPairType(diffBuilder, diffPairType);
- auto loadedParam = primalBuilder->emitLoad(param);
- return InstPair(
- primalBuilder->emitDifferentialPairGetPrimal(loadedParam),
- primalBuilder->emitDifferentialPairGetDifferential(diffType, loadedParam));
- }
-
- // Everything else should have already been split.
auto primalPtr = lookupPrimalInst(mixedLoad->getPtr());
auto diffPtr = lookupDiffInst(mixedLoad->getPtr());
auto primalVal = primalBuilder->emitLoad(primalPtr);
@@ -685,22 +676,14 @@ struct DiffUnzipPass
InstPair splitStore(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRStore* mixedStore)
{
- // We will only generate mixed store to parameters.
- if (!as<IRParam>(mixedStore->getPtr()))
- {
- SLANG_UNIMPLEMENTED_X("Splitting a store that is not writing to a param.");
- }
-
- auto primalAddr = mixedStore->getPtr();
+ auto primalAddr = lookupPrimalInst(mixedStore->getPtr());
+ auto diffAddr = lookupDiffInst(mixedStore->getPtr());
auto primalVal = lookupPrimalInst(mixedStore->getVal());
auto diffVal = lookupDiffInst(mixedStore->getVal());
- // For now the param type and value type will not type-check in these store insts,
- // but the param inst will be changed to the correct type after we synthesize primal and
- // propagate func.
auto primalStore = primalBuilder->emitStore(primalAddr, primalVal);
- auto diffStore = diffBuilder->emitStore(primalAddr, diffVal);
+ auto diffStore = diffBuilder->emitStore(diffAddr, diffVal);
diffBuilder->markInstAsDifferential(diffStore, primalVal->getFullType());
return InstPair(primalStore, diffStore);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 7a2e8c75e..fdaff4960 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -496,6 +496,8 @@ void stripTempDecorations(IRInst* inst)
case kIROp_DifferentialInstDecoration:
case kIROp_MixedDifferentialInstDecoration:
case kIROp_AutoDiffOriginalValueDecoration:
+ case kIROp_BackwardDerivativePrimalReturnDecoration:
+ case kIROp_PrimalValueStructKeyDecoration:
decor->removeAndDeallocate();
break;
default:
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 30f053673..0662cf846 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -286,6 +286,7 @@ struct DifferentialPairTypeBuilder
};
void stripAutoDiffDecorations(IRModule* module);
+void stripTempDecorations(IRInst* inst);
bool isNoDiffType(IRType* paramType);
@@ -309,4 +310,20 @@ bool isBackwardDifferentiableFunc(IRInst* func);
bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst);
+inline bool isRelevantDifferentialPair(IRType* type)
+{
+ if (as<IRDifferentialPairType>(type))
+ {
+ return true;
+ }
+ else if (auto argPtrType = as<IRPtrTypeBase>(type))
+ {
+ if (as<IRDifferentialPairType>(argPtrType->getValueType()))
+ {
+ return true;
+ }
+ }
+ return false;
+}
+
};
diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp
index 5b5ace64b..dbeb1e934 100644
--- a/source/slang/slang-ir-clone.cpp
+++ b/source/slang/slang-ir-clone.cpp
@@ -279,6 +279,7 @@ IRInst* cloneInst(
}
void cloneDecoration(
+ IRCloneEnv* cloneEnv,
IRDecoration* oldDecoration,
IRInst* newParent,
IRModule* module)
@@ -292,6 +293,7 @@ void cloneDecoration(
builder.setInsertInto(newParent);
IRCloneEnv env;
+ env.parent = cloneEnv;
cloneInst(&env, &builder, oldDecoration);
}
@@ -300,6 +302,7 @@ void cloneDecoration(
IRInst* newParent)
{
cloneDecoration(
+ nullptr,
oldDecoration,
newParent,
newParent->getModule());
diff --git a/source/slang/slang-ir-clone.h b/source/slang/slang-ir-clone.h
index 824806d57..f4f53ff92 100644
--- a/source/slang/slang-ir-clone.h
+++ b/source/slang/slang-ir-clone.h
@@ -133,6 +133,7 @@ IRInst* cloneInst(
/// Uses `module` to allocate any new instructions.
///
void cloneDecoration(
+ IRCloneEnv* parentEnv,
IRDecoration* oldDecoration,
IRInst* newParent,
IRModule* module);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index e1143b7b9..1cb839751 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -330,6 +330,18 @@ INST(Store, store, 2, 0)
// currently accumulated derivative to pass to some dOut argument in a nested call.
INST(LoadReverseGradient, LoadReverseGradient, 1, 0)
+// Produced and removed during backward auto-diff pass as a temporary placeholder containing the
+// primal and accumulated derivative values to pass to an inout argument in a nested call.
+INST(ReverseGradientDiffPairRef, ReverseGradientDiffPairRef, 2, 0)
+
+// Produced and removed during backward auto-diff pass. This inst is generated by the splitting step
+// to represent a reference to an inout parameter for use in the primal part of the computation.
+INST(PrimalParamRef, PrimalParamRef, 1, 0)
+
+// Produced and removed during backward auto-diff pass. This inst is generated by the splitting step
+// to represent a reference to an inout parameter for use in the back-prop part of the computation.
+INST(DiffParamRef, DiffParamRef, 1, 0)
+
INST(FieldExtract, get_field, 2, 0)
INST(FieldAddress, get_field_addr, 2, 0)
@@ -771,12 +783,6 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// forward-differentiated updateElement inst.
INST(PrimalElementTypeDecoration, primalElementType, 1, 0)
- /// Used by the auto-diff pass. An `out T` parameter will transcribe to a `in T.Differential` parameter.
- /// We will also create a temp var of type `T.Differential` in the function body so the `load` and `stores`
- /// can operand on a valid address. We use this decoration to associate this temp var with its corresponding
- /// input parameter.
- INST(OutParamReverseGradientDecoration, outParamRevGrad, 1, 0)
-
/// Used by the auto-diff pass to hold a reference to a
/// differential member of a type in its associated differential type.
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 132a96f16..d1374477f 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -723,18 +723,6 @@ struct IRMixedDifferentialInstDecoration : IRDecoration
IRType* getPairType() { return as<IRType>(getOperand(0)); }
};
-struct IROutParamReverseGradientDecoration : IRDecoration
-{
- enum
- {
- kOp = kIROp_OutParamReverseGradientDecoration
- };
-
- IR_LEAF_ISA(OutParamReverseGradientDecoration)
-
- IRInst* getValue() { return getOperand(0); }
-};
-
struct IRBackwardDifferentiableDecoration : IRDecoration
{
enum
@@ -1782,12 +1770,31 @@ struct IRGetElementPtr : IRInst
IRInst* getIndex() { return getOperand(1); }
};
-struct IRLoadReverseGradient :IRInst
+struct IRLoadReverseGradient : IRInst
{
IR_LEAF_ISA(LoadReverseGradient)
IRInst* getValue() { return getOperand(0); }
};
+struct IRReverseGradientDiffPairRef : IRInst
+{
+ IR_LEAF_ISA(ReverseGradientDiffPairRef)
+ IRInst* getPrimal() { return getOperand(0); }
+ IRInst* getDiff() { return getOperand(1); }
+};
+
+struct IRPrimalParamRef : IRInst
+{
+ IR_LEAF_ISA(PrimalParamRef)
+ IRInst* getReferencedParam() { return getOperand(0); }
+};
+
+struct IRDiffParamRef : IRInst
+{
+ IR_LEAF_ISA(DiffParamRef)
+ IRInst* getReferencedParam() { return getOperand(0); }
+};
+
struct IRGetNativePtr : IRInst
{
IR_LEAF_ISA(GetNativePtr);
@@ -3145,6 +3152,9 @@ public:
IRInst* ptr);
IRInst* emitLoadReverseGradient(IRType* type, IRInst* diffValue);
+ IRInst* emitReverseGradientDiffPairRef(IRType* type, IRInst* primalVar, IRInst* diffVar);
+ IRInst* emitPrimalParamRef(IRInst* param);
+ IRInst* emitDiffParamRef(IRType* type, IRInst* param);
IRInst* emitStore(
IRInst* dstPtr,
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 747d0ccdd..cf7acd46c 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -392,13 +392,14 @@ struct SpecializationContext
{
for (auto decor : genericReturnVal->getDecorations())
{
+ bool specialized = false;
if (decor->getOp() == kIROp_ForwardDerivativeDecoration ||
decor->getOp() == kIROp_UserDefinedBackwardDerivativeDecoration)
{
// If we already have a diff func on this specialize, skip.
if (auto specDiffRef = specInst->findDecorationImpl(decor->getOp()))
{
- return false;
+ continue;
}
auto specDiffFunc = as<IRSpecialize>(decor->getOperand(0));
@@ -443,8 +444,10 @@ struct SpecializationContext
builder.addDecoration(specInst, decor->getOp(), newDiffFunc);
- return true;
+ specialized = true;
}
+ if (specialized)
+ return true;
}
}
return false;
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index ee55a6546..3f250e31e 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -386,7 +386,7 @@ static void cloneRelevantDecorations(
//
if( !val->findDecorationImpl(decoration->getOp()) )
{
- cloneDecoration(decoration, val, var->getModule());
+ cloneDecoration(nullptr, decoration, val, var->getModule());
}
break;
}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 4814726cf..558574bf6 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4234,6 +4234,50 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitReverseGradientDiffPairRef(IRType* type, IRInst* primalVar, IRInst* diffVar)
+ {
+ auto inst = createInst<IRReverseGradientDiffPairRef>(
+ this,
+ kIROp_ReverseGradientDiffPairRef,
+ type,
+ primalVar,
+ diffVar);
+
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitPrimalParamRef(IRInst* param)
+ {
+ auto type = param->getFullType();
+ auto ptrType = as<IRPtrTypeBase>(type);
+ auto valueType = type;
+ if (ptrType) valueType = ptrType->getValueType();
+ auto pairType = as<IRDifferentialPairType>(valueType);
+ IRType* finalType = pairType->getValueType();
+ if (ptrType) finalType = getPtrType(ptrType->getOp(), finalType);
+ auto inst = createInst<IRPrimalParamRef>(
+ this,
+ kIROp_PrimalParamRef,
+ finalType,
+ param);
+
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitDiffParamRef(IRType* type, IRInst* param)
+ {
+ auto inst = createInst<IRDiffParamRef>(
+ this,
+ kIROp_DiffParamRef,
+ type,
+ param);
+
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitLoad(
IRType* type,
IRInst* ptr)
@@ -6753,6 +6797,8 @@ namespace Slang
// common subexpression elimination, etc.
//
auto call = cast<IRCall>(this);
+ if (call->findDecoration<IRNoSideEffectDecoration>())
+ return false;
return !isPureFunctionalCall(call);
}
break;
@@ -6809,10 +6855,14 @@ namespace Slang
case kIROp_MakeOptionalNone:
case kIROp_OptionalHasValue:
case kIROp_GetOptionalValue:
+ case kIROp_DifferentialPairGetPrimal:
+ case kIROp_DifferentialPairGetDifferential:
+ case kIROp_MakeDifferentialPair:
case kIROp_MakeTuple:
case kIROp_GetTupleElement:
case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads
case kIROp_LoadReverseGradient:
+ case kIROp_ReverseGradientDiffPairRef:
case kIROp_ImageSubscript:
case kIROp_FieldExtract:
case kIROp_FieldAddress: