summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-17 22:19:10 -0800
committerGitHub <noreply@github.com>2023-01-17 22:19:10 -0800
commit86ddb9c452c4f33d09b4f7d4f90a9abad4984071 (patch)
tree833f8bb0fd5df8f0328e20a4568a19081593cedb /source
parenta0994a8da142e54362e9ec1fdb5e5abc708ec3d2 (diff)
First custom backward-derivative test case working. (#2598)
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang18
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp10
-rw-r--r--source/slang/slang-ir-autodiff-pairs.cpp11
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp30
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp176
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h16
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h43
-rw-r--r--source/slang/slang-ir-autodiff.cpp4
-rw-r--r--source/slang/slang-ir-dce.cpp4
-rw-r--r--source/slang/slang-ir-deduplicate-generic-children.cpp43
-rw-r--r--source/slang/slang-ir-deduplicate-generic-children.h12
-rw-r--r--source/slang/slang-ir-link.cpp1
-rw-r--r--source/slang/slang-ir-specialize.cpp96
-rw-r--r--source/slang/slang-ir-ssa-simplification.cpp2
-rw-r--r--source/slang/slang-ir-util.cpp116
-rw-r--r--source/slang/slang-ir-util.h23
17 files changed, 451 insertions, 156 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index e19923c80..c732d1a5e 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -223,6 +223,24 @@ DifferentialPair<vector<T, N>> __d_exp_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_UNARY(T, N, __d_exp, dpx);
}
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(exp)]
+void __d_exp(inout DifferentialPair<T> dpx, T.Differential dOut)
+{
+ dpx = diffPair(
+ dpx.p,
+ T.dmul(exp(dpx.p), dOut));
+}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(exp)]
+void __d_exp_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
+{
+ dpx = diffPair(
+ dpx.p,
+ vector<T, N>.dmul(exp(dpx.p), dOut));
+}
+
// Absolute value
__generic<T : __BuiltinFloatingPointType>
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index e8fe2beac..ee159b80b 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -880,10 +880,12 @@ InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* build
auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue());
SLANG_ASSERT(diffDiffVal);
+ auto primalPairType = findOrTranscribePrimalInst(builder, origInst->getFullType());
+ auto diffPairType = findOrTranscribeDiffInst(builder, origInst->getFullType());
auto primalPair = builder->emitMakeDifferentialPair(
- tryGetDiffPairType(builder, primalVal->getDataType()), primalVal, diffPrimalVal);
+ (IRType*)primalPairType, primalVal, diffPrimalVal);
auto diffPair = builder->emitMakeDifferentialPair(
- tryGetDiffPairType(builder, differentiateType(builder, origInst->getPrimalValue()->getDataType())),
+ (IRType*)diffPairType,
primalDiffVal,
diffDiffVal);
return InstPair(primalPair, diffPair);
@@ -901,7 +903,9 @@ InstPair ForwardDiffTranscriber::transcribeDifferentialPairGetElement(IRBuilder*
auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0));
SLANG_ASSERT(diffVal);
- auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal);
+ auto primalType = findOrTranscribePrimalInst(builder, origInst->getFullType());
+
+ auto primalResult = builder->emitIntrinsicInst((IRType*)primalType, origInst->getOp(), 1, &primalVal);
auto diffValPairType = as<IRDifferentialPairType>(diffVal->getDataType());
IRInst* diffResultType = nullptr;
diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp
index dc72ed44a..b3665f27c 100644
--- a/source/slang/slang-ir-autodiff-pairs.cpp
+++ b/source/slang/slang-ir-autodiff-pairs.cpp
@@ -124,15 +124,22 @@ struct DiffPairLoweringPass : InstPassBase
}
});
+ OrderedDictionary<IRInst*, IRInst*> pendingReplacements;
processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst)
{
if (auto loweredType = lowerPairType(builder, inst))
{
- inst->replaceUsesWith(loweredType);
- inst->removeAndDeallocate();
+ pendingReplacements.Add(inst, loweredType);
modified = true;
}
});
+ for (auto replacement : pendingReplacements)
+ {
+ replacement.Key->replaceUsesWith(replacement.Value);
+ replacement.Key->removeAndDeallocate();
+ }
+ autodiffContext->sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
+
return modified;
}
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index de4fbe182..d3a6137c1 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -36,7 +36,7 @@ namespace Slang
}
else
{
- if (auto diffPairType = tryGetDiffPairType(builder, primalType))
+ if (auto diffPairType = tryGetDiffPairType(builder, origType))
{
auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
newParameterTypes.add(inoutDiffPairType);
@@ -311,6 +311,7 @@ namespace Slang
IRFunc* primalFunc = origFunc;
+ maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);
differentiableTypeConformanceContext.setFunc(origFunc);
auto diffFunc = builder.createFunc();
@@ -337,7 +338,8 @@ namespace Slang
// Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
{
- cloneDecoration(dictDecor, diffFunc);
+ builder.setInsertBefore(diffFunc->getFirstDecorationOrChild());
+ cloneInst(&cloneEnv, &builder, dictDecor);
}
return InstPair(primalFunc, diffFunc);
@@ -526,7 +528,7 @@ namespace Slang
auto diffOuterGeneric = as<IRGeneric>(findOuterGeneric(diffPropagateFunc));
SLANG_RELEASE_ASSERT(diffOuterGeneric);
- migrationContext.init(fwdParentGeneric, diffOuterGeneric);
+ migrationContext.init(fwdParentGeneric, diffOuterGeneric, diffPropagateFunc);
auto inst = fwdParentGeneric->getFirstBlock()->getFirstOrdinaryInst();
builder->setInsertBefore(diffPropagateFunc);
while (inst)
@@ -580,24 +582,12 @@ namespace Slang
//
IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
- // Clone the primal blocks from unzippedFwdDiffFunc
- // to the reverse-mode function.
- //
- // Special care needs to be taken for the first block since it holds the parameters
-
- // Clone all blocks into a temporary diff func.
- // We're using a temporary sice we don't want to clone decorations,
- // only blocks, and right now there's no provision in slang-ir-clone.h
- // for that.
- //
+
+ // Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell.
builder->setInsertInto(diffPropagateFunc->getParent());
- IRCloneEnv subCloneEnv;
- auto tempDiffFunc = as<IRFunc>(cloneInst(&subCloneEnv, builder, unzippedFwdDiffFunc));
-
- // Move blocks to the diffFunc shell.
{
List<IRBlock*> workList;
- for (auto block = tempDiffFunc->getFirstBlock(); block; block = block->getNextBlock())
+ for (auto block = unzippedFwdDiffFunc->getFirstBlock(); block; block = block->getNextBlock())
workList.add(block);
for (auto block : workList)
@@ -623,9 +613,7 @@ namespace Slang
auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(
diffPropagateFunc, primalFunc, intermediateType);
- // Clean up by deallocating intermediate versions.
- tempDiffFunc->removeAndDeallocate();
- unzippedFwdDiffFunc->removeAndDeallocate();
+ // Clean up by deallocating the tempoarary forward derivative func.
fwdDiffFunc->removeAndDeallocate();
// If primal function is nested in a generic, we want to create separate generics for all the associated things
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 275b40b46..cfbc9638a 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -176,61 +176,56 @@ IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRI
}
// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
-IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRInst* inDiffPairType)
+IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType)
{
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(inDiffPairType->parent);
- auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
- SLANG_ASSERT(diffPairType);
-
- auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
-
// Differentiate the pair type to get it's differential (which is itself a pair)
- auto diffDiffPairType = differentiateType(&builder, diffPairType);
+ auto diffDiffPairType = differentiateType(builder, (IRType*)inOriginalDiffPairType);
+
+ auto table = builder->createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, (IRType*)inPrimalDiffPairType);
// And place it in the synthesized witness table.
- builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType);
+ builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType);
+
// Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
// Record this in the context for future lookups
- differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[(IRType*)inOriginalDiffPairType] = table;
return table;
}
-IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
+IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness)
{
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(primalType->parent);
- return builder.getDifferentialPairType(
+ return builder->getDifferentialPairType(
(IRType*)primalType,
witness);
}
-IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRInst* primalType)
+IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType)
{
- IRBuilder builder(sharedBuilder);
- if (!primalType->next)
- builder.setInsertInto(primalType->parent);
- else
- builder.setInsertBefore(primalType->next);
-
- IRInst* witness = as<IRWitnessTable>(
- differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
+ auto primalType = lookupPrimalInst(builder, originalType, nullptr);
+ SLANG_RELEASE_ASSERT(primalType);
+ IRInst* witness =
+ differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)originalType);
+ if (witness)
+ {
+ witness = lookupPrimalInst(builder, witness, nullptr);
+ SLANG_RELEASE_ASSERT(witness);
+ }
if (!witness)
{
if (auto primalPairType = as<IRDifferentialPairType>(primalType))
{
- witness = getDifferentialPairWitness(primalPairType);
+ witness = getDifferentialPairWitness(builder, originalType, primalPairType);
}
- else if (auto extractExistential = as<IRExtractExistentialType>(primalType))
+ else if (auto extractExistential = as<IRExtractExistentialType>(originalType))
{
- differentiateExtractExistentialType(&builder, extractExistential, witness);
+ differentiateExtractExistentialType(builder, extractExistential, witness);
}
}
- return builder.getDifferentialPairType(
+ return builder->getDifferentialPairType(
(IRType*)primalType,
witness);
}
@@ -242,7 +237,8 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o
primalType->getParent() && primalType->getParent()->getParent() &&
primalType->getParent()->getParent()->getOp() == kIROp_Generic)
{
- return (IRType*)differentiableTypeConformanceContext.getDifferentialForType(builder, origType);
+ auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialForType(builder, origType);
+ return (IRType*)findOrTranscribePrimalInst(builder, diffType);
}
return (IRType*)transcribe(builder, origType);
}
@@ -254,10 +250,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
origType->getOp(),
differentiateType(builder, ptrType->getValueType()));
- // If there is an explicit primal version of this type in the local scope, load that
- // otherwise use the original type.
- //
- IRInst* primalType = lookupPrimalInst(builder, origType, origType);
+ auto primalType = maybeCloneForPrimalInst(builder, origType);
// Special case certain compound types (PtrType, FuncType, etc..)
// otherwise try to lookup a differential definition for the given type.
@@ -288,6 +281,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
{
auto primalPairType = as<IRDifferentialPairType>(primalType);
return getOrCreateDiffPairType(
+ builder,
pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
}
@@ -409,6 +403,44 @@ InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBui
return InstPair(primalResult, nullptr);
}
+void AutoDiffTranscriberBase::maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc)
+{
+ auto decor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
+ if (decor)
+ return;
+ // A differentiable func must have `IRDifferentiableTypeDictionaryDecoration`, except it has a
+ // `IRUserDefinedBackwardDerivativeDecoration`.
+ auto udfDecor = origFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>();
+ SLANG_RELEASE_ASSERT(udfDecor);
+ // We need to migrate the dictionary from the backward derivative func so we can properly
+ // differentiate the function header.
+ IRBuilder subBuilder = *builder;
+ subBuilder.setInsertBefore(origFunc);
+
+ auto derivative = udfDecor->getBackwardDerivativeFunc();
+ if (auto specialize = as<IRSpecialize>(derivative))
+ {
+ auto derivativeGeneric = cast<IRGeneric>(specialize->getBase());
+ GenericChildrenMigrationContext migrationContext;
+ migrationContext.init(derivativeGeneric, cast<IRGeneric>(findOuterGeneric(origFunc)), origFunc);
+ auto derivativeFunc = findGenericReturnVal(derivativeGeneric);
+ auto derivativeBlock = cast<IRBlock>(derivativeFunc->getParent());
+ for (auto dInst = derivativeBlock->getFirstOrdinaryInst(); dInst != derivativeFunc;
+ dInst = dInst->getNextInst())
+ {
+ migrationContext.cloneInst(&subBuilder, dInst);
+ }
+ auto udfDictDecor = derivativeFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
+ SLANG_RELEASE_ASSERT(udfDictDecor);
+ subBuilder.setInsertBefore(origFunc->getFirstDecorationOrChild());
+ migrationContext.cloneInst(&subBuilder, udfDictDecor);
+ eliminateDeadCode(origFunc->getParent());
+ }
+ else
+ {
+ cloneDecoration(udfDecor, origFunc);
+ }
+}
IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& outWitnessTable)
{
@@ -435,21 +467,21 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder*
return nullptr;
}
-IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
+IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* originalType)
{
// If this is a PtrType (out, inout, etc..), then create diff pair from
// value type and re-apply the appropropriate PtrType wrapper.
//
- if (auto origPtrType = as<IRPtrTypeBase>(primalType))
+ if (auto origPtrType = as<IRPtrTypeBase>(originalType))
{
if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
- return builder->getPtrType(primalType->getOp(), diffPairValueType);
+ return builder->getPtrType(originalType->getOp(), diffPairValueType);
else
return nullptr;
}
- auto diffType = differentiateType(builder, primalType);
+ auto diffType = differentiateType(builder, originalType);
if (diffType)
- return (IRType*)getOrCreateDiffPairType(primalType);
+ return (IRType*)getOrCreateDiffPairType(builder, originalType);
return nullptr;
}
@@ -628,7 +660,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
}
}
-InstPair AutoDiffTranscriberBase::transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
+InstPair AutoDiffTranscriberBase::transcribeBlockImpl(IRBuilder* builder, IRBlock* origBlock, HashSet<IRInst*>& instsToSkip)
{
IRBuilder subBuilder(builder->getSharedBuilder());
subBuilder.setInsertLoc(builder->getInsertLoc());
@@ -653,7 +685,14 @@ InstPair AutoDiffTranscriberBase::transcribeBlock(IRBuilder* builder, IRBlock* o
// derivative code.
//
for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
+ {
+ if (instsToSkip.Contains(child))
+ {
+ continue;
+ }
+
this->transcribe(&subBuilder, child);
+ }
return InstPair(diffBlock, diffBlock);
}
@@ -713,12 +752,63 @@ InstPair AutoDiffTranscriberBase::transcribeReturn(IRBuilder* builder, IRReturn*
}
}
+static void _markGenericChildrenWithoutRelaventUse(IRGeneric* origGeneric, HashSet<IRInst*>& outInstsToSkip)
+{
+ for (;;)
+ {
+ bool changed = false;
+ for (auto inst = origGeneric->getFirstBlock()->getFirstOrdinaryInst(); inst;
+ inst = inst->getNextInst())
+ {
+ // If an inst is only referenced by a UserDefinedDerivativeDecoration, we need to skip
+ // its transcription.
+ switch (inst->getOp())
+ {
+ case kIROp_Return:
+ continue;
+ default:
+ break;
+ }
+
+ bool hasRelaventUse = false;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ switch (use->getUser()->getOp())
+ {
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_BackwardDerivativePrimalDecoration:
+ case kIROp_BackwardDerivativePropagateDecoration:
+ break;
+ default:
+ if (!outInstsToSkip.Contains(use->getUser()))
+ {
+ hasRelaventUse = true;
+ }
+ break;
+ }
+ }
+ if (!hasRelaventUse)
+ {
+ if (outInstsToSkip.Add(inst))
+ {
+ changed = true;
+ }
+ }
+ }
+ if (!changed)
+ break;
+ }
+}
+
// Transcribe a generic definition
InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric)
{
auto innerVal = findInnerMostGenericReturnVal(origGeneric);
if (auto innerFunc = as<IRFunc>(innerVal))
{
+ maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, innerFunc);
differentiableTypeConformanceContext.setFunc(innerFunc);
}
else if (auto funcType = as<IRFuncType>(innerVal))
@@ -752,10 +842,14 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene
diffGeneric->setFullType(diffType);
+ HashSet<IRInst*> instsToSkip;
+ _markGenericChildrenWithoutRelaventUse(origGeneric, instsToSkip);
+
// Transcribe children from origFunc into diffFunc.
builder.setInsertInto(diffGeneric);
- for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock())
- this->transcribe(&builder, block);
+ auto transcribedBlock = transcribeBlockImpl(&builder, origGeneric->getFirstBlock(), instsToSkip);
+ mapPrimalInst(origGeneric->getFirstBlock(), transcribedBlock.primal);
+ mapDifferentialInst(origGeneric->getFirstBlock(), transcribedBlock.differential);
return InstPair(primalGeneric, diffGeneric);
}
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index 2d980145e..b0397069b 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -96,12 +96,14 @@ struct AutoDiffTranscriberBase
InstPair transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst);
+ void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc);
+
// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
- IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType);
+ IRWitnessTable* getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType);
- IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness);
+ IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness);
- IRType* getOrCreateDiffPairType(IRInst* primalType);
+ IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType);
IRType* differentiateType(IRBuilder* builder, IRType* origType);
@@ -121,7 +123,13 @@ struct AutoDiffTranscriberBase
InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst);
- InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock);
+ InstPair transcribeBlockImpl(IRBuilder* builder, IRBlock* origBlock, HashSet<IRInst*>& instsToSkip);
+
+ InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
+ {
+ HashSet<IRInst*> emptySet;
+ return transcribeBlockImpl(builder, origBlock, emptySet);
+ }
// Transcribe a generic definition
InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index a95fd7b9b..640041ecf 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -76,7 +76,7 @@ struct ExtractPrimalFuncContext
outIntermediateType = createIntermediateType(destFunc);
GenericChildrenMigrationContext migrationContext;
- migrationContext.init(as<IRGeneric>(findOuterGeneric(originalFunc)), as<IRGeneric>(findOuterGeneric(destFunc)));
+ migrationContext.init(as<IRGeneric>(findOuterGeneric(originalFunc)), as<IRGeneric>(findOuterGeneric(destFunc)), destFunc);
originalFuncType = as<IRFuncType>(originalFunc->getDataType());
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index e616578c1..c525191a3 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -14,39 +14,6 @@
namespace Slang
{
-struct GenericChildrenMigrationContext
-{
- IRCloneEnv cloneEnv;
- IRGeneric* srcGeneric;
- void init(IRGeneric* genericSrc, IRGeneric* genericDst)
- {
- srcGeneric = genericSrc;
- if (!genericSrc)
- return;
- auto srcParam = genericSrc->getFirstBlock()->getFirstParam();
- auto dstParam = genericDst->getFirstBlock()->getFirstParam();
- while (srcParam && dstParam)
- {
- cloneEnv.mapOldValToNew[srcParam] = dstParam;
- srcParam = srcParam->getNextParam();
- dstParam = dstParam->getNextParam();
- }
- cloneEnv.mapOldValToNew[genericSrc] = genericDst;
- cloneEnv.mapOldValToNew[genericSrc->getFirstBlock()] = genericDst->getFirstBlock();
- }
-
- IRInst* cloneInst(IRBuilder* builder, IRInst* src)
- {
- if (!srcGeneric)
- return src;
- if (findOuterGeneric(src) == srcGeneric)
- {
- return Slang::cloneInst(&cloneEnv, builder, src);
- }
- return src;
- }
-};
-
struct DiffUnzipPass
{
AutoDiffSharedContext* autodiffContext;
@@ -210,10 +177,14 @@ struct DiffUnzipPass
auto func = findSpecializeReturnVal(specialize);
auto outerGen = findOuterGeneric(func);
intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(outerGen);
- intermediateType = specializeWithGeneric(
- *primalBuilder,
+ List<IRInst*> args;
+ for (UInt i = 0; i < specialize->getArgCount(); i++)
+ args.add(specialize->getArg(i));
+ intermediateType = primalBuilder->emitSpecializeInst(
+ primalBuilder->getTypeKind(),
intermediateType,
- as<IRGeneric>(findOuterGeneric(primalBuilder->getInsertLoc().getParent())));
+ args.getCount(),
+ args.getBuffer());
}
else
{
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 7182375de..363006f58 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -31,11 +31,11 @@ static IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requi
return entry->getSatisfyingVal();
}
}
- else if (auto witnessTableParam = as<IRParam>(witness))
+ else
{
return builder->emitLookupInterfaceMethodInst(
builder->getTypeKind(),
- witnessTableParam,
+ witness,
requirementKey);
}
return nullptr;
diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp
index ae04acbd1..33e5b3cb4 100644
--- a/source/slang/slang-ir-dce.cpp
+++ b/source/slang/slang-ir-dce.cpp
@@ -314,11 +314,15 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o
{
if (inst->findDecoration<IRForwardDerivativeDecoration>())
return true;
+ if (inst->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
+ return true;
if (auto genInst = as<IRGeneric>(inst))
{
auto inner = findInnerMostGenericReturnVal(genInst);
if (inner->findDecoration<IRForwardDerivativeDecoration>())
return true;
+ if (inner->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
+ return true;
}
}
}
diff --git a/source/slang/slang-ir-deduplicate-generic-children.cpp b/source/slang/slang-ir-deduplicate-generic-children.cpp
new file mode 100644
index 000000000..f933f77cc
--- /dev/null
+++ b/source/slang/slang-ir-deduplicate-generic-children.cpp
@@ -0,0 +1,43 @@
+#include "slang-ir-deduplicate-generic-children.h"
+#include "slang-ir.h"
+#include "slang-ir-clone.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+
+bool deduplicateGenericChildren(IRGeneric* genericInst)
+{
+ bool changed = false;
+ GenericChildrenMigrationContext ctx;
+ ctx.init(genericInst, genericInst, nullptr);
+ List<IRInst*> instsToRemove;
+ for (auto inst = genericInst->getFirstBlock()->getFirstInst(); inst; inst = inst->getNextInst())
+ {
+ auto deduped = ctx.deduplicate(inst);
+ if (deduped != inst)
+ {
+ inst->replaceUsesWith(deduped);
+ instsToRemove.add(inst);
+ changed = true;
+ }
+ }
+ for (auto inst : instsToRemove)
+ inst->removeAndDeallocate();
+ return changed;
+}
+
+bool deduplicateGenericChildren(IRModule* module)
+{
+ bool changed = false;
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto gen = as<IRGeneric>(inst))
+ {
+ changed |= deduplicateGenericChildren(gen);
+ }
+ }
+ return changed;
+}
+
+}
diff --git a/source/slang/slang-ir-deduplicate-generic-children.h b/source/slang/slang-ir-deduplicate-generic-children.h
new file mode 100644
index 000000000..71cfafd4e
--- /dev/null
+++ b/source/slang/slang-ir-deduplicate-generic-children.h
@@ -0,0 +1,12 @@
+// slang-ir-deduplicate-generic-children.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+ struct IRGeneric;
+
+ // Deduplicate insts inside a generic.
+ bool deduplicateGenericChildren(IRModule* module);
+ bool deduplicateGenericChildren(IRGeneric* genericInst);
+}
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index ea4bfb7b9..80f974536 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -432,6 +432,7 @@ static void cloneExtraDecorationsFromInst(
case kIROp_PublicDecoration:
case kIROp_SequentialIDDecoration:
case kIROp_ForwardDerivativeDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
case kIROp_IntrinsicOpDecoration:
if (!clonedInst->findDecorationImpl(decoration->getOp()))
{
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index cbb5ccf09..747d0ccdd 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -390,57 +390,61 @@ struct SpecializationContext
auto genericReturnVal = findInnerMostGenericReturnVal(genericVal);
if (genericReturnVal->findDecoration<IRTargetIntrinsicDecoration>())
{
- if (auto customDiffRef = genericReturnVal->findDecoration<IRForwardDerivativeDecoration>())
+ for (auto decor : genericReturnVal->getDecorations())
{
- // If we already have a diff func on this specialize, skip.
- if (auto specDiffRef = specInst->findDecoration<IRForwardDerivativeDecoration>())
+ if (decor->getOp() == kIROp_ForwardDerivativeDecoration ||
+ decor->getOp() == kIROp_UserDefinedBackwardDerivativeDecoration)
{
- return false;
- }
-
- auto specDiffFunc = as<IRSpecialize>(customDiffRef->getForwardDerivativeFunc());
+ // If we already have a diff func on this specialize, skip.
+ if (auto specDiffRef = specInst->findDecorationImpl(decor->getOp()))
+ {
+ return false;
+ }
- // If the base is specialized, the JVP version must be also be a specialized
- // generic.
- //
- SLANG_RELEASE_ASSERT(specDiffFunc);
+ auto specDiffFunc = as<IRSpecialize>(decor->getOperand(0));
- // Build specialization arguments from specInst.
- // Note that if we've reached this point, we can safely assume
- // that our args are fully specialized/concrete.
- //
- UCount argCount = specInst->getArgCount();
- List<IRInst*> args;
- for (UIndex ii = 0; ii < argCount; ii++)
- args.add(specInst->getArg(ii));
-
- IRBuilder builder(&sharedBuilderStorage);
-
- // Specialize the custom JVP function type with the original arguments.
- builder.setInsertInto(module);
- auto newDiffFuncType = builder.emitSpecializeInst(
- builder.getTypeKind(),
- specDiffFunc->getBase()->getDataType(),
- argCount,
- args.getBuffer());
-
- // Specialize the custom JVP function with the original arguments.
- builder.setInsertBefore(specInst);
- auto newDiffFunc = builder.emitSpecializeInst(
- (IRType*) newDiffFuncType,
- specDiffFunc->getBase(),
- argCount,
- args.getBuffer());
-
- // Add the new spec insts to the list so they get specialized with
- // the usual logic.
- //
- addToWorkList(newDiffFuncType);
- addToWorkList(newDiffFunc);
-
- builder.addForwardDerivativeDecoration(specInst, newDiffFunc);
+ // If the base is specialized, the JVP version must be also be a specialized
+ // generic.
+ //
+ SLANG_RELEASE_ASSERT(specDiffFunc);
- return true;
+ // Build specialization arguments from specInst.
+ // Note that if we've reached this point, we can safely assume
+ // that our args are fully specialized/concrete.
+ //
+ UCount argCount = specInst->getArgCount();
+ List<IRInst*> args;
+ for (UIndex ii = 0; ii < argCount; ii++)
+ args.add(specInst->getArg(ii));
+
+ IRBuilder builder(&sharedBuilderStorage);
+
+ // Specialize the custom derivative function type with the original arguments.
+ builder.setInsertInto(module);
+ auto newDiffFuncType = builder.emitSpecializeInst(
+ builder.getTypeKind(),
+ specDiffFunc->getBase()->getDataType(),
+ argCount,
+ args.getBuffer());
+
+ // Specialize the custom derivative function with the original arguments.
+ builder.setInsertBefore(specInst);
+ auto newDiffFunc = builder.emitSpecializeInst(
+ (IRType*)newDiffFuncType,
+ specDiffFunc->getBase(),
+ argCount,
+ args.getBuffer());
+
+ // Add the new spec insts to the list so they get specialized with
+ // the usual logic.
+ //
+ addToWorkList(newDiffFuncType);
+ addToWorkList(newDiffFunc);
+
+ builder.addDecoration(specInst, decor->getOp(), newDiffFunc);
+
+ return true;
+ }
}
}
return false;
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
index 938094551..fd5f41f49 100644
--- a/source/slang/slang-ir-ssa-simplification.cpp
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -7,6 +7,7 @@
#include "slang-ir-simplify-cfg.h"
#include "slang-ir-peephole.h"
#include "slang-ir-hoist-constants.h"
+#include "slang-ir-deduplicate-generic-children.h"
#include "slang-ir-remove-unused-generic-param.h"
namespace Slang
@@ -22,6 +23,7 @@ namespace Slang
{
changed = false;
changed |= hoistConstants(module);
+ changed |= deduplicateGenericChildren(module);
changed |= applySparseConditionalConstantPropagation(module);
changed |= peepholeOptimize(module);
changed |= simplifyCFG(module);
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index fb465f638..881f041c0 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -219,4 +219,120 @@ void moveInstChildren(IRInst* dest, IRInst* src)
}
}
+struct GenericChildrenMigrationContextImpl
+{
+ IRCloneEnv cloneEnv;
+ IRGeneric* srcGeneric;
+ IRGeneric* dstGeneric;
+ Dictionary<IRInstKey, IRInst*> deduplicateMap;
+
+ void init(IRGeneric* genericSrc, IRGeneric* genericDst, IRInst* insertBefore)
+ {
+ srcGeneric = genericSrc;
+ dstGeneric = genericDst;
+
+ if (!genericSrc)
+ return;
+ auto srcParam = genericSrc->getFirstBlock()->getFirstParam();
+ auto dstParam = genericDst->getFirstBlock()->getFirstParam();
+ while (srcParam && dstParam)
+ {
+ cloneEnv.mapOldValToNew[srcParam] = dstParam;
+ srcParam = srcParam->getNextParam();
+ dstParam = dstParam->getNextParam();
+ }
+ cloneEnv.mapOldValToNew[genericSrc] = genericDst;
+ cloneEnv.mapOldValToNew[genericSrc->getFirstBlock()] = genericDst->getFirstBlock();
+
+ if (insertBefore)
+ {
+ for (auto inst = genericDst->getFirstBlock()->getFirstOrdinaryInst();
+ inst && inst != insertBefore;
+ inst = inst->getNextInst())
+ {
+ IRInstKey key = { inst };
+ deduplicateMap.AddIfNotExists(key, inst);
+ }
+ }
+ }
+
+ IRInst* deduplicate(IRInst* value)
+ {
+ if (!value) return nullptr;
+ if (value->getParent() != dstGeneric->getFirstBlock())
+ return value;
+ switch (value->getOp())
+ {
+ case kIROp_Param:
+ case kIROp_StructType:
+ case kIROp_StructKey:
+ case kIROp_InterfaceType:
+ case kIROp_ClassType:
+ case kIROp_Func:
+ case kIROp_Generic:
+ return value;
+ default:
+ break;
+ }
+ if (as<IRConstant>(value))
+ return value;
+
+ for (UInt i = 0; i < value->getOperandCount(); i++)
+ {
+ value->setOperand(i, deduplicate(value->getOperand(i)));
+ }
+ value->setFullType((IRType*)deduplicate(value->getFullType()));
+ IRInstKey key = { value };
+ if (auto newValue = deduplicateMap.TryGetValue(key))
+ return *newValue;
+ deduplicateMap[key] = value;
+ return value;
+ }
+
+ IRInst* cloneInst(IRBuilder* builder, IRInst* src)
+ {
+ if (!srcGeneric)
+ return src;
+ if (findOuterGeneric(src) == srcGeneric)
+ {
+ auto cloned = Slang::cloneInst(&cloneEnv, builder, src);
+ auto deduplicated = deduplicate(cloned);
+ if (deduplicated != cloned)
+ cloneEnv.mapOldValToNew[src] = deduplicated;
+ return deduplicated;
+ }
+ return src;
+ }
+};
+
+GenericChildrenMigrationContext::GenericChildrenMigrationContext()
+{
+ impl = new GenericChildrenMigrationContextImpl();
+}
+
+GenericChildrenMigrationContext::~GenericChildrenMigrationContext()
+{
+ delete impl;
+}
+
+IRCloneEnv* GenericChildrenMigrationContext::getCloneEnv()
+{
+ return &impl->cloneEnv;
+}
+
+void GenericChildrenMigrationContext::init(IRGeneric* genericSrc, IRGeneric* genericDst, IRInst* insertBefore)
+{
+ impl->init(genericSrc, genericDst, insertBefore);
+}
+
+IRInst* GenericChildrenMigrationContext::deduplicate(IRInst* value)
+{
+ return impl->deduplicate(value);
+}
+
+IRInst* GenericChildrenMigrationContext::cloneInst(IRBuilder* builder, IRInst* src)
+{
+ return impl->cloneInst(builder, src);
+}
+
}
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 4885dcd96..92446138f 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -8,6 +8,29 @@
namespace Slang
{
+struct GenericChildrenMigrationContextImpl;
+struct IRCloneEnv;
+
+// A helper class to clone children insts to a different generic parent that has equivalent set of
+// generic parameters. The clone will take care of substitution of equivalent generic parameters and
+// intermediate values between the two generic parents.
+struct GenericChildrenMigrationContext : public RefObject
+{
+private:
+ GenericChildrenMigrationContextImpl* impl;
+
+public:
+ IRCloneEnv* getCloneEnv();
+
+ GenericChildrenMigrationContext();
+ ~GenericChildrenMigrationContext();
+
+ void init(IRGeneric* genericSrc, IRGeneric* genericDst, IRInst* insertBefore);
+
+ IRInst* deduplicate(IRInst* value);
+
+ IRInst* cloneInst(IRBuilder* builder, IRInst* src);
+};
bool isPtrToClassType(IRInst* type);