summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-12-08 11:50:55 -0500
committerGitHub <noreply@github.com>2022-12-08 08:50:55 -0800
commit468bb7ecf65c000c308adae511bf65a1ca4cc412 (patch)
tree8042aaa77224d00f14a7267564ce7452ad6de67e /source
parent53e891eb28ceac5f956399c65f2ae27d37f3d724 (diff)
More type support for reverse-mode (#2551)
* Add vector arithmetic test. Make gradient accumulation work for any IRLoad * Added support for general vector types, and split transposition into transpose & materialize to allow emitting the fully accumulated gradient for complex types. * Several bug fixes + finished up support for vector & struct types + removed prop pass * minor fixes (int/uint casts) * Removed IRConstruct * Added some type casts to prevent warnings * minor fix for unused variable
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-expr.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp84
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp5
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h882
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h8
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h11
7 files changed, 849 insertions, 147 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 336682bf4..1ffc45fbd 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2136,9 +2136,9 @@ namespace Slang
type->paramTypes.add(derivType);
}
}
-
+
// Last parameter is the initial derivative of the original return type
- type->paramTypes.add(originalType->resultType);
+ type->paramTypes.add(getDifferentialType(m_astBuilder, originalType->resultType, SourceLoc()));
return type;
}
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index d2d9a0e7d..60c2721c7 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -529,25 +529,54 @@ InstPair ForwardDerivativeTranscriber::transcribeBinaryArith(IRBuilder* builder,
diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType());
auto resultType = primalArith->getDataType();
+ auto diffType = (IRType*) differentiableTypeConformanceContext.getDifferentialForType(builder, resultType);
+
switch(origArith->getOp())
{
case kIROp_Add:
- return InstPair(primalArith, builder->emitAdd(resultType, diffLeft, diffRight));
+ {
+ auto diffAdd = builder->emitAdd(diffType, diffLeft, diffRight);
+ builder->markInstAsDifferential(diffAdd, resultType);
+
+ return InstPair(primalArith, diffAdd);
+ }
+
case kIROp_Mul:
- return InstPair(primalArith, builder->emitAdd(resultType,
- builder->emitMul(resultType, diffLeft, primalRight),
- builder->emitMul(resultType, primalLeft, diffRight)));
+ {
+ auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight);
+ auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight);
+ builder->markInstAsDifferential(diffLeftTimesRight, resultType);
+ builder->markInstAsDifferential(diffRightTimesLeft, resultType);
+
+ auto diffAdd = builder->emitAdd(diffType, diffLeftTimesRight, diffRightTimesLeft);
+ builder->markInstAsDifferential(diffAdd, resultType);
+
+ return InstPair(primalArith, diffAdd);
+ }
+
case kIROp_Sub:
- return InstPair(primalArith, builder->emitSub(resultType, diffLeft, diffRight));
+ {
+ auto diffSub = builder->emitSub(diffType, diffLeft, diffRight);
+ builder->markInstAsDifferential(diffSub, resultType);
+
+ return InstPair(primalArith, diffSub);
+ }
case kIROp_Div:
- return InstPair(primalArith, builder->emitDiv(resultType,
- builder->emitSub(
- resultType,
- builder->emitMul(resultType, diffLeft, primalRight),
- builder->emitMul(resultType, primalLeft, diffRight)),
- builder->emitMul(
- primalRight->getDataType(), primalRight, primalRight
- )));
+ {
+ auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight);
+ auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight);
+ auto diffSub = builder->emitSub(diffType, diffLeftTimesRight, diffRightTimesLeft);
+ builder->markInstAsDifferential(diffLeftTimesRight, resultType);
+ builder->markInstAsDifferential(diffRightTimesLeft, resultType);
+ builder->markInstAsDifferential(diffSub, resultType);
+
+ auto diffMul = builder->emitMul(resultType, primalRight, primalRight);
+
+ auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul);
+ builder->markInstAsDifferential(diffDiv, resultType);
+
+ return InstPair(primalArith, diffDiv);
+ }
default:
getSink()->diagnose(origArith->sourceLoc,
Diagnostics::unimplemented,
@@ -558,7 +587,6 @@ InstPair ForwardDerivativeTranscriber::transcribeBinaryArith(IRBuilder* builder,
return InstPair(primalArith, nullptr);
}
-
InstPair ForwardDerivativeTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic)
{
SLANG_ASSERT(origLogic->getOperandCount() == 2);
@@ -619,6 +647,8 @@ InstPair ForwardDerivativeTranscriber::transcribeStore(IRBuilder* builder, IRSto
if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType()))
{
auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal);
+ builder->markInstAsDifferential(diffStoreVal, diffPairType);
+
auto store = builder->emitStore(primalStoreLocation, valToStore);
return InstPair(store, nullptr);
}
@@ -674,6 +704,8 @@ InstPair ForwardDerivativeTranscriber::transcribeReturn(IRBuilder* builder, IRRe
SLANG_RELEASE_ASSERT(diffReturnVal);
auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal);
+ builder->markInstAsDifferential(diffPair, pairType);
+
IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair));
return InstPair(pairReturn, pairReturn);
}
@@ -817,7 +849,10 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall
// If a pair type can be formed, this must be non-null.
SLANG_RELEASE_ASSERT(diffArg);
+
auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
+ builder->markInstAsDifferential(diffPair, pairType);
+
args.add(diffPair);
continue;
}
@@ -826,7 +861,7 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall
// Add original/primal argument.
args.add(primalArg);
}
-
+
IRType* diffReturnType = nullptr;
diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
@@ -840,6 +875,7 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall
diffReturnType,
diffCallee,
args);
+ builder->markInstAsDifferential(callInst, origCall->getFullType());
if (diffReturnType->getOp() != kIROp_VoidType)
{
@@ -1145,7 +1181,11 @@ IRInst* ForwardDerivativeTranscriber::getDifferentialZeroOfType(IRBuilder* build
SLANG_RELEASE_ASSERT(zeroMethod);
auto emptyArgList = List<IRInst*>();
- return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
+
+ auto callInst = builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
+ builder->markInstAsDifferential(callInst, primalType);
+
+ return callInst;
}
else
{
@@ -1489,10 +1529,10 @@ InstPair ForwardDerivativeTranscriber::transcribeGeneric(IRBuilder* inBuilder, I
diffGeneric->setFullType(diffType);
- // Transcribe children from origFunc into diffFunc.
- builder.setInsertInto(diffGeneric);
- for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock())
- this->transcribe(&builder, block);
+ // Transcribe children from origFunc into diffFunc.
+ builder.setInsertInto(diffGeneric);
+ for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribe(&builder, block);
return InstPair(primalGeneric, diffGeneric);
}
@@ -1537,6 +1577,10 @@ IRInst* ForwardDerivativeTranscriber::transcribe(IRBuilder* builder, IRInst* ori
sb << "s_diff_" << primalNameHint->getName();
builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
}
+
+ // Tag the differential inst using a decoration.
+ builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType()));
+
break;
}
}
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 8ec8f581c..34a08ee93 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -511,7 +511,10 @@ struct BackwardDiffTranscriber
this->makeParameterBlock(builder, as<IRFunc>(fwdDiffFunc));
// This steps adds a decoration to instructions that are computing the differential.
- diffPropagationPass->propagateDiffInstDecoration(builder, fwdDiffFunc);
+ // TODO: This is disabled for now because fwd-mode already adds differential decorations
+ // wherever need. We need to run this pass only for user-writted forward derivativecode.
+ //
+ // diffPropagationPass->propagateDiffInstDecoration(builder, fwdDiffFunc);
// Copy primal insts to the first block of the unzipped function, copy diff insts to the
// second block of the unzipped function.
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 659131820..75491d753 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -13,39 +13,61 @@ namespace Slang
struct DiffTransposePass
{
- AutoDiffSharedContext* autodiffContext;
-
- DifferentialPairTypeBuilder pairBuilder;
+
+ struct RevGradient
+ {
+ enum Flavor
+ {
+ Simple,
+ Swizzle,
+ GetElement,
+ GetDifferential,
+ FieldExtract,
- Dictionary<IRInst*, List<IRInst*>> assignmentsMap;
+ Invalid
+ };
- Dictionary<IRInst*, IRInst*>* primalsMap;
+ RevGradient() :
+ flavor(Flavor::Invalid), targetInst(nullptr), revGradInst(nullptr), fwdGradInst(nullptr)
+ { }
+
+ RevGradient(Flavor flavor, IRInst* targetInst, IRInst* revGradInst, IRInst* fwdGradInst) :
+ flavor(flavor), targetInst(targetInst), revGradInst(revGradInst), fwdGradInst(fwdGradInst)
+ { }
- DiffTransposePass(AutoDiffSharedContext* autodiffContext) :
- autodiffContext(autodiffContext), pairBuilder(autodiffContext)
- { }
+ RevGradient(IRInst* targetInst, IRInst* revGradInst, IRInst* fwdGradInst) :
+ flavor(Flavor::Simple), targetInst(targetInst), revGradInst(revGradInst), fwdGradInst(fwdGradInst)
+ { }
- struct RevAssignment
- {
- IRInst* lvalue;
- IRInst* rvalue;
+ bool operator==(const RevGradient& other) const
+ {
+ return (other.targetInst == targetInst) &&
+ (other.revGradInst == revGradInst) &&
+ (other.fwdGradInst == fwdGradInst) &&
+ (other.flavor == flavor);
+ }
+
+ IRInst* targetInst;
+ IRInst* revGradInst;
+ IRInst* fwdGradInst;
- RevAssignment(IRInst* lvalue, IRInst* rvalue) : lvalue(lvalue), rvalue(rvalue)
- { }
- RevAssignment() : lvalue(nullptr), rvalue(nullptr)
- { }
+ Flavor flavor;
};
+ DiffTransposePass(AutoDiffSharedContext* autodiffContext) :
+ autodiffContext(autodiffContext), pairBuilder(autodiffContext), diffTypeContext(autodiffContext)
+ { }
+
struct TranspositionResult
{
// Holds a set of pairs of
// (original-inst, inst-to-accumulate-for-orig-inst)
- List<RevAssignment> revPairs;
+ List<RevGradient> revPairs;
TranspositionResult()
{ }
- TranspositionResult(List<RevAssignment> revPairs) : revPairs(revPairs)
+ TranspositionResult(List<RevGradient> revPairs) : revPairs(revPairs)
{ }
};
@@ -64,9 +86,10 @@ struct DiffTransposePass
void transposeDiffBlocksInFunc(
IRFunc* revDiffFunc,
- // TODO: Maybe there's a more elegant way to pass this information.
FuncTranspositionInfo transposeInfo)
{
+ // Grab all differentiable type information.
+ diffTypeContext.setFunc(revDiffFunc);
// Traverse all instructions/blocks in reverse (starting from the terminator inst)
// look for insts/blocks marked with IRDifferentialInstDecoration,
@@ -103,7 +126,7 @@ struct DiffTransposePass
// Set dOutParameter as the transpose gradient for the return inst, if any.
if (auto returnInst = as<IRReturn>(block->getTerminator()))
{
- this->addRevAssignmentForFwdInst(returnInst, transposeInfo.dOutInst);
+ this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr));
}
IRBlock* revBlock = builder.emitBlock();
@@ -117,6 +140,8 @@ struct DiffTransposePass
}
}
+ // A[cond_inst] -> (B or C) -> D => D[cond_inst] -> (B_T -> C_T) -> A_T
+
void transposeBlock(IRBlock* fwdBlock, IRBlock* revBlock)
{
IRBuilder builder;
@@ -125,7 +150,24 @@ struct DiffTransposePass
// Insert after the last block.
builder.setInsertInto(revBlock);
+ List<IRInst*> ptrInsts;
+ for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
+ {
+ // If the instruction is pointer typed, move to top of new reverse-mode block
+ if (as<IRPtrTypeBase>(child->getDataType()))
+ ptrInsts.add(child);
+ }
+
+ for (auto ptrInst : ptrInsts)
+ {
+ ptrInst->insertAtEnd(revBlock);
+ }
+
+
+ // Then, go backwards through the regular instructions, and transpose them into the new
+ // rev block.
// Note the 'reverse' traversal here.
+ //
for (IRInst* child = fwdBlock->getLastChild(); child; child = child->getPrevInst())
{
if (as<IRDecoration>(child))
@@ -141,7 +183,7 @@ struct DiffTransposePass
// function scope variable, since control flow can affect what blocks contribute to
// for a specific inst.
//
- for (auto pair : assignmentsMap)
+ for (auto pair : gradientsMap)
{
if (auto param = as<IRLoad>(pair.Key))
accumulateGradientsForLoad(&builder, param);
@@ -163,20 +205,77 @@ struct DiffTransposePass
void transposeInst(IRBuilder* builder, IRInst* inst)
{
- // Look for assignment entry for this inst.
- IRInst* revValue = builder->getFloatValue(builder->getType(kIROp_FloatType), 0.0);
- if (hasRevAssignments(inst))
+ // Look for gradient entries for this inst.
+ List<RevGradient> gradients;
+ if (hasRevGradients(inst))
+ gradients = popRevGradients(inst);
+
+ // Are we dealing with DifferentialPairType?
+ if (as<IRDifferentialPairType>(inst->getDataType()))
+ {
+ // This will be a 'hybrid' primal-differential inst,
+ // so we add a pair (primal_value, 0) as an additional
+ // gradient to represent the primal part of the computation.
+ //
+ // Now, if the unzip pass has done it's job, the _only_
+ // case should be that inst is IRMakeDifferentialPair
+ //
+ SLANG_ASSERT(as<IRMakeDifferentialPair>(inst));
+ auto primalType = as<IRDifferentialPairType>(inst->getDataType())->getValueType();
+ auto diffType = (IRType*)pairBuilder.getDiffTypeFromPairType(builder, as<IRDifferentialPairType>(inst->getDataType()));
+
+ auto primalInst = as<IRMakeDifferentialPair>(inst)->getPrimalValue();
+ auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType);
+
+ // Must exist.
+ SLANG_ASSERT(zeroMethod);
+ auto diffInst = builder->emitCallInst(diffType, zeroMethod, List<IRInst*>());
+
+ gradients.add(
+ RevGradient(
+ inst,
+ builder->emitMakeDifferentialPair(inst->getDataType(), primalInst, diffInst),
+ nullptr));
+ }
+
+ IRType* primalType = tryGetPrimalTypeFromDiffInst(inst);
+
+ if (!primalType)
+ {
+ // Special-case instructions.
+ if (auto returnInst = as<IRReturn>(inst))
+ {
+ auto returnPairType = as<IRDifferentialPairType>(
+ tryGetPrimalTypeFromDiffInst(returnInst->getVal()));
+ primalType = returnPairType->getValueType();
+ }
+ }
+
+ if (!primalType)
{
- // Emit the aggregate of all the assignments here. This will form the derivative
- revValue = emitAggregateValue(builder, popRevAssignments(inst));
+ // Check for special insts for which a reverse-mode gradient doesn't apply.
+ if(!as<IRStore>(inst))
+ {
+ SLANG_UNEXPECTED("Could not resolve primal type for diff inst");
+ }
}
+ // Emit the aggregate of all the gradients here. This will form the total derivative for this inst.
+ auto revValue = emitAggregateValue(builder, primalType, gradients);
+
auto transposeResult = transposeInst(builder, inst, revValue);
- // Add the new results to the assignments map.
- for (auto pair : transposeResult.revPairs)
+ if (auto fwdNameHint = inst->findDecoration<IRNameHintDecoration>())
{
- addRevAssignmentForFwdInst(pair.lvalue, pair.rvalue);
+ StringBuilder sb;
+ sb << fwdNameHint->getName() << "_T";
+ builder->addNameHintDecoration(revValue, sb.getUnownedSlice());
+ }
+
+ // Add the new results to the gradients map.
+ for (auto gradient : transposeResult.revPairs)
+ {
+ addRevGradientForFwdInst(gradient.targetInst, gradient);
}
}
@@ -189,59 +288,176 @@ struct DiffTransposePass
case kIROp_Mul:
case kIROp_Sub:
return transposeArithmetic(builder, fwdInst, revValue);
+
+ case kIROp_swizzle:
+ return transposeSwizzle(builder, as<IRSwizzle>(fwdInst), revValue);
+
+ case kIROp_FieldExtract:
+ return transposeFieldExtract(builder, as<IRFieldExtract>(fwdInst), revValue);
case kIROp_Return:
return transposeReturn(builder, as<IRReturn>(fwdInst), revValue);
+
+ case kIROp_Store:
+ return transposeStore(builder, as<IRStore>(fwdInst), revValue);
+
+ case kIROp_Load:
+ return transposeLoad(builder, as<IRLoad>(fwdInst), revValue);
case kIROp_MakeDifferentialPair:
return transposeMakePair(builder, as<IRMakeDifferentialPair>(fwdInst), revValue);
case kIROp_DifferentialPairGetDifferential:
return transposeGetDifferential(builder, as<IRDifferentialPairGetDifferential>(fwdInst), revValue);
+
+ case kIROp_MakeVector:
+ return transposeMakeVector(builder, fwdInst, revValue);
default:
SLANG_ASSERT_FAILURE("Unhandled instruction");
}
}
- TranspositionResult transposeMakePair(IRBuilder*, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue)
+ TranspositionResult transposeLoad(IRBuilder* builder, IRLoad* fwdLoad, IRInst* revValue)
+ {
+ auto revPtr = fwdLoad->getPtr();
+
+ if (usedPtrs.contains(revPtr))
+ {
+ // Re-emit a load to get the _current_ value of revPtr.
+ auto revCurrGrad = builder->emitLoad(revPtr);
+
+ // Add the current value to the aggregation list.
+ List<RevGradient> gradients(
+ RevGradient(
+ revCurrGrad,
+ revValue,
+ nullptr),
+ RevGradient(
+ revCurrGrad,
+ revCurrGrad,
+ nullptr));
+
+ auto primalType = tryGetPrimalTypeFromDiffInst(fwdLoad);
+ // Get the _total_ value.
+ auto aggregateGradient = emitAggregateValue(builder, primalType, gradients);
+
+ // Store this back into the pointer.
+ builder->emitStore(revPtr, aggregateGradient);
+ }
+ else
+ {
+ usedPtrs.add(revPtr);
+
+ // Store into pointer
+ builder->emitStore(revPtr, revValue);
+ }
+
+ return TranspositionResult(List<RevGradient>());
+ }
+
+
+ TranspositionResult transposeStore(IRBuilder* builder, IRStore* fwdStore, IRInst*)
+ {
+ // (A = p.x) -> (p = float3(dA, 0, 0))
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdStore->getVal(),
+ builder->emitLoad(fwdStore->getPtr()),
+ fwdStore)));
+ }
+
+ TranspositionResult transposeSwizzle(IRBuilder*, IRSwizzle* fwdSwizzle, IRInst* revValue)
+ {
+ // (A = p.x) -> (p = float3(dA, 0, 0))
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::Swizzle,
+ fwdSwizzle->getBase(),
+ revValue,
+ fwdSwizzle)));
+ }
+
+
+ TranspositionResult transposeFieldExtract(IRBuilder*, IRFieldExtract* fwdExtract, IRInst* revValue)
+ {
+ // (A = p.x) -> (p = float3(dA, 0, 0))
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::FieldExtract,
+ fwdExtract->getBase(),
+ revValue,
+ fwdExtract)));
+ }
+
+ TranspositionResult transposeMakePair(IRBuilder* builder, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue)
{
// (P = (A, dA)) -> (dA += dP)
return TranspositionResult(
- List<RevAssignment>(
- RevAssignment(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::Simple,
fwdMakePair->getDifferentialValue(),
- revValue)));
+ builder->emitDifferentialPairGetDifferential(
+ fwdMakePair->getDifferentialValue()->getDataType(),
+ revValue),
+ fwdMakePair)));
}
TranspositionResult transposeGetDifferential(IRBuilder*, IRDifferentialPairGetDifferential* fwdGetDiff, IRInst* revValue)
{
// (A = GetDiff(P)) -> (dP.d += dA)
return TranspositionResult(
- List<RevAssignment>(
- RevAssignment(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::GetDifferential,
fwdGetDiff->getBase(),
- revValue)));
+ revValue,
+ fwdGetDiff)));
}
- // Gather all reverse-mode gradients for parameters, and store to the differential
- //
- void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad)
+ TranspositionResult transposeMakeVector(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue)
{
- auto revParam = revLoad->getPtr();
+ // For now, we support only vector types. Extend this to other built-in types if necessary.
+ SLANG_ASSERT(fwdMakeVector->getOp() == kIROp_MakeVector);
- // Don't currently handle loads from non-param insts.
- SLANG_ASSERT(as<IRParam>(revParam));
+ List<RevGradient> gradients;
+ for (UIndex ii = 0; ii < fwdMakeVector->getOperandCount(); ii++)
+ {
+ auto gradAtIndex = builder->emitElementExtract(
+ fwdMakeVector->getOperand(ii)->getDataType(),
+ revValue,
+ builder->getIntValue(builder->getIntType(), ii));
+
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeVector->getOperand(ii),
+ gradAtIndex,
+ fwdMakeVector));
+ }
+
+ // (A = float3(X, Y, Z)) -> [(dX += dA), (dY += dA), (dZ += dA)]
+ return TranspositionResult(gradients);
+ }
- // Assert that param type is of the form IRPtrTypeBase<IRDifferentialPairType<T>>
- SLANG_ASSERT(as<IRPtrTypeBase>(revParam->getDataType()));
- SLANG_ASSERT(as<IRPtrTypeBase>(revParam->getDataType())->getValueType()->getOp() == kIROp_DifferentialPairType);
+ // Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr.
+ //
+ void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad)
+ {
+ auto revPtr = revLoad->getPtr();
+
+ // Assert that ptr type is of the form IRPtrTypeBase<IRDifferentialPairType<T>>
+ SLANG_ASSERT(as<IRPtrTypeBase>(revPtr->getDataType()));
+ SLANG_ASSERT(as<IRPtrTypeBase>(revPtr->getDataType())->getValueType()->getOp() == kIROp_DifferentialPairType);
- auto paramPairType = as<IRDifferentialPairType>(as<IRPtrTypeBase>(revParam->getDataType())->getValueType());
- auto diffType = (IRType*) pairBuilder.getDiffTypeFromPairType(builder, paramPairType);
+ auto paramPairType = as<IRDifferentialPairType>(as<IRPtrTypeBase>(revPtr->getDataType())->getValueType());
// Gather gradients.
- auto gradients = popRevAssignments(revLoad);
+ auto gradients = popRevGradients(revLoad);
if (gradients.getCount() == 0)
{
// Ignore.
@@ -249,42 +465,42 @@ struct DiffTransposePass
}
else
{
- // Re-emit a load to get the _current_ value of revParam.
- auto revCurrLoad = builder->emitLoad(revParam);
-
- // Grab the current gradient value.
- auto revCurrGrad = builder->emitDifferentialPairGetDifferential(diffType, revCurrLoad);
+ // Re-emit a load to get the _current_ value of revPtr.
+ auto revCurrGrad = builder->emitLoad(revPtr);
// Add the current value to the aggregation list.
- gradients.add(revCurrGrad);
+ gradients.add(
+ RevGradient(
+ revLoad,
+ revCurrGrad,
+ nullptr));
// Get the _total_ value.
- auto aggregateGradient = emitAggregateValue(builder, gradients);
-
- // Grab the current primal value.
- auto revCurrPrimal = builder->emitDifferentialPairGetPrimal(revCurrLoad);
+ auto aggregateGradient = emitAggregateValue(builder, paramPairType, gradients);
- // Make the pair with the new gradient.
- auto newDiffPair = builder->emitMakeDifferentialPair(paramPairType, revCurrPrimal, aggregateGradient);
-
- // Store this back into the parameter.
- builder->emitStore(revParam, newDiffPair);
+ // Store this back into the pointer.
+ builder->emitStore(revPtr, aggregateGradient);
}
}
TranspositionResult transposeReturn(IRBuilder*, IRReturn* fwdReturn, IRInst* revValue)
{
-
+ // TODO: This check needs to be changed to something like: isRelevantDifferentialPair()
if (as<IRDifferentialPairType>(fwdReturn->getVal()->getDataType()))
{
- // If the type is a differential pair, we add the reverse-value for the *pair*
- // itself. TODO: Signal this through flags in the 'RevAssignment' struct.
- // (return (A, dA)) -> (dA += dOut)
+ // This is a subtle case, even though the returned value is returning
+ // a pair, we need to pretend that the primal value is not being returned
+ // since we only care about transposing differential computation.
+ // So we're going to assume there is an implicit GetDifferential()
+ // around the return value before returning.
+ //
return TranspositionResult(
- List<RevAssignment>(
- RevAssignment(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::GetDifferential,
fwdReturn->getVal(),
- revValue)));
+ revValue,
+ fwdReturn)));
}
else
{
@@ -293,35 +509,136 @@ struct DiffTransposePass
}
}
+ IRInst* promoteToType(IRBuilder* builder, IRType* targetType, IRInst* inst)
+ {
+ auto currentType = inst->getDataType();
+
+ switch (targetType->getOp())
+ {
+
+ case kIROp_VectorType:
+ {
+ // current type should be a scalar.
+ SLANG_RELEASE_ASSERT(!as<IRVectorType>(currentType->getDataType()));
+
+ auto targetVectorType = as<IRVectorType>(targetType);
+
+ List<IRInst*> operands;
+ for (Index ii = 0; ii < as<IRIntLit>(targetVectorType->getElementCount())->getValue(); ii++)
+ {
+ operands.add(inst);
+ }
+
+ IRInst* newInst = builder->emitMakeVector(targetType, operands.getCount(), operands.getBuffer());
+
+ if (isDifferentialInst(inst))
+ builder->markInstAsDifferential(newInst);
+
+ return newInst;
+ }
+
+ default:
+ SLANG_ASSERT_FAILURE("Unhandled target type for promotion");
+ }
+ }
+
+ IRInst* promoteOperandsToTargetType(IRBuilder* builder, IRInst* fwdInst)
+ {
+ auto oldLoc = builder->getInsertLoc();
+ // If operands are not of the same type, cast them to the target type.
+ IRType* targetType = fwdInst->getDataType();
+
+ bool needNewInst = false;
+
+ List<IRInst*> newOperands;
+ for (UIndex ii = 0; ii < fwdInst->getOperandCount(); ii++)
+ {
+ auto operand = fwdInst->getOperand(ii);
+ if (operand->getDataType() != targetType)
+ {
+ // Insert new operand just after the old operand, so we have the old
+ // operands available.
+ //
+ builder->setInsertAfter(operand);
+
+ IRInst* newOperand = promoteToType(builder, targetType, operand);
+ newOperands.add(newOperand);
+
+ needNewInst = true;
+ }
+ else
+ {
+ newOperands.add(operand);
+ }
+ }
+
+ if(needNewInst)
+ {
+ builder->setInsertAfter(fwdInst);
+ IRInst* newInst = builder->emitIntrinsicInst(
+ fwdInst->getDataType(),
+ fwdInst->getOp(),
+ newOperands.getCount(),
+ newOperands.getBuffer());
+
+ builder->setInsertLoc(oldLoc);
+
+ if (isDifferentialInst(fwdInst))
+ builder->markInstAsDifferential(newInst);
+
+ return newInst;
+ }
+ else
+ {
+ builder->setInsertLoc(oldLoc);
+ return fwdInst;
+ }
+ }
+
TranspositionResult transposeArithmetic(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
{
- IRType* floatType = builder->getType(kIROp_FloatType);
+
+ // Only handle arithmetic on uniform types. If the types aren't uniform, we need some
+ // promotion/demotion logic. Note that this can create a new inst in place of the old, but since we're
+ // at the transposition step for the old inst, and already have it's aggregate gradient, there's
+ // no need to worry about the 'gradientsMap' being out-of-date
+ // TODO: There are some opportunities for optimization here (otherwise we might be increasing the intermediate
+ // data size unnecessarily)
+ //
+ fwdInst = promoteOperandsToTargetType(builder, fwdInst);
+
+ auto operandType = fwdInst->getOperand(0)->getDataType();
+
switch(fwdInst->getOp())
{
case kIROp_Add:
{
// (Out = dA + dB) -> [(dA += dOut), (dB += dOut)]
return TranspositionResult(
- List<RevAssignment>(
- RevAssignment(
+ List<RevGradient>(
+ RevGradient(
fwdInst->getOperand(0),
- revValue),
- RevAssignment(
+ revValue,
+ fwdInst),
+ RevGradient(
fwdInst->getOperand(1),
- revValue)));
+ revValue,
+ fwdInst)));
}
case kIROp_Sub:
{
// (Out = dA - dB) -> [(dA += dOut), (dB -= dOut)]
return TranspositionResult(
- List<RevAssignment>(
- RevAssignment(
+ List<RevGradient>(
+ RevGradient(
fwdInst->getOperand(0),
- revValue),
- RevAssignment(
+ revValue,
+ fwdInst),
+ RevGradient(
fwdInst->getOperand(1),
builder->emitNeg(
- revValue->getDataType(), revValue))));
+ revValue->getDataType(), revValue),
+ fwdInst)));
}
case kIROp_Mul:
{
@@ -329,19 +646,21 @@ struct DiffTransposePass
{
// (Out = dA * B) -> (dA += B * dOut)
return TranspositionResult(
- List<RevAssignment>(
- RevAssignment(
+ List<RevGradient>(
+ RevGradient(
fwdInst->getOperand(0),
- builder->emitMul(floatType, fwdInst->getOperand(1), revValue))));
+ builder->emitMul(operandType, fwdInst->getOperand(1), revValue),
+ fwdInst)));
}
else if (isDifferentialInst(fwdInst->getOperand(1)))
{
// (Out = A * dB) -> (dB += A * dOut)
return TranspositionResult(
- List<RevAssignment>(
- RevAssignment(
+ List<RevGradient>(
+ RevGradient(
fwdInst->getOperand(1),
- builder->emitMul(floatType, fwdInst->getOperand(0), revValue))));
+ builder->emitMul(operandType, fwdInst->getOperand(0), revValue),
+ fwdInst)));
}
else
{
@@ -354,66 +673,397 @@ struct DiffTransposePass
}
}
- IRInst* emitAggregateValue(IRBuilder* builder, List<IRInst*> values)
+ RevGradient materializeSwizzleGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
{
- // We're handling the case where the types are all float,
- // so we can use a bunch of kIROp_Add insts to add them up.
- // If this is an arbitrary type T, we need to lookup and
- // call T.dadd()
+ List<RevGradient> simpleGradients;
- IRInst* initialValue = builder->getFloatValue(builder->getType(kIROp_FloatType), 0.0);
- if (values.getCount() == 0)
+ for (auto gradient : gradients)
{
- // If there's not values to add up, emit a 0 value.
- return initialValue;
+ // Peek at the fwd-mode swizzle inst to see what type we need to materialize.
+ IRSwizzle* fwdSwizzleInst = as<IRSwizzle>(gradient.fwdGradInst);
+ SLANG_ASSERT(fwdSwizzleInst);
+
+ auto baseType = fwdSwizzleInst->getBase()->getDataType();
+
+ // Assume for now that this is a vector type.
+ SLANG_ASSERT(as<IRVectorType>(baseType));
+
+ IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount();
+ IRType* elementType = as<IRVectorType>(baseType)->getElementType();
+
+ // Must be a concrete integer (auto-diff must always occur after specialization)
+ // For generic code, we would need to generate a for loop.
+ //
+ SLANG_ASSERT(as<IRIntLit>(elementCountInst));
+
+ auto elementCount = as<IRIntLit>(elementCountInst)->getValue();
+
+ // Make a list of 0s
+ List<IRInst*> constructArgs;
+ auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, elementType);
+
+ // Must exist.
+ SLANG_ASSERT(zeroMethod);
+
+ auto zeroValueInst = builder->emitCallInst(elementType, zeroMethod, List<IRInst*>());
+
+ for (Index ii = 0; ii < ((Index)elementCount); ii++)
+ {
+ constructArgs.add(zeroValueInst);
+ }
+
+ // Replace swizzled elements with their gradients.
+ for (Index ii = 0; ii < ((Index)fwdSwizzleInst->getElementCount()); ii++)
+ {
+ auto sourceIndex = ii;
+ auto targetIndexInst = fwdSwizzleInst->getElementIndex(ii);
+ SLANG_ASSERT(as<IRIntLit>(targetIndexInst));
+ auto targetIndex = as<IRIntLit>(targetIndexInst)->getValue();
+
+ // Special-case for when the swizzled output is a single element.
+ if (fwdSwizzleInst->getElementCount() == 1)
+ {
+ constructArgs[(Index)targetIndex] = gradient.revGradInst;
+ }
+ else
+ {
+ auto gradAtIndex = builder->emitElementExtract(elementType, gradient.revGradInst, builder->getIntValue(builder->getIntType(), sourceIndex));
+ constructArgs[(Index)targetIndex] = gradAtIndex;
+ }
+ }
+
+ simpleGradients.add(
+ RevGradient(
+ gradient.targetInst,
+ builder->emitMakeVector(baseType, (UInt)elementCount, constructArgs.getBuffer()),
+ gradient.fwdGradInst));
}
- else if (values.getCount() == 1)
+
+ return materializeSimpleGradients(builder, aggPrimalType, simpleGradients);
+ }
+
+ RevGradient materializeGradientSet(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
+ {
+ switch (gradients[0].flavor)
+ {
+ case RevGradient::Flavor::Simple:
+ return materializeSimpleGradients(builder, aggPrimalType, gradients);
+
+ case RevGradient::Flavor::Swizzle:
+ return materializeSwizzleGradients(builder, aggPrimalType, gradients);
+
+ case RevGradient::Flavor::FieldExtract:
+ return materializeFieldExtractGradients(builder, aggPrimalType, gradients);
+
+ default:
+ SLANG_ASSERT_FAILURE("Unhandled gradient flavor for materialization");
+ }
+ }
+
+ RevGradient materializeFieldExtractGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
+ {
+ // Setup a temporary variable to aggregate gradients.
+ // TODO: We can extend this later to grab an existing ptr to allow aggregation of
+ // gradients across blocks without constructing new variables.
+ // Looking up an existing pointer could also allow chained accesses like x.a.b[1] to directly
+ // write into the specific sub-field that is affected without constructing intermediate vars.
+ //
+ auto revGradVar = builder->emitVar(
+ (IRType*)diffTypeContext.getDifferentialForType(builder, aggPrimalType));
+
+ // Initialize with T.dzero()
+ auto zeroValueInst = emitDZeroOfDiffInstType(builder, aggPrimalType);
+
+ builder->emitStore(revGradVar, zeroValueInst);
+
+ Dictionary<IRStructKey*, List<RevGradient>> bucketedGradients;
+ for (auto gradient : gradients)
+ {
+ // Grab the field affected by this gradient.
+ auto fieldExtractInst = as<IRFieldExtract>(gradient.fwdGradInst);
+ SLANG_ASSERT(fieldExtractInst);
+
+ auto structKey = as<IRStructKey>(fieldExtractInst->getField());
+ SLANG_ASSERT(structKey);
+
+ if (!bucketedGradients.ContainsKey(structKey))
+ {
+ bucketedGradients[structKey] = List<RevGradient>();
+ }
+
+ bucketedGradients[structKey].GetValue().add(RevGradient(
+ RevGradient::Flavor::Simple,
+ gradient.targetInst,
+ gradient.revGradInst,
+ gradient.fwdGradInst
+ ));
+
+ }
+
+ for (auto pair : bucketedGradients)
+ {
+ auto subGrads = pair.Value;
+
+ auto primalType = tryGetPrimalTypeFromDiffInst(subGrads[0].fwdGradInst);
+
+ SLANG_ASSERT(primalType);
+
+ // Consruct address to this field in revGradVar.
+ auto revGradTargetAddress = builder->emitFieldAddress(
+ builder->getPtrType(subGrads[0].revGradInst->getDataType()),
+ revGradVar,
+ pair.Key);
+
+ builder->emitStore(revGradTargetAddress, emitAggregateValue(builder, primalType, subGrads));
+ }
+
+ // Load the entire var and return it.
+ return RevGradient(
+ RevGradient::Flavor::Simple,
+ gradients[0].targetInst,
+ builder->emitLoad(revGradVar),
+ nullptr);
+ }
+
+ RevGradient materializeSimpleGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
+ {
+ if (gradients.getCount() == 1)
{
// If there's only one value to add up, just return it in order
// to avoid a stack of 0 + 0 + 0 + ...
- return values[0];
+ return gradients[0];
+ }
+
+ // If there's more than one gradient, aggregate them by adding them up.
+ IRInst* currentValue = nullptr;
+ for (auto gradient : gradients)
+ {
+ if (!currentValue)
+ {
+ currentValue = gradient.revGradInst;
+ continue;
+ }
+
+ currentValue = emitDAddOfDiffInstType(builder, aggPrimalType, currentValue, gradient.revGradInst);
}
- // If there's more than one value, aggregate them by adding them up.
+ return RevGradient(
+ RevGradient::Flavor::Simple,
+ gradients[0].targetInst,
+ currentValue,
+ nullptr);
+ }
+
+ IRInst* emitAggregateDifferentialPair(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> pairGradients)
+ {
+ auto aggPairType = as<IRDifferentialPairType>(aggPrimalType);
+ SLANG_ASSERT(aggPairType);
- SLANG_ASSERT(values[0]->getDataType()->getOp() == kIROp_FloatType);
+ IRType* diffType = (IRType*)pairBuilder.getDiffTypeFromPairType(builder, aggPairType);
- IRInst* currentValue = initialValue;
- for (auto value : values)
+ IRInst* primalInst = nullptr;
+ IRInst* diffInst = nullptr;
+
+ List<RevGradient> gradients;
+ for (auto gradient : pairGradients)
{
- currentValue = builder->emitAdd(
- builder->getType(kIROp_FloatType), currentValue, value);
+ switch (gradient.flavor)
+ {
+ case RevGradient::Flavor::Simple:
+ {
+ // In this case, the gradient is a 'pair' already, but we need to treat the primal element
+ // as if it didn't exist (we simply copy it over)
+ // If we already saw a pair, throw an error since we don't know how to combine to primals.
+ // (i.e. something went wrong prior to this step.)
+ //
+ if (primalInst)
+ {
+ SLANG_UNEXPECTED("Encountered multiple pair types in emitAggregateDifferentialPair");
+ }
+
+ primalInst = builder->emitDifferentialPairGetPrimal(gradient.revGradInst);
+ gradients.add(
+ RevGradient(
+ RevGradient::Flavor::Simple,
+ gradient.targetInst,
+ builder->emitDifferentialPairGetDifferential(
+ diffType,
+ gradient.revGradInst),
+ gradient.fwdGradInst));
+ break;
+ }
+
+ case RevGradient::Flavor::GetDifferential:
+ {
+ // In this case, the gradient is the result of transposing a GetDifferential
+ // so we have only the gradient part. Just add it to the list of gradients to aggregate
+ gradients.add(
+ RevGradient(
+ RevGradient::Flavor::Simple,
+ gradient.targetInst,
+ gradient.revGradInst,
+ gradient.fwdGradInst));
+ break;
+ }
+ default:
+ SLANG_UNEXPECTED("Unexpected gradient flavor in emitAggregateDifferentialPair");
+ }
}
- return currentValue;
+ // Aggregate only the differentials
+ diffInst = emitAggregateValue(builder, aggPairType->getValueType(), gradients);
+
+ // Pack them back together.
+ return builder->emitMakeDifferentialPair(aggPrimalType, primalInst, diffInst);
}
- void addRevAssignmentForFwdInst(IRInst* fwdInst, IRInst* assignment)
+ IRInst* emitAggregateValue(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
{
- if (!hasRevAssignments(fwdInst))
+ // If we're dealing with the differential-pair types, we need to use a different aggregation method, since
+ // a differential pair is really a 'hybrid' primal-differential type.
+ //
+ if (as<IRDifferentialPairType>(aggPrimalType))
+ return emitAggregateDifferentialPair(builder, aggPrimalType, gradients);
+
+ // Process non-simple gradients into simple gradients.
+ // TODO: This is where we can improve efficiency later.
+ // For instance if we have one gradient each for var.x, var.y and var.z
+ // we can construct one single gradient vector out of the three vectors (i.e. float3(x_grad, y_grad, z_grad))
+ // instead of creating one vector for each gradient and accumulating them
+ // (i.e. float3(x_grad, 0, 0) + float3(0, y_grad, 0) + float3(0, 0, z_grad))
+ // The same concept can be extended for struct and array types (and for any combination of the three)
+ //
+ List<RevGradient> simpleGradients;
{
- assignmentsMap[fwdInst] = List<IRInst*>();
+ // Start by sorting gradients based on flavor.
+ gradients.sort([&](const RevGradient& a, const RevGradient& b) -> bool { return a.flavor < b.flavor; });
+
+ Index ii = 0;
+ while (ii < gradients.getCount())
+ {
+ List<RevGradient> gradientsOfFlavor;
+
+ RevGradient::Flavor currentFlavor = (gradients.getCount() > 0) ? gradients[ii].flavor : RevGradient::Flavor::Simple;
+
+ // Pull all the gradients matching the flavor of the top-most gradeint into a temporary list.
+ for (; ii < gradients.getCount(); ii++)
+ {
+ if (gradients[ii].flavor == currentFlavor)
+ {
+ gradientsOfFlavor.add(gradients[ii]);
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ // Turn the set into a simple gradient.
+ auto simpleGradient = materializeGradientSet(builder, aggPrimalType, gradientsOfFlavor);
+ SLANG_ASSERT(simpleGradient.flavor == RevGradient::Flavor::Simple);
+
+ simpleGradients.add(simpleGradient);
+ }
}
- assignmentsMap[fwdInst].GetValue().add(assignment);
+ if (simpleGradients.getCount() == 0)
+ {
+ // If there are no gradients to add up, check the type and emit a 0/null value.
+ auto aggDiffType = (aggPrimalType) ? diffTypeContext.getDifferentialForType(builder, aggPrimalType) : nullptr;
+ if (aggDiffType != nullptr)
+ {
+ // If type is non-null/non-void, call T.dzero() to produce a 0 gradient.
+ return emitDZeroOfDiffInstType(builder, aggPrimalType);
+ }
+ else
+ {
+ // Otherwise, gradients may not be applicable for this inst. return N/A
+ return nullptr;
+ }
+ }
+ else
+ {
+ return materializeSimpleGradients(builder, aggPrimalType, simpleGradients).revGradInst;
+ }
+ }
+
+ IRType* tryGetPrimalTypeFromDiffInst(IRInst* diffInst)
+ {
+ // Look for differential inst decoration.
+ if (auto diffInstDecoration = diffInst->findDecoration<IRDifferentialInstDecoration>())
+ {
+ return diffInstDecoration->getPrimalType();
+ }
+ else
+ {
+ return nullptr;
+ }
+ }
+
+ IRInst* emitDZeroOfDiffInstType(IRBuilder* builder, IRType* primalType)
+ {
+ auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType);
+
+ // Should exist.
+ SLANG_ASSERT(zeroMethod);
+
+ return builder->emitCallInst(
+ (IRType*)diffTypeContext.getDifferentialForType(builder, primalType),
+ zeroMethod,
+ List<IRInst*>());
+ }
+
+ IRInst* emitDAddOfDiffInstType(IRBuilder* builder, IRType* primalType, IRInst* op1, IRInst* op2)
+ {
+ auto addMethod = diffTypeContext.getAddMethodForType(builder, primalType);
+
+ // Should exist.
+ SLANG_ASSERT(addMethod);
+
+ return builder->emitCallInst(
+ (IRType*)diffTypeContext.getDifferentialForType(builder, primalType),
+ addMethod,
+ List<IRInst*>(op1, op2));
}
- List<IRInst*> getRevAssignments(IRInst* fwdInst)
+ void addRevGradientForFwdInst(IRInst* fwdInst, RevGradient assignment)
{
- return assignmentsMap[fwdInst];
+ if (!hasRevGradients(fwdInst))
+ {
+ gradientsMap[fwdInst] = List<RevGradient>();
+ }
+
+ gradientsMap[fwdInst].GetValue().add(assignment);
}
- List<IRInst*> popRevAssignments(IRInst* fwdInst)
+ List<RevGradient> getRevGradients(IRInst* fwdInst)
{
- List<IRInst*> val = assignmentsMap[fwdInst].GetValue();
- assignmentsMap.Remove(fwdInst);
+ return gradientsMap[fwdInst];
+ }
+
+ List<RevGradient> popRevGradients(IRInst* fwdInst)
+ {
+ List<RevGradient> val = gradientsMap[fwdInst].GetValue();
+ gradientsMap.Remove(fwdInst);
return val;
}
- bool hasRevAssignments(IRInst* fwdInst)
+ bool hasRevGradients(IRInst* fwdInst)
{
- return assignmentsMap.ContainsKey(fwdInst);
+ return gradientsMap.ContainsKey(fwdInst);
}
+
+ AutoDiffSharedContext* autodiffContext;
+
+ DifferentiableTypeConformanceContext diffTypeContext;
+
+ DifferentialPairTypeBuilder pairBuilder;
+
+ Dictionary<IRInst*, List<RevGradient>> gradientsMap;
+
+ Dictionary<IRInst*, IRInst*>* primalsMap;
+
+ List<IRInst*> usedPtrs;
};
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 344a930f2..79dec365c 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -83,15 +83,11 @@ struct DiffUnzipPass
if (isDifferentialInst(child) || as<IRTerminatorInst>(child))
{
- auto newInst = cloneInst(&cloneEnv, &diffBuilder, child);
- child->replaceUsesWith(newInst);
- child->removeAndDeallocate();
+ child->insertAtEnd(diffBlock);
}
else
{
- auto newInst = cloneInst(&cloneEnv, &primalBuilder, child);
- child->replaceUsesWith(newInst);
- child->removeAndDeallocate();
+ child->insertAtEnd(primalBlock);
}
child = nextChild;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 5f9ee37fa..5784f60cb 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -735,7 +735,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// Used by the auto-diff pass to mark insts that compute
/// a differential value.
- INST(DifferentialInstDecoration, diffInstDecoration, 0, 0)
+ INST(DifferentialInstDecoration, diffInstDecoration, 1, 0)
/// Used by the auto-diff pass to hold a reference to a
/// differential member of a type in its associated differential type.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index a8bc04701..1ef0fa4f8 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -603,7 +603,11 @@ struct IRDifferentialInstDecoration : IRDecoration
{
kOp = kIROp_DifferentialInstDecoration
};
+
+ IRUse primalType;
IR_LEAF_ISA(DifferentialInstDecoration)
+
+ IRType* getPrimalType() { return as<IRType>(getOperand(0)); }
};
struct IRBackwardDifferentiableDecoration : IRDecoration
@@ -3370,7 +3374,12 @@ public:
void markInstAsDifferential(IRInst* value)
{
- addDecoration(value, kIROp_DifferentialInstDecoration);
+ addDecoration(value, kIROp_DifferentialInstDecoration, nullptr);
+ }
+
+ void markInstAsDifferential(IRInst* value, IRType* primalType)
+ {
+ addDecoration(value, kIROp_DifferentialInstDecoration, primalType);
}
void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable)