summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ir-addr-inst-elimination.cpp31
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp215
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h6
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp557
-rw-r--r--source/slang/slang-ir-autodiff-rev.h18
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp65
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h2
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h75
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp82
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h176
-rw-r--r--source/slang/slang-ir-autodiff.cpp10
-rw-r--r--source/slang/slang-ir-autodiff.h2
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp30
-rw-r--r--source/slang/slang-ir-init-local-var.cpp34
-rw-r--r--source/slang/slang-ir-init-local-var.h14
-rw-r--r--source/slang/slang-ir-inst-defs.h10
-rw-r--r--source/slang/slang-ir-insts.h23
-rw-r--r--source/slang/slang-ir-util.cpp4
-rw-r--r--source/slang/slang-ir.cpp35
19 files changed, 835 insertions, 554 deletions
diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp
index a5e0e0a4e..a451e24a5 100644
--- a/source/slang/slang-ir-addr-inst-elimination.cpp
+++ b/source/slang/slang-ir-addr-inst-elimination.cpp
@@ -54,11 +54,18 @@ struct AddressInstEliminationContext
}
endLoop:;
auto lastAddr = accessChain.getLast();
- auto lastVal = builder.emitLoad(lastAddr);
accessChain.removeLast();
accessChain.reverse();
- auto update = builder.emitUpdateElement(lastVal, accessChain, val);
- builder.emitStore(lastAddr, update);
+ if (accessChain.getCount())
+ {
+ auto lastVal = builder.emitLoad(lastAddr);
+ auto update = builder.emitUpdateElement(lastVal, accessChain, val);
+ builder.emitStore(lastAddr, update);
+ }
+ else
+ {
+ builder.emitStore(lastAddr, val);
+ }
}
void transformLoadAddr(IRUse* use)
@@ -92,7 +99,22 @@ struct AddressInstEliminationContext
IRBuilder builder(sharedBuilder);
builder.setInsertBefore(call);
auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType());
- builder.emitStore(tempVar, getValue(builder, addr));
+ auto callee = getResolvedInstForDecorations(call->getCallee());
+ auto funcType = as<IRFuncType>(callee->getFullType());
+ SLANG_RELEASE_ASSERT(funcType);
+ UInt paramIndex = (UInt)(use - call->getOperands() - 1);
+ SLANG_RELEASE_ASSERT(call->getArg(paramIndex) == addr);
+ if (!as<IROutType>(funcType->getParamType(paramIndex)))
+ {
+ builder.emitStore(tempVar, getValue(builder, addr));
+ }
+ else
+ {
+ builder.emitStore(
+ tempVar,
+ builder.emitDefaultConstruct(
+ as<IRPtrTypeBase>(tempVar->getDataType())->getValueType()));
+ }
builder.setInsertAfter(call);
storeValue(builder, addr, builder.emitLoad(tempVar));
use->set(tempVar);
@@ -170,4 +192,5 @@ SlangResult eliminateAddressInsts(
AddressInstEliminationContext ctx;
return ctx.eliminateAddressInstsImpl(sharedBuilder, policy, func, sink);
}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index f60412efb..a9e716ce4 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -7,6 +7,7 @@
#include "slang-ir-eliminate-phis.h"
#include "slang-ir-util.h"
#include "slang-ir-inst-pass-base.h"
+#include "slang-ir-single-return.h"
namespace Slang
{
@@ -232,6 +233,8 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or
builder->markInstAsMixedDifferential(diffStoreVal, diffPairType);
auto store = builder->emitStore(primalStoreLocation, valToStore);
+ builder->markInstAsMixedDifferential(store, diffPairType);
+
return InstPair(store, nullptr);
}
}
@@ -385,12 +388,18 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
SLANG_ASSERT(calleeType);
SLANG_RELEASE_ASSERT(calleeType->getParamCount() == origCall->getArgCount());
+ auto placeholderCall = builder->emitCallInst(nullptr, builder->emitUndefined(builder->getTypeKind()), 0, nullptr);
+ builder->setInsertBefore(placeholderCall);
+ IRBuilder argBuilder = *builder;
+ IRBuilder afterBuilder = argBuilder;
+ afterBuilder.setInsertAfter(placeholderCall);
+
List<IRInst*> args;
// Go over the parameter list and create pairs for each input (if required)
for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
{
auto origArg = origCall->getArg(ii);
- auto primalArg = findOrTranscribePrimalInst(builder, origArg);
+ auto primalArg = findOrTranscribePrimalInst(&argBuilder, origArg);
SLANG_ASSERT(primalArg);
auto primalType = primalArg->getDataType();
@@ -402,20 +411,71 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
while (auto attrType = as<IRAttributedType>(primalType))
primalType = attrType->getBaseType();
}
- if (auto pairType = tryGetDiffPairType(builder, primalType))
+ if (auto pairType = tryGetDiffPairType(&argBuilder, primalType))
{
- auto diffArg = findOrTranscribeDiffInst(builder, origArg);
- if (!diffArg)
- diffArg = getDifferentialZeroOfType(builder, primalType);
+ auto pairPtrType = as<IRPtrTypeBase>(pairType);
+ auto pairValType = as<IRDifferentialPairType>(
+ pairPtrType ? pairPtrType->getValueType() : pairType);
+ auto diffType = differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(&argBuilder, pairValType);
+ if (auto ptrParamType = as<IRPtrTypeBase>(paramType))
+ {
+ // Create temp var to pass in/out arguments.
+ auto srcVar = argBuilder.emitVar(ptrParamType->getValueType());
+ argBuilder.markInstAsMixedDifferential(srcVar, pairValType->getValueType());
+
+ auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg);
+ if (ptrParamType->getOp() == kIROp_InOutType)
+ {
+ // Set initial value.
+ auto primalVal = argBuilder.emitLoad(primalArg);
+ auto diffArgVal = diffArg;
+ if (!diffArg)
+ diffArgVal = getDifferentialZeroOfType(builder, (IRType*)pairValType->getValueType());
+ else
+ {
+ diffArgVal = argBuilder.emitLoad(diffArg);
+ argBuilder.markInstAsDifferential(diffArgVal, pairValType->getValueType());
+ }
+ auto initVal = argBuilder.emitMakeDifferentialPair(pairValType, primalVal, diffArgVal);
+ argBuilder.markInstAsMixedDifferential(initVal, primalType);
+ auto store = argBuilder.emitStore(srcVar, initVal);
+ argBuilder.markInstAsMixedDifferential(store, primalType);
+ }
+ if (as<IROutTypeBase>(ptrParamType))
+ {
+ // Read back new value.
+ auto newVal = afterBuilder.emitLoad(srcVar);
+ afterBuilder.markInstAsMixedDifferential(newVal, pairValType->getValueType());
+ auto newPrimalVal = afterBuilder.emitDifferentialPairGetPrimal(newVal);
+ afterBuilder.emitStore(primalArg, newPrimalVal);
+
+ if (diffArg)
+ {
+ auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential((IRType*)diffType, newVal);
+ afterBuilder.markInstAsDifferential(newDiffVal, pairValType->getValueType());
+ auto storeInst = afterBuilder.emitStore(diffArg, newDiffVal);
+ afterBuilder.markInstAsDifferential(storeInst, pairValType->getValueType());
+ }
+ }
+ args.add(srcVar);
+ continue;
+ }
+ else
+ {
+ auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg);
+ if (!diffArg)
+ diffArg = getDifferentialZeroOfType(&argBuilder, primalType);
- // If a pair type can be formed, this must be non-null.
- SLANG_RELEASE_ASSERT(diffArg);
-
- auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
- builder->markInstAsMixedDifferential(diffPair, pairType);
+ // If a pair type can be formed, this must be non-null.
+ SLANG_RELEASE_ASSERT(diffArg);
- args.add(diffPair);
- continue;
+ auto diffPair = argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffArg);
+ argBuilder.markInstAsMixedDifferential(diffPair, pairType);
+
+ args.add(diffPair);
+ continue;
+ }
+
}
}
// Argument is not differentiable.
@@ -424,26 +484,29 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
}
IRType* diffReturnType = nullptr;
- diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
+ diffReturnType = tryGetDiffPairType(&argBuilder, origCall->getFullType());
if (!diffReturnType)
{
SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType);
- diffReturnType = builder->getVoidType();
+ diffReturnType = argBuilder.getVoidType();
}
- auto callInst = builder->emitCallInst(
+ auto callInst = argBuilder.emitCallInst(
diffReturnType,
diffCallee,
args);
- builder->markInstAsMixedDifferential(callInst, diffReturnType);
- builder->addAutoDiffOriginalValueDecoration(callInst, primalCallee);
+ placeholderCall->removeAndDeallocate();
+ argBuilder.markInstAsMixedDifferential(callInst, diffReturnType);
+ argBuilder.addAutoDiffOriginalValueDecoration(callInst, primalCallee);
+
+ *builder = afterBuilder;
if (diffReturnType->getOp() != kIROp_VoidType)
{
- IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst);
- auto diffType = differentiateType(builder, origCall->getFullType());
- IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst);
+ IRInst* primalResultValue = afterBuilder.emitDifferentialPairGetPrimal(callInst);
+ auto diffType = differentiateType(&afterBuilder, origCall->getFullType());
+ IRInst* diffResultValue = afterBuilder.emitDifferentialPairGetDifferential(diffType, callInst);
return InstPair(primalResultValue, diffResultValue);
}
else
@@ -1150,6 +1213,8 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr
builder.setInsertInto(diffFunc);
differentiableTypeConformanceContext.setFunc(primalFunc);
+
+ mapInOutParamToWriteBackValue.Clear();
// Transcribe children from origFunc into diffFunc
for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
@@ -1160,6 +1225,43 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr
for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
as<IRBlock>(lookupDiffInst(block))->insertAtEnd(diffFunc);
+ for (auto block : diffFunc->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_Return)
+ {
+ // Insert write backs to mutable parameters before returning.
+ builder.setInsertBefore(inst);
+ for (auto& writeBack : mapInOutParamToWriteBackValue)
+ {
+ auto param = writeBack.Key;
+ auto primalVal = builder.emitLoad(writeBack.Value.primal);
+ IRInst* valToStore = nullptr;
+ if (writeBack.Value.differential)
+ {
+ auto diffVal = builder.emitLoad(writeBack.Value.differential);
+ builder.markInstAsDifferential(diffVal, primalVal->getFullType());
+ valToStore = builder.emitMakeDifferentialPair(cast<IRPtrTypeBase>(param->getFullType())->getValueType(),
+ primalVal, diffVal);
+ builder.markInstAsMixedDifferential(valToStore, valToStore->getFullType());
+ }
+ else
+ {
+ valToStore = builder.emitLoad(writeBack.Value.primal);
+ }
+
+ auto storeInst = builder.emitStore(param, valToStore);
+
+ if (writeBack.Value.differential)
+ {
+ builder.markInstAsMixedDifferential(storeInst, valToStore->getFullType());
+ }
+ }
+ }
+ }
+ }
+
return InstPair(primalFunc, diffFunc);
}
@@ -1297,4 +1399,77 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
return InstPair(nullptr, nullptr);
}
+String ForwardDiffTranscriber::makeDiffPairName(IRInst* origVar)
+{
+ if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
+ {
+ return ("dp" + String(namehintDecoration->getName()));
+ }
+
+ return String("");
+}
+
+InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType)
+{
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalType))
+ {
+ IRInst* diffPairParam = builder->emitParam(diffPairType);
+
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
+
+ SLANG_ASSERT(diffPairParam);
+
+ if (auto pairType = as<IRDifferentialPairType>(diffPairType))
+ {
+ return InstPair(
+ builder->emitDifferentialPairGetPrimal(diffPairParam),
+ builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ diffPairParam));
+ }
+ else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
+ {
+ auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType());
+ // Make a local copy of the parameter for primal and diff parts.
+ auto primal = builder->emitVar(ptrInnerPairType->getValueType());
+ auto diffType = differentiateType(builder, cast<IRPtrTypeBase>(origParam->getDataType())->getValueType());
+ auto diff = builder->emitVar(diffType);
+
+ IRInst* primalInitVal = nullptr;
+ IRInst* diffInitVal = nullptr;
+ if (as<IROutType>(diffPairType))
+ {
+ primalInitVal = builder->emitDefaultConstruct(ptrInnerPairType->getValueType());
+ diffInitVal = builder->emitDefaultConstructRaw(diffType);
+ }
+ else
+ {
+ auto initVal = builder->emitLoad(diffPairParam);
+ primalInitVal = builder->emitDifferentialPairGetPrimal(initVal);
+ diffInitVal = builder->emitDifferentialPairGetDifferential(diffType, initVal);
+ }
+ builder->markInstAsDifferential(diffInitVal, ptrInnerPairType->getValueType());
+
+ builder->emitStore(primal, primalInitVal);
+
+ auto diffStore = builder->emitStore(diff, diffInitVal);
+ builder->markInstAsDifferential(diffStore, ptrInnerPairType->getValueType());
+
+ mapInOutParamToWriteBackValue[diffPairParam] = InstPair(primal, diff);
+ return InstPair(primal, diff);
+ }
+ }
+
+ auto primalInst = cloneInst(&cloneEnv, builder, origParam);
+ if (auto primalParam = as<IRParam>(primalInst))
+ {
+ SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock());
+ primalParam->removeFromParent();
+ builder->getInsertLoc().getBlock()->addParam(primalParam);
+ }
+ return InstPair(primalInst, nullptr);
+}
+
}
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index e595191a3..260b0a433 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -8,11 +8,15 @@ namespace Slang
struct ForwardDiffTranscriber : AutoDiffTranscriberBase
{
+ // Pending values to write back to inout params at the end of the current function.
+ OrderedDictionary<IRInst*, InstPair> mapInOutParamToWriteBackValue;
+
ForwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
: AutoDiffTranscriberBase(shared, inSharedBuilder, inSink)
{
}
+
// Returns "d<var-name>" to use as a name hint for variables and parameters.
// If no primal name is available, returns a blank string.
//
@@ -95,6 +99,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override;
+ virtual InstPair transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) override;
+
virtual IROp getInterfaceRequirementDerivativeDecorationOp() override
{
return kIROp_ForwardDerivativeDecoration;
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 0f2ceceb4..9c63a4012 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -11,6 +11,7 @@
#include "slang-ir-single-return.h"
#include "slang-ir-addr-inst-elimination.h"
#include "slang-ir-eliminate-multilevel-break.h"
+#include "slang-ir-init-local-var.h"
namespace Slang
{
@@ -21,32 +22,10 @@ namespace Slang
for (UIndex i = 0; i < funcType->getParamCount(); i++)
{
- bool noDiff = false;
auto origType = funcType->getParamType(i);
- auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origType);
-
- if (auto attrType = as<IRAttributedType>(primalType))
- {
- if (attrType->findAttr<IRNoDiffAttr>())
- {
- noDiff = true;
- primalType = attrType->getBaseType();
- }
- }
- if (noDiff)
- {
- newParameterTypes.add(primalType);
- }
- else
- {
- if (auto diffPairType = tryGetDiffPairType(builder, origType))
- {
- auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
- newParameterTypes.add(inoutDiffPairType);
- }
- else
- newParameterTypes.add(primalType);
- }
+ auto paramType = transcribeParamTypeForPropagateFunc(builder, origType);
+ if (paramType)
+ newParameterTypes.add(paramType);
}
if (auto diffResultType = differentiateType(builder, funcType->getResultType()))
@@ -75,7 +54,7 @@ namespace Slang
for (UInt i = 0; i < funcType->getParamCount(); i++)
{
auto origType = funcType->getParamType(i);
- auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origType);
+ auto primalType = transcribeParamTypeForPrimalFunc(builder, origType);
paramTypes.add(primalType);
}
paramTypes.add(outType);
@@ -252,52 +231,57 @@ namespace Slang
return String("");
}
- InstPair BackwardDiffTranscriberBase::transposeBlock(IRBuilder* builder, IRBlock* origBlock)
+ static IRType* _getPrimalTypeFromNoDiffType(BackwardDiffTranscriberBase* transcriber, IRBuilder* builder, IRType* origType)
{
- IRBuilder subBuilder(builder->getSharedBuilder());
- subBuilder.setInsertLoc(builder->getInsertLoc());
+ IRType* valueType = origType;
+ auto ptrType = as<IROutTypeBase>(valueType);
+ if (ptrType)
+ valueType = ptrType->getValueType();
- IRBlock* diffBlock = subBuilder.emitBlock();
-
- subBuilder.setInsertInto(diffBlock);
-
- // First transcribe every parameter in the block.
- for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
- this->copyParam(&subBuilder, param);
-
- // The extra param for input gradient
- auto gradParam = subBuilder.emitParam(as<IRFuncType>(origBlock->getParent()->getFullType())->getResultType());
-
- // Then, run through every instruction and use the transcriber to generate the appropriate
- // derivative code.
- //
- for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
- this->copyInst(&subBuilder, child);
-
- auto lastInst = diffBlock->getLastOrdinaryInst();
- List<IRInst*> grads = { gradParam };
- upperGradients.Add(lastInst, grads);
- for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst())
+ if (auto attrType = as<IRAttributedType>(valueType))
{
- auto upperGrads = upperGradients.TryGetValue(child);
- if (!upperGrads)
- continue;
- if (upperGrads->getCount() > 1)
+ if (attrType->findAttr<IRNoDiffAttr>())
{
- auto sumGrad = upperGrads->getFirst();
- for (auto i = 1; i < upperGrads->getCount(); i++)
- {
- sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]);
- }
- this->transposeInstBackward(&subBuilder, child, sumGrad);
+ auto primalValueType = (IRType*)transcriber->findOrTranscribePrimalInst(builder, valueType);
+ if (ptrType)
+ return builder->getPtrType(ptrType->getOp(), primalValueType);
+ return primalValueType;
}
- else
- this->transposeInstBackward(&subBuilder, child, upperGrads->getFirst());
}
+ return nullptr;
+ }
+
+ IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPrimalFunc(IRBuilder* builder, IRType* paramType)
+ {
+ // If the param is marked as no_diff, return the primal type.
+ if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType))
+ return primalNoDiffType;
- subBuilder.emitReturn();
+ return (IRType*)findOrTranscribePrimalInst(builder, paramType);
+ }
- return InstPair(diffBlock, diffBlock);
+ IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(IRBuilder* builder, IRType* paramType)
+ {
+ if (auto outType = as<IROutType>(paramType))
+ {
+ auto valueType = outType->getValueType();
+ auto diffValueType = differentiateType(builder, valueType);
+ return diffValueType;
+ }
+
+ // If the param is marked as no_diff, return the primal type.
+ if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType))
+ return primalNoDiffType;
+
+ auto diffPairType = tryGetDiffPairType(builder, paramType);
+ if (diffPairType)
+ {
+ if (!as<IRPtrTypeBase>(diffPairType))
+ return builder->getInOutType(diffPairType);
+ return diffPairType;
+ }
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType);
+ return primalType;
}
// Create an empty func to represent the transcribed func of `origFunc`.
@@ -387,39 +371,65 @@ namespace Slang
IRBuilder builder(inBuilder->getSharedBuilder());
builder.setInsertInto(header.differential);
builder.emitBlock();
- auto funcType = as<IRFuncType>(header.differential->getDataType());
+ auto origFuncType = as<IRFuncType>(origFunc->getFullType());
List<IRInst*> primalArgs, propagateArgs;
List<IRType*> primalTypes, propagateTypes;
- for (UInt i = 0; i < funcType->getParamCount(); i++)
+ for (UInt i = 0; i < origFuncType->getParamCount(); i++)
{
- auto paramType = (IRType*)findOrTranscribePrimalInst(&builder, funcType->getParamType(i));
- auto param = builder.emitParam(paramType);
- if (i != funcType->getParamCount() - 1)
+ auto primalParamType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i));
+ auto propagateParamType = transcribeParamTypeForPropagateFunc(&builder, origFuncType->getParamType(i));
+ if (propagateParamType)
{
- primalArgs.add(param);
- }
- propagateArgs.add(param);
- propagateTypes.add(paramType);
- }
+ auto param = builder.emitParam(propagateParamType);
+ propagateTypes.add(propagateParamType);
+ propagateArgs.add(param);
- // Fetch primal values to use as arguments in primal func call.
- for (auto& arg : primalArgs)
- {
- IRInst* valueType = arg->getDataType();
- auto inoutType = as<IRPtrTypeBase>(arg->getDataType());
- if (inoutType)
+ // Fetch primal values to use as arguments in primal func call.
+ IRInst* primalArg = param;
+ if (!as<IROutType>(primalParamType))
+ {
+ // As long as the primal parameter is not an out type,
+ // we need to fetch the primal value from the parameter.
+ if (as<IRPtrTypeBase>(propagateParamType))
+ {
+ primalArg = builder.emitLoad(param);
+ }
+ if (auto diffPairType = as<IRDifferentialPairType>(primalArg->getDataType()))
+ {
+ primalArg = builder.emitDifferentialPairGetPrimal(primalArg);
+ }
+ }
+ if (auto primalParamPtrType = as<IRPtrTypeBase>(primalParamType))
+ {
+ // If primal parameter is mutable, we need to pass in a temp var.
+ auto tempVar = builder.emitVar(primalParamPtrType->getValueType());
+ if (primalParamPtrType->getOp() == kIROp_InOutType)
+ {
+ // If the primal parameter is inout, we need to set the initial value.
+ builder.emitStore(tempVar, primalArg);
+ }
+ primalArgs.add(tempVar);
+ }
+ else
+ {
+ primalArgs.add(primalArg);
+ }
+ }
+ else
{
- valueType = inoutType->getValueType();
- arg = builder.emitLoad(arg);
+ auto var = builder.emitVar(primalParamType);
+ primalArgs.add(var);
}
- auto diffPairType = as<IRDifferentialPairType>(valueType);
- if (!diffPairType) continue;
- arg = builder.emitDifferentialPairGetPrimal(arg);
+ primalTypes.add(primalParamType);
}
- for (auto& arg : primalArgs)
+ // Add dOut argument to propagateArgs.
+ auto diffResultType = differentiateType(&builder, origFunc->getResultType());
+ if (diffResultType)
{
- primalTypes.add(arg->getFullType());
+ auto param = builder.emitParam(diffResultType);
+ propagateArgs.add(param);
+ propagateTypes.add(param->getFullType());
}
auto outerGeneric = findOuterGeneric(origFunc);
@@ -433,7 +443,6 @@ namespace Slang
auto intermediateVar = builder.emitVar(intermediateType);
- auto origFuncType = as<IRFuncType>(origFunc->getDataType());
auto primalFuncType = builder.getFuncType(
primalTypes,
origFuncType->getResultType());
@@ -486,6 +495,51 @@ namespace Slang
builder.emitBranch(firstBlock);
}
+ void insertTempVarForMutableParams(SharedIRBuilder* sharedBuilder, IRFunc* func)
+ {
+ IRBuilder builder(sharedBuilder);
+ auto firstBlock = func->getFirstBlock();
+ builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
+
+ OrderedDictionary<IRParam*, IRVar*> mapParamToTempVar;
+ List<IRParam*> params;
+ for (auto param : firstBlock->getParams())
+ {
+ if (auto ptrType = as<IRPtrTypeBase>(param->getDataType()))
+ {
+ params.add(param);
+ }
+ }
+
+ for (auto param : params)
+ {
+ auto ptrType = as<IRPtrTypeBase>(param->getDataType());
+ auto tempVar = builder.emitVar(ptrType->getValueType());
+ mapParamToTempVar[param] = tempVar;
+ if (param->getOp() != kIROp_OutType)
+ {
+ builder.emitStore(tempVar, builder.emitLoad(param));
+ }
+ param->replaceUsesWith(tempVar);
+ }
+
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_Return)
+ {
+ builder.setInsertBefore(inst);
+ for (auto& kv : mapParamToTempVar)
+ {
+ builder.emitStore(kv.Key, builder.emitLoad(kv.Value));
+ }
+ }
+ }
+ }
+ }
+
+
struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy
{
DifferentiableTypeConformanceContext* diffTypeContext;
@@ -512,6 +566,8 @@ namespace Slang
IRCFGNormalizationPass cfgPass = {this->getSink()};
normalizeCFG(autoDiffSharedContext->sharedBuilder, func);
+ insertTempVarForMutableParams(sharedBuilder, func);
+
AutoDiffAddressConversionPolicy cvtPolicty;
cvtPolicty.diffTypeContext = &diffTypeContext;
auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink);
@@ -592,6 +648,23 @@ namespace Slang
return fwdDiffFunc;
}
+ InstPair BackwardDiffTranscriberBase::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType)
+ {
+ SLANG_UNUSED(primalType);
+
+ SLANG_RELEASE_ASSERT(origParam->getParent() && origParam->getParent()->getParent()
+ && origParam->getParent()->getParent()->getOp() == kIROp_Generic);
+
+ auto primalInst = maybeCloneForPrimalInst(builder, origParam);
+ if (auto primalParam = as<IRParam>(primalInst))
+ {
+ SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock());
+ primalParam->removeFromParent();
+ builder->getInsertLoc().getBlock()->addParam(primalParam);
+ }
+ return InstPair(primalInst, nullptr);
+ }
+
// Transcribe a function definition.
void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc)
{
@@ -615,6 +688,8 @@ namespace Slang
if (!fwdDiffFunc)
return;
+ bool isResultDifferentiable = as<IRDifferentialPairType>(fwdDiffFunc->getResultType());
+
// Split first block into a paramter block.
this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc));
@@ -642,12 +717,11 @@ namespace Slang
}
// Transpose the first block (parameter block)
- transposeParameterBlock(builder, diffPropagateFunc);
+ List<IRInst*> primalFuncSpecificParams;
+ auto dOutParameter = transposeParameterBlock(builder, diffPropagateFunc, primalFuncSpecificParams, isResultDifferentiable);
builder->setInsertInto(diffPropagateFunc);
- auto dOutParameter = diffPropagateFunc->getLastParam()->getPrevParam();
-
// Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr};
diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info);
@@ -658,11 +732,32 @@ namespace Slang
// with the intermediate results computed from the extracted func.
IRInst* intermediateType = nullptr;
auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(
- diffPropagateFunc, primalFunc, intermediateType);
+ diffPropagateFunc, primalFunc, isResultDifferentiable, intermediateType);
// Clean up by deallocating the tempoarary forward derivative func.
fwdDiffFunc->removeAndDeallocate();
+ // Remove primalFuncSpecificParams.
+ for (auto specificParam : primalFuncSpecificParams)
+ {
+ while (auto use = specificParam->firstUse)
+ {
+ if (use->getUser()->getOp() == kIROp_Store && use == use->getUser()->getOperands())
+ {
+ use->getUser()->removeAndDeallocate();
+ }
+ else if (auto decor = as<IRDecoration>(use->getUser()))
+ {
+ decor->removeAndDeallocate();
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unexpected use of transcribed param.");
+ }
+ }
+ specificParam->removeAndDeallocate();
+ }
+
// If primal function is nested in a generic, we want to create separate generics for all the associated things
// we have just created.
auto primalOuterGeneric = findOuterGeneric(primalFunc);
@@ -689,9 +784,16 @@ namespace Slang
auto specializedBackwardPrimalFunc = maybeSpecializeWithGeneric(*builder, primalFuncGeneric, primalOuterGeneric);
builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc);
}
+
+ initializeLocalVariables(builder->getSharedBuilder(), primalFunc);
+ initializeLocalVariables(builder->getSharedBuilder(), diffPropagateFunc);
}
- void BackwardDiffTranscriberBase::transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc)
+ IRInst* BackwardDiffTranscriberBase::transposeParameterBlock(
+ IRBuilder* builder,
+ IRFunc* diffFunc,
+ List<IRInst*>& primalFuncSpecificParams,
+ bool isResultDifferentiable)
{
IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock();
@@ -699,7 +801,7 @@ namespace Slang
auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator());
auto nextBlock = fwdParamBlockBranch->getTargetBlock();
- builder->setInsertInto(fwdDiffParameterBlock);
+ builder->setInsertBefore(fwdParamBlockBranch);
List<IRParam*> fwdParams;
for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam())
@@ -710,8 +812,37 @@ namespace Slang
// 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<>
for (auto fwdParam : fwdParams)
{
- // TODO: Handle ptr<pair> types.
- if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType()))
+ if (auto outType = as<IROutType>(fwdParam->getDataType()))
+ {
+ IRParam* newPropParam = nullptr;
+ IRParam* newPrimalParam = nullptr;
+ auto diffPairType = as<IRDifferentialPairType>(outType->getValueType());
+ if (diffPairType)
+ {
+ // Create dOut param.
+ auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(builder, diffPairType);
+ newPropParam = builder->emitParam(diffType);
+ newPrimalParam = builder->emitParam(builder->getOutType(diffPairType->getValueType()));
+ }
+ else
+ {
+ newPrimalParam = builder->emitParam(outType);
+ }
+
+ // Create a temp var to represent the original `out` param.
+ auto arg = builder->emitVar(outType->getValueType());
+ builder->addAutoDiffOriginalValueDecoration(arg, newPrimalParam);
+ if (newPropParam)
+ {
+ builder->addDecoration(arg, kIROp_OutParamReverseGradientDecoration, newPropParam);
+ }
+
+ fwdParam->replaceUsesWith(arg);
+ fwdParam->removeAndDeallocate();
+
+ primalFuncSpecificParams.add(newPrimalParam);
+ }
+ else if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType()))
{
// Create inout version.
auto inoutDiffPairType = builder->getInOutType(diffPairType);
@@ -725,7 +856,7 @@ namespace Slang
}
else
{
- // Default case (parameter has nothing to do with differentiation)
+ // Default case (parameter is inout type or has nothing to do with differentiation)
// Simply move the parameter to the end.
//
fwdParam->removeFromParent();
@@ -735,236 +866,24 @@ namespace Slang
auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount();
- // 2. Add a parameter for 'derivative of the output' (d_out).
+ // 2. If the return type of the original function is differentiable,
+ // add a parameter for 'derivative of the output' (d_out).
// The type is the second last parameter type of the function.
//
- auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 2);
-
- SLANG_ASSERT(dOutParamType);
-
- builder->emitParam(dOutParamType);
-
- // Add a parameter for intermediate val.
- builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
- }
-
- IRInst* BackwardDiffTranscriberBase::copyParam(IRBuilder* builder, IRParam* origParam)
- {
- auto primalDataType = origParam->getDataType();
-
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
- {
- auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
- IRInst* diffParam = builder->emitParam(inoutDiffPairType);
-
- auto diffPairVarName = makeDiffPairName(origParam);
- if (diffPairVarName.getLength() > 0)
- builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice());
-
- SLANG_ASSERT(diffParam);
- auto paramValue = builder->emitLoad(diffParam);
- auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
- orginalToTranscribed.Add(origParam, primal);
- primalToDiffPair.Add(primal, diffParam);
-
- return diffParam;
- }
-
- return maybeCloneForPrimalInst(builder, origParam);
- }
-
- InstPair BackwardDiffTranscriberBase::copyBinaryArith(IRBuilder* builder, IRInst* origArith)
- {
- SLANG_ASSERT(origArith->getOperandCount() == 2);
-
- auto origLeft = origArith->getOperand(0);
- auto origRight = origArith->getOperand(1);
-
- IRInst* primalLeft;
- if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft))
- {
- primalLeft = origLeft;
- }
- IRInst* primalRight;
- if (!orginalToTranscribed.TryGetValue(origRight, primalRight))
- {
- primalRight = origRight;
- }
-
- auto resultType = origArith->getDataType();
- IRInst* newInst;
- switch (origArith->getOp())
- {
- case kIROp_Add:
- newInst = builder->emitAdd(resultType, primalLeft, primalRight);
- break;
- case kIROp_Mul:
- newInst = builder->emitMul(resultType, primalLeft, primalRight);
- break;
- case kIROp_Sub:
- newInst = builder->emitSub(resultType, primalLeft, primalRight);
- break;
- case kIROp_Div:
- newInst = builder->emitDiv(resultType, primalLeft, primalRight);
- break;
- default:
- newInst = nullptr;
- getSink()->diagnose(origArith->sourceLoc,
- Diagnostics::unimplemented,
- "this arithmetic instruction cannot be differentiated");
- }
- orginalToTranscribed.Add(origArith, newInst);
- return InstPair(newInst, nullptr);
- }
-
- IRInst* BackwardDiffTranscriberBase::transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad)
- {
- SLANG_ASSERT(origArith->getOperandCount() == 2);
-
- auto lhs = origArith->getOperand(0);
- auto rhs = origArith->getOperand(1);
-
- if (as<IRInOutType>(lhs->getDataType()))
- {
- lhs = builder->emitLoad(lhs);
- lhs = builder->emitDifferentialPairGetPrimal(lhs);
- }
- if (as<IRInOutType>(rhs->getDataType()))
- {
- rhs = builder->emitLoad(rhs);
- rhs = builder->emitDifferentialPairGetPrimal(rhs);
- }
-
- IRInst* leftGrad;
- IRInst* rightGrad;
-
-
- switch (origArith->getOp())
- {
- case kIROp_Add:
- leftGrad = grad;
- rightGrad = grad;
- break;
- case kIROp_Mul:
- leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
- rightGrad = builder->emitMul(grad->getDataType(), lhs, grad);
- break;
- case kIROp_Sub:
- leftGrad = grad;
- rightGrad = builder->emitNeg(grad->getDataType(), grad);
- break;
- case kIROp_Div:
- leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
- rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad
- break;
- default:
- getSink()->diagnose(origArith->sourceLoc,
- Diagnostics::unimplemented,
- "this arithmetic instruction cannot be differentiated");
- }
-
- lhs = origArith->getOperand(0);
- rhs = origArith->getOperand(1);
- if (auto leftGrads = upperGradients.TryGetValue(lhs))
- {
- leftGrads->add(leftGrad);
- }
- else
+ IRParam* dOutParam = nullptr;
+ if (isResultDifferentiable)
{
- upperGradients.Add(lhs, leftGrad);
- }
- if (auto rightGrads = upperGradients.TryGetValue(rhs))
- {
- rightGrads->add(rightGrad);
- }
- else
- {
- upperGradients.Add(rhs, rightGrad);
- }
+ auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 2);
- return nullptr;
- }
+ SLANG_ASSERT(dOutParamType);
- InstPair BackwardDiffTranscriberBase::copyInst(IRBuilder* builder, IRInst* origInst)
- {
- // Handle common SSA-style operations
- switch (origInst->getOp())
- {
- case kIROp_Param:
- return transcribeParam(builder, as<IRParam>(origInst));
-
- case kIROp_Return:
- return InstPair(nullptr, nullptr);
-
- case kIROp_Add:
- case kIROp_Mul:
- case kIROp_Sub:
- case kIROp_Div:
- return copyBinaryArith(builder, origInst);
-
- default:
- // Not yet implemented
- SLANG_ASSERT(0);
+ dOutParam = builder->emitParam(dOutParamType);
}
- return InstPair(nullptr, nullptr);
- }
-
- IRInst* BackwardDiffTranscriberBase::transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad)
- {
- IRInOutType* inoutParam = as<IRInOutType>(param->getDataType());
- auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType());
- auto paramValue = builder->emitLoad(param);
- auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
- auto diff = builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
- paramValue
- );
- auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad);
- auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff);
- auto store = builder->emitStore(param, updatedParam);
-
- return store;
- }
-
- IRInst* BackwardDiffTranscriberBase::transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad)
- {
- // Handle common SSA-style operations
- switch (origInst->getOp())
- {
- case kIROp_Param:
- return transposeParamBackward(builder, as<IRParam>(origInst), grad);
-
- case kIROp_Add:
- case kIROp_Mul:
- case kIROp_Sub:
- case kIROp_Div:
- return transposeBinaryArithBackward(builder, origInst, grad);
-
- case kIROp_DifferentialPairGetPrimal:
- {
- if (auto param = primalToDiffPair.TryGetValue(origInst))
- {
- if (auto leftGrads = upperGradients.TryGetValue(*param))
- {
- leftGrads->add(grad);
- }
- else
- {
- upperGradients.Add(*param, grad);
- }
- }
- else
- SLANG_ASSERT(0);
- return nullptr;
- }
-
- default:
- // Not yet implemented
- SLANG_ASSERT(0);
- }
+ // Add a parameter for intermediate val.
+ builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
- return nullptr;
+ return dOutParam;
}
InstPair BackwardDiffTranscriberBase::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index f789089b0..617e6b79b 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -61,7 +61,8 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
IRFuncType* differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermediateType);
- InstPair transposeBlock(IRBuilder* builder, IRBlock* origBlock);
+ IRType* transcribeParamTypeForPrimalFunc(IRBuilder* builder, IRType* paramType);
+ IRType* transcribeParamTypeForPropagateFunc(IRBuilder* builder, IRType* paramType);
// Puts parameters into their own block.
void makeParameterBlock(IRBuilder* inBuilder, IRFunc* func);
@@ -69,19 +70,10 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
// Transcribe a function definition.
virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) = 0;
- void transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc);
+ // Transcribes the parameter block and returns the dOut param if exists.
+ IRInst* transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc, List<IRInst*>& primalFuncSpecificParams, bool isResultDifferentiable);
- IRInst* copyParam(IRBuilder* builder, IRParam* origParam);
-
- InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith);
-
- IRInst* transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad);
-
- InstPair copyInst(IRBuilder* builder, IRInst* origInst);
-
- IRInst* transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad);
-
- IRInst* transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad);
+ InstPair transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType);
InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 520c6d276..8f21e8c62 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -17,16 +17,6 @@ DiagnosticSink* AutoDiffTranscriberBase::getSink()
return sink;
}
-String AutoDiffTranscriberBase::makeDiffPairName(IRInst* origVar)
-{
- if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
- {
- return ("dp" + String(namehintDecoration->getName()));
- }
-
- return String("");
-}
-
void AutoDiffTranscriberBase::mapDifferentialInst(IRInst* origInst, IRInst* diffInst)
{
if (hasDifferentialInst(origInst))
@@ -523,46 +513,7 @@ InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* o
bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock());
if (isFuncParam)
{
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
- {
- IRInst* diffPairParam = builder->emitParam(diffPairType);
-
- auto diffPairVarName = makeDiffPairName(origParam);
- if (diffPairVarName.getLength() > 0)
- builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
-
- SLANG_ASSERT(diffPairParam);
-
- if (auto pairType = as<IRDifferentialPairType>(diffPairType))
- {
- return InstPair(
- builder->emitDifferentialPairGetPrimal(diffPairParam),
- builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
- diffPairParam));
- }
- else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
- {
- auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType());
-
- return InstPair(
- builder->emitDifferentialPairAddressPrimal(diffPairParam),
- builder->emitDifferentialPairAddressDifferential(
- builder->getPtrType(
- kIROp_PtrType,
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)),
- diffPairParam));
- }
- }
-
- auto primalInst = cloneInst(&cloneEnv, builder, origParam);
- if (auto primalParam = as<IRParam>(primalInst))
- {
- SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock());
- primalParam->removeFromParent();
- builder->getInsertLoc().getBlock()->addParam(primalParam);
- }
- return InstPair(primalInst, nullptr);
+ return transcribeFuncParam(builder, origParam, primalDataType);
}
else
{
@@ -617,10 +568,14 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
switch (diffType->getOp())
{
case kIROp_DifferentialPairType:
- return builder->emitMakeDifferentialPair(
- diffType,
- getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
- getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
+ {
+ auto makeDiffPair = builder->emitMakeDifferentialPair(
+ diffType,
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
+ builder->markInstAsDifferential(makeDiffPair, as<IRDifferentialPairType>(diffType)->getValueType());
+ return makeDiffPair;
+ }
}
if (auto arrayType = as<IRArrayType>(primalType))
@@ -647,6 +602,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
{
auto wt = lookupInterface->getWitnessTable();
zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey);
+ builder->markInstAsDifferential(zeroMethod);
}
}
SLANG_RELEASE_ASSERT(zeroMethod);
@@ -759,6 +715,7 @@ InstPair AutoDiffTranscriberBase::transcribeReturn(IRBuilder* builder, IRReturn*
IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
IRInst* primalReturn = builder->emitReturn(primalReturnVal);
+ builder->markInstAsMixedDifferential(primalReturn, nullptr);
return InstPair(primalReturn, nullptr);
}
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index e6a525dee..208bfbc28 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -123,6 +123,8 @@ struct AutoDiffTranscriberBase
InstPair transcribeParam(IRBuilder* builder, IRParam* origParam);
+ virtual InstPair transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) = 0;
+
InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst);
InstPair transcribeBlockImpl(IRBuilder* builder, IRBlock* origBlock, HashSet<IRInst*>& instsToSkip);
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 5aad6e3a3..2a341ed38 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -501,6 +501,10 @@ struct DiffTransposePass
List<IRBlock*> workList;
// Build initial list of blocks to process by checking if they're differential blocks.
+ List<IRBlock*> traverseWorkList;
+ HashSet<IRBlock*> traverseSet;
+ traverseWorkList.add(revDiffFunc->getFirstBlock());
+ traverseSet.Add(revDiffFunc->getFirstBlock());
for (IRBlock* block = revDiffFunc->getFirstBlock(); block; block = block->getNextBlock())
{
if (!isDifferentialInst(block))
@@ -534,10 +538,13 @@ struct DiffTransposePass
for (auto block : workList)
{
// Set dOutParameter as the transpose gradient for the return inst, if any.
- if (auto returnInst = as<IRReturn>(block->getTerminator()))
+ if (transposeInfo.dOutInst)
{
- this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr));
- retVal = returnInst->getVal();
+ if (auto returnInst = as<IRReturn>(block->getTerminator()))
+ {
+ this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr));
+ retVal = returnInst->getVal();
+ }
}
IRBlock* revBlock = revBlockMap[block];
@@ -575,7 +582,7 @@ struct DiffTransposePass
auto branch = subBuilder.emitBranch(firstRevBlock);
- if (!retVal)
+ if (!retVal || retVal->getOp() == kIROp_VoidLit)
{
retVal = subBuilder.getVoidValue();
}
@@ -849,6 +856,8 @@ struct DiffTransposePass
{
auto returnPairType = as<IRDifferentialPairType>(
tryGetPrimalTypeFromDiffInst(returnInst->getVal()));
+ if (!returnPairType)
+ return;
primalType = returnPairType->getValueType();
}
else if (auto loadInst = as<IRLoad>(inst))
@@ -955,21 +964,33 @@ struct DiffTransposePass
{
auto arg = fwdCall->getArg(ii);
- // If this isn't a ptr-type, make a var.
- if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType()))
+ if (arg->getOp() == kIROp_LoadReverseGradient)
{
- auto pairType = as<IRDifferentialPairType>(arg->getDataType());
+ // Original parameters that are `out DifferentiableType` will turn into
+ // a `in Differential` parameter. The split logic will insert LoadReverseGradient insts
+ // to inform us this case. Here we just need to generate a load of the derivative variable
+ // and use it as the final argument.
+ args.add(builder->emitLoad(arg->getOperand(0)));
+ }
+ else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType()))
+ {
+ // Normal differentiable input parameter will become an inout DiffPair parameter
+ // in the propagate func. The split logic has already prepared the initial value
+ // to pass in. We need to define a temp variable with this initial value and pass
+ // in the temp variable as argument to the inout parameter.
- auto var = builder->emitVar(arg->getDataType());
+ auto makePairArg = as<IRMakeDifferentialPair>(arg);
+ SLANG_RELEASE_ASSERT(makePairArg);
- SLANG_ASSERT(as<IRMakeDifferentialPair>(arg));
+ auto pairType = as<IRDifferentialPairType>(arg->getDataType());
+ auto var = builder->emitVar(arg->getDataType());
// Initialize this var to (arg.primal, 0).
builder->emitStore(
- var,
+ var,
builder->emitMakeDifferentialPair(
arg->getDataType(),
- as<IRMakeDifferentialPair>(arg)->getPrimalValue(),
+ makePairArg->getPrimalValue(),
builder->emitCallInst(
(IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()),
diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()),
@@ -987,9 +1008,12 @@ struct DiffTransposePass
}
}
- args.add(revValue);
- argTypes.add(revValue->getDataType());
- argRequiresLoad.add(false);
+ if (revValue)
+ {
+ args.add(revValue);
+ argTypes.add(revValue->getDataType());
+ argRequiresLoad.add(false);
+ }
args.add(primalContextDecor->getBackwardDerivativePrimalContextVar());
argTypes.add(builder->getOutType(
@@ -1024,10 +1048,8 @@ struct DiffTransposePass
gradients.add(RevGradient(
RevGradient::Flavor::Simple,
fwdCall->getArg(ii),
- builder->emitLoad(
- builder->emitDifferentialPairAddressDifferential(
- diffArgPtrType,
- args[ii])),
+ builder->emitDifferentialPairGetDifferential(
+ diffArgPtrType, builder->emitLoad(args[ii])),
nullptr));
}
}
@@ -1213,6 +1235,8 @@ struct DiffTransposePass
case kIROp_UpdateElement:
return transposeUpdateElement(builder, fwdInst, revValue);
+ case kIROp_LoadReverseGradient:
+ case kIROp_DefaultConstruct:
case kIROp_Specialize:
case kIROp_unconditionalBranch:
case kIROp_conditionalBranch:
@@ -1266,8 +1290,8 @@ struct DiffTransposePass
if (as<IRDifferentialPairType>(loadType))
{
- auto primalPtr = builder->emitDifferentialPairAddressPrimal(revPtr);
- auto primalVal = builder->emitLoad(primalPtr);
+ auto primalPairVal = builder->emitLoad(revPtr);
+ auto primalVal = builder->emitDifferentialPairGetPrimal(primalPairVal);
auto pairVal = builder->emitMakeDifferentialPair(loadType, primalVal, aggregateGradient);
@@ -1284,12 +1308,21 @@ struct DiffTransposePass
TranspositionResult transposeStore(IRBuilder* builder, IRStore* fwdStore, IRInst*)
{
+ IRInst* revVal = nullptr;
+ if (auto revGradDecor = fwdStore->getPtr()->findDecoration<IROutParamReverseGradientDecoration>())
+ {
+ revVal = revGradDecor->getValue();
+ }
+ else
+ {
+ revVal = builder->emitLoad(fwdStore->getPtr());
+ }
return TranspositionResult(
List<RevGradient>(
RevGradient(
RevGradient::Flavor::Simple,
fwdStore->getVal(),
- builder->emitLoad(fwdStore->getPtr()),
+ revVal,
fwdStore)));
}
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index daf6e44d4..378ea1cc2 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -279,7 +279,7 @@ struct ExtractPrimalFuncContext
inst);
}
- IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, IRInst*& outIntermediateType)
+ IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& outIntermediateType)
{
IRBuilder builder(sharedBuilder);
@@ -369,33 +369,59 @@ struct ExtractPrimalFuncContext
builder.emitStore(outIntermediary, defVal);
// The primal func will not have the result derivative param (second to last param), so we remove it.
- auto resultDerivativeParam = func->getLastParam()->getPrevParam();
- SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses());
- resultDerivativeParam->removeAndDeallocate();
+ if (isResultDifferentiable)
+ {
+ auto resultDerivativeParam = func->getLastParam()->getPrevParam();
+ SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses());
+ resultDerivativeParam->removeAndDeallocate();
+ }
- // Finally, go through parameters and turn DifferentiablePair<T> back to T.
- for (auto param : func->getParams())
+ // Finally, go through parameters and translate their type back to primal type.
+ for (auto param = func->getFirstParam(); param;)
{
- IRInst* valueType = param->getDataType();
- auto inoutType = as<IRPtrTypeBase>(param->getDataType());
- if (inoutType) valueType = inoutType->getValueType();
- auto diffPairType = as<IRDifferentialPairType>(valueType);
- if (!diffPairType) continue;
- builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
-
- auto originalValueType = diffPairType->getValueType();
-
- // Create a local var to act as the old param.
- auto tempVar = builder.emitVar(diffPairType);
- param->replaceUsesWith(tempVar);
- auto pairValue = builder.emitMakeDifferentialPair(
- diffPairType,
- param,
- backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType));
- builder.emitStore(tempVar, pairValue);
-
- // Change the param type to original type.
- param->setFullType(originalValueType);
+ auto next = param->getNextParam();
+ [this, firstBlock, &builder, param]()
+ {
+ for (auto use = param->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser()->getOp() == kIROp_AutoDiffOriginalValueDecoration)
+ {
+ use->getUser()->getParent()->replaceUsesWith(param);
+ return;
+ }
+ else if (use->getUser()->getOp() == kIROp_OutParamReverseGradientDecoration)
+ {
+ // This is a propagate func specific parameter, we should remove it.
+ SLANG_RELEASE_ASSERT(!param->hasMoreThanOneUse());
+ param->removeAndDeallocate();
+ return;
+ }
+ }
+
+ IRInst* valueType = param->getDataType();
+ auto inoutType = as<IRPtrTypeBase>(param->getDataType());
+ if (inoutType) valueType = inoutType->getValueType();
+ auto diffPairType = as<IRDifferentialPairType>(valueType);
+ if (!diffPairType)
+ return;
+
+ builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
+
+ auto originalValueType = diffPairType->getValueType();
+
+ // Create a local var to act as the old param.
+ auto tempVar = builder.emitVar(diffPairType);
+ param->replaceUsesWith(tempVar);
+ auto pairValue = builder.emitMakeDifferentialPair(
+ diffPairType,
+ param,
+ backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType));
+ builder.emitStore(tempVar, pairValue);
+
+ // Change the param type to original type.
+ param->setFullType(originalValueType);
+ }();
+ param = next;
}
return unzippedFunc;
@@ -420,7 +446,7 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE
}
IRFunc* DiffUnzipPass::extractPrimalFunc(
- IRFunc* func, IRFunc* originalFunc, IRInst*& intermediateType)
+ IRFunc* func, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& intermediateType)
{
IRBuilder builder(this->autodiffContext->sharedBuilder);
builder.setInsertBefore(func);
@@ -434,7 +460,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
context.init(autodiffContext->sharedBuilder, autodiffContext->transcriberSet.primalTranscriber);
intermediateType = nullptr;
- auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, intermediateType);
+ auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, isResultDifferentiable, intermediateType);
if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>())
{
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index d808cbb5e..3055d057b 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -139,6 +139,13 @@ struct DiffUnzipPass
builder->setInsertInto(unzippedFunc);
+ auto originalParam = func->getFirstParam();
+ for (auto primalParam = unzippedFunc->getFirstParam(); primalParam; primalParam = primalParam->getNextParam())
+ {
+ primalMap[originalParam] = primalParam;
+ originalParam = originalParam->getNextParam();
+ }
+
// Functions need to have at least two blocks at this point (one for parameters,
// and atleast one for code)
//
@@ -469,7 +476,7 @@ struct DiffUnzipPass
}
}
- IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* originalFunc, IRInst*& intermediateType);
+ IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& intermediateType);
bool isRelevantDifferentialPair(IRType* type)
{
@@ -537,7 +544,6 @@ struct DiffUnzipPass
for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++)
{
auto arg = mixedCall->getArg(ii);
-
if (isRelevantDifferentialPair(arg->getDataType()))
{
primalArgs.add(lookupPrimalInst(arg));
@@ -552,20 +558,29 @@ struct DiffUnzipPass
auto mixedDecoration = mixedCall->findDecoration<IRMixedDifferentialInstDecoration>();
SLANG_ASSERT(mixedDecoration);
- auto fwdPairResultType = as<IRDifferentialPairType>(mixedDecoration->getPairType());
- SLANG_ASSERT(fwdPairResultType);
-
- auto primalType = fwdPairResultType->getValueType();
- auto diffType = (IRType*) diffTypeContext.getDifferentialForType(&globalBuilder, primalType);
+ IRType* primalType = mixedCall->getFullType();
+ IRType* diffType = mixedCall->getFullType();
+ IRType* resultType = mixedCall->getFullType();
+ if (auto fwdPairResultType = as<IRDifferentialPairType>(mixedDecoration->getPairType()))
+ {
+ primalType = fwdPairResultType->getValueType();
+ diffType = (IRType*)diffTypeContext.getDifferentialForType(&globalBuilder, primalType);
+ resultType = fwdPairResultType;
+ }
auto primalVal = primalBuilder->emitCallInst(primalType, primalFn, primalArgs);
primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar);
+ SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= primalFuncType->getParamCount());
+
List<IRInst*> diffArgs;
for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++)
{
auto arg = mixedCall->getArg(ii);
+ // Depending on the type and direction of each argument,
+ // we might need to prepare a different value for the transposition logic to produce the
+ // correct final argument in the propagate function call.
if (isRelevantDifferentialPair(arg->getDataType()))
{
auto primalArg = lookupPrimalInst(arg);
@@ -574,18 +589,45 @@ struct DiffUnzipPass
// If arg is a mixed differential (pair), it should have already been split.
SLANG_ASSERT(primalArg);
SLANG_ASSERT(diffArg);
-
- auto pairArg = diffBuilder->emitMakeDifferentialPair(
+ auto primalParamType = primalFuncType->getParamType(ii);
+
+ if (auto outType = as<IROutType>(primalParamType))
+ {
+ // For `out` parameters that expects an input derivative to propagate through,
+ // we insert a `LoadReverseGradient` inst here to signify the logic in `transposeStore`
+ // that this argument should actually be the currently accumulated derivative on
+ // this variable. The end purpose is that we will generate a load(diffArg) in the
+ // final transposed code and use that as the argument for the call, but we can't just
+ // emit a normal load inst here because the transposition logic will turn loads into stores.
+ auto outDiffType = cast<IRPtrTypeBase>(diffArg->getDataType())->getValueType();
+ auto gradArg = diffBuilder->emitLoadReverseGradient(outDiffType, diffArg);
+ diffBuilder->markInstAsDifferential(gradArg, primalArg->getDataType());
+ diffArgs.add(gradArg);
+ }
+ else if (auto inoutType = as<IRInOutType>(primalParamType))
+ {
+ SLANG_UNIMPLEMENTED_X("nested call inout parameter");
+ }
+ else
+ {
+ // For ordinary differentiable input parameters, we make sure to provide
+ // a differential pair. The actual logic that generates an inout variable
+ // will be handled in `transposeCall()`.
+ auto pairArg = diffBuilder->emitMakeDifferentialPair(
arg->getDataType(),
primalArg,
diffArg);
- diffBuilder->markInstAsDifferential(pairArg, primalArg->getDataType());
- diffArgs.add(pairArg);
+ diffBuilder->markInstAsDifferential(pairArg, primalArg->getDataType());
+ diffArgs.add(pairArg);
+ }
}
else
{
- diffArgs.add(arg);
+ // For non differentiable arguments, we can simply pass the argument as is
+ // if this isn't a `out` parameter, in which case it is removed from propagate call.
+ if (!as<IROutType>(arg->getDataType()))
+ diffArgs.add(arg);
}
}
@@ -593,19 +635,22 @@ struct DiffUnzipPass
diffBuilder->markInstAsDifferential(newFwdCallee);
- auto diffPairVal = diffBuilder->emitCallInst(
- fwdPairResultType,
+ auto callInst = diffBuilder->emitCallInst(
+ resultType,
newFwdCallee,
diffArgs);
- diffBuilder->markInstAsDifferential(diffPairVal, primalType);
+ diffBuilder->markInstAsDifferential(callInst, primalType);
disableIRValidationAtInsert();
- diffBuilder->addBackwardDerivativePrimalContextDecoration(diffPairVal, intermediateVar);
+ diffBuilder->addBackwardDerivativePrimalContextDecoration(callInst, intermediateVar);
enableIRValidationAtInsert();
- auto diffVal = diffBuilder->emitDifferentialPairGetDifferential(diffType, diffPairVal);
- diffBuilder->markInstAsDifferential(diffVal, primalType);
-
+ IRInst* diffVal = nullptr;
+ if (as<IRDifferentialPairType>(callInst->getDataType()))
+ {
+ diffVal = diffBuilder->emitDifferentialPairGetDifferential(diffType, callInst);
+ diffBuilder->markInstAsDifferential(diffVal, primalType);
+ }
return InstPair(primalVal, diffVal);
}
@@ -616,52 +661,92 @@ struct DiffUnzipPass
InstPair splitLoad(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRLoad* mixedLoad)
{
- // By the nature of how diff pairs are used, and the fact that FieldAddress/GetElementPtr,
- // etc, cannot appear before a GetDifferential/GetPrimal, a mixed load can only be from a
- // parameter or a variable.
- //
- if (as<IRParam>(mixedLoad->getPtr()))
+ if (auto param = as<IRParam>(mixedLoad->getPtr()))
{
- // Should not occur with current impl of fwd-mode.
- // If impl. changes, impl this case too.
- //
- SLANG_UNIMPLEMENTED_X("Splitting a load from a param is not currently implemented.");
+ auto diffPairPtrType = as<IRPtrTypeBase>(param->getFullType());
+ SLANG_RELEASE_ASSERT(diffPairPtrType);
+ auto diffPairType = as<IRDifferentialPairType>(diffPairPtrType->getValueType());
+ SLANG_RELEASE_ASSERT(diffPairType);
+ auto diffType = (IRType*)diffTypeContext.getDifferentialTypeFromDiffPairType(diffBuilder, diffPairType);
+ auto loadedParam = primalBuilder->emitLoad(param);
+ return InstPair(
+ primalBuilder->emitDifferentialPairGetPrimal(loadedParam),
+ primalBuilder->emitDifferentialPairGetDifferential(diffType, loadedParam));
}
// Everything else should have already been split.
auto primalPtr = lookupPrimalInst(mixedLoad->getPtr());
auto diffPtr = lookupDiffInst(mixedLoad->getPtr());
+ auto primalVal = primalBuilder->emitLoad(primalPtr);
+ auto diffVal = diffBuilder->emitLoad(diffPtr);
+ diffBuilder->markInstAsDifferential(diffVal, primalVal->getFullType());
+ return InstPair(primalVal, diffVal);
+ }
+
+ InstPair splitStore(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRStore* mixedStore)
+ {
+ // We will only generate mixed store to parameters.
+ if (!as<IRParam>(mixedStore->getPtr()))
+ {
+ SLANG_UNIMPLEMENTED_X("Splitting a store that is not writing to a param.");
+ }
+
+ auto primalAddr = mixedStore->getPtr();
+
+ auto primalVal = lookupPrimalInst(mixedStore->getVal());
+ auto diffVal = lookupDiffInst(mixedStore->getVal());
- return InstPair(primalBuilder->emitLoad(primalPtr), diffBuilder->emitLoad(diffPtr));
+ // For now the param type and value type will not type-check in these store insts,
+ // but the param inst will be changed to the correct type after we synthesize primal and
+ // propagate func.
+ auto primalStore = primalBuilder->emitStore(primalAddr, primalVal);
+ auto diffStore = diffBuilder->emitStore(primalAddr, diffVal);
+
+ diffBuilder->markInstAsDifferential(diffStore, primalVal->getFullType());
+ return InstPair(primalStore, diffStore);
}
InstPair splitVar(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRVar* mixedVar)
{
- auto pairType = as<IRDifferentialPairType>(mixedVar->getDataType());
+ auto pairType = as<IRDifferentialPairType>(as<IRPtrTypeBase>(mixedVar->getDataType())->getValueType());
auto primalType = pairType->getValueType();
auto diffType = (IRType*) diffTypeContext.getDifferentialForType(primalBuilder, primalType);
-
- return InstPair(primalBuilder->emitVar(primalType), diffBuilder->emitVar(diffType));
+ auto primalVar = primalBuilder->emitVar(primalType);
+ auto diffVar = diffBuilder->emitVar(diffType);
+ diffBuilder->markInstAsDifferential(diffVar, primalType);
+ return InstPair(primalVar, diffVar);
}
InstPair splitReturn(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRReturn* mixedReturn)
{
auto pairType = as<IRDifferentialPairType>(mixedReturn->getVal()->getDataType());
- auto primalType = pairType->getValueType();
+ // Are we returning a differentiable value?
+ if (pairType)
+ {
+ auto primalType = pairType->getValueType();
- // Check that we have an unambiguous 'first' differential block.
- SLANG_ASSERT(firstDiffBlock);
- auto primalBranch = primalBuilder->emitBranch(firstDiffBlock);
- auto pairVal = diffBuilder->emitMakeDifferentialPair(
- pairType,
- lookupPrimalInst(mixedReturn->getVal()),
- lookupDiffInst(mixedReturn->getVal()));
- diffBuilder->markInstAsDifferential(pairVal, primalType);
+ // Check that we have an unambiguous 'first' differential block.
+ SLANG_ASSERT(firstDiffBlock);
+ auto primalBranch = primalBuilder->emitBranch(firstDiffBlock);
+ auto pairVal = diffBuilder->emitMakeDifferentialPair(
+ pairType,
+ lookupPrimalInst(mixedReturn->getVal()),
+ lookupDiffInst(mixedReturn->getVal()));
+ diffBuilder->markInstAsDifferential(pairVal, primalType);
- auto returnInst = diffBuilder->emitReturn(pairVal);
- diffBuilder->markInstAsDifferential(returnInst, primalType);
+ auto returnInst = diffBuilder->emitReturn(pairVal);
+ diffBuilder->markInstAsDifferential(returnInst, primalType);
- return InstPair(primalBranch, returnInst);
+ return InstPair(primalBranch, returnInst);
+ }
+ else
+ {
+ // If return value is not differentiable, just turn it into a trivial branch.
+ auto primalBranch = primalBuilder->emitBranch(firstDiffBlock);
+ auto returnInst = diffBuilder->emitReturn();
+ diffBuilder->markInstAsDifferential(returnInst, nullptr);
+ return InstPair(primalBranch, returnInst);
+ }
}
bool isBlockIndexed(IRBlock* block)
@@ -973,6 +1058,9 @@ struct DiffUnzipPass
case kIROp_Load:
return splitLoad(primalBuilder, diffBuilder, as<IRLoad>(inst));
+ case kIROp_Store:
+ return splitStore(primalBuilder, diffBuilder, as<IRStore>(inst));
+
case kIROp_Return:
return splitReturn(primalBuilder, diffBuilder, as<IRReturn>(inst));
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 8952f9756..7a2e8c75e 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -24,7 +24,7 @@ bool isBackwardDifferentiableFunc(IRInst* func)
return false;
}
-static IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
+IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
{
if (auto witnessTable = as<IRWitnessTable>(witness))
{
@@ -400,6 +400,14 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* b
return nullptr;
}
+IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairType(
+ IRBuilder* builder, IRDifferentialPairType* diffPairType)
+{
+ auto witness = diffPairType->getWitness();
+ SLANG_RELEASE_ASSERT(witness);
+ return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey);
+}
+
void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
for (auto globalInst : sharedContext->moduleInst->getChildren())
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 2258ff753..30f053673 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -160,6 +160,8 @@ struct DifferentiableTypeConformanceContext
IRInst* lookUpConformanceForType(IRInst* type);
IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key);
+
+ IRInst* getDifferentialTypeFromDiffPairType(IRBuilder* builder, IRDifferentialPairType* diffPairType);
// Lookup and return the 'Differential' type declared in the concrete type
// in order to conform to the IDifferentiable interface.
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index ce3e563f5..e7d5a0e5c 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -5,28 +5,6 @@
namespace Slang
{
-IRInst* getSpecializedVal(IRInst* inst)
-{
- int loopLimit = 1024;
- while (inst && inst->getOp() == kIROp_Specialize)
- {
- inst = as<IRSpecialize>(inst)->getBase();
- loopLimit--;
- if (loopLimit == 0)
- return inst;
- }
- return inst;
-}
-
-IRInst* getLeafFunc(IRInst* func)
-{
- func = getSpecializedVal(func);
- if (!func)
- return nullptr;
- if (auto genericFunc = as<IRGeneric>(func))
- return findInnerMostGenericReturnVal(genericFunc);
- return func;
-}
struct CheckDifferentiabilityPassContext : public InstPassBase
{
@@ -47,7 +25,7 @@ public:
bool _isFuncMarkedForAutoDiff(IRInst* func)
{
- func = getLeafFunc(func);
+ func = getResolvedInstForDecorations(func);
if (!func)
return false;
for (auto decorations : func->getDecorations())
@@ -65,7 +43,7 @@ public:
bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level)
{
- func = getLeafFunc(func);
+ func = getResolvedInstForDecorations(func);
if (!func)
return false;
@@ -103,7 +81,7 @@ public:
}
}
- func = getLeafFunc(func);
+ func = getResolvedInstForDecorations(func);
if (!func)
return false;
@@ -332,7 +310,7 @@ public:
sink->diagnose(
inst,
Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
- getLeafFunc(call->getCallee()),
+ getResolvedInstForDecorations(call->getCallee()),
requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward");
}
}
diff --git a/source/slang/slang-ir-init-local-var.cpp b/source/slang/slang-ir-init-local-var.cpp
new file mode 100644
index 000000000..4b28db268
--- /dev/null
+++ b/source/slang/slang-ir-init-local-var.cpp
@@ -0,0 +1,34 @@
+// slang-ir-init-local-var.cpp
+#include "slang-ir-init-local-var.h"
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+void initializeLocalVariables(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func)
+{
+ IRBuilder builder(sharedBuilder);
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_Var)
+ {
+ auto firstUse = inst->firstUse;
+ bool initialized =
+ (firstUse && firstUse->getUser()->getOp() == kIROp_Store &&
+ firstUse->getUser()->getParent() == inst->getParent());
+ if (initialized)
+ continue;
+ builder.setInsertAfter(inst);
+ builder.emitStore(
+ inst,
+ builder.emitDefaultConstruct(
+ as<IRPtrTypeBase>(inst->getFullType())->getValueType()));
+ }
+ }
+ }
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-init-local-var.h b/source/slang/slang-ir-init-local-var.h
new file mode 100644
index 000000000..ad06684fc
--- /dev/null
+++ b/source/slang/slang-ir-init-local-var.h
@@ -0,0 +1,14 @@
+// slang-ir-init-local-var.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+ struct IRGlobalValueWithCode;
+ struct SharedIRBuilder;
+
+ // Init local variables with default values if the variable isn't being initialized locally in
+ // the same basic block.
+ void initializeLocalVariables(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func);
+
+}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f2294671e..e1143b7b9 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -326,6 +326,10 @@ INST(Var, var, 0, 0)
INST(Load, load, 1, 0)
INST(Store, store, 2, 0)
+// Produced and removed during backward auto-diff pass as a temporary placeholder representing the
+// currently accumulated derivative to pass to some dOut argument in a nested call.
+INST(LoadReverseGradient, LoadReverseGradient, 1, 0)
+
INST(FieldExtract, get_field, 2, 0)
INST(FieldAddress, get_field_addr, 2, 0)
@@ -767,6 +771,12 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// forward-differentiated updateElement inst.
INST(PrimalElementTypeDecoration, primalElementType, 1, 0)
+ /// Used by the auto-diff pass. An `out T` parameter will transcribe to a `in T.Differential` parameter.
+ /// We will also create a temp var of type `T.Differential` in the function body so the `load` and `stores`
+ /// can operand on a valid address. We use this decoration to associate this temp var with its corresponding
+ /// input parameter.
+ INST(OutParamReverseGradientDecoration, outParamRevGrad, 1, 0)
+
/// Used by the auto-diff pass to hold a reference to a
/// differential member of a type in its associated differential type.
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index aca832c0c..132a96f16 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -723,6 +723,18 @@ struct IRMixedDifferentialInstDecoration : IRDecoration
IRType* getPairType() { return as<IRType>(getOperand(0)); }
};
+struct IROutParamReverseGradientDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_OutParamReverseGradientDecoration
+ };
+
+ IR_LEAF_ISA(OutParamReverseGradientDecoration)
+
+ IRInst* getValue() { return getOperand(0); }
+};
+
struct IRBackwardDifferentiableDecoration : IRDecoration
{
enum
@@ -1770,6 +1782,12 @@ struct IRGetElementPtr : IRInst
IRInst* getIndex() { return getOperand(1); }
};
+struct IRLoadReverseGradient :IRInst
+{
+ IR_LEAF_ISA(LoadReverseGradient)
+ IRInst* getValue() { return getOperand(0); }
+};
+
struct IRGetNativePtr : IRInst
{
IR_LEAF_ISA(GetNativePtr);
@@ -2598,7 +2616,6 @@ public:
IRInst* getBoolValue(bool value);
IRInst* getIntValue(IRType* type, IRIntegerValue value);
IRInst* getFloatValue(IRType* type, IRFloatingPointValue value);
- IRInst* getDifferentialBottom();
IRStringLit* getStringValue(const UnownedStringSlice& slice);
IRPtrLit* _getPtrValue(void* ptr);
IRPtrLit* getNullPtrValue(IRType* type);
@@ -2920,8 +2937,6 @@ public:
IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue);
IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair);
IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair);
- IRInst* emitDifferentialPairAddressDifferential(IRType* diffType, IRInst* diffPair);
- IRInst* emitDifferentialPairAddressPrimal(IRInst* diffPair);
IRInst* emitMakeVector(
IRType* type,
UInt argCount,
@@ -3129,6 +3144,8 @@ public:
IRInst* emitLoad(
IRInst* ptr);
+ IRInst* emitLoadReverseGradient(IRType* type, IRInst* diffValue);
+
IRInst* emitStore(
IRInst* dstPtr,
IRInst* srcVal);
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 3ffbb75f7..5cf074484 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -223,7 +223,9 @@ String dumpIRToString(IRInst* root)
{
StringBuilder sb;
StringWriter writer(&sb, Slang::WriterFlag::AutoFlush);
- dumpIR(root, IRDumpOptions(), nullptr, &writer);
+ IRDumpOptions options = {};
+ options.flags = IRDumpOptions::Flag::DumpDebugIds;
+ dumpIR(root, options, nullptr, &writer);
return sb.ToString();
}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 2a4ae59a7..4814726cf 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3745,15 +3745,7 @@ namespace Slang
IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair)
{
- return emitIntrinsicInst(
- diffType,
- kIROp_DifferentialPairGetDifferential,
- 1,
- &diffPair);
- }
-
- IRInst* IRBuilder::emitDifferentialPairAddressDifferential(IRType* diffType, IRInst* diffPair)
- {
+ SLANG_ASSERT(as<IRDifferentialPairType>(diffPair->getDataType()));
return emitIntrinsicInst(
diffType,
kIROp_DifferentialPairGetDifferential,
@@ -3763,7 +3755,7 @@ namespace Slang
IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair)
{
- auto valueType = as<IRDifferentialPairType>(diffPair->getDataType())->getValueType();
+ auto valueType = cast<IRDifferentialPairType>(diffPair->getDataType())->getValueType();
return emitIntrinsicInst(
valueType,
kIROp_DifferentialPairGetPrimal,
@@ -3771,16 +3763,6 @@ namespace Slang
&diffPair);
}
- IRInst* IRBuilder::emitDifferentialPairAddressPrimal(IRInst* diffPair)
- {
- auto valueType = as<IRDifferentialPairType>(
- as<IRPtrTypeBase>(diffPair->getDataType())->getValueType())->getValueType();
- return emitIntrinsicInst(
- this->getPtrType(kIROp_PtrType, valueType),
- kIROp_DifferentialPairGetPrimal,
- 1,
- &diffPair);
- }
IRInst* IRBuilder::emitMakeMatrix(
IRType* type,
@@ -4240,6 +4222,18 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitLoadReverseGradient(IRType* type, IRInst* diffValue)
+ {
+ auto inst = createInst<IRLoadReverseGradient>(
+ this,
+ kIROp_LoadReverseGradient,
+ type,
+ diffValue);
+
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitLoad(
IRType* type,
IRInst* ptr)
@@ -6818,6 +6812,7 @@ namespace Slang
case kIROp_MakeTuple:
case kIROp_GetTupleElement:
case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads
+ case kIROp_LoadReverseGradient:
case kIROp_ImageSubscript:
case kIROp_FieldExtract:
case kIROp_FieldAddress: