summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp30
-rw-r--r--source/slang/slang-ir-autodiff-propagate.h5
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h281
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h244
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-insts.h24
-rw-r--r--tests/autodiff/reverse-nested-calls.slang29
-rw-r--r--tests/autodiff/reverse-nested-calls.slang.expected.txt6
-rw-r--r--tests/autodiff/reverse-struct-types.slang23
-rw-r--r--tests/autodiff/reverse-struct-types.slang.expected.txt2
11 files changed, 519 insertions, 131 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 60c2721c7..d1e9f91ec 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -615,7 +615,12 @@ InstPair ForwardDerivativeTranscriber::transcribeLoad(IRBuilder* builder, IRLoad
// Special case load from an `out` param, which will not have corresponding `diff` and
// `primal` insts yet.
+ // TODO: Could we move this load to _after_ DifferentialPairGetPrimal,
+ // and DifferentialPairGetDifferential?
+ //
auto load = builder->emitLoad(primalPtr);
+ builder->markInstAsMixedDifferential(load, diffPairType);
+
auto primalElement = builder->emitDifferentialPairGetPrimal(load);
auto diffElement = builder->emitDifferentialPairGetDifferential(
(IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load);
@@ -647,7 +652,7 @@ InstPair ForwardDerivativeTranscriber::transcribeStore(IRBuilder* builder, IRSto
if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType()))
{
auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal);
- builder->markInstAsDifferential(diffStoreVal, diffPairType);
+ builder->markInstAsMixedDifferential(diffStoreVal, diffPairType);
auto store = builder->emitStore(primalStoreLocation, valToStore);
return InstPair(store, nullptr);
@@ -690,6 +695,7 @@ InstPair ForwardDerivativeTranscriber::transcribeReturn(IRBuilder* builder, IRRe
// Neither of these should be nullptr.
SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal);
IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal));
+ builder->markInstAsMixedDifferential(diffReturn, nullptr);
return InstPair(diffReturn, diffReturn);
}
@@ -704,9 +710,11 @@ InstPair ForwardDerivativeTranscriber::transcribeReturn(IRBuilder* builder, IRRe
SLANG_RELEASE_ASSERT(diffReturnVal);
auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal);
- builder->markInstAsDifferential(diffPair, pairType);
+ builder->markInstAsMixedDifferential(diffPair, pairType);
IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair));
+ builder->markInstAsMixedDifferential(pairReturn, pairType);
+
return InstPair(pairReturn, pairReturn);
}
else
@@ -804,7 +812,8 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall
// If the user has already provided an differentiated implementation, use that.
diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc();
}
- else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>())
+ else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>() ||
+ primalCallee->findDecoration<IRBackwardDifferentiableDecoration>())
{
// If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
// to generate the implementation.
@@ -851,7 +860,7 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall
SLANG_RELEASE_ASSERT(diffArg);
auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
- builder->markInstAsDifferential(diffPair, pairType);
+ builder->markInstAsMixedDifferential(diffPair, pairType);
args.add(diffPair);
continue;
@@ -875,7 +884,7 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall
diffReturnType,
diffCallee,
args);
- builder->markInstAsDifferential(callInst, origCall->getFullType());
+ builder->markInstAsMixedDifferential(callInst, diffReturnType);
if (diffReturnType->getOp() != kIROp_VoidType)
{
@@ -1578,8 +1587,15 @@ IRInst* ForwardDerivativeTranscriber::transcribe(IRBuilder* builder, IRInst* ori
builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
}
- // Tag the differential inst using a decoration.
- builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType()));
+ // Tag the differential inst using a decoration (if it doesn't have one)
+ if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() &&
+ !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>())
+ {
+ // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential
+ // instead.
+ //
+ builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType()));
+ }
break;
}
diff --git a/source/slang/slang-ir-autodiff-propagate.h b/source/slang/slang-ir-autodiff-propagate.h
index 9518ccb34..0d5686899 100644
--- a/source/slang/slang-ir-autodiff-propagate.h
+++ b/source/slang/slang-ir-autodiff-propagate.h
@@ -15,6 +15,11 @@ bool isDifferentialInst(IRInst* inst)
return inst->findDecoration<IRDifferentialInstDecoration>();
}
+bool isMixedDifferentialInst(IRInst* inst)
+{
+ return inst->findDecoration<IRMixedDifferentialInstDecoration>();
+}
+
struct DiffPropagationPass : InstPassBase
{
AutoDiffSharedContext* autodiffContext;
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 34a08ee93..c7fbc415a 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -589,7 +589,7 @@ struct BackwardDiffTranscriber
{
// Create inout version.
auto inoutDiffPairType = builder->getInOutType(diffPairType);
- auto newParam = builder->emitParam(inoutDiffPairType);
+ auto newParam = builder->emitParam(inoutDiffPairType);
// Map the _load_ of the new parameter as the clone of the old one.
auto newParamLoad = builder->emitLoad(newParam);
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 75491d753..a14ecad84 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -150,17 +150,25 @@ struct DiffTransposePass
// Insert after the last block.
builder.setInsertInto(revBlock);
- List<IRInst*> ptrInsts;
+ // Move pointer & reference insts to the top of the reverse-mode block.
+ List<IRInst*> nonValueInsts;
for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
{
- // If the instruction is pointer typed, move to top of new reverse-mode block
+ // If the instruction is pointer typed, it's not actually computing a value.
+ //
if (as<IRPtrTypeBase>(child->getDataType()))
- ptrInsts.add(child);
+ nonValueInsts.add(child);
+
+ // Slang doesn't support function values. So if we see a func-typed inst
+ // it's proabably a reference to a function.
+ //
+ if (as<IRFuncType>(child->getDataType()))
+ nonValueInsts.add(child);
}
- for (auto ptrInst : ptrInsts)
+ for (auto inst : nonValueInsts)
{
- ptrInst->insertAtEnd(revBlock);
+ inst->insertAtEnd(revBlock);
}
@@ -210,34 +218,6 @@ struct DiffTransposePass
if (hasRevGradients(inst))
gradients = popRevGradients(inst);
- // Are we dealing with DifferentialPairType?
- if (as<IRDifferentialPairType>(inst->getDataType()))
- {
- // This will be a 'hybrid' primal-differential inst,
- // so we add a pair (primal_value, 0) as an additional
- // gradient to represent the primal part of the computation.
- //
- // Now, if the unzip pass has done it's job, the _only_
- // case should be that inst is IRMakeDifferentialPair
- //
- SLANG_ASSERT(as<IRMakeDifferentialPair>(inst));
- auto primalType = as<IRDifferentialPairType>(inst->getDataType())->getValueType();
- auto diffType = (IRType*)pairBuilder.getDiffTypeFromPairType(builder, as<IRDifferentialPairType>(inst->getDataType()));
-
- auto primalInst = as<IRMakeDifferentialPair>(inst)->getPrimalValue();
- auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType);
-
- // Must exist.
- SLANG_ASSERT(zeroMethod);
- auto diffInst = builder->emitCallInst(diffType, zeroMethod, List<IRInst*>());
-
- gradients.add(
- RevGradient(
- inst,
- builder->emitMakeDifferentialPair(inst->getDataType(), primalInst, diffInst),
- nullptr));
- }
-
IRType* primalType = tryGetPrimalTypeFromDiffInst(inst);
if (!primalType)
@@ -249,6 +229,14 @@ struct DiffTransposePass
tryGetPrimalTypeFromDiffInst(returnInst->getVal()));
primalType = returnPairType->getValueType();
}
+ else if (auto loadInst = as<IRLoad>(inst))
+ {
+ // TODO: Unzip loads properly to avoid having to side-step this check for IRLoad
+ if (auto pairType = as<IRDifferentialPairType>(loadInst->getDataType()))
+ {
+ primalType = pairType->getValueType();
+ }
+ }
}
if (!primalType)
@@ -278,6 +266,116 @@ struct DiffTransposePass
addRevGradientForFwdInst(gradient.targetInst, gradient);
}
}
+
+ TranspositionResult transposeCall(IRBuilder* builder, IRCall* fwdCall, IRInst* revValue)
+ {
+ auto fwdDiffCallee = as<IRForwardDifferentiate>(fwdCall->getCallee());
+
+ // If the callee is not a fwd-differentiate(fn), then there's only two
+ // cases. This is a call to something that doesn't need to be transposed
+ // or this is a user-written function calling something that isn't marked
+ // with IRForwardDifferentiate, but is handling differentials.
+ // We currently do not handle the latter.
+ // However, if we see a callee with no parameters, we can just skip over.
+ // since there's nothing to backpropagate to.
+ //
+ if (!fwdDiffCallee)
+ {
+ if (fwdCall->getArgCount() == 0)
+ {
+ return TranspositionResult(List<RevGradient>());
+ }
+ else
+ {
+ SLANG_UNIMPLEMENTED_X(
+ "This case should only trigger on a user-defined fwd-mode function"
+ " calling another user-defined function not marked with __fwd_diff()");
+ }
+ }
+
+ auto baseFn = fwdDiffCallee->getBaseFn();
+
+ List<IRInst*> args;
+ List<IRType*> argTypes;
+ List<bool> isArgPtrTyped;
+
+ for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++)
+ {
+ auto arg = fwdCall->getArg(ii);
+
+ // If this isn't a ptr-type, make a var.
+ if (!as<IRPtrTypeBase>(arg->getDataType()) && as<IRDifferentialPairType>(arg->getDataType()))
+ {
+ auto pairType = as<IRDifferentialPairType>(arg->getDataType());
+
+ auto var = builder->emitVar(arg->getDataType());
+
+ SLANG_ASSERT(as<IRMakeDifferentialPair>(arg));
+
+ // Initialize this var to (arg.primal, 0).
+ builder->emitStore(
+ var,
+ builder->emitMakeDifferentialPair(
+ arg->getDataType(),
+ as<IRMakeDifferentialPair>(arg)->getPrimalValue(),
+ builder->emitCallInst(
+ (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()),
+ diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()),
+ List<IRInst*>())));
+
+ args.add(var);
+ argTypes.add(builder->getInOutType(pairType));
+ isArgPtrTyped.add(false);
+ }
+ else
+ {
+ args.add(arg);
+ argTypes.add(arg->getDataType());
+ isArgPtrTyped.add(true);
+ }
+ }
+
+ args.add(revValue);
+ argTypes.add(revValue->getDataType());
+
+ auto revFnType = builder->getFuncType(argTypes, builder->getVoidType());
+ auto revCallee = builder->emitBackwardDifferentiateInst(
+ revFnType,
+ baseFn);
+
+ builder->emitCallInst(revFnType->getResultType(), revCallee, args);
+
+ List<RevGradient> gradients;
+ for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++)
+ {
+ // Is this arg relevant to auto-diff?
+ if (as<IRDifferentialPairType>(as<IRPtrTypeBase>(args[ii]->getDataType())->getValueType()))
+ {
+ // If this is ptr typed, ignore (the gradient will be accumulated on the pointer)
+ // automatically.
+ //
+ if (!isArgPtrTyped[ii])
+ {
+ auto diffArgType = (IRType*)diffTypeContext.getDifferentialForType(
+ builder,
+ as<IRDifferentialPairType>(
+ as<IRPtrTypeBase>(argTypes[ii])->getValueType())->getValueType());
+ auto diffArgPtrType = builder->getPtrType(kIROp_PtrType, diffArgType);
+
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdCall->getArg(ii),
+ builder->emitLoad(
+ builder->emitDifferentialPairAddressDifferential(
+ diffArgPtrType,
+ args[ii])),
+ nullptr));
+ }
+ }
+ }
+
+ return TranspositionResult(gradients);
+ }
TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
{
@@ -288,6 +386,9 @@ struct DiffTransposePass
case kIROp_Mul:
case kIROp_Sub:
return transposeArithmetic(builder, fwdInst, revValue);
+
+ case kIROp_Call:
+ return transposeCall(builder, as<IRCall>(fwdInst), revValue);
case kIROp_swizzle:
return transposeSwizzle(builder, as<IRSwizzle>(fwdInst), revValue);
@@ -322,35 +423,49 @@ struct DiffTransposePass
{
auto revPtr = fwdLoad->getPtr();
+ auto primalType = tryGetPrimalTypeFromDiffInst(fwdLoad);
+ auto loadType = fwdLoad->getDataType();
+
+ List<RevGradient> gradients(RevGradient(
+ revPtr,
+ revValue,
+ nullptr));
+
if (usedPtrs.contains(revPtr))
{
// Re-emit a load to get the _current_ value of revPtr.
auto revCurrGrad = builder->emitLoad(revPtr);
// Add the current value to the aggregation list.
- List<RevGradient> gradients(
- RevGradient(
- revCurrGrad,
- revValue,
- nullptr),
- RevGradient(
- revCurrGrad,
- revCurrGrad,
- nullptr));
-
- auto primalType = tryGetPrimalTypeFromDiffInst(fwdLoad);
- // Get the _total_ value.
- auto aggregateGradient = emitAggregateValue(builder, primalType, gradients);
-
- // Store this back into the pointer.
- builder->emitStore(revPtr, aggregateGradient);
+ gradients.add(RevGradient(
+ revPtr,
+ revCurrGrad,
+ nullptr));
}
else
{
usedPtrs.add(revPtr);
+ }
+
+ // Get the _total_ value.
+ auto aggregateGradient = emitAggregateValue(
+ builder,
+ primalType,
+ gradients);
+
+ if (as<IRDifferentialPairType>(loadType))
+ {
+ auto primalPtr = builder->emitDifferentialPairAddressPrimal(revPtr);
+ auto primalVal = builder->emitLoad(primalPtr);
+
+ auto pairVal = builder->emitMakeDifferentialPair(loadType, primalVal, aggregateGradient);
- // Store into pointer
- builder->emitStore(revPtr, revValue);
+ builder->emitStore(revPtr, pairVal);
+ }
+ else
+ {
+ // Store this back into the pointer.
+ builder->emitStore(revPtr, aggregateGradient);
}
return TranspositionResult(List<RevGradient>());
@@ -359,7 +474,6 @@ struct DiffTransposePass
TranspositionResult transposeStore(IRBuilder* builder, IRStore* fwdStore, IRInst*)
{
- // (A = p.x) -> (p = float3(dA, 0, 0))
return TranspositionResult(
List<RevGradient>(
RevGradient(
@@ -384,7 +498,6 @@ struct DiffTransposePass
TranspositionResult transposeFieldExtract(IRBuilder*, IRFieldExtract* fwdExtract, IRInst* revValue)
{
- // (A = p.x) -> (p = float3(dA, 0, 0))
return TranspositionResult(
List<RevGradient>(
RevGradient(
@@ -394,17 +507,19 @@ struct DiffTransposePass
fwdExtract)));
}
- TranspositionResult transposeMakePair(IRBuilder* builder, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue)
+ TranspositionResult transposeMakePair(IRBuilder*, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue)
{
+ // Even though makePair returns a pair of (primal, differential)
+ // revValue will only contain the reverse-value for 'differential'
+ //
// (P = (A, dA)) -> (dA += dP)
+ //
return TranspositionResult(
List<RevGradient>(
RevGradient(
RevGradient::Flavor::Simple,
fwdMakePair->getDifferentialValue(),
- builder->emitDifferentialPairGetDifferential(
- fwdMakePair->getDifferentialValue()->getDataType(),
- revValue),
+ revValue,
fwdMakePair)));
}
@@ -414,7 +529,7 @@ struct DiffTransposePass
return TranspositionResult(
List<RevGradient>(
RevGradient(
- RevGradient::Flavor::GetDifferential,
+ RevGradient::Flavor::Simple,
fwdGetDiff->getBase(),
revValue,
fwdGetDiff)));
@@ -448,39 +563,7 @@ struct DiffTransposePass
//
void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad)
{
- auto revPtr = revLoad->getPtr();
-
- // Assert that ptr type is of the form IRPtrTypeBase<IRDifferentialPairType<T>>
- SLANG_ASSERT(as<IRPtrTypeBase>(revPtr->getDataType()));
- SLANG_ASSERT(as<IRPtrTypeBase>(revPtr->getDataType())->getValueType()->getOp() == kIROp_DifferentialPairType);
-
- auto paramPairType = as<IRDifferentialPairType>(as<IRPtrTypeBase>(revPtr->getDataType())->getValueType());
-
- // Gather gradients.
- auto gradients = popRevGradients(revLoad);
- if (gradients.getCount() == 0)
- {
- // Ignore.
- return;
- }
- else
- {
- // Re-emit a load to get the _current_ value of revPtr.
- auto revCurrGrad = builder->emitLoad(revPtr);
-
- // Add the current value to the aggregation list.
- gradients.add(
- RevGradient(
- revLoad,
- revCurrGrad,
- nullptr));
-
- // Get the _total_ value.
- auto aggregateGradient = emitAggregateValue(builder, paramPairType, gradients);
-
- // Store this back into the pointer.
- builder->emitStore(revPtr, aggregateGradient);
- }
+ return transposeInst(builder, revLoad);
}
TranspositionResult transposeReturn(IRBuilder*, IRReturn* fwdReturn, IRInst* revValue)
@@ -488,16 +571,14 @@ struct DiffTransposePass
// TODO: This check needs to be changed to something like: isRelevantDifferentialPair()
if (as<IRDifferentialPairType>(fwdReturn->getVal()->getDataType()))
{
- // This is a subtle case, even though the returned value is returning
- // a pair, we need to pretend that the primal value is not being returned
- // since we only care about transposing differential computation.
- // So we're going to assume there is an implicit GetDifferential()
- // around the return value before returning.
+ // Simply pass on the gradient to the previous inst.
+ // (Even if the return value is pair typed, we only care about the differential part)
+ // So this will remain a 'simple' gradient.
//
return TranspositionResult(
List<RevGradient>(
RevGradient(
- RevGradient::Flavor::GetDifferential,
+ RevGradient::Flavor::Simple,
fwdReturn->getVal(),
revValue,
fwdReturn)));
@@ -856,6 +937,8 @@ struct DiffTransposePass
IRInst* emitAggregateDifferentialPair(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> pairGradients)
{
+ SLANG_UNEXPECTED("Should not run.");
+
auto aggPairType = as<IRDifferentialPairType>(aggPrimalType);
SLANG_ASSERT(aggPairType);
@@ -923,7 +1006,9 @@ struct DiffTransposePass
// a differential pair is really a 'hybrid' primal-differential type.
//
if (as<IRDifferentialPairType>(aggPrimalType))
- return emitAggregateDifferentialPair(builder, aggPrimalType, gradients);
+ {
+ SLANG_UNEXPECTED("Should not occur");
+ }
// Process non-simple gradients into simple gradients.
// TODO: This is where we can improve efficiency later.
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 79dec365c..2bfe972ec 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -17,12 +17,32 @@ struct DiffUnzipPass
IRCloneEnv cloneEnv;
+ DifferentiableTypeConformanceContext diffTypeContext;
+
+ // Maps used to keep track of primal and
+ // differential versions of split insts.
+ //
+ Dictionary<IRInst*, IRInst*> primalMap;
+ Dictionary<IRInst*, IRInst*> diffMap;
+
DiffUnzipPass(AutoDiffSharedContext* autodiffContext) :
- autodiffContext(autodiffContext)
+ autodiffContext(autodiffContext), diffTypeContext(autodiffContext)
{ }
+ IRInst* lookupPrimalInst(IRInst* inst)
+ {
+ return primalMap[inst];
+ }
+
+ IRInst* lookupDiffInst(IRInst* inst)
+ {
+ return diffMap[inst];
+ }
+
IRFunc* unzipDiffInsts(IRFunc* func)
{
+ diffTypeContext.setFunc(func);
+
IRBuilder builderStorage;
builderStorage.init(autodiffContext->sharedBuilder);
@@ -66,6 +86,185 @@ struct DiffUnzipPass
return unzippedFunc;
}
+ bool isRelevantDifferentialPair(IRType* type)
+ {
+ if (as<IRDifferentialPairType>(type))
+ {
+ return true;
+ }
+ else if (auto argPtrType = as<IRPtrTypeBase>(type))
+ {
+ if (as<IRDifferentialPairType>(argPtrType->getValueType()))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ InstPair splitCall(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRCall* mixedCall)
+ {
+ IRBuilder globalBuilder;
+ globalBuilder.init(autodiffContext->sharedBuilder);
+
+ auto fwdCallee = as<IRForwardDifferentiate>(mixedCall->getCallee());
+ auto fwdCalleeType = as<IRFuncType>(fwdCallee->getDataType());
+ auto baseFn = fwdCallee->getBaseFn();
+
+ List<IRInst*> primalArgs;
+ for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++)
+ {
+ auto arg = mixedCall->getArg(0);
+
+ if (isRelevantDifferentialPair(arg->getDataType()))
+ {
+ primalArgs.add(lookupPrimalInst(arg));
+ }
+ else
+ {
+ primalArgs.add(arg);
+ }
+ }
+
+ auto mixedDecoration = mixedCall->findDecoration<IRMixedDifferentialInstDecoration>();
+ SLANG_ASSERT(mixedDecoration);
+
+ auto fwdPairResultType = as<IRDifferentialPairType>(mixedDecoration->getPairType());
+ SLANG_ASSERT(fwdPairResultType);
+
+ auto primalType = fwdPairResultType->getValueType();
+ auto diffType = (IRType*) diffTypeContext.getDifferentialForType(&globalBuilder, primalType);
+
+ auto primalVal = primalBuilder->emitCallInst(primalType, baseFn, primalArgs);
+
+ List<IRInst*> diffArgs;
+ for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++)
+ {
+ auto arg = mixedCall->getArg(0);
+
+ if (isRelevantDifferentialPair(arg->getDataType()))
+ {
+ auto primalArg = lookupPrimalInst(arg);
+ auto diffArg = lookupDiffInst(arg);
+
+ // If arg is a mixed differential (pair), it should have already been split.
+ SLANG_ASSERT(primalArg);
+ SLANG_ASSERT(diffArg);
+
+ auto pairArg = diffBuilder->emitMakeDifferentialPair(
+ arg->getDataType(),
+ primalArg,
+ diffArg);
+
+ diffBuilder->markInstAsDifferential(pairArg, primalArg->getDataType());
+ diffArgs.add(pairArg);
+ }
+ else
+ {
+ diffArgs.add(arg);
+ }
+ }
+
+ auto newFwdCallee = diffBuilder->emitForwardDifferentiateInst(fwdCalleeType, baseFn);
+ diffBuilder->markInstAsDifferential(newFwdCallee);
+
+ auto diffPairVal = diffBuilder->emitCallInst(
+ fwdPairResultType,
+ newFwdCallee,
+ diffArgs);
+ diffBuilder->markInstAsDifferential(diffPairVal, primalType);
+
+ auto diffVal = diffBuilder->emitDifferentialPairGetDifferential(diffType, diffPairVal);
+ diffBuilder->markInstAsDifferential(diffVal, primalType);
+
+ return InstPair(primalVal, diffVal);
+ }
+
+ InstPair splitMakePair(IRBuilder*, IRBuilder*, IRMakeDifferentialPair* mixedPair)
+ {
+ return InstPair(mixedPair->getPrimalValue(), mixedPair->getDifferentialValue());
+ }
+
+ InstPair splitLoad(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRLoad* mixedLoad)
+ {
+ // By the nature of how diff pairs are used, and the fact that FieldAddress/GetElementPtr,
+ // etc, cannot appear before a GetDifferential/GetPrimal, a mixed load can only be from a
+ // parameter or a variable.
+ //
+ if (as<IRParam>(mixedLoad->getPtr()))
+ {
+ // Should not occur with current impl of fwd-mode.
+ // If impl. changes, impl this case too.
+ //
+ SLANG_UNIMPLEMENTED_X("Splitting a load from a param is not currently implemented.");
+ }
+
+ // Everything else should have already been split.
+ auto primalPtr = lookupPrimalInst(mixedLoad->getPtr());
+ auto diffPtr = lookupDiffInst(mixedLoad->getPtr());
+
+ return InstPair(primalBuilder->emitLoad(primalPtr), diffBuilder->emitLoad(diffPtr));
+ }
+
+ InstPair splitVar(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRVar* mixedVar)
+ {
+ auto pairType = as<IRDifferentialPairType>(mixedVar->getDataType());
+ auto primalType = pairType->getValueType();
+ auto diffType = (IRType*) diffTypeContext.getDifferentialForType(primalBuilder, primalType);
+
+ return InstPair(primalBuilder->emitVar(primalType), diffBuilder->emitVar(diffType));
+ }
+
+ InstPair splitReturn(IRBuilder*, IRBuilder* diffBuilder, IRReturn* mixedReturn)
+ {
+ auto pairType = as<IRDifferentialPairType>(mixedReturn->getVal()->getDataType());
+ auto primalType = pairType->getValueType();
+
+ auto pairVal = diffBuilder->emitMakeDifferentialPair(
+ pairType,
+ lookupPrimalInst(mixedReturn->getVal()),
+ lookupDiffInst(mixedReturn->getVal()));
+ diffBuilder->markInstAsDifferential(pairVal, primalType);
+
+ auto returnInst = diffBuilder->emitReturn(pairVal);
+ diffBuilder->markInstAsDifferential(returnInst, primalType);
+
+ return InstPair(nullptr, returnInst);
+ }
+
+ InstPair _splitMixedInst(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Call:
+ return splitCall(primalBuilder, diffBuilder, as<IRCall>(inst));
+
+ case kIROp_Var:
+ return splitVar(primalBuilder, diffBuilder, as<IRVar>(inst));
+
+ case kIROp_MakeDifferentialPair:
+ return splitMakePair(primalBuilder, diffBuilder, as<IRMakeDifferentialPair>(inst));
+
+ case kIROp_Load:
+ return splitLoad(primalBuilder, diffBuilder, as<IRLoad>(inst));
+
+ case kIROp_Return:
+ return splitReturn(primalBuilder, diffBuilder, as<IRReturn>(inst));
+
+ default:
+ SLANG_ASSERT_FAILURE("Unhandled mixed diff inst");
+ }
+ }
+
+ void splitMixedInst(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* inst)
+ {
+ auto instPair = _splitMixedInst(primalBuilder, diffBuilder, inst);
+
+ primalMap[inst] = instPair.primal;
+ diffMap[inst] = instPair.differential;
+ }
+
void splitBlock(IRBlock* mainBlock, IRBlock* primalBlock, IRBlock* diffBlock)
{
// Make two builders for primal and differential blocks.
@@ -77,14 +276,42 @@ struct DiffUnzipPass
diffBuilder.init(autodiffContext->sharedBuilder);
diffBuilder.setInsertInto(diffBlock);
+ List<IRInst*> splitInsts;
for (auto child = mainBlock->getFirstChild(); child;)
{
IRInst* nextChild = child->getNextInst();
- if (isDifferentialInst(child) || as<IRTerminatorInst>(child))
+ if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(child))
+ {
+ if (diffMap.ContainsKey(getDiffInst->getBase()))
+ {
+ getDiffInst->replaceUsesWith(lookupDiffInst(getDiffInst->getBase()));
+ getDiffInst->removeAndDeallocate();
+ child = nextChild;
+ continue;
+ }
+ }
+
+ if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(child))
+ {
+ if (primalMap.ContainsKey(getPrimalInst->getBase()))
+ {
+ getPrimalInst->replaceUsesWith(lookupPrimalInst(getPrimalInst->getBase()));
+ getPrimalInst->removeAndDeallocate();
+ child = nextChild;
+ continue;
+ }
+ }
+
+ if (isDifferentialInst(child))
{
child->insertAtEnd(diffBlock);
}
+ else if (isMixedDifferentialInst(child))
+ {
+ splitMixedInst(&primalBuilder, &diffBuilder, child);
+ splitInsts.add(child);
+ }
else
{
child->insertAtEnd(primalBlock);
@@ -93,6 +320,19 @@ struct DiffUnzipPass
child = nextChild;
}
+ // Remove insts that were split.
+ for (auto inst : splitInsts)
+ {
+ // Consistency check.
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ SLANG_RELEASE_ASSERT((use->getUser()->getParent() != primalBlock) &&
+ (use->getUser()->getParent() != diffBlock));
+ }
+
+ inst->removeAndDeallocate();
+ }
+
// Nothing should be left in the original block.
SLANG_ASSERT(mainBlock->getFirstChild() == nullptr);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 5784f60cb..c74388406 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -737,6 +737,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// a differential value.
INST(DifferentialInstDecoration, diffInstDecoration, 1, 0)
+ /// Used by the auto-diff pass to mark insts that compute
+ /// BOTH a differential and a primal value.
+ INST(MixedDifferentialInstDecoration, mixedDiffInstDecoration, 1, 0)
+
/// Used by the auto-diff pass to hold a reference to a
/// differential member of a type in its associated differential type.
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 1ef0fa4f8..67f17f5b2 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -610,6 +610,20 @@ struct IRDifferentialInstDecoration : IRDecoration
IRType* getPrimalType() { return as<IRType>(getOperand(0)); }
};
+
+struct IRMixedDifferentialInstDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_MixedDifferentialInstDecoration
+ };
+
+ IRUse pairType;
+ IR_LEAF_ISA(MixedDifferentialInstDecoration)
+
+ IRType* getPairType() { return as<IRType>(getOperand(0)); }
+};
+
struct IRBackwardDifferentiableDecoration : IRDecoration
{
enum
@@ -3377,6 +3391,16 @@ public:
addDecoration(value, kIROp_DifferentialInstDecoration, nullptr);
}
+ void markInstAsMixedDifferential(IRInst* value)
+ {
+ addDecoration(value, kIROp_MixedDifferentialInstDecoration, nullptr);
+ }
+
+ void markInstAsMixedDifferential(IRInst* value, IRType* pairType)
+ {
+ addDecoration(value, kIROp_MixedDifferentialInstDecoration, pairType);
+ }
+
void markInstAsDifferential(IRInst* value, IRType* primalType)
{
addDecoration(value, kIROp_DifferentialInstDecoration, primalType);
diff --git a/tests/autodiff/reverse-nested-calls.slang b/tests/autodiff/reverse-nested-calls.slang
new file mode 100644
index 000000000..2b55efd60
--- /dev/null
+++ b/tests/autodiff/reverse-nested-calls.slang
@@ -0,0 +1,29 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+[BackwardDifferentiable]
+float g(float y)
+{
+ return 4.0f * y;
+}
+
+[BackwardDifferentiable]
+float f(float x)
+{
+ return 3.0f * g(2.0f * x);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ dpfloat dpa = dpfloat(1.0, 0.0);
+
+ __bwd_diff(f)(dpa, 1.0f);
+ outputBuffer[0] = dpa.d; // Expect: 24.0
+}
diff --git a/tests/autodiff/reverse-nested-calls.slang.expected.txt b/tests/autodiff/reverse-nested-calls.slang.expected.txt
new file mode 100644
index 000000000..0a39c4da6
--- /dev/null
+++ b/tests/autodiff/reverse-nested-calls.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+24.000000
+0.000000
+0.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/reverse-struct-types.slang b/tests/autodiff/reverse-struct-types.slang
index 699e50480..d2b52a008 100644
--- a/tests/autodiff/reverse-struct-types.slang
+++ b/tests/autodiff/reverse-struct-types.slang
@@ -9,27 +9,6 @@ struct A : IDifferentiable
{
float x;
float y;
-
- [__unsafeForceInlineEarly]
- static Differential dzero()
- {
- Differential b = {0.0, float.dzero()};
- return b;
- }
-
- [__unsafeForceInlineEarly]
- static Differential dadd(Differential a, Differential b)
- {
- Differential o = {a.x + b.x, 0.0};
- return o;
- }
-
- [__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
- {
- Differential o = {a.x * b.x, 0.0};
- return o;
- }
};
typedef DifferentialPair<A> dpA;
@@ -56,7 +35,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
A.Differential dout = {1.0, 1.0};
__bwd_diff(f)(dpa, dout);
- outputBuffer[0] = dpa.d.x; // Expect: 10
+ outputBuffer[0] = dpa.d.x; // Expect: 7
outputBuffer[1] = dpa.d.y; // Expect: 0
}
}
diff --git a/tests/autodiff/reverse-struct-types.slang.expected.txt b/tests/autodiff/reverse-struct-types.slang.expected.txt
index 82bc8f733..b94f4fec6 100644
--- a/tests/autodiff/reverse-struct-types.slang.expected.txt
+++ b/tests/autodiff/reverse-struct-types.slang.expected.txt
@@ -1,5 +1,5 @@
type: float
-5.000000
+7.000000
0.000000
0.000000
0.000000