summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-decl.cpp26
-rw-r--r--source/slang/slang-emit.cpp28
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp143
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h2
-rw-r--r--source/slang/slang-ir-autodiff-pairs.cpp95
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp141
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp49
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp112
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h210
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp16
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h91
-rw-r--r--source/slang/slang-ir-autodiff.cpp1177
-rw-r--r--source/slang/slang-ir-autodiff.h73
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h12
-rw-r--r--source/slang/slang-ir-link.cpp3
-rw-r--r--source/slang/slang-ir-lower-generics.cpp2
-rw-r--r--source/slang/slang-ir-specialize.cpp18
-rw-r--r--source/slang/slang-ir-specialize.h14
-rw-r--r--source/slang/slang-lower-to-ir.cpp75
20 files changed, 1764 insertions, 526 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 3667a36ba..04d5b7a75 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -9247,12 +9247,16 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl*
if (!decl->hasModifier<NoDiffThisAttribute>())
{
// Build decl-ref-type from interface.
- auto interfaceType =
- DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl));
+ auto thisType = DeclRefType::create(
+ m_astBuilder,
+ createDefaultSubstitutionsIfNeeded(
+ m_astBuilder,
+ this,
+ makeDeclRef(interfaceDecl->getThisTypeDecl())));
// If the interface is differentiable, make the this type a pair.
- if (tryGetDifferentialType(getASTBuilder(), interfaceType))
- reqDecl->diffThisType = getDifferentialPairType(interfaceType);
+ if (tryGetDifferentialType(getASTBuilder(), thisType))
+ reqDecl->diffThisType = getDifferentialPairType(thisType);
}
auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
@@ -9277,13 +9281,17 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl*
reqDecl->parentDecl = interfaceDecl;
if (!decl->hasModifier<NoDiffThisAttribute>())
{
- // Build decl-ref-type from interface.
- auto interfaceType =
- DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl));
+ // Build decl-ref-type for this-type.
+ auto thisType = DeclRefType::create(
+ m_astBuilder,
+ createDefaultSubstitutionsIfNeeded(
+ m_astBuilder,
+ this,
+ makeDeclRef(interfaceDecl->getThisTypeDecl())));
// If the interface is differentiable, make the this type a pair.
- if (tryGetDifferentialType(getASTBuilder(), interfaceType))
- reqDecl->diffThisType = getDifferentialPairType(interfaceType);
+ if (tryGetDifferentialType(getASTBuilder(), thisType))
+ reqDecl->diffThisType = getDifferentialPairType(thisType);
}
auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index b9217de41..cd1b177b2 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -815,7 +815,18 @@ Result linkAndOptimizeIR(
bool changed = false;
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE");
if (!codeGenContext->isSpecializationDisabled())
- changed |= specializeModule(targetProgram, irModule, codeGenContext->getSink());
+ {
+ // Pre-autodiff, we will attempt to specialize as much as possible.
+ //
+ // Note: Lowered dynamic-dispatch code cannot be differentiated correctly due to
+ // missing information, so we defer that to after the auto-dff step.
+ //
+ SpecializationOptions specOptions;
+ specOptions.lowerWitnessLookups = false;
+ changed |=
+ specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions);
+ }
+
if (codeGenContext->getSink()->getErrorCount() != 0)
return SLANG_FAIL;
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE");
@@ -867,9 +878,20 @@ Result linkAndOptimizeIR(
reportCheckpointIntermediates(codeGenContext, sink, irModule);
// Finalization is always run so AD-related instructions can be removed,
- // even the AD pass itself is not run.
+ // even if the AD pass itself is not run.
//
finalizeAutoDiffPass(targetProgram, irModule);
+ eliminateDeadCode(irModule, deadCodeEliminationOptions);
+
+ // After auto-diff, we can perform more aggressive specialization with dynamic-dispatch
+ // lowering.
+ //
+ if (!codeGenContext->isSpecializationDisabled())
+ {
+ SpecializationOptions specOptions;
+ specOptions.lowerWitnessLookups = true;
+ specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions);
+ }
finalizeSpecialization(irModule);
@@ -930,6 +952,8 @@ Result linkAndOptimizeIR(
validateIRModuleIfEnabled(codeGenContext, irModule);
+ inferAnyValueSizeWhereNecessary(targetProgram, irModule);
+
// If we have any witness tables that are marked as `KeepAlive`,
// but are not used for dynamic dispatch, unpin them so we don't
// do unnecessary work to lower them.
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 9f26f9d55..30c14f706 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -160,6 +160,40 @@ InstPair ForwardDiffTranscriber::transcribeReinterpret(IRBuilder* builder, IRIns
return InstPair(primalVal, diffVal);
}
+InstPair ForwardDiffTranscriber::transcribeDifferentiableTypeAnnotation(
+ IRBuilder* builder,
+ IRInst* origInst)
+{
+ auto primalAnnotation =
+ as<IRDifferentiableTypeAnnotation>(maybeCloneForPrimalInst(builder, origInst));
+
+ IRDifferentiableTypeAnnotation* annotation = as<IRDifferentiableTypeAnnotation>(origInst);
+
+ differentiableTypeConformanceContext.addTypeToDictionary(
+ (IRType*)primalAnnotation->getBaseType(),
+ primalAnnotation->getWitness());
+
+ auto diffType = differentiateType(builder, (IRType*)annotation->getBaseType());
+ if (!diffType)
+ return InstPair(primalAnnotation, nullptr);
+
+ auto diffTypeDiffWitness =
+ tryGetDifferentiableWitness(builder, diffType, DiffConformanceKind::Any);
+
+ IRInst* args[] = {diffType, diffTypeDiffWitness};
+
+ auto diffAnnotation = builder->emitIntrinsicInst(
+ builder->getVoidType(),
+ kIROp_DifferentiableTypeAnnotation,
+ 2,
+ args);
+
+ builder->markInstAsPrimal(diffAnnotation);
+ builder->markInstAsPrimal(primalAnnotation);
+
+ return InstPair(primalAnnotation, diffAnnotation);
+}
+
InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar)
{
if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
@@ -752,9 +786,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
auto pairValType = as<IRDifferentialPairTypeBase>(
pairPtrType ? pairPtrType->getValueType() : pairType);
- auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType(
- &argBuilder,
- pairValType);
+ auto diffType = differentiateType(&argBuilder, primalType);
if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType))
{
// Create temp var to pass in/out arguments.
@@ -795,7 +827,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
if (diffArg)
{
auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential(
- (IRType*)diffType,
+ (IRType*)as<IRPtrTypeBase>(diffType)->getValueType(),
newVal);
markDiffTypeInst(
&afterBuilder,
@@ -827,17 +859,72 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
}
}
}
+
+ {
+ // --WORKAROUND--
+ // This is a temporary workaround for a very specific case..
+ //
+ // If all the following are true:
+ // 1. the parameter type expects a differential pair,
+ // 2. the argument is derived from a no_diff type, and
+ // 3. the argument type is a run-time type (i.e. extract_existential_type),
+ // then we need to generate a differential 0, but the IR has no
+ // information on the diff witness.
+ //
+ // We will bypass the conformance system & brute-force the lookup for the interface
+ // keys, but the proper fix is to lower this key mapping during `no_diff` lowering.
+ //
+
+ // Condition 1
+ if (differentiableTypeConformanceContext.isDifferentiableType((originalParamType)))
+ {
+ // Condition 3
+ if (auto extractExistentialType = as<IRExtractExistentialType>(primalType))
+ {
+ // Condition 2
+ if (isNoDiffType(extractExistentialType->getOperand(0)->getDataType()))
+ {
+ // Force-differentiate the type (this will perform a search for the witness
+ // without going through the diff-type annotation list)
+ //
+ IRInst* witnessTable = nullptr;
+ auto diffType = differentiateExtractExistentialType(
+ &argBuilder,
+ extractExistentialType,
+ witnessTable);
+
+ auto pairType =
+ getOrCreateDiffPairType(&argBuilder, primalType, witnessTable);
+ auto zeroMethod = argBuilder.emitLookupInterfaceMethodInst(
+ differentiableTypeConformanceContext.sharedContext->zeroMethodType,
+ witnessTable,
+ differentiableTypeConformanceContext.sharedContext
+ ->zeroMethodStructKey);
+ auto diffZero = argBuilder.emitCallInst(diffType, zeroMethod, 0, nullptr);
+ auto diffPair =
+ argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffZero);
+
+ args.add(diffPair);
+ continue;
+ }
+ }
+ }
+ }
+
// Argument is not differentiable.
// Add original/primal argument.
args.add(primalArg);
}
IRType* diffReturnType = nullptr;
- diffReturnType = tryGetDiffPairType(&argBuilder, origCall->getFullType());
+ auto primalReturnType =
+ (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType());
+
+ diffReturnType = tryGetDiffPairType(&argBuilder, primalReturnType);
if (!diffReturnType)
{
- diffReturnType = (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType());
+ diffReturnType = primalReturnType;
}
auto callInst = argBuilder.emitCallInst(diffReturnType, diffCallee, args);
@@ -1035,6 +1122,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(
IRInst* diffBase = nullptr;
if (instMapD.tryGetValue(origSpecialize->getBase(), diffBase))
{
+ auto diffType = differentiateType(builder, origSpecialize->getFullType());
if (diffBase)
{
List<IRInst*> args;
@@ -1042,11 +1130,8 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(
{
args.add(primalSpecialize->getArg(i));
}
- auto diffSpecialize = builder->emitSpecializeInst(
- builder->getTypeKind(),
- diffBase,
- args.getCount(),
- args.getBuffer());
+ auto diffSpecialize =
+ builder->emitSpecializeInst(diffType, diffBase, args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
else
@@ -1572,7 +1657,24 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu
return InstPair(origFunc, fwdDecor->getForwardDerivativeFunc());
}
- auto diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc);
+ IRFunc* diffFunc = nullptr;
+
+ // If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the
+ // insert location unchanged). If we're transcribing it as a declaration, we should
+ // insert into the module.
+ //
+ auto origOuterGen = as<IRGeneric>(findOuterGeneric(origFunc));
+ if (!origOuterGen || findInnerMostGenericReturnVal(origOuterGen) != origFunc)
+ {
+ // Dealing with a declaration.. insert into module scope.
+ IRBuilder subBuilder = *inBuilder;
+ subBuilder.setInsertInto(inBuilder->getModule());
+ diffFunc = transcribeFuncHeaderImpl(&subBuilder, origFunc);
+ }
+ else
+ {
+ diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc);
+ }
if (auto outerGen = findOuterGeneric(diffFunc))
{
@@ -1605,7 +1707,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I
IRBuilder builder = *inBuilder;
maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);
-
differentiableTypeConformanceContext.setFunc(origFunc);
auto diffFunc = builder.createFunc();
@@ -1632,12 +1733,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I
// Transfer checkpoint hint decorations
copyCheckpointHints(&builder, origFunc, diffFunc);
-
- // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
- if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
- {
- cloneDecoration(&cloneEnv, dictDecor, diffFunc, diffFunc->getModule());
- }
return diffFunc;
}
@@ -2012,6 +2107,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_Reinterpret:
return transcribeReinterpret(builder, origInst);
+ case kIROp_DifferentiableTypeAnnotation:
+ return transcribeDifferentiableTypeAnnotation(builder, origInst);
+
// Differentiable insts that should have been lowered in a previous pass.
case kIROp_SwizzledStore:
{
@@ -2138,13 +2236,10 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(
if (as<IRDifferentialPairType>(diffPairType) || as<IRDifferentialPtrPairType>(diffPairType))
{
+ auto diffType = differentiateType(builder, (IRType*)origParam->getFullType());
return InstPair(
builder->emitDifferentialPairGetPrimal(diffPairParam),
- builder->emitDifferentialPairGetDifferential(
- (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(
- builder,
- as<IRDifferentialPairTypeBase>(diffPairType)),
- diffPairParam));
+ builder->emitDifferentialPairGetDifferential(diffType, diffPairParam));
}
else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
{
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 39e195464..09b3f14b8 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -94,6 +94,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeReinterpret(IRBuilder* builder, IRInst* origInst);
+ InstPair transcribeDifferentiableTypeAnnotation(IRBuilder* builder, IRInst* origInst);
+
virtual IRFuncType* differentiateFunctionType(
IRBuilder* builder,
IRInst* func,
diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp
index a49a2f762..c732263f0 100644
--- a/source/slang/slang-ir-autodiff-pairs.cpp
+++ b/source/slang/slang-ir-autodiff-pairs.cpp
@@ -13,7 +13,6 @@ struct DiffPairLoweringPass : InstPassBase
IRInst* lowerPairType(IRBuilder* builder, IRType* pairType)
{
- builder->setInsertBefore(pairType);
auto loweredPairType = pairBuilder->lowerDiffPairType(builder, pairType);
return loweredPairType;
}
@@ -22,26 +21,81 @@ struct DiffPairLoweringPass : InstPassBase
{
if (auto makePairInst = as<IRMakeDifferentialPairBase>(inst))
{
- bool isTrivial = false;
auto pairType = as<IRDifferentialPairTypeBase>(makePairInst->getDataType());
- if (auto loweredPairType = lowerPairType(builder, pairType))
+ builder->setInsertBefore(makePairInst);
+ if (auto loweredPairType = (IRType*)lowerPairType(builder, pairType))
{
- builder->setInsertBefore(makePairInst);
- IRInst* result = nullptr;
- if (isTrivial)
+ if (isRuntimeType(pairType->getValueType()))
{
- result = makePairInst->getPrimalValue();
+ auto result = pairBuilder->emitExistentialMakePair(
+ builder,
+ loweredPairType,
+ makePairInst->getPrimalValue(),
+ makePairInst->getDifferentialValue());
+
+ makePairInst->replaceUsesWith(result);
+ makePairInst->removeAndDeallocate();
+ return result;
+ }
+ else if (auto typePack = as<IRTypePack>(pairType->getValueType()))
+ {
+ // TODO: Do we need to flatten the packs here?
+
+ // If the type is a type pack, then the value must be in
+ // MakePair(MakeValuePack(p_0, p_1, ...), MakeValuePack(d_0, d_1, ...)) form
+ // Convert it to MakeValuePack(MakePair(p_0, d_0), MakePair(p_1, d_1), ...)
+ // and lower each MakePair.
+ //
+
+ // Primal pack
+ auto primalValue = as<IRMakeValuePack>(makePairInst->getPrimalValue());
+ SLANG_ASSERT(primalValue);
+
+ // Differential pack
+ auto diffValue = as<IRMakeValuePack>(makePairInst->getDifferentialValue());
+ SLANG_ASSERT(diffValue);
+
+ // Expect the lowered pair type to be a type pack of pair types.
+ SLANG_ASSERT(as<IRTypePack>(loweredPairType));
+
+ List<IRInst*> newValues;
+ for (UInt i = 0; i < typePack->getOperandCount(); i++)
+ {
+ auto primalElement = primalValue->getOperand(i);
+ auto diffElement = diffValue->getOperand(i);
+
+ auto loweredElementPairType = (IRType*)loweredPairType->getOperand(i);
+
+ IRInst* operands[] = {primalElement, diffElement};
+
+ auto loweredMakePair =
+ builder->emitMakeStruct((IRType*)loweredElementPairType, 2, operands);
+
+ newValues.add(loweredMakePair);
+ }
+
+ auto newPack = builder->emitMakeValuePack(
+ loweredPairType,
+ newValues.getCount(),
+ newValues.getBuffer());
+
+ makePairInst->replaceUsesWith(newPack);
+ makePairInst->removeAndDeallocate();
+ return newPack;
}
else
{
+ IRInst* result = nullptr;
+
IRInst* operands[2] = {
makePairInst->getPrimalValue(),
makePairInst->getDifferentialValue()};
result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands);
+
+ makePairInst->replaceUsesWith(result);
+ makePairInst->removeAndDeallocate();
+ return result;
}
- makePairInst->replaceUsesWith(result);
- makePairInst->removeAndDeallocate();
- return result;
}
}
@@ -58,12 +112,14 @@ struct DiffPairLoweringPass : InstPassBase
pairType = pairPtrType->getValueType();
}
- if (lowerPairType(builder, pairType))
+ builder->setInsertBefore(getDiffInst);
+ if (auto loweredType = lowerPairType(builder, pairType))
{
- builder->setInsertBefore(getDiffInst);
IRInst* diffFieldExtract = nullptr;
- diffFieldExtract =
- pairBuilder->emitDiffFieldAccess(builder, getDiffInst->getBase());
+ diffFieldExtract = pairBuilder->emitDiffFieldAccess(
+ builder,
+ (IRType*)loweredType,
+ getDiffInst->getBase());
getDiffInst->replaceUsesWith(diffFieldExtract);
getDiffInst->removeAndDeallocate();
return diffFieldExtract;
@@ -77,13 +133,14 @@ struct DiffPairLoweringPass : InstPassBase
pairType = pairPtrType->getValueType();
}
- if (lowerPairType(builder, pairType))
+ builder->setInsertBefore(getPrimalInst);
+ if (auto loweredType = lowerPairType(builder, pairType))
{
- builder->setInsertBefore(getPrimalInst);
-
IRInst* primalFieldExtract = nullptr;
- primalFieldExtract =
- pairBuilder->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
+ primalFieldExtract = pairBuilder->emitPrimalFieldAccess(
+ builder,
+ (IRType*)loweredType,
+ getPrimalInst->getBase());
getPrimalInst->replaceUsesWith(primalFieldExtract);
getPrimalInst->removeAndDeallocate();
return primalFieldExtract;
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index a3f6079ac..ef5161104 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -344,8 +344,18 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
continue;
}
+ // General case: we'll add all primal operands to the work list.
addPrimalOperandsToWorkList(child);
+ // Also add type annotations to the list, since these have to be made available to the
+ // function context.
+ //
+ if (as<IRDifferentiableTypeAnnotation>(child))
+ {
+ checkpointInfo->recomputeSet.add(child);
+ addPrimalOperandsToWorkList(child);
+ }
+
// We'll be conservative with the decorations we consider as differential uses
// of a primal inst, in order to avoid weird behaviour with some decorations
//
@@ -1333,7 +1343,7 @@ struct UseChain
return result;
}
- void replace(IRBuilder* builder, IRInst* inst)
+ void replace(IROutOfOrderCloneContext* ctx, IRBuilder* builder, IRInst* inst)
{
SLANG_ASSERT(chain.getCount() > 0);
@@ -1345,30 +1355,27 @@ struct UseChain
return;
}
- IRCloneEnv env;
-
// Pop the last use, which is the base use that needs to be replaced.
auto baseUse = chain.getLast();
chain.removeLast();
// Ensure that replacement inst is set as mapping for the baseUse.
- env.mapOldValToNew[baseUse->get()] = inst;
-
- auto lastInstInChain = inst;
+ ctx->cloneEnv.mapOldValToNew[baseUse->get()] = inst;
IRBuilder chainBuilder(builder->getModule());
setInsertAfterOrdinaryInst(&chainBuilder, inst);
chain.reverse();
+ chain.removeLast();
// Clone the rest of the chain.
for (auto& use : chain)
{
- lastInstInChain = cloneInst(&env, &chainBuilder, use->get());
+ ctx->cloneInstOutOfOrder(&chainBuilder, use->get());
}
- // Replace the base use.
- builder->replaceOperand(chain.getLast(), lastInstInChain);
+ // We won't actually replace the final use, because if there are multiple chains
+ // it can cause problems. The parent UseGraph will handle that.
chain.clear();
}
@@ -1380,13 +1387,93 @@ struct UseChain
}
};
+struct UseGraph
+{
+ // Set of linear paths to the base use.
+ // Note that some nodes may be common to multiple paths.
+ //
+ OrderedDictionary<IRUse*, List<UseChain>> chainSets;
+
+ static UseGraph from(
+ IRInst* baseInst,
+ Func<bool, IRUse*> isRelevantUse,
+ Func<bool, IRInst*> passthroughInst)
+ {
+ UseGraph result;
+ for (auto use = baseInst->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+
+ auto chains = UseChain::from(use, isRelevantUse, passthroughInst);
+ for (auto& chain : chains)
+ {
+ auto finalUse = chain.chain.getFirst();
+
+ if (!result.chainSets.containsKey(finalUse))
+ {
+ result.chainSets[finalUse] = List<UseChain>();
+ }
+
+ result.chainSets[finalUse].getValue().add(chain);
+ }
+
+ use = nextUse;
+ }
+ return result;
+ }
+
+ void replace(IRBuilder* builder, IRUse* use, IRInst* inst)
+ {
+ // Since we may have common nodes, we will use an out-of-order cloning context
+ // that can retroactively correct the uses as needed.
+ //
+ IROutOfOrderCloneContext ctx;
+ List<UseChain> chains = chainSets[use];
+ for (auto chain : chains)
+ {
+ chain.replace(&ctx, builder, inst);
+ }
+
+ if (!isTrivial())
+ {
+ builder->setInsertBefore(use->getUser());
+ auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, use->get());
+
+ // Replace the base use.
+ builder->replaceOperand(use, lastInstInChain);
+ }
+ }
+
+ bool isTrivial()
+ {
+ // We're trivial if there's only one chain, and it has only one use.
+ if (chainSets.getCount() != 1)
+ return false;
+
+ auto& chain = chainSets.getFirst().value;
+ return chain.getCount() == 1;
+ }
+
+ List<IRUse*> getUniqueUses() const
+ {
+ List<IRUse*> result;
+
+ for (auto& pair : chainSets)
+ {
+ result.add(pair.key);
+ }
+
+ return result;
+ }
+};
+
// Trim defBlockIndices based on the indices of out of scope uses.
//
static List<IndexTrackingInfo> maybeTrimIndices(
const List<IndexTrackingInfo>& defBlockIndices,
const Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo,
- const List<UseChain>& outOfScopeUses)
+ const List<IRUse*>& outOfScopeUses)
{
// Go through uses, lookup the defBlockIndices, and remove any indices if they
// are not present in any of the uses. (This is sort of slow...)
@@ -1397,7 +1484,7 @@ static List<IndexTrackingInfo> maybeTrimIndices(
bool found = false;
for (const auto& use : outOfScopeUses)
{
- auto useInst = use.getUser();
+ auto useInst = use->getUser();
auto useBlock = useInst->getParent();
auto useBlockIndices = indexedBlockInfo.getValue(as<IRBlock>(useBlock));
if (useBlockIndices.contains(index))
@@ -1419,7 +1506,8 @@ bool canInstBeStored(IRInst* inst)
// stored into variables or context structs as normal values.
//
if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType()) ||
- as<IRTypeKind>(inst->getDataType()) || as<IRFuncType>(inst->getDataType()))
+ as<IRTypeKind>(inst->getDataType()) || as<IRFuncType>(inst->getDataType()) ||
+ !inst->getDataType())
return false;
return true;
@@ -1577,6 +1665,9 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
//
auto isPassthroughInst = [&](IRInst* inst)
{
+ if (as<IRTerminatorInst>(inst))
+ return false;
+
if (!canInstBeStored(inst))
return true;
@@ -1590,16 +1681,9 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
return false;
};
- List<UseChain> outOfScopeUses;
- for (auto use = instToStore->firstUse; use;)
- {
- auto nextUse = use->nextUse;
+ UseGraph useGraph = UseGraph::from(instToStore, isRelevantUse, isPassthroughInst);
- List<UseChain> useChains = UseChain::from(use, isRelevantUse, isPassthroughInst);
- outOfScopeUses.addRange(useChains);
-
- use = nextUse;
- }
+ List<IRUse*> outOfScopeUses = useGraph.getUniqueUses();
if (outOfScopeUses.getCount() == 0)
{
@@ -1659,10 +1743,10 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
for (auto use : outOfScopeUses)
{
- setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use.getUser()));
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
List<IndexTrackingInfo>& useBlockIndices =
- indexedBlockInfo[getBlock(use.getUser())];
+ indexedBlockInfo[getBlock(use->getUser())];
IRInst* loadAddr = emitIndexedLoadAddressForVar(
&builder,
@@ -1670,7 +1754,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
defBlock,
defBlockIndices,
useBlockIndices);
- use.replace(&builder, loadAddr);
+
+ useGraph.replace(&builder, use, loadAddr);
}
if (!isRecomputeInst)
@@ -1729,11 +1814,13 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
for (auto use : outOfScopeUses)
{
+ // TODO: Prevent terminator insts from being treated as passthrough..
List<IndexTrackingInfo> useBlockIndices =
- indexedBlockInfo[getBlock(use.getUser())];
- setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use.getUser()));
- use.replace(
+ indexedBlockInfo[getBlock(use->getUser())];
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
+ useGraph.replace(
&builder,
+ use,
loadIndexedValue(
&builder,
localVar,
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 36093518a..5ac4016d7 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -203,13 +203,23 @@ IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType(
IRInst* func,
IRFuncType* funcType)
{
- IRType* intermediateType =
- builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));
+ IRType* intermediateType = nullptr;
if (auto outerGeneric = findOuterGeneric(builder->getInsertLoc().getParent()))
{
intermediateType =
+ builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));
+ intermediateType =
(IRType*)specializeWithGeneric(*builder, intermediateType, as<IRGeneric>(outerGeneric));
}
+ else if (as<IRLookupWitnessMethod>(func))
+ {
+ intermediateType = nullptr;
+ }
+ else
+ {
+ intermediateType =
+ builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));
+ }
return differentiateFunctionTypeImpl(builder, funcType, intermediateType);
}
@@ -382,14 +392,7 @@ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(
IRFunc* primalFunc = origFunc;
maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);
-
- // The original func may not have a type dictionary if it is not originally marked as
- // differentiable, in this case we would have already pulled the necessary types from
- // the user-provided derivative function, so we are still fine.
- if (origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
- {
- differentiableTypeConformanceContext.setFunc(origFunc);
- }
+ differentiableTypeConformanceContext.setFunc(origFunc);
auto diffFunc = builder.createFunc();
@@ -414,12 +417,7 @@ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(
// Mark the generated derivative function itself as differentiable.
builder.addBackwardDifferentiableDecoration(diffFunc);
- // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
- if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
- {
- builder.setInsertBefore(diffFunc->getFirstDecorationOrChild());
- cloneInst(&cloneEnv, &builder, dictDecor);
- }
+
copyOriginalDecorations(origFunc, diffFunc);
builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast);
return InstPair(primalFunc, diffFunc);
@@ -446,7 +444,24 @@ void BackwardDiffTranscriberBase::addTranscribedFuncDecoration(
InstPair BackwardDiffTranscriberBase::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
- auto result = transcribeFuncHeaderImpl(inBuilder, origFunc);
+ InstPair result;
+
+ // If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the
+ // insert location unchanges). If we're transcribing it as a declaration, we should
+ // insert into the module.
+ //
+ auto origOuterGen = as<IRGeneric>(findOuterGeneric(origFunc));
+ if (!origOuterGen || !(findInnerMostGenericReturnVal(origOuterGen) == origFunc))
+ {
+ // Dealing with a declaration.. insert into module scope.
+ IRBuilder subBuilder = *inBuilder;
+ subBuilder.setInsertInto(inBuilder->getModule());
+ result = transcribeFuncHeaderImpl(&subBuilder, origFunc);
+ }
+ else
+ {
+ result = transcribeFuncHeaderImpl(inBuilder, origFunc);
+ }
FuncBodyTranscriptionTask task;
task.originalFunc = as<IRFunc>(result.primal);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 1b3825a7d..38a7a18bb 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -256,7 +256,7 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o
return nullptr;
// Special-case for differentiable existential types.
- if (as<IRInterfaceType>(origType) || as<IRAssociatedType>(origType))
+ if (as<IRInterfaceType>(origType))
{
if (differentiableTypeConformanceContext.lookUpConformanceForType(
origType,
@@ -269,6 +269,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o
else
return nullptr;
}
+ else if (as<IRAssociatedType>(origType))
+ {
+ SLANG_UNEXPECTED("unexpected associated type during auto-diff");
+ }
auto primalType = lookupPrimalInst(builder, origType, origType);
if (primalType->getOp() == kIROp_Param && primalType->getParent() &&
@@ -324,9 +328,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
auto primalPairType = as<IRDifferentialPairTypeBase>(primalType);
return getOrCreateDiffPairType(
builder,
- differentiableTypeConformanceContext.getDiffTypeFromPairType(
- builder,
- primalPairType),
+ differentiateType(builder, primalPairType->getValueType()),
differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType(
builder,
primalPairType));
@@ -336,9 +338,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
{
auto primalPairType = as<IRDifferentialPairUserCodeType>(primalType);
return builder->getDifferentialPairUserCodeType(
- (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(
- builder,
- primalPairType),
+ differentiateType(builder, primalPairType->getValueType()),
differentiableTypeConformanceContext.getDiffTypeWitnessFromPairType(
builder,
primalPairType));
@@ -406,6 +406,7 @@ bool AutoDiffTranscriberBase::isExistentialType(IRType* type)
case kIROp_ExtractExistentialType:
case kIROp_InterfaceType:
case kIROp_AssociatedType:
+ case kIROp_LookupWitness:
return true;
default:
return false;
@@ -460,47 +461,34 @@ void AutoDiffTranscriberBase::maybeMigrateDifferentiableDictionaryFromDerivative
IRBuilder* builder,
IRInst* origFunc)
{
- auto decor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
- if (decor)
- return;
- // A differentiable func must have `IRDifferentiableTypeDictionaryDecoration`, except it has a
- // `IRUserDefinedBackwardDerivativeDecoration`.
- auto udfDecor = origFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>();
- SLANG_RELEASE_ASSERT(udfDecor);
- // We need to migrate the dictionary from the backward derivative func so we can properly
- // differentiate the function header.
- IRBuilder subBuilder = *builder;
- subBuilder.setInsertBefore(origFunc);
-
- auto derivative = udfDecor->getBackwardDerivativeFunc();
- if (auto specialize = as<IRSpecialize>(derivative))
- {
- auto derivativeGeneric = cast<IRGeneric>(specialize->getBase());
- GenericChildrenMigrationContext migrationContext;
- migrationContext.init(
- derivativeGeneric,
- cast<IRGeneric>(findOuterGeneric(origFunc)),
- origFunc);
- auto derivativeFunc = findGenericReturnVal(derivativeGeneric);
- auto derivativeBlock = cast<IRBlock>(derivativeFunc->getParent());
- for (auto dInst = derivativeBlock->getFirstOrdinaryInst(); dInst != derivativeFunc;
- dInst = dInst->getNextInst())
- {
- migrationContext.cloneInst(&subBuilder, dInst);
- }
- auto udfDictDecor =
- derivativeFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
- SLANG_RELEASE_ASSERT(udfDictDecor);
- subBuilder.setInsertBefore(origFunc->getFirstDecorationOrChild());
- migrationContext.cloneInst(&subBuilder, udfDictDecor);
- eliminateDeadCode(origFunc->getParent());
- }
- else
+ // There's one corner case where our function may not have the differentiable type annotations.
+ // If the function is not declared differentiable, but has a custom derivative, we need to copy
+ // over any IRDifferentiableTypeAnnotation insts
+ if (auto udfDecor = origFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
{
- auto udfDictDecor = derivative->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
- if (udfDictDecor)
+ // We need to migrate the dictionary from the backward derivative func so we can properly
+ // differentiate the function header.
+ IRBuilder subBuilder = *builder;
+ subBuilder.setInsertBefore(origFunc);
+
+ auto derivative = udfDecor->getBackwardDerivativeFunc();
+ if (auto specialize = as<IRSpecialize>(derivative))
{
- cloneDecoration(udfDictDecor, origFunc);
+ auto derivativeGeneric = cast<IRGeneric>(specialize->getBase());
+
+ GenericChildrenMigrationContext migrationContext;
+ migrationContext.init(
+ derivativeGeneric,
+ cast<IRGeneric>(findOuterGeneric(origFunc)),
+ origFunc);
+ auto derivativeFunc = findGenericReturnVal(derivativeGeneric);
+ auto derivativeBlock = cast<IRBlock>(derivativeFunc->getParent());
+ for (auto dInst = derivativeBlock->getFirstOrdinaryInst(); dInst != derivativeFunc;
+ dInst = dInst->getNextInst())
+ {
+ migrationContext.cloneInst(&subBuilder, dInst);
+ }
+ eliminateDeadCode(origFunc->getParent());
}
}
}
@@ -575,8 +563,8 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType*
else
return nullptr;
}
- auto diffType = differentiateType(builder, originalType);
- if (diffType)
+
+ if (tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Any))
return (IRType*)getOrCreateDiffPairType(builder, originalType);
return nullptr;
}
@@ -690,6 +678,15 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(
return InstPair(primal, diffWitness);
}
}
+ else if (as<IRTypeKind>(lookupInst->getDataType()))
+ {
+ if (auto diffType = differentiableTypeConformanceContext.getDifferentialForType(
+ builder,
+ (IRType*)primalType))
+ {
+ return InstPair(primal, diffType);
+ }
+ }
auto decor = lookupInst->getRequirementKey()->findDecorationImpl(
getInterfaceRequirementDerivativeDecorationOp());
@@ -997,8 +994,15 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene
if (auto innerFunc = as<IRFunc>(innerVal))
{
maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, innerFunc);
- if (!innerFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
+ // Is our function differentiable?
+ if (!(innerFunc->findDecoration<IRForwardDifferentiableDecoration>() ||
+ innerFunc->findDecoration<IRBackwardDifferentiableDecoration>() ||
+ innerFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>() ||
+ innerFunc->findDecoration<IRForwardDerivativeDecoration>()))
+ {
return InstPair(origGeneric, nullptr);
+ }
+
differentiableTypeConformanceContext.setFunc(innerFunc);
}
else if (const auto funcType = as<IRFuncType>(innerVal))
@@ -1027,7 +1031,14 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene
IRType* diffType = nullptr;
if (primalType)
{
- diffType = (IRType*)findOrTranscribeDiffInst(&builder, primalType);
+ if (as<IRGenericKind>(primalType))
+ {
+ diffType = builder.getGenericKind();
+ }
+ else
+ {
+ diffType = (IRType*)findOrTranscribeDiffInst(&builder, primalType);
+ }
}
diffGeneric->setFullType(diffType);
@@ -1110,7 +1121,6 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
mapPrimalInst(origInst, pair.primal);
mapDifferentialInst(origInst, pair.differential);
-
if (pair.primal != pair.differential &&
!pair.primal->findDecoration<IRAutodiffInstDecoration>() &&
!as<IRConstant>(pair.primal))
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index c84fd778c..a5ed5814c 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -623,8 +623,9 @@ struct DiffTransposePass
{
varInst->insertAtEnd(firstRevDiffBlock);
- auto dzero =
- emitDZeroOfDiffInstType(&builder, ptrPrimalType->getValueType());
+ auto dzero = diffTypeContext.emitDZeroOfDiffInstType(
+ &builder,
+ ptrPrimalType->getValueType());
builder.emitStore(varInst, dzero);
}
else
@@ -726,7 +727,9 @@ struct DiffTransposePass
auto gradValue = builder->emitLoad(accVar);
builder->emitStore(
accVar,
- emitDZeroOfDiffInstType(builder, tryGetPrimalTypeFromDiffInst(fwdInst)));
+ diffTypeContext.emitDZeroOfDiffInstType(
+ builder,
+ tryGetPrimalTypeFromDiffInst(fwdInst)));
return gradValue;
}
@@ -760,7 +763,7 @@ struct DiffTransposePass
auto primalType = tryGetPrimalTypeFromDiffInst(fwdInst);
auto diffType = fwdInst->getDataType();
- auto zero = emitDZeroOfDiffInstType(&tempVarBuilder, primalType);
+ auto zero = diffTypeContext.emitDZeroOfDiffInstType(&tempVarBuilder, primalType);
// Emit a var in the top-level differential block to hold the gradient,
// and initialize it.
@@ -925,8 +928,9 @@ struct DiffTransposePass
}
else
{
- phiParamRevGradInsts.add(
- emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param)));
+ phiParamRevGradInsts.add(diffTypeContext.emitDZeroOfDiffInstType(
+ &builder,
+ tryGetPrimalTypeFromDiffInst(param)));
}
}
else
@@ -1177,7 +1181,8 @@ struct DiffTransposePass
auto pairType = as<IRDifferentialPairType>(arg->getDataType());
auto var = builder->emitVar(arg->getDataType());
- auto diffZero = emitDZeroOfDiffInstType(builder, pairType->getValueType());
+ auto diffZero =
+ diffTypeContext.emitDZeroOfDiffInstType(builder, pairType->getValueType());
// Initialize this var to (arg.primal, 0).
builder->emitStore(
@@ -1236,7 +1241,13 @@ struct DiffTransposePass
argRequiresLoad.add(false);
}
- auto revFnType = builder->getFuncType(argTypes, builder->getVoidType());
+
+ auto revFnType =
+ this->autodiffContext->transcriberSet.propagateTranscriber->differentiateFunctionType(
+ builder,
+ getResolvedInstForDecorations(baseFn),
+ baseFnType);
+
IRInst* revCallee = nullptr;
if (getResolvedInstForDecorations(baseFn)->getOp() == kIROp_LookupWitness)
{
@@ -1615,7 +1626,7 @@ struct DiffTransposePass
SLANG_ASSERT(primalType);
// Clear the value at the differential address, by setting to 0.
- IRInst* emptyVal = emitDZeroOfDiffInstType(builder, primalType);
+ IRInst* emptyVal = diffTypeContext.emitDZeroOfDiffInstType(builder, primalType);
builder->emitStore(fwdStore->getPtr(), emptyVal);
if (auto diffPairType = as<IRDifferentialPairType>(revVal->getDataType()))
@@ -2071,7 +2082,7 @@ struct DiffTransposePass
auto primalElementTypeDecor = updateInst->findDecoration<IRPrimalElementTypeDecoration>();
SLANG_RELEASE_ASSERT(primalElementTypeDecor);
- auto diffZero = emitDZeroOfDiffInstType(
+ auto diffZero = diffTypeContext.emitDZeroOfDiffInstType(
builder,
(IRType*)primalElementTypeDecor->getPrimalElementType());
SLANG_ASSERT(diffZero);
@@ -2350,16 +2361,18 @@ struct DiffTransposePass
{
auto primalCondition = fwdInst->getOperand(0);
- auto leftZero =
- emitDZeroOfDiffInstType(builder, tryGetPrimalTypeFromDiffInst(fwdInst->getOperand(1)));
+ auto leftZero = diffTypeContext.emitDZeroOfDiffInstType(
+ builder,
+ tryGetPrimalTypeFromDiffInst(fwdInst->getOperand(1)));
auto leftGradientInst = builder->emitIntrinsicInst(
fwdInst->getOperand(1)->getDataType(),
kIROp_Select,
3,
List<IRInst*>(primalCondition, revValue, leftZero).getBuffer());
- auto rightZero =
- emitDZeroOfDiffInstType(builder, tryGetPrimalTypeFromDiffInst(fwdInst->getOperand(2)));
+ auto rightZero = diffTypeContext.emitDZeroOfDiffInstType(
+ builder,
+ tryGetPrimalTypeFromDiffInst(fwdInst->getOperand(2)));
auto rightGradientInst = builder->emitIntrinsicInst(
fwdInst->getOperand(2)->getDataType(),
kIROp_Select,
@@ -2527,7 +2540,8 @@ struct DiffTransposePass
List<IRInst*> zeroElements;
for (Index i = 0; i < elementCount; ++i)
{
- auto zeroElement = emitDZeroOfDiffInstType(builder, primalElementTypes[i]);
+ auto zeroElement =
+ diffTypeContext.emitDZeroOfDiffInstType(builder, primalElementTypes[i]);
elementGrads.add(zeroElement);
zeroElements.add(zeroElement);
}
@@ -2537,8 +2551,11 @@ struct DiffTransposePass
if (elementGrads[i] == zeroElements[i])
elementGrads[i] = grad;
else
- elementGrads[i] =
- emitDAddOfDiffInstType(builder, primalElementTypes[i], elementGrads[i], grad);
+ elementGrads[i] = diffTypeContext.emitDAddOfDiffInstType(
+ builder,
+ primalElementTypes[i],
+ elementGrads[i],
+ grad);
};
for (auto gradient : gradients)
@@ -2624,7 +2641,7 @@ struct DiffTransposePass
gradient.targetInst,
builder->emitMakeDifferentialPairUserCode(
baseType,
- emitDZeroOfDiffInstType(builder, baseType->getValueType()),
+ diffTypeContext.emitDZeroOfDiffInstType(builder, baseType->getValueType()),
gradient.revGradInst),
gradient.fwdGradInst));
}
@@ -2640,7 +2657,9 @@ struct DiffTransposePass
builder->emitMakeDifferentialPairUserCode(
baseType,
gradient.revGradInst,
- emitDZeroOfDiffInstType(builder, fwdGetPrimal->getFullType())),
+ diffTypeContext.emitDZeroOfDiffInstType(
+ builder,
+ fwdGetPrimal->getFullType())),
gradient.fwdGradInst));
}
}
@@ -2694,7 +2713,7 @@ struct DiffTransposePass
(IRType*)diffTypeContext.getDifferentialForType(builder, aggPrimalType));
// Initialize with T.dzero()
- auto zeroValueInst = emitDZeroOfDiffInstType(builder, aggPrimalType);
+ auto zeroValueInst = diffTypeContext.emitDZeroOfDiffInstType(builder, aggPrimalType);
builder->emitStore(revGradVar, zeroValueInst);
@@ -2764,7 +2783,7 @@ struct DiffTransposePass
(IRType*)diffTypeContext.getDifferentialForType(builder, aggPrimalType));
// Initialize with T.dzero()
- auto zeroValueInst = emitDZeroOfDiffInstType(builder, aggPrimalType);
+ auto zeroValueInst = diffTypeContext.emitDZeroOfDiffInstType(builder, aggPrimalType);
builder->emitStore(revGradVar, zeroValueInst);
@@ -2839,8 +2858,11 @@ struct DiffTransposePass
continue;
}
- currentValue =
- emitDAddOfDiffInstType(builder, aggPrimalType, currentValue, gradient.revGradInst);
+ currentValue = diffTypeContext.emitDAddOfDiffInstType(
+ builder,
+ aggPrimalType,
+ currentValue,
+ gradient.revGradInst);
}
return RevGradient(
@@ -2919,7 +2941,7 @@ struct DiffTransposePass
if (aggDiffType != nullptr)
{
// If type is non-null/non-void, call T.dzero() to produce a 0 gradient.
- return emitDZeroOfDiffInstType(builder, aggPrimalType);
+ return diffTypeContext.emitDZeroOfDiffInstType(builder, aggPrimalType);
}
else
{
@@ -2951,146 +2973,6 @@ struct DiffTransposePass
return nullptr;
}
- IRInst* emitDZeroOfDiffInstType(IRBuilder* builder, IRType* primalType)
- {
- if (auto arrayType = as<IRArrayType>(primalType))
- {
- auto diffElementType = (IRType*)diffTypeContext.getDifferentialForType(
- builder,
- arrayType->getElementType());
- SLANG_RELEASE_ASSERT(diffElementType);
- auto diffArrayType =
- builder->getArrayType(diffElementType, arrayType->getElementCount());
- auto diffElementZero = emitDZeroOfDiffInstType(builder, arrayType->getElementType());
- return builder->emitMakeArrayFromElement(diffArrayType, diffElementZero);
- }
- else if (auto diffPairUserType = as<IRDifferentialPairUserCodeType>(primalType))
- {
- auto primalZero = emitDZeroOfDiffInstType(builder, diffPairUserType->getValueType());
- auto diffZero = primalZero;
- auto diffType = primalZero->getFullType();
- auto diffWitness =
- diffTypeContext.getDiffTypeWitnessFromPairType(builder, diffPairUserType);
- auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness);
- return builder->emitMakeDifferentialPairUserCode(
- diffDiffPairType,
- primalZero,
- diffZero);
- }
- else if (as<IRInterfaceType>(primalType) || as<IRAssociatedType>(primalType))
- {
- // Pack a null value into an existential type.
- auto existentialZero = builder->emitMakeExistential(
- autodiffContext->differentiableInterfaceType,
- diffTypeContext.emitNullDifferential(builder),
- autodiffContext->nullDifferentialWitness);
-
- return existentialZero;
- }
-
- auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType);
-
- // Should exist.
- SLANG_ASSERT(zeroMethod);
-
- return builder->emitCallInst(
- (IRType*)diffTypeContext.getDifferentialForType(builder, primalType),
- zeroMethod,
- List<IRInst*>());
- }
-
- IRInst* emitDAddForExistentialType(
- IRBuilder* builder,
- IRType* primalType,
- IRInst* op1,
- IRInst* op2)
- {
- auto existentialDAddFunc = diffTypeContext.getOrCreateExistentialDAddMethod();
-
- // Should exist.
- SLANG_ASSERT(existentialDAddFunc);
-
- return builder->emitCallInst(
- (IRType*)diffTypeContext.getDifferentialForType(builder, primalType),
- existentialDAddFunc,
- List<IRInst*>({op1, op2}));
- }
-
- IRInst* emitDAddOfDiffInstType(IRBuilder* builder, IRType* primalType, IRInst* op1, IRInst* op2)
- {
- if (auto arrayType = as<IRArrayType>(primalType))
- {
- auto diffElementType = (IRType*)diffTypeContext.getDifferentialForType(
- builder,
- arrayType->getElementType());
- SLANG_RELEASE_ASSERT(diffElementType);
- auto arraySize = arrayType->getElementCount();
-
- if (auto constArraySize = as<IRIntLit>(arraySize))
- {
- List<IRInst*> args;
- for (IRIntegerValue i = 0; i < constArraySize->getValue(); i++)
- {
- auto index = builder->getIntValue(builder->getIntType(), i);
- auto op1Val = builder->emitElementExtract(diffElementType, op1, index);
- auto op2Val = builder->emitElementExtract(diffElementType, op2, index);
- args.add(emitDAddOfDiffInstType(
- builder,
- arrayType->getElementType(),
- op1Val,
- op2Val));
- }
- auto diffArrayType =
- builder->getArrayType(diffElementType, arrayType->getElementCount());
- return builder->emitMakeArray(
- diffArrayType,
- (UInt)args.getCount(),
- args.getBuffer());
- }
- else
- {
- // TODO: insert a runtime loop here.
- SLANG_UNIMPLEMENTED_X("dadd of dynamic array.");
- }
- }
- else if (auto diffPairUserType = as<IRDifferentialPairUserCodeType>(primalType))
- {
- auto diffType =
- (IRType*)diffTypeContext.getDiffTypeFromPairType(builder, diffPairUserType);
- auto diffWitness =
- diffTypeContext.getDiffTypeWitnessFromPairType(builder, diffPairUserType);
-
- auto primal1 = builder->emitDifferentialPairGetPrimalUserCode(op1);
- auto primal2 = builder->emitDifferentialPairGetPrimalUserCode(op2);
- auto primal =
- emitDAddOfDiffInstType(builder, diffPairUserType->getValueType(), primal1, primal2);
-
- auto diff1 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op1);
- auto diff2 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op2);
- auto diff = emitDAddOfDiffInstType(builder, diffType, diff1, diff2);
-
- auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness);
- return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primal, diff);
- }
- else if (as<IRInterfaceType>(primalType) || as<IRAssociatedType>(primalType))
- {
- // If our type is existential, we need to handle the case where
- // one or both of our operands are null-type.
- //
- return emitDAddForExistentialType(builder, primalType, op1, op2);
- }
-
- auto addMethod = diffTypeContext.getAddMethodForType(builder, primalType);
-
- // Should exist.
- SLANG_ASSERT(addMethod);
-
- return builder->emitCallInst(
- (IRType*)diffTypeContext.getDifferentialForType(builder, primalType),
- addMethod,
- List<IRInst*>(op1, op2));
- }
-
void addRevGradientForFwdInst(IRInst* fwdInst, RevGradient assignment)
{
if (!hasRevGradients(fwdInst))
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 9ee2cb4d2..49c1d9ff7 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -93,6 +93,22 @@ struct ExtractPrimalFuncContext
as<IRGeneric>(findOuterGeneric(destFunc)),
destFunc);
+ if (auto origGeneric = as<IRGeneric>(findOuterGeneric(originalFunc)))
+ {
+ // Clone in everything else except the return value.
+ IRBuilder subBuilder(destFunc);
+ builder.setInsertAfter(findOuterGeneric(destFunc)->getFirstBlock()->getLastParam());
+
+ // Clone in any hoistable insts.
+ for (auto child = origGeneric->getFirstBlock()->getFirstOrdinaryInst(); child;
+ child = child->getNextInst())
+ {
+ if ((child != originalFunc) && !as<IRReturn>(child) &&
+ !as<IRGlobalValueWithCode>(child))
+ migrationContext.cloneInst(&subBuilder, child);
+ }
+ }
+
originalFuncType = as<IRFuncType>(originalFunc->getDataType());
SLANG_RELEASE_ASSERT(originalFuncType);
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 31c4dbf91..556fb58a8 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -201,49 +201,61 @@ struct DiffUnzipPass
return nullptr;
}
- InstPair splitCall(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRCall* mixedCall)
+ IRInst* getIntermediateType(IRBuilder* builder, IRInst* baseFn)
{
- IRBuilder globalBuilder(autodiffContext->moduleInst->getModule());
-
- auto fwdCalleeType = mixedCall->getCallee()->getDataType();
- auto baseFn = _getOriginalFunc(mixedCall);
- SLANG_RELEASE_ASSERT(baseFn);
-
- auto primalFuncType = autodiffContext->transcriberSet.primalTranscriber->transcribe(
- primalBuilder,
- baseFn->getDataType());
-
- IRInst* intermediateType = nullptr;
-
- if (auto specialize = as<IRSpecialize>(baseFn))
+ if (as<IRLookupWitnessMethod>(baseFn))
+ {
+ return builder->getVoidType();
+ }
+ else if (auto specialize = as<IRSpecialize>(baseFn))
{
+ if (as<IRLookupWitnessMethod>(specialize->getBase()))
+ return builder->getVoidType();
+
auto func = findSpecializeReturnVal(specialize);
- auto outerGen = findOuterGeneric(func);
- if (func->getOp() == kIROp_LookupWitness)
+ if (as<IRLookupWitnessMethod>(func))
{
// An interface method won't have intermediate type.
- intermediateType = primalBuilder->getVoidType();
+ return builder->getVoidType();
}
else
{
- intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(outerGen);
+ auto outerGen = findOuterGeneric(func);
+ auto innerIntermediateType =
+ builder->getBackwardDiffIntermediateContextType(outerGen);
+
List<IRInst*> args;
for (UInt i = 0; i < specialize->getArgCount(); i++)
args.add(specialize->getArg(i));
- intermediateType = primalBuilder->emitSpecializeInst(
- primalBuilder->getTypeKind(),
- intermediateType,
+
+ return builder->emitSpecializeInst(
+ builder->getTypeKind(),
+ innerIntermediateType,
args.getCount(),
args.getBuffer());
}
}
else
{
- if (baseFn->getOp() == kIROp_LookupWitness)
- intermediateType = primalBuilder->getVoidType();
- else
- intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(baseFn);
+ return builder->getBackwardDiffIntermediateContextType(baseFn);
}
+ }
+
+ InstPair splitCall(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRCall* mixedCall)
+ {
+ IRBuilder globalBuilder(autodiffContext->moduleInst->getModule());
+
+ auto fwdCalleeType = mixedCall->getCallee()->getDataType();
+ auto baseFn = _getOriginalFunc(mixedCall);
+ SLANG_RELEASE_ASSERT(baseFn);
+
+ auto primalFuncType =
+ autodiffContext->transcriberSet.primalTranscriber->differentiateFunctionType(
+ primalBuilder,
+ baseFn,
+ as<IRFuncType>(baseFn->getDataType()));
+
+ IRInst* intermediateType = getIntermediateType(primalBuilder, baseFn);
IRVar* intermediateVar = nullptr;
if (!as<IRVoidType>(intermediateType))
@@ -314,8 +326,8 @@ struct DiffUnzipPass
auto arg = mixedCall->getArg(ii);
// Depending on the type and direction of each argument,
- // we might need to prepare a different value for the transposition logic to produce the
- // correct final argument in the propagate function call.
+ // we might need to prepare a different value for the transposition logic to produce
+ // the correct final argument in the propagate function call.
if (isRelevantDifferentialPair(arg->getDataType()))
{
auto primalArg = lookupPrimalInst(arg);
@@ -328,13 +340,13 @@ struct DiffUnzipPass
if (const auto outType = as<IROutType>(primalParamType))
{
- // For `out` parameters that expects an input derivative to propagate through,
- // we insert a `LoadReverseGradient` inst here to signify the logic in
- // `transposeStore` that this argument should actually be the currently
+ // For `out` parameters that expects an input derivative to propagate
+ // through, we insert a `LoadReverseGradient` inst here to signify the logic
+ // in `transposeStore` that this argument should actually be the currently
// accumulated derivative on this variable. The end purpose is that we will
// generate a load(diffArg) in the final transposed code and use that as the
- // argument for the call, but we can't just emit a normal load inst here because
- // the transposition logic will turn loads into stores.
+ // argument for the call, but we can't just emit a normal load inst here
+ // because the transposition logic will turn loads into stores.
auto outDiffType = cast<IRPtrTypeBase>(diffArg->getDataType())->getValueType();
auto gradArg = diffBuilder->emitLoadReverseGradient(outDiffType, diffArg);
diffBuilder->markInstAsDifferential(gradArg, primalArg->getDataType());
@@ -342,23 +354,24 @@ struct DiffUnzipPass
}
else if (const auto inoutType = as<IRInOutType>(primalParamType))
{
- // Since arg is split into separate vars, we need a new temp var that represents
- // the remerged diff pair.
+ // Since arg is split into separate vars, we need a new temp var that
+ // represents the remerged diff pair.
auto diffPairType = as<IRDifferentialPairType>(
as<IRPtrTypeBase>(arg->getDataType())->getValueType());
auto primalValueType = diffPairType->getValueType();
// We can't simply reuse primalArg for an inout parameter since this will
// represent the value after the primal call which can potentially alter
- // primalArg. Therefore, we will find the first store into primalArg, and create
- // a temp var holding that value (i.e. value prior to primal call)
+ // primalArg. Therefore, we will find the first store into primalArg, and
+ // create a temp var holding that value (i.e. value prior to primal call)
//
auto storeUse = findUniqueStoredVal(cast<IRVar>(primalArg));
auto storeInst = cast<IRStore>(storeUse->getUser());
auto storedVal = storeInst->getVal();
- // Emit the temp var into the primal blocks since it's holding a primal value.
+ // Emit the temp var into the primal blocks since it's holding a primal
+ // value.
auto tempPrimalVar = primalBuilder->emitVar(primalValueType);
primalBuilder->emitStore(tempPrimalVar, storedVal);
@@ -407,8 +420,8 @@ struct DiffUnzipPass
// For pure 'in' type. Simply re-use the original argument inst.
//
// For 'out' type parameters, it doesn't really matter what we pass in here,
- // since the tranposition logic will discard the argument anyway (we'll pass in
- // the old arg, just to keep the number of arguments consistent)
+ // since the tranposition logic will discard the argument anyway (we'll pass
+ // in the old arg, just to keep the number of arguments consistent)
//
diffArgs.add(arg);
}
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 4edd8eabe..7507e2fac 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -230,9 +230,6 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
}
else if (auto specializedType = as<IRSpecialize>(pairType))
{
- // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's
- // type, emit the specialization type.
- //
auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(specializedType->getBase()));
if (auto genericBasePairStructType = as<IRStructType>(genericType))
{
@@ -263,14 +260,142 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
return nullptr;
}
-IRInst* DifferentialPairTypeBuilder::emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst)
+bool isExistentialOrRuntimeInst(IRInst* inst)
{
- return emitFieldAccessor(builder, baseInst, this->globalPrimalKey);
+ if (auto lookup = as<IRLookupWitnessMethod>(inst))
+ {
+ return isExistentialOrRuntimeInst(lookup->getWitnessTable());
+ }
+
+ return as<IRExtractExistentialType>(inst) || as<IRExtractExistentialWitnessTable>(inst) ||
+ as<IRMakeExistential>(inst) || as<IRInterfaceType>(inst->getDataType());
}
-IRInst* DifferentialPairTypeBuilder::emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst)
+bool isRuntimeType(IRType* type)
{
- return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
+ if (as<IRExtractExistentialType>(type))
+ return true;
+
+ if (auto lookup = as<IRLookupWitnessMethod>(type))
+ {
+ return isExistentialOrRuntimeInst(lookup->getWitnessTable());
+ }
+
+ return false;
+}
+
+IRInst* getExistentialBaseWitnessTable(IRBuilder* builder, IRType* type)
+{
+ if (auto lookupWitnessMethod = as<IRLookupWitnessMethod>(type))
+ {
+ return lookupWitnessMethod->getWitnessTable();
+ }
+ else if (auto extractExistentialType = as<IRExtractExistentialType>(type))
+ {
+ return builder->emitExtractExistentialWitnessTable(extractExistentialType->getOperand(0));
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unexpected existential type");
+ }
+}
+
+IRInst* getCacheKey(IRBuilder* builder, IRInst* primalType)
+{
+ if (auto lookupWitness = as<IRLookupWitnessMethod>(primalType))
+ return lookupWitness->getRequirementKey();
+ else if (auto extractExistentialType = as<IRExtractExistentialType>(primalType))
+ {
+ auto interfaceType = extractExistentialType->getOperand(0)->getDataType();
+
+ // We will cache on the interface's this-type, since the interface type itself can be
+ // deallocated during the lowering process.
+ //
+ return builder->getThisType(interfaceType);
+ }
+
+ return primalType;
+}
+
+IRInst* DifferentialPairTypeBuilder::emitExistentialMakePair(
+ IRBuilder* builder,
+ IRInst* pairType,
+ IRInst* primalInst,
+ IRInst* diffInst)
+{
+ auto baseWitness = getExistentialBaseWitnessTable(builder, (IRType*)pairType);
+
+ auto pairTypeKey = cast<IRLookupWitnessMethod>(pairType)->getRequirementKey();
+ auto makePairKey = makePairKeyMap[pairTypeKey];
+
+ auto makePairMethod = builder->emitLookupInterfaceMethodInst(
+ makePairFuncTypeMap[makePairKey],
+ baseWitness,
+ makePairKey);
+
+ List<IRInst*> args;
+ args.add(primalInst);
+ args.add(diffInst);
+
+ auto makePairVal = builder->emitCallInst((IRType*)pairType, makePairMethod, args);
+
+ return makePairVal;
+}
+
+IRInst* DifferentialPairTypeBuilder::emitPrimalFieldAccess(
+ IRBuilder* builder,
+ IRType* loweredPairType,
+ IRInst* baseInst)
+{
+ if (isRuntimeType(loweredPairType))
+ {
+ auto baseWitness = getExistentialBaseWitnessTable(builder, (IRType*)loweredPairType);
+
+ auto pairTypeKey = cast<IRLookupWitnessMethod>(loweredPairType)->getRequirementKey();
+ auto getPrimalKey = getPrimalKeyMap[pairTypeKey];
+
+ auto primalFieldMethod = builder->emitLookupInterfaceMethodInst(
+ getPrimalFuncTypeMap[getPrimalKey],
+ baseWitness,
+ getPrimalKey);
+
+ auto primalFieldVal =
+ builder->emitCallInst(primalTypeMap[loweredPairType], primalFieldMethod, baseInst);
+
+ return primalFieldVal;
+ }
+ else
+ {
+ return emitFieldAccessor(builder, baseInst, this->globalPrimalKey);
+ }
+}
+
+IRInst* DifferentialPairTypeBuilder::emitDiffFieldAccess(
+ IRBuilder* builder,
+ IRType* loweredPairType,
+ IRInst* baseInst)
+{
+ if (isRuntimeType(loweredPairType))
+ {
+ auto baseWitness = getExistentialBaseWitnessTable(builder, (IRType*)loweredPairType);
+
+ auto pairTypeKey = cast<IRLookupWitnessMethod>(loweredPairType)->getRequirementKey();
+ auto getDiffKey = getDiffKeyMap[pairTypeKey];
+
+ auto diffFieldMethod = builder->emitLookupInterfaceMethodInst(
+ getDiffFuncTypeMap[getDiffKey],
+ baseWitness,
+ getDiffKey);
+
+ auto diffFieldVal =
+ builder->emitCallInst(diffTypeMap[loweredPairType], diffFieldMethod, baseInst);
+
+ return diffFieldVal;
+ }
+ else
+ {
+ return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
+ }
}
IRStructKey* DifferentialPairTypeBuilder::_getOrCreateDiffStructKey()
@@ -307,6 +432,380 @@ IRStructKey* DifferentialPairTypeBuilder::_getOrCreatePrimalStructKey()
return this->globalPrimalKey;
}
+IRInst* DifferentialPairTypeBuilder::getOrCreateCommonDiffPairInterface(IRBuilder* builder)
+{
+ if (!this->commonDiffPairInterface)
+ {
+ this->commonDiffPairInterface = builder->createInterfaceType(0, nullptr);
+ builder->addNameHintDecoration(
+ this->commonDiffPairInterface,
+ UnownedStringSlice("IDiffPair"));
+ }
+
+ return this->commonDiffPairInterface;
+}
+
+IRInst* DifferentialPairTypeBuilder::_createDiffPairInterfaceRequirement(
+ IRType* origBaseType,
+ IRType*)
+{
+ // We will create an interface requirement for the type's pair & then create implementations in
+ // all the implementing witness tables.
+ //
+
+ IRBuilder builder(sharedContext->moduleInst);
+
+ // Find the right interface to put the requirement in.
+ IRInterfaceType* interfaceType = nullptr;
+
+ // Find the effective type to put in the requirement entry
+ // for the base type
+ //
+ IRType* requirementBaseType = nullptr;
+
+ // Requirement key (only used for associated types)
+ //
+ IRInst* requirementKey = nullptr;
+
+ // Add a name hint to the key.
+ StringBuilder nameBuilderReqKey;
+ nameBuilderReqKey << "DiffPair_Req_";
+
+ if (auto lookup = as<IRLookupWitnessMethod>(origBaseType))
+ {
+ interfaceType =
+ cast<IRInterfaceType>(cast<IRWitnessTableType>(lookup->getWitnessTable()->getDataType())
+ ->getConformanceType());
+
+ requirementBaseType =
+ cast<IRType>(findInterfaceRequirement(interfaceType, lookup->getRequirementKey()));
+
+ requirementKey = lookup->getRequirementKey();
+
+ if (auto nameHint = lookup->getRequirementKey()->findDecoration<IRNameHintDecoration>())
+ {
+ nameBuilderReqKey << nameHint->getName();
+ }
+ else
+ {
+ nameBuilderReqKey << "unknown_assoc_type";
+ }
+ }
+ else if (auto extractType = as<IRExtractExistentialType>(origBaseType))
+ {
+ auto existentialType = extractType->getOperand(0);
+ interfaceType = cast<IRInterfaceType>(existentialType->getDataType());
+ requirementBaseType = builder.getThisType(interfaceType);
+
+ requirementKey = nullptr;
+
+ if (auto nameHint = interfaceType->findDecoration<IRNameHintDecoration>())
+ {
+ nameBuilderReqKey << nameHint->getName();
+ }
+ else
+ {
+ nameBuilderReqKey << "unknown_interface_type";
+ }
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unexpected type for differential pair interface requirement");
+ }
+
+ auto diffPairInterfaceType =
+ cast<IRInterfaceType>(getOrCreateCommonDiffPairInterface(&builder));
+
+ // Add 4 requirements to the interface:
+ // the associated pair type, getPrimal, getDiff & makePair
+ //
+ builder.setInsertInto(interfaceType);
+ IRStructKey* diffPairRequirementKey = builder.createStructKey();
+ IRStructKey* getPrimalRequirementKey = builder.createStructKey();
+ IRStructKey* getDiffRequirementKey = builder.createStructKey();
+ IRStructKey* makePairRequirementKey = builder.createStructKey();
+
+ makePairKeyMap[diffPairRequirementKey] = makePairRequirementKey;
+ getPrimalKeyMap[diffPairRequirementKey] = getPrimalRequirementKey;
+ getDiffKeyMap[diffPairRequirementKey] = getDiffRequirementKey;
+
+ List<IRInst*> entries;
+
+ // Add all the old requirements to the new interface.
+ for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
+ entries.add(interfaceType->getOperand(i));
+
+ //
+ // Create the new requirement entries.
+ //
+
+ {
+ // Create & insert the requirement key.
+ List<IRInterfaceType*> constraintTypes;
+ constraintTypes.add(diffPairInterfaceType);
+ auto entry = builder.createInterfaceRequirementEntry(
+ diffPairRequirementKey,
+ builder.getAssociatedType(constraintTypes.getArrayView()));
+
+ builder.addNameHintDecoration(diffPairRequirementKey, nameBuilderReqKey.getUnownedSlice());
+ entries.add(entry);
+ }
+
+ {
+ // Create & insert the getPrimal requirement.
+
+ List<IRType*> paramTypes;
+ List<IRInterfaceType*> paramConstraintTypes;
+ paramConstraintTypes.add(diffPairInterfaceType);
+ paramTypes.add(builder.getAssociatedType(paramConstraintTypes.getArrayView()));
+
+ auto entryFuncType = builder.getFuncType(paramTypes, requirementBaseType);
+ auto entry =
+ builder.createInterfaceRequirementEntry(getPrimalRequirementKey, entryFuncType);
+
+ getPrimalFuncTypeMap[getPrimalRequirementKey] = entryFuncType;
+
+ StringBuilder entryNameBuilder;
+ entryNameBuilder << nameBuilderReqKey.getUnownedSlice() << "_getPrimal";
+ builder.addNameHintDecoration(entry, entryNameBuilder.getUnownedSlice());
+
+ entries.add(entry);
+ }
+
+ {
+ // Create & insert the getDiff requirement.
+
+ List<IRType*> paramTypes;
+ List<IRInterfaceType*> paramConstraintTypes;
+ paramConstraintTypes.add(diffPairInterfaceType);
+ paramTypes.add(builder.getAssociatedType(paramConstraintTypes.getArrayView()));
+
+ List<IRInterfaceType*> resultConstraintTypes;
+ resultConstraintTypes.add(sharedContext->differentiableInterfaceType);
+ auto resultType = builder.getAssociatedType(resultConstraintTypes.getArrayView());
+
+ auto entryFuncType = builder.getFuncType(paramTypes, resultType);
+ auto entry = builder.createInterfaceRequirementEntry(getDiffRequirementKey, entryFuncType);
+
+ getDiffFuncTypeMap[getDiffRequirementKey] = entryFuncType;
+
+ StringBuilder entryNameBuilder;
+ entryNameBuilder << nameBuilderReqKey.getUnownedSlice() << "_getDiff";
+ builder.addNameHintDecoration(entry, entryNameBuilder.getUnownedSlice());
+
+ entries.add(entry);
+ }
+
+ {
+ // Create & insert the makePair requirement.
+
+ List<IRType*> paramTypes;
+ paramTypes.add(requirementBaseType);
+
+ List<IRInterfaceType*> paramConstraintTypes;
+ paramConstraintTypes.add(sharedContext->differentiableInterfaceType);
+ paramTypes.add(builder.getAssociatedType(paramConstraintTypes.getArrayView()));
+
+ List<IRInterfaceType*> resultConstraintTypes;
+ resultConstraintTypes.add(diffPairInterfaceType);
+ auto entryFuncType = builder.getFuncType(
+ paramTypes,
+ builder.getAssociatedType(resultConstraintTypes.getArrayView()));
+ auto entry = builder.createInterfaceRequirementEntry(makePairRequirementKey, entryFuncType);
+
+ makePairFuncTypeMap[makePairRequirementKey] = entryFuncType;
+
+ StringBuilder entryNameBuilder;
+ entryNameBuilder << nameBuilderReqKey.getUnownedSlice() << "_makePair";
+ builder.addNameHintDecoration(entry, entryNameBuilder.getUnownedSlice());
+
+ entries.add(entry);
+ }
+
+ {
+ // Create the new interface type.
+
+ auto newInterfaceType =
+ builder.createInterfaceType(entries.getCount(), entries.getBuffer());
+
+ // Transfer decorations from the old interface to the new one.
+ interfaceType->transferDecorationsTo(newInterfaceType);
+ interfaceType->replaceUsesWith(newInterfaceType);
+
+ // Replace the interface maps in the caches.
+ if (this->pairTypeCache.containsKey(interfaceType))
+ this->pairTypeCache[newInterfaceType] = this->pairTypeCache[interfaceType];
+
+ if (this->existentialPairTypeCache.containsKey(interfaceType))
+ this->existentialPairTypeCache[newInterfaceType] =
+ this->existentialPairTypeCache[interfaceType];
+
+ interfaceType->removeAndDeallocate();
+ interfaceType = newInterfaceType;
+ }
+
+ //
+ // Implement the requirements in all the witness tables.
+ //
+
+ // Collect all witness tables of the given interfaceType.
+ List<IRWitnessTable*> concreteWitnessTables;
+ auto witnessTableType = builder.getWitnessTableType(interfaceType);
+ for (auto use = witnessTableType->firstUse; use; use = use->nextUse)
+ {
+ if (auto witnessTable = as<IRWitnessTable>(use->getUser()))
+ {
+ if (use->getUser()->getFullType() == witnessTableType)
+ concreteWitnessTables.add(witnessTable);
+ }
+ }
+
+ DifferentiableTypeConformanceContext ctx(sharedContext);
+ ctx.buildGlobalWitnessDictionary();
+
+ for (auto concreteWitnessTable : concreteWitnessTables)
+ {
+ IRType* concretePrimalType = nullptr;
+
+ // What requirement are we trying to satisfy?
+ if (as<IRThisType>(requirementBaseType))
+ {
+ // For this types, we should lower the concrete type of the witness table itself.
+ concretePrimalType = concreteWitnessTable->getConcreteType();
+ }
+ else if (as<IRAssociatedType>(requirementBaseType))
+ {
+ // For associated types, look it up in the witness table.
+ concretePrimalType =
+ (IRType*)findWitnessTableEntry(concreteWitnessTable, requirementKey);
+ }
+ else
+ {
+ // We shouldn't see any other case here.
+ SLANG_UNEXPECTED("Unexpected requirement base type");
+ }
+
+ // Create the pair type.
+ auto witness = ctx.tryGetDifferentiableWitness(
+ &builder,
+ concretePrimalType,
+ DiffConformanceKind::Value);
+
+ // Really should not see a case where the original interface is differentiable, but
+ // we can't find the witness table.
+ //
+ SLANG_ASSERT(witness);
+
+ auto concretePairType = builder.getDifferentialPairType(
+ concretePrimalType,
+ witness); // TODO: Need to handle the other conformance kinds
+ auto concreteDiffType =
+ (IRType*)_getDiffTypeFromPairType(sharedContext, &builder, concretePairType);
+
+ auto loweredStructType = (IRType*)lowerDiffPairType(&builder, concretePairType);
+
+ // Create an (empty) witness table for loweredStuctType : IDiffPair_...
+ // This is just so that there is a bound on the any-value-size for each group of pair types.
+ //
+ auto witnessTable = builder.createWitnessTable(diffPairInterfaceType, loweredStructType);
+ builder.addKeepAliveDecoration(witnessTable);
+
+ builder.setInsertInto(concreteWitnessTable);
+
+ // Create the associated type entry.
+ {
+ builder.createWitnessTableEntry(
+ concreteWitnessTable,
+ diffPairRequirementKey,
+ loweredStructType);
+ }
+
+ // Create the getPrimal method.
+ {
+ auto primalMethod = builder.createFunc();
+
+ StringBuilder nameBuilder;
+ getTypeNameHint(nameBuilder, loweredStructType);
+ nameBuilder << "_getPrimal";
+ builder.addNameHintDecoration(primalMethod, nameBuilder.getUnownedSlice());
+
+ primalMethod->setFullType(builder.getFuncType(
+ List<IRType*>({(IRType*)loweredStructType}),
+ concretePrimalType));
+
+ builder.setInsertInto(primalMethod);
+ auto block = builder.emitBlock();
+ builder.setInsertInto(block);
+ auto param = builder.emitParam((IRType*)loweredStructType);
+ builder.emitReturn(
+ builder.emitFieldExtract(concretePrimalType, param, _getOrCreatePrimalStructKey()));
+
+ builder.setInsertInto(concreteWitnessTable);
+ builder.createWitnessTableEntry(
+ concreteWitnessTable,
+ getPrimalRequirementKey,
+ primalMethod);
+ }
+
+ // Create the getDiff method.
+ {
+ auto diffMethod = builder.createFunc();
+
+ StringBuilder nameBuilder;
+ getTypeNameHint(nameBuilder, loweredStructType);
+ nameBuilder << "_getDiff";
+ builder.addNameHintDecoration(diffMethod, nameBuilder.getUnownedSlice());
+
+ diffMethod->setFullType(
+ builder.getFuncType(List<IRType*>({(IRType*)loweredStructType}), concreteDiffType));
+
+ builder.setInsertInto(diffMethod);
+ auto block = builder.emitBlock();
+ builder.setInsertInto(block);
+ auto param = builder.emitParam((IRType*)loweredStructType);
+ builder.emitReturn(
+ builder.emitFieldExtract(concreteDiffType, param, _getOrCreateDiffStructKey()));
+
+ builder.setInsertInto(concreteWitnessTable);
+ builder.createWitnessTableEntry(
+ concreteWitnessTable,
+ getDiffRequirementKey,
+ diffMethod);
+ }
+
+ // Create the makePair method.
+ {
+ auto makePairMethod = builder.createFunc();
+
+ StringBuilder nameBuilder;
+ getTypeNameHint(nameBuilder, loweredStructType);
+ nameBuilder << "_makePair";
+ builder.addNameHintDecoration(makePairMethod, nameBuilder.getUnownedSlice());
+
+ makePairMethod->setFullType(builder.getFuncType(
+ List<IRType*>({concretePrimalType, concreteDiffType}),
+ (IRType*)loweredStructType));
+
+ builder.setInsertInto(makePairMethod);
+ auto block = builder.emitBlock();
+ builder.setInsertInto(block);
+ auto param1 = builder.emitParam(concretePrimalType);
+ auto param2 = builder.emitParam(concreteDiffType);
+ List<IRInst*> args = {param1, param2};
+ auto pair = builder.emitMakeStruct((IRType*)loweredStructType, args);
+ builder.emitReturn(pair);
+
+ builder.setInsertInto(concreteWitnessTable);
+ builder.createWitnessTableEntry(
+ concreteWitnessTable,
+ makePairRequirementKey,
+ makePairMethod);
+ }
+ }
+
+ return diffPairRequirementKey;
+}
+
IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, IRType* diffType)
{
switch (origBaseType->getOp())
@@ -333,6 +832,7 @@ IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, I
return pairStructType;
}
+
IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(IRBuilder* builder, IRType* originalPairType)
{
IRInst* result = nullptr;
@@ -352,26 +852,119 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(IRBuilder* builder, IRTyp
// purposes.
auto primalType = pairType->getValueType();
- if (pairTypeCache.tryGetValue(primalType, result))
- return result;
- if (!pairType)
+
+ if (isRuntimeType(primalType))
{
- result = originalPairType;
+ // Existential case.
+ auto cacheKey = getCacheKey(builder, primalType);
+ auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType);
+
+ IRInst* pairReqKey = nullptr;
+ if (!existentialPairTypeCache.tryGetValue(cacheKey, pairReqKey))
+ {
+ pairReqKey = _createDiffPairInterfaceRequirement(primalType, (IRType*)diffType);
+ existentialPairTypeCache.add(cacheKey, pairReqKey);
+ }
+
+ auto baseWitnessTable = getExistentialBaseWitnessTable(builder, primalType);
+ result = builder->emitLookupInterfaceMethodInst(
+ builder->getTypeKind(),
+ baseWitnessTable,
+ pairReqKey);
+
+ primalTypeMap[result] = primalType;
+ diffTypeMap[result] = (IRType*)diffType;
+
return result;
}
- if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(primalType))
+ else if (auto typePack = as<IRTypePack>(primalType))
{
- result = nullptr;
- return result;
+ // Lower DiffPair(TypePack(a_0, a_1, ...), MakeWitnessPack(w_0, w_1, ...)) as
+ // TypePack(DiffPair(a_0, w_0), DiffPair(a_1, w_1), ...)
+ //
+ auto cacheKey = primalType;
+ if (pairTypeCache.tryGetValue(cacheKey, result))
+ return result;
+
+ auto packWitness = pairType->getWitness();
+
+ // Right now we only support concrete witness tables for type packs.
+ auto concretePackWitness = as<IRWitnessTable>(packWitness);
+ SLANG_ASSERT(concretePackWitness);
+
+ // Get diff type pack.
+ IRTypePack* diffTypePack = nullptr;
+
+ if (concretePackWitness->getConformanceType() ==
+ this->sharedContext->differentiableInterfaceType)
+ diffTypePack = as<IRTypePack>(findWitnessTableEntry(
+ concretePackWitness,
+ this->sharedContext->differentialAssocTypeStructKey));
+ else if (
+ concretePackWitness->getConformanceType() ==
+ this->sharedContext->differentiablePtrInterfaceType)
+ diffTypePack = as<IRTypePack>(findWitnessTableEntry(
+ concretePackWitness,
+ this->sharedContext->differentialAssocRefTypeStructKey));
+ else
+ SLANG_UNEXPECTED("Unexpected witness table");
+
+ SLANG_ASSERT(diffTypePack);
+
+ List<IRType*> args;
+ for (UInt i = 0; i < typePack->getOperandCount(); i++)
+ {
+ auto type = (IRType*)typePack->getOperand(i);
+ auto diffType = (IRType*)typePack->getOperand(i);
+
+ if (pairTypeCache.tryGetValue(type, result))
+ {
+ args.add((IRType*)result);
+ continue;
+ }
+
+ // Lower the diff pair type.
+ auto loweredPairType = (IRType*)_createDiffPairType(type, diffType);
+
+ pairTypeCache.add(type, loweredPairType);
+ args.add(loweredPairType);
+ }
+
+ auto loweredTypePack = builder->getTypePack(args.getCount(), args.getBuffer());
+ // TODO: Unify the cache between the three cases.
+ pairTypeCache.add(cacheKey, loweredTypePack);
+
+ return loweredTypePack;
}
+ else
+ {
+ auto cacheKey = primalType;
+ if (pairTypeCache.tryGetValue(primalType, result))
+ return result;
- auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType);
- if (!diffType)
- return result;
- result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
- pairTypeCache.add(primalType, result);
+ if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(primalType))
+ {
+ result = nullptr;
+ return result;
+ }
+
+ if (as<IRThisType>(primalType) || as<IRAssociatedType>(primalType))
+ {
+ List<IRInterfaceType*> constraintTypes;
+ constraintTypes.add(this->commonDiffPairInterface);
+ return builder->getAssociatedType(constraintTypes.getArrayView());
+ }
+
+ auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType);
+ if (!diffType)
+ return result;
+
+ // Concrete case.
+ result = _createDiffPairType(primalType, (IRType*)diffType);
+ pairTypeCache.add(cacheKey, result);
- return result;
+ return result;
+ }
}
IRInterfaceType* findDifferentiableRefInterface(IRModuleInst* moduleInst)
@@ -550,6 +1143,13 @@ IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWit
auto innerWitnessTableType = cast<IRWitnessTableType>(operand);
return cast<IRInterfaceType>(innerWitnessTableType->getConformanceType());
}
+ else if (auto genericWitness = as<IRGeneric>(witness))
+ {
+ // This is a generic witness table.
+ auto innerWitness = getGenericReturnVal(genericWitness);
+ SLANG_ASSERT(as<IRWitnessTableType>(innerWitness->getDataType()));
+ return getConformanceTypeFromWitness(innerWitness);
+ }
else
{
SLANG_UNEXPECTED("Unexpected witness type");
@@ -558,81 +1158,134 @@ IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWit
return diffInterfaceType;
}
+List<IRDifferentiableTypeAnnotation*> DifferentiableTypeConformanceContext::getAnnotations(
+ IRGlobalValueWithCode* code)
+{
+ // Scan function for all IRDifferentiableTypeAnnotation insts.
+ List<IRDifferentiableTypeAnnotation*> annotations;
+ for (auto block : code->getBlocks())
+ {
+ for (auto child : block->getChildren())
+ {
+ if (auto annotation = as<IRDifferentiableTypeAnnotation>(child))
+ {
+ annotations.add(annotation);
+ }
+ }
+ }
+
+ return annotations;
+}
+
+List<IRDifferentiableTypeAnnotation*> DifferentiableTypeConformanceContext::getAnnotations(
+ IRModuleInst* module)
+{
+ // Scan module for all IRDifferentiableTypeAnnotation insts.
+ List<IRDifferentiableTypeAnnotation*> annotations;
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto annotation = as<IRDifferentiableTypeAnnotation>(globalInst))
+ {
+ annotations.add(annotation);
+ }
+ }
+
+ return annotations;
+}
+
void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
parentFunc = func;
+ List<IRDifferentiableTypeAnnotation*> annotations = getAnnotations(func);
- auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
- SLANG_RELEASE_ASSERT(decor);
-
- // Build lookup dictionary for type witnesses.
- for (auto child = decor->getFirstChild(); child; child = child->next)
+ // Go up the parents of func & add the annotations of any IRGeneric or IRModule parent:
+ IRInst* parent = func;
+ while (parent)
{
- if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
+ if (auto upperFunc = as<IRGlobalValueWithCode>(parent))
{
- IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness());
+ // TODO: Cache this.
+ auto parentAnnotations = getAnnotations(upperFunc);
+ annotations.addRange(parentAnnotations);
+ }
+ else if (auto module = as<IRModuleInst>(parent))
+ {
+ // TODO: Cache this.
+ auto parentAnnotations = getAnnotations(module);
+ annotations.addRange(parentAnnotations);
+ }
+ parent = parent->getParent();
+ }
- SLANG_ASSERT(
- diffInterfaceType == sharedContext->differentiableInterfaceType ||
- diffInterfaceType == sharedContext->differentiablePtrInterfaceType);
+ for (auto item : annotations)
+ {
+ IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness());
- auto existingItem =
- differentiableTypeWitnessDictionary.tryGetValue(item->getConcreteType());
- if (existingItem)
- {
- *existingItem = item->getWitness();
- }
- else
- {
- auto witness = item->getWitness();
+ SLANG_ASSERT(
+ diffInterfaceType == sharedContext->differentiableInterfaceType ||
+ diffInterfaceType == sharedContext->differentiablePtrInterfaceType);
- // Also register the type's differential type with the same witness.
- auto concreteType = item->getConcreteType();
- IRBuilder subBuilder(item->getConcreteType());
- if (as<IRTypePack>(concreteType) || as<IRTupleType>(concreteType))
+ auto existingItem = differentiableTypeWitnessDictionary.tryGetValue(item->getBaseType());
+ if (existingItem)
+ {
+ *existingItem = item->getWitness();
+ }
+ else
+ {
+ auto witness = item->getWitness();
+
+ // Also register the type's differential type with the same witness.
+ auto concreteType = item->getBaseType();
+ IRBuilder subBuilder(item->getBaseType());
+ if (as<IRTypePack>(concreteType) || as<IRTupleType>(concreteType))
+ {
+ // For tuple types with concrete element types,
+ // register the differential type for each element, but don't register for the
+ // tuple/typepack itself.
+ if (auto witnessPack = as<IRMakeWitnessPack>(witness))
{
- // For tuple types with concrete element types,
- // register the differential type for each element, but don't register for the
- // tuple/typepack itself.
- if (auto witnessPack = as<IRMakeWitnessPack>(witness))
+
+ for (UInt i = 0; i < concreteType->getOperandCount(); i++)
{
+ auto element = concreteType->getOperand(i);
+ auto elementWitness = witnessPack->getOperand(i);
- for (UInt i = 0; i < concreteType->getOperandCount(); i++)
- {
- auto element = concreteType->getOperand(i);
- auto elementWitness = witnessPack->getOperand(i);
-
- if (diffInterfaceType == sharedContext->differentiableInterfaceType)
- addTypeToDictionary((IRType*)element, elementWitness);
- else if (
- diffInterfaceType == sharedContext->differentiablePtrInterfaceType)
- addTypeToDictionary((IRType*)element, elementWitness);
- }
- return;
+ if (diffInterfaceType == sharedContext->differentiableInterfaceType)
+ addTypeToDictionary((IRType*)element, elementWitness);
+ else if (diffInterfaceType == sharedContext->differentiablePtrInterfaceType)
+ addTypeToDictionary((IRType*)element, elementWitness);
}
+ return;
}
+ }
- addTypeToDictionary((IRType*)item->getConcreteType(), item->getWitness());
+ addTypeToDictionary((IRType*)item->getBaseType(), item->getWitness());
- if (!as<IRInterfaceType>(item->getConcreteType()))
- {
- addTypeToDictionary(
- (IRType*)_lookupWitness(
- &subBuilder,
- item->getWitness(),
- sharedContext->differentialAssocTypeStructKey,
- subBuilder.getTypeKind()),
- item->getWitness());
- }
+ // TODO: Is this really needed?
+ if (!as<IRInterfaceType>(item->getBaseType()) &&
+ !as<IRAssociatedType>(item->getBaseType()))
+ {
+ addTypeToDictionary(
+ (IRType*)_lookupWitness(
+ &subBuilder,
+ item->getWitness(),
+ sharedContext->differentialAssocTypeStructKey,
+ subBuilder.getTypeKind()),
+ item->getWitness());
+ }
- if (auto diffPairType = as<IRDifferentialPairTypeBase>(item->getConcreteType()))
- {
- // For differential pair types, register the differential type as well.
- IRBuilder builder(diffPairType);
- builder.setInsertAfter(diffPairType->getWitness());
+ // TODO: Is this really needed?
+ if (auto diffPairType = as<IRDifferentialPairTypeBase>(item->getBaseType()))
+ {
+ // For differential pair types, register the differential type as well.
+ IRBuilder builder(diffPairType);
+ builder.setInsertAfter(diffPairType->getWitness());
- // TODO(sai): lot of this logic is duplicated. need to refactor.
+ // TODO(sai): lot of this logic is duplicated. need to refactor.
+ if (!as<IRInterfaceType>(diffPairType->getValueType()) &&
+ !as<IRAssociatedType>(diffPairType->getValueType()))
+ {
auto diffType =
(diffInterfaceType == sharedContext->differentiableInterfaceType)
? _lookupWitness(
@@ -665,12 +1318,28 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
}
}
+IRWitnessTable* findGlobalWitness(IRInterfaceType* interface, IRInst* type)
+{
+ for (auto use = type->firstUse; use; use = use->nextUse)
+ {
+ if (auto witnessTable = as<IRWitnessTable>(use->getUser()))
+ {
+ if (witnessTable->getConcreteType() == type &&
+ witnessTable->getConformanceType() == interface)
+ return witnessTable;
+ }
+ }
+
+ return nullptr;
+}
+
IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(
IRInst* type,
DiffConformanceKind kind)
{
IRInst* foundResult = nullptr;
differentiableTypeWitnessDictionary.tryGetValue(type, foundResult);
+
if (!foundResult)
return nullptr;
@@ -791,8 +1460,8 @@ IRInst* DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface
return nullptr;
}
-// Given an interface type, return the lookup path from a witness table of `type` to a witness table
-// of `supType`.
+// Given an interface type, return the lookup path from a witness table of `type` to a witness
+// table of `supType`.
static bool _findInterfaceLookupPathImpl(
HashSet<IRInst*>& processedTypes,
IRInterfaceType* supType,
@@ -967,6 +1636,11 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
addTypeToDictionary(pairType->getValueType(), pairType->getWitness());
}
+
+ if (auto annotation = as<IRDifferentiableTypeAnnotation>(globalInst))
+ {
+ addTypeToDictionary((IRType*)annotation->getBaseType(), annotation->getWitness());
+ }
}
}
@@ -1071,6 +1745,20 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(
}
}
+IRType* getAssociatedTypeForKey(IRInst* key)
+{
+ for (auto use = key->firstUse; use; use = use->nextUse)
+ {
+ if (auto interfaceReq = as<IRInterfaceRequirementEntry>(key))
+ {
+ if (auto assocType = as<IRAssociatedType>(interfaceReq->getRequirementVal()))
+ return assocType;
+ }
+ }
+
+ return nullptr;
+}
+
IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(
IRBuilder* builder,
IRInst* primalType,
@@ -1118,8 +1806,9 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(
}
else if (auto lookup = as<IRLookupWitnessMethod>(primalType))
{
- // For types that are lookups from a table, we can simply lookup the witness from the same
- // table
+ // Trivial cases: For types that are lookups from a table, we can simply lookup the
+ // witness from the same table
+ //
if (lookup->getRequirementKey() == sharedContext->differentialAssocTypeStructKey)
{
witness = builder->emitLookupInterfaceMethodInst(
@@ -1203,8 +1892,8 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness(
auto p0 = b.emitParam(diffDiffPairType);
auto p1 = b.emitParam(diffDiffPairType);
- // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value
- // type == diff type.
+ // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that
+ // value type == diff type.
auto innerAdd = _lookupWitness(
&b,
innerWitness,
@@ -1325,8 +2014,8 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
auto p0 = b.emitParam(diffArrayType);
auto p1 = b.emitParam(diffArrayType);
- // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value
- // type == diff type.
+ // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that
+ // value type == diff type.
auto innerAdd = _lookupWitness(
&b,
innerWitness,
@@ -1566,6 +2255,143 @@ IRInst* DifferentiableTypeConformanceContext::buildExtractExistensialTypeWitness
return nullptr;
}
+IRInst* DifferentiableTypeConformanceContext::emitDAddOfDiffInstType(
+ IRBuilder* builder,
+ IRType* primalType,
+ IRInst* op1,
+ IRInst* op2)
+{
+ if (auto arrayType = as<IRArrayType>(primalType))
+ {
+ // TODO: This case should really not be necessary anymore
+ auto diffElementType =
+ (IRType*)this->getDifferentialForType(builder, arrayType->getElementType());
+ SLANG_RELEASE_ASSERT(diffElementType);
+ auto arraySize = arrayType->getElementCount();
+
+ if (auto constArraySize = as<IRIntLit>(arraySize))
+ {
+ List<IRInst*> args;
+ for (IRIntegerValue i = 0; i < constArraySize->getValue(); i++)
+ {
+ auto index = builder->getIntValue(builder->getIntType(), i);
+ auto op1Val = builder->emitElementExtract(diffElementType, op1, index);
+ auto op2Val = builder->emitElementExtract(diffElementType, op2, index);
+ args.add(
+ emitDAddOfDiffInstType(builder, arrayType->getElementType(), op1Val, op2Val));
+ }
+ auto diffArrayType =
+ builder->getArrayType(diffElementType, arrayType->getElementCount());
+ return builder->emitMakeArray(diffArrayType, (UInt)args.getCount(), args.getBuffer());
+ }
+ else
+ {
+ // TODO: insert a runtime loop here.
+ SLANG_UNIMPLEMENTED_X("dadd of dynamic array.");
+ }
+ }
+ else if (auto diffPairUserType = as<IRDifferentialPairUserCodeType>(primalType))
+ {
+ // TODO: This case should really not be necessary anymore
+ auto diffType = (IRType*)this->getDiffTypeFromPairType(builder, diffPairUserType);
+ auto diffWitness = this->getDiffTypeWitnessFromPairType(builder, diffPairUserType);
+
+ auto primal1 = builder->emitDifferentialPairGetPrimalUserCode(op1);
+ auto primal2 = builder->emitDifferentialPairGetPrimalUserCode(op2);
+ auto primal =
+ emitDAddOfDiffInstType(builder, diffPairUserType->getValueType(), primal1, primal2);
+
+ auto diff1 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op1);
+ auto diff2 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op2);
+ auto diff = emitDAddOfDiffInstType(builder, diffType, diff1, diff2);
+
+ auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness);
+ return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primal, diff);
+ }
+ else if (as<IRInterfaceType>(primalType))
+ {
+ // If our type is existential, we need to handle the case where
+ // one or both of our operands are null-type.
+ //
+ return emitDAddForExistentialType(builder, primalType, op1, op2);
+ }
+ else if (as<IRAssociatedType>(primalType))
+ {
+ // Should not happen. associated type does not have any additional info, we can't
+ // lookup the necessary methods.
+ //
+ SLANG_UNEXPECTED("unexpected associated type during transposition");
+ }
+
+ auto addMethod = this->getAddMethodForType(builder, primalType);
+
+ // Should exist.
+ SLANG_ASSERT(addMethod);
+
+ return builder->emitCallInst(
+ (IRType*)this->getDifferentialForType(builder, primalType),
+ addMethod,
+ List<IRInst*>(op1, op2));
+}
+
+IRInst* DifferentiableTypeConformanceContext::emitDAddForExistentialType(
+ IRBuilder* builder,
+ IRType* primalType,
+ IRInst* op1,
+ IRInst* op2)
+{
+ return builder->emitCallInst(
+ (IRType*)this->getDifferentialForType(builder, primalType),
+ this->getOrCreateExistentialDAddMethod(),
+ List<IRInst*>({op1, op2}));
+}
+
+IRInst* DifferentiableTypeConformanceContext::emitDZeroOfDiffInstType(
+ IRBuilder* builder,
+ IRType* primalType)
+{
+ if (auto arrayType = as<IRArrayType>(primalType))
+ {
+ // TODO: This case should really not be necessary anymore
+ auto diffElementType =
+ (IRType*)this->getDifferentialForType(builder, arrayType->getElementType());
+ SLANG_RELEASE_ASSERT(diffElementType);
+ auto diffArrayType = builder->getArrayType(diffElementType, arrayType->getElementCount());
+ auto diffElementZero = emitDZeroOfDiffInstType(builder, arrayType->getElementType());
+ return builder->emitMakeArrayFromElement(diffArrayType, diffElementZero);
+ }
+ else if (auto diffPairUserType = as<IRDifferentialPairUserCodeType>(primalType))
+ {
+ // TODO: This case should really not be necessary anymore.
+ auto primalZero = emitDZeroOfDiffInstType(builder, diffPairUserType->getValueType());
+ auto diffZero = primalZero;
+ auto diffType = primalZero->getFullType();
+ auto diffWitness = this->getDiffTypeWitnessFromPairType(builder, diffPairUserType);
+ auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness);
+ return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primalZero, diffZero);
+ }
+ else if (as<IRInterfaceType>(primalType) || as<IRAssociatedType>(primalType))
+ {
+ // Pack a null value into an existential type.
+ auto existentialZero = builder->emitMakeExistential(
+ this->sharedContext->differentiableInterfaceType,
+ this->emitNullDifferential(builder),
+ this->sharedContext->nullDifferentialWitness);
+
+ return existentialZero;
+ }
+
+ auto zeroMethod = this->getZeroMethodForType(builder, primalType);
+
+ // Should exist.
+ SLANG_ASSERT(zeroMethod);
+
+ return builder->emitCallInst(
+ (IRType*)this->getDifferentialForType(builder, primalType),
+ zeroMethod,
+ List<IRInst*>());
+}
+
void copyCheckpointHints(
IRBuilder* builder,
IRGlobalValueWithCode* oldInst,
@@ -1883,6 +2709,7 @@ struct AutoDiffPass : public InstPassBase
{
bool result = false;
OrderedHashSet<IRInst*> loweredIntermediateTypes;
+ Dictionary<IRInst*, IRGlobalValueWithCode*> typeToBwdFuncMap;
// Replace all `BackwardDiffIntermediateContextType` insts with the struct type
// that we generated during backward diff pass.
@@ -1906,6 +2733,38 @@ struct AutoDiffPass : public InstPassBase
if (type)
{
loweredIntermediateTypes.add(type);
+
+ auto func = differentiateInst->getFunc();
+
+ if (auto spec = as<IRSpecialize>(func))
+ func = spec->getBase();
+
+ if (auto generic = as<IRGeneric>(func))
+ {
+ func =
+ cast<IRGlobalValueWithCode>(findGenericReturnVal(generic));
+
+ auto bwdFuncDecor = func->findDecoration<
+ IRBackwardDerivativePropagateDecoration>();
+
+ typeToBwdFuncMap.add(
+ type,
+ cast<IRGlobalValueWithCode>(
+ as<IRSpecialize>(
+ bwdFuncDecor->getBackwardDerivativePropagateFunc())
+ ->getBase()));
+ }
+ else
+ {
+ auto bwdFuncDecor = func->findDecoration<
+ IRBackwardDerivativePropagateDecoration>();
+
+ typeToBwdFuncMap.add(
+ type,
+ cast<IRGlobalValueWithCode>(
+ bwdFuncDecor->getBackwardDerivativePropagateFunc()));
+ }
+
inst->replaceUsesWith(type);
inst->removeAndDeallocate();
changed = true;
@@ -1922,7 +2781,9 @@ struct AutoDiffPass : public InstPassBase
}
// Now we generate the differential type for the intermediate context type
// to allow higher order differentiation.
- generateDifferentialImplementationForContextType(loweredIntermediateTypes);
+ generateDifferentialImplementationForContextType(
+ loweredIntermediateTypes,
+ typeToBwdFuncMap);
return result;
}
@@ -1977,22 +2838,13 @@ struct AutoDiffPass : public InstPassBase
IRInst* addMethod = nullptr;
};
- // Register the differential type for an intermediate context type to the derivative functions
- // that uses the type.
+ // Register the differential type for an intermediate context type to the derivative
+ // functions that uses the type.
void registerDiffContextType(
IRBuilder& builder,
- IRDifferentiableTypeDictionaryDecoration* diffDecor,
OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes,
IRInst* origType)
{
- HashSet<IRInst*> registeredType;
- for (auto entry : diffDecor->getChildren())
- {
- if (auto e = as<IRDifferentiableTypeDictionaryItem>(entry))
- {
- registeredType.add(e->getOperand(0));
- }
- }
// Use a work list to recursively walk through all sub fields of the struct type.
List<IRInst*> wlist;
wlist.add(origType);
@@ -2002,10 +2854,13 @@ struct AutoDiffPass : public InstPassBase
IntermediateContextTypeDifferentialInfo diffInfo;
if (!diffTypes.tryGetValue(t, diffInfo))
continue;
- if (registeredType.add(t))
- builder.addDifferentiableTypeEntry(diffDecor, t, diffInfo.diffWitness);
- else
- continue;
+
+ IRInst* args[] = {t, diffInfo.diffWitness};
+ builder.emitIntrinsicInst(
+ builder.getVoidType(),
+ kIROp_DifferentiableTypeAnnotation,
+ 2,
+ args);
if (auto structType = as<IRStructType>(getResolvedInstForDecorations(t)))
{
@@ -2017,7 +2872,9 @@ struct AutoDiffPass : public InstPassBase
}
}
- void generateDifferentialImplementationForContextType(OrderedHashSet<IRInst*>& contextTypes)
+ void generateDifferentialImplementationForContextType(
+ OrderedHashSet<IRInst*>& contextTypes,
+ Dictionary<IRInst*, IRGlobalValueWithCode*> typeToBwdFuncMap)
{
// First we are going to topology sort all intermediate context types.
OrderedHashSet<IRInst*> sortedContextTypes;
@@ -2043,6 +2900,10 @@ struct AutoDiffPass : public InstPassBase
IRBuilder builder(module);
for (auto t : sortedContextTypes)
{
+ auto func = typeToBwdFuncMap[t];
+ DifferentiableTypeConformanceContext ctx(this->autodiffContext);
+ ctx.setFunc(func);
+
if (t->getOp() == kIROp_Generic || t->getOp() == kIROp_StructType)
{
// For generics/struct types, we will generate a new generic/struct type
@@ -2050,7 +2911,7 @@ struct AutoDiffPass : public InstPassBase
SLANG_RELEASE_ASSERT(t->getParent() && t->getParent()->getOp() == kIROp_Module);
builder.setInsertBefore(t);
- auto diffInfo = fillDifferentialTypeImplementation(diffTypes, t);
+ auto diffInfo = fillDifferentialTypeImplementation(&ctx, diffTypes, t);
diffTypes[t] = diffInfo;
}
else if (auto specialize = as<IRSpecialize>(t))
@@ -2085,30 +2946,29 @@ struct AutoDiffPass : public InstPassBase
// function without a intermediate-type via an interface.
SLANG_RELEASE_ASSERT(diffTypes.containsKey(t));
}
- }
- // Register the differential types into the conformance dictionaries of the functions that
- // uses them.
- for (auto t : diffTypes)
- {
+ if (!diffTypes.containsKey(t))
+ continue;
+
+ // If we created a new differential type, we need to place into the contexts of all
+ // functions that use it.
+ //
HashSet<IRFunc*> registeredFuncs;
- for (auto use = t.key->firstUse; use; use = use->nextUse)
+ for (auto use = t->firstUse; use; use = use->nextUse)
{
auto parentFunc = getParentFunc(use->getUser());
if (!parentFunc)
continue;
if (!registeredFuncs.add(parentFunc))
continue;
- if (auto dictDecor =
- parentFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
- {
- registerDiffContextType(builder, dictDecor, diffTypes, t.key);
- }
+
+ registerDiffContextType(builder, diffTypes, t);
}
}
}
IntermediateContextTypeDifferentialInfo fillDifferentialTypeImplementationForStruct(
+ DifferentiableTypeConformanceContext* ctx,
OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes,
IRStructType* originalType,
IRStructType* diffType)
@@ -2122,6 +2982,7 @@ struct AutoDiffPass : public InstPassBase
// Generate the fields for all differentiable members of the original struct type.
struct FieldInfo
{
+ IRType* primalType;
IRStructField* field;
IRInst* witness;
};
@@ -2130,30 +2991,30 @@ struct AutoDiffPass : public InstPassBase
for (auto field : originalType->getFields())
{
IRInst* diffFieldWitness = nullptr;
- if (auto diffDecor =
- field->findDecoration<IRIntermediateContextFieldDifferentialTypeDecoration>())
- {
- diffFieldWitness = diffDecor->getDifferentialWitness();
- }
- else
+
+ diffFieldWitness = ctx->tryGetDifferentiableWitness(
+ &builder,
+ field->getFieldType(),
+ DiffConformanceKind::Value);
+
+ if (!diffFieldWitness)
{
IntermediateContextTypeDifferentialInfo diffFieldTypeInfo;
diffTypes.tryGetValue(field->getFieldType(), diffFieldTypeInfo);
diffFieldWitness = diffFieldTypeInfo.diffWitness;
}
+
if (diffFieldWitness)
{
FieldInfo info;
IRBuilder keyBuilder = builder;
keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType));
auto diffKey = keyBuilder.createStructKey();
- auto diffFieldType = _lookupWitness(
- &keyBuilder,
- diffFieldWitness,
- autodiffContext->differentialAssocTypeStructKey,
- builder.getTypeKind());
+ auto diffFieldType = ctx->getDifferentialForType(&builder, field->getFieldType());
+
info.field = builder.createStructField(diffType, diffKey, (IRType*)diffFieldType);
info.witness = diffFieldWitness;
+ info.primalType = field->getFieldType();
builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey);
builder.addDecoration(diffKey, kIROp_DerivativeMemberDecoration, diffKey);
diffFields.add(info);
@@ -2172,16 +3033,10 @@ struct AutoDiffPass : public InstPassBase
builder.setInsertInto(zeroMethod);
builder.emitBlock();
List<IRInst*> fieldVals;
+
for (auto info : diffFields)
{
- auto innerZeroMethod = _lookupWitness(
- &builder,
- info.witness,
- autodiffContext->zeroMethodStructKey,
- autodiffContext->zeroMethodType);
- IRInst* val =
- builder.emitCallInst(info.field->getFieldType(), innerZeroMethod, 0, nullptr);
- fieldVals.add(val);
+ fieldVals.add(ctx->emitDZeroOfDiffInstType(&builder, info.primalType));
}
builder.emitReturn(builder.emitMakeStruct(diffType, fieldVals));
}
@@ -2203,20 +3058,15 @@ struct AutoDiffPass : public InstPassBase
List<IRInst*> fieldVals;
for (auto info : diffFields)
{
- auto innerAddMethod = _lookupWitness(
- &builder,
- info.witness,
- autodiffContext->addMethodStructKey,
- autodiffContext->addMethodType);
IRInst* args[2] = {
builder
.emitFieldExtract(info.field->getFieldType(), param1, info.field->getKey()),
builder
.emitFieldExtract(info.field->getFieldType(), param2, info.field->getKey()),
};
- IRInst* val =
- builder.emitCallInst(info.field->getFieldType(), innerAddMethod, 2, args);
- fieldVals.add(val);
+
+ fieldVals.add(
+ ctx->emitDAddOfDiffInstType(&builder, info.primalType, args[0], args[1]));
}
builder.emitReturn(builder.emitMakeStruct(diffType, fieldVals));
}
@@ -2265,6 +3115,7 @@ struct AutoDiffPass : public InstPassBase
}
IntermediateContextTypeDifferentialInfo fillDifferentialTypeImplementation(
+ DifferentiableTypeConformanceContext* ctx,
OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes,
IRInst* originalType)
{
@@ -2274,6 +3125,7 @@ struct AutoDiffPass : public InstPassBase
builder.setInsertBefore(originalType);
auto diffType = builder.createStructType();
return fillDifferentialTypeImplementationForStruct(
+ ctx,
diffTypes,
as<IRStructType>(originalType),
as<IRStructType>(diffType));
@@ -2286,7 +3138,7 @@ struct AutoDiffPass : public InstPassBase
auto structType = as<IRStructType>(findGenericReturnVal(genType));
SLANG_RELEASE_ASSERT(structType);
- auto innerResult = fillDifferentialTypeImplementation(diffTypes, structType);
+ auto innerResult = fillDifferentialTypeImplementation(ctx, diffTypes, structType);
IRBuilder builder(originalType);
builder.setInsertBefore(originalType);
@@ -2421,7 +3273,8 @@ struct AutoDiffPass : public InstPassBase
{
bool changed = false;
List<IRInst*> autoDiffWorkList;
- // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the call graph.
+ // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the call
+ // graph.
processAllReachableInsts(
[&](IRInst* inst)
{
@@ -2438,6 +3291,7 @@ struct AutoDiffPass : public InstPassBase
case kIROp_Func:
case kIROp_Specialize:
case kIROp_LookupWitness:
+ case kIROp_Generic:
if (auto innerFunc =
as<IRFunc>(getResolvedInstForDecorations(inst->getOperand(0))))
{
@@ -2519,8 +3373,8 @@ struct AutoDiffPass : public InstPassBase
}
// Run transcription logic to generate the body of forward/backward derivatives
- // functions. While doing so, we may discover new functions to differentiate, so we keep
- // running until the worklist goes dry.
+ // functions. While doing so, we may discover new functions to differentiate, so we
+ // keep running until the worklist goes dry.
List<IRFunc*> autodiffCleanupList;
while (autodiffContext->followUpFunctionsToTranscribe.getCount() != 0)
{
@@ -2582,10 +3436,10 @@ struct AutoDiffPass : public InstPassBase
hasChanges = true;
// We have done transcribing the functions, now it is time to demote all
- // DifferentialPair types and their operations down to DifferentialPairUserCodeType and
- // *UserCode operations so they can be treated just like normal types with no special
- // semantics in future processing, and won't be confused with the semantics of a
- // DifferentialPair type during future autodiff code gen.
+ // DifferentialPair types and their operations down to DifferentialPairUserCodeType
+ // and *UserCode operations so they can be treated just like normal types with no
+ // special semantics in future processing, and won't be confused with the semantics
+ // of a DifferentialPair type during future autodiff code gen.
rewriteDifferentialPairToUserCode(module);
hasChanges |= changed;
@@ -2693,8 +3547,8 @@ void checkAutodiffPatterns(TargetProgram* target, IRModule* module, DiagnosticSi
if (func->sourceLoc.isValid() && // Don't diagnose for synthesized functions
func->findDecoration<IRPreferRecomputeDecoration>())
{
- // If we don't have any side-effect behavior, we should warn (note: read-none is a
- // stronger guarantee than no-side-effect)
+ // If we don't have any side-effect behavior, we should warn (note: read-none is
+ // a stronger guarantee than no-side-effect)
//
if (func->findDecoration<IRNoSideEffectDecoration>() ||
func->findDecoration<IRReadNoneDecoration>())
@@ -2759,6 +3613,27 @@ void removeDetachInsts(IRModule* module)
pass.processModule();
}
+
+struct RemoveTypeAnnotationInstsPass : InstPassBase
+{
+ RemoveTypeAnnotationInstsPass(IRModule* module)
+ : InstPassBase(module)
+ {
+ }
+ void processModule()
+ {
+ processInstsOfType<IRDifferentiableTypeAnnotation>(
+ kIROp_DifferentiableTypeAnnotation,
+ [&](IRDifferentiableTypeAnnotation* annotation) { annotation->removeAndDeallocate(); });
+ }
+};
+
+void removeTypeAnnotations(IRModule* module)
+{
+ RemoveTypeAnnotationInstsPass pass(module);
+ pass.processModule();
+}
+
struct LowerNullCheckPass : InstPassBase
{
LowerNullCheckPass(IRModule* module, AutoDiffSharedContext* context)
@@ -2841,6 +3716,8 @@ bool finalizeAutoDiffPass(TargetProgram* target, IRModule* module)
removeDetachInsts(module);
+ removeTypeAnnotations(module);
+
lowerNullCheckInsts(module, &autodiffContext);
stripNoDiffTypeAttribute(module);
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 2b03f3923..433b6093f 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -221,6 +221,8 @@ struct DifferentiableTypeConformanceContext
IRGlobalValueWithCode* parentFunc = nullptr;
OrderedDictionary<IRType*, IRInst*> differentiableTypeWitnessDictionary;
+ Dictionary<IRInst*, List<IRDifferentiableTypeAnnotation*>> annotationCache;
+
IRFunc* existentialDAddFunc = nullptr;
DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared)
@@ -235,6 +237,10 @@ struct DifferentiableTypeConformanceContext
void setFunc(IRGlobalValueWithCode* func);
+ List<IRDifferentiableTypeAnnotation*> getAnnotations(IRGlobalValueWithCode* inst);
+
+ List<IRDifferentiableTypeAnnotation*> getAnnotations(IRModuleInst* inst);
+
void buildGlobalWitnessDictionary();
// Lookup a witness table for the concreteType. One should exist if concreteType
@@ -445,6 +451,20 @@ struct DifferentiableTypeConformanceContext
IRBuilder* builder,
IRExtractExistentialType* extractExistentialType,
DiffConformanceKind target);
+
+ IRInst* emitDAddOfDiffInstType(
+ IRBuilder* builder,
+ IRType* primalType,
+ IRInst* op1,
+ IRInst* op2);
+
+ IRInst* emitDAddForExistentialType(
+ IRBuilder* builder,
+ IRType* primalType,
+ IRInst* op1,
+ IRInst* op2);
+
+ IRInst* emitDZeroOfDiffInstType(IRBuilder* builder, IRType* primalType);
};
@@ -461,9 +481,15 @@ struct DifferentialPairTypeBuilder
IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key);
- IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst);
+ IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRType* loweredPairType, IRInst* baseInst);
+
+ IRInst* emitDiffFieldAccess(IRBuilder* builder, IRType* loweredPairType, IRInst* baseInst);
- IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst);
+ IRInst* emitExistentialMakePair(
+ IRBuilder* builder,
+ IRInst* type,
+ IRInst* primalInst,
+ IRInst* diffInst);
IRStructKey* _getOrCreateDiffStructKey();
@@ -471,17 +497,52 @@ struct DifferentialPairTypeBuilder
IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType);
+ IRInst* _createDiffPairInterfaceRequirement(IRType* origBaseType, IRType* diffType);
+
IRInst* lowerDiffPairType(IRBuilder* builder, IRType* originalPairType);
+ IRInst* getOrCreateCommonDiffPairInterface(IRBuilder* builder);
+
struct PairStructKey
{
IRInst* originalType;
IRInst* diffType;
};
- // Cache from `IRDifferentialPairType` to materialized struct type.
+ // Cache from pair types to lowered type.
Dictionary<IRInst*, IRInst*> pairTypeCache;
+ // Cache from existential pair types to their lowered interface keys.
+ // We use a different cache because an interface type can have
+ // a regular pair for the pair of interface types, as well as an
+ // interface key for the associated pair types used for its implementations
+ //
+ Dictionary<IRInst*, IRInst*> existentialPairTypeCache;
+
+ // Cache for any interface requirement keys (generated for existential
+ // pair types)
+ //
+ Dictionary<IRInst*, IRStructKey*> assocPairTypeKeyMap;
+ Dictionary<IRInst*, IRStructKey*> makePairKeyMap;
+ Dictionary<IRInst*, IRStructKey*> getPrimalKeyMap;
+ Dictionary<IRInst*, IRStructKey*> getDiffKeyMap;
+
+ // More caches for easier lookups of the types associated with the
+ // keys. (avoid having to keep recomputing or performing complicated
+ // lookups)
+ //
+ Dictionary<IRInst*, IRFuncType*> makePairFuncTypeMap;
+ Dictionary<IRInst*, IRFuncType*> getPrimalFuncTypeMap;
+ Dictionary<IRInst*, IRFuncType*> getDiffFuncTypeMap;
+
+ // Even more caches for easier access to original primal/diff types
+ // (Only used for existential pair types). For regular pair types,
+ // these are easy to find right on the type itself.
+ //
+ Dictionary<IRInst*, IRType*> primalTypeMap;
+ Dictionary<IRInst*, IRType*> diffTypeMap;
+
+
IRStructKey* globalPrimalKey = nullptr;
IRStructKey* globalDiffKey = nullptr;
@@ -491,6 +552,8 @@ struct DifferentialPairTypeBuilder
List<IRInst*> generatedTypeList;
AutoDiffSharedContext* sharedContext = nullptr;
+
+ IRInterfaceType* commonDiffPairInterface = nullptr;
};
void stripAutoDiffDecorations(IRModule* module);
@@ -551,6 +614,10 @@ inline bool isRelevantDifferentialPair(IRType* type)
return false;
}
+bool isRuntimeType(IRType* type);
+
+IRInst* getExistentialBaseWitnessTable(IRBuilder* builder, IRType* type);
+
UIndex addPhiOutputArg(
IRBuilder* builder,
IRBlock* block,
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 88a9ac5e3..38e5f8869 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -1282,6 +1282,9 @@ INST(ExistentialTypeSpecializationDictionary, ExistentialTypeSpecializationDicti
/* Differentiable Type Dictionary */
INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0)
+/* Differentiable Type Annotation (for run-time types)*/
+INST(DifferentiableTypeAnnotation, DifferentiableTypeAnnotation, 2, HOISTABLE)
+
INST(BeginFragmentShaderInterlock, BeginFragmentShaderInterlock, 0, 0)
INST(EndFragmentShaderInterlock, BeginFragmentShaderInterlock, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 53adce87a..a288bca97 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1334,6 +1334,18 @@ struct IRPrimalSubstitute : IRInst
IR_LEAF_ISA(PrimalSubstitute)
};
+struct IRDifferentiableTypeAnnotation : IRInst
+{
+ enum
+ {
+ kOp = kIROp_DifferentiableTypeAnnotation
+ };
+ IRInst* getBaseType() { return getOperand(0); }
+ IRInst* getWitness() { return getOperand(1); }
+
+ IR_LEAF_ISA(DifferentiableTypeAnnotation)
+};
+
struct IRDispatchKernel : IRInst
{
enum
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 28bb63a87..d60903cfc 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -1886,7 +1886,8 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
{
// We need to copy over exported symbols,
// and any global parameters if preserve-params option is set.
- if (_isHLSLExported(inst) || shouldCopyGlobalParams && as<IRGlobalParam>(inst))
+ if (_isHLSLExported(inst) || shouldCopyGlobalParams && as<IRGlobalParam>(inst) ||
+ as<IRDifferentiableTypeAnnotation>(inst))
{
auto cloned = cloneValue(context, inst);
if (!cloned->findDecorationImpl(kIROp_KeepAliveDecoration))
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index 4fd162e53..b01abcbc5 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -228,8 +228,6 @@ void lowerGenerics(TargetProgram* targetProgram, IRModule* module, DiagnosticSin
checkTypeConformanceExists(&sharedContext);
- inferAnyValueSizeWhereNecessary(targetProgram, module);
-
// Replace all `makeExistential` insts with `makeExistentialWithRTTI`
// before making any other changes. This is necessary because a parameter of
// generic type will be lowered into `AnyValueType`, and after that we can no longer
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 50dfa2c6a..40cd40758 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -51,15 +51,17 @@ struct SpecializationContext
IRModule* module;
DiagnosticSink* sink;
TargetProgram* targetProgram;
+ SpecializationOptions options;
bool changed = false;
- SpecializationContext(IRModule* inModule, TargetProgram* target)
+ SpecializationContext(IRModule* inModule, TargetProgram* target, SpecializationOptions options)
: workList(*inModule->getContainerPool().getList<IRInst>())
, workListSet(*inModule->getContainerPool().getHashSet<IRInst>())
, cleanInsts(*inModule->getContainerPool().getHashSet<IRInst>())
, module(inModule)
, targetProgram(target)
+ , options(options)
{
}
~SpecializationContext()
@@ -1102,7 +1104,11 @@ struct SpecializationContext
// Now we consider lower lookupWitnessMethod insts into dynamic dispatch calls,
// which may open up more specialization opportunities.
//
- iterChanged = lowerWitnessLookup(module, sink);
+ if (options.lowerWitnessLookups)
+ {
+ iterChanged = lowerWitnessLookup(module, sink);
+ }
+
if (!iterChanged || sink->getErrorCount())
break;
}
@@ -2882,10 +2888,14 @@ struct SpecializationContext
}
};
-bool specializeModule(TargetProgram* target, IRModule* module, DiagnosticSink* sink)
+bool specializeModule(
+ TargetProgram* target,
+ IRModule* module,
+ DiagnosticSink* sink,
+ SpecializationOptions options)
{
SLANG_PROFILE;
- SpecializationContext context(module, target);
+ SpecializationContext context(module, target, options);
context.sink = sink;
context.processModule();
return context.changed;
diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h
index 734c76427..72f2c6130 100644
--- a/source/slang/slang-ir-specialize.h
+++ b/source/slang/slang-ir-specialize.h
@@ -7,8 +7,20 @@ struct IRModule;
class DiagnosticSink;
class TargetProgram;
+struct SpecializationOptions
+{
+ // Option that allows specializeModule to generate dynamic-dispatch code
+ // wherever possible to open up more specialization opportunities.
+ //
+ bool lowerWitnessLookups = false;
+};
+
/// Specialize generic and interface-based code to use concrete types.
-bool specializeModule(TargetProgram* target, IRModule* module, DiagnosticSink* sink);
+bool specializeModule(
+ TargetProgram* target,
+ IRModule* module,
+ DiagnosticSink* sink,
+ SpecializationOptions options);
void finalizeSpecialization(IRModule* module);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 011ea6bc7..e82fc03fd 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -592,11 +592,21 @@ struct IRGenContext
// The element index if we are inside an `expand` expression.
IRInst* expandIndex = nullptr;
+ // Callback function to call when after lowering a type.
+ std::function<IRType*(IRGenContext* context, Type* type, IRType* irType)> lowerTypeCallback =
+ nullptr;
+
explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder)
: shared(inShared), astBuilder(inAstBuilder), env(&inShared->globalEnv), irBuilder(nullptr)
{
}
+ void registerTypeCallback(
+ std::function<IRType*(IRGenContext* context, Type* type, IRType* irType)> callback)
+ {
+ lowerTypeCallback = callback;
+ }
+
void setGlobalValue(Decl* decl, LoweredValInfo value) { shared->setGlobalValue(decl, value); }
void setValue(Decl* decl, LoweredValInfo value) { env->mapDeclToValue[decl] = value; }
@@ -2202,7 +2212,12 @@ IRType* lowerType(IRGenContext* context, Type* type)
{
ValLoweringVisitor visitor;
visitor.context = context;
- return (IRType*)getSimpleVal(context, visitor.dispatchType(type));
+ IRType* loweredType = (IRType*)getSimpleVal(context, visitor.dispatchType(type));
+
+ if (context->lowerTypeCallback && loweredType)
+ context->lowerTypeCallback(context, type, loweredType);
+
+ return loweredType;
}
void addVarDecorations(IRGenContext* context, IRInst* inst, Decl* decl)
@@ -8105,6 +8120,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
subContextStorage.thisTypeWitness = outerContext->thisTypeWitness;
subContextStorage.returnDestination = LoweredValInfo();
+ subContextStorage.lowerTypeCallback = nullptr;
}
IRBuilder* getBuilder() { return &subBuilderStorage; }
@@ -8629,7 +8645,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto finalVal = finishOuterGenerics(subBuilder, irInterface, outerGeneric);
// Add `irInterface` to decl mapping now to prevent cyclic lowering.
- context->setValue(decl, LoweredValInfo::simple(finalVal));
+ context->setGlobalValue(decl, LoweredValInfo::simple(finalVal));
subBuilder->setInsertBefore(irInterface);
@@ -8783,7 +8799,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
-
addNameHint(context, irInterface, decl);
addLinkageDecoration(context, irInterface, decl);
if (auto anyValueSizeAttr = decl->findModifier<AnyValueSizeAttribute>())
@@ -9910,6 +9925,48 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
else
outerGeneric = emitOuterGenerics(subContext, decl, decl);
+ // If our function is differentiable, register a callback so the derivative
+ // annotations for types can be lowered.
+ //
+ if (auto diffAttr = decl->findModifier<DifferentiableAttribute>())
+ {
+ auto diffTypeWitnessMap = diffAttr->getMapTypeToIDifferentiableWitness();
+ OrderedDictionary<DeclRefBase*, SubtypeWitness*> resolveddiffTypeWitnessMap;
+
+ // Go through each entry in the map and resolve the key.
+ for (auto& entry : diffTypeWitnessMap)
+ {
+ auto resolvedKey = as<DeclRefBase>(entry.key->resolve());
+ resolveddiffTypeWitnessMap[resolvedKey] =
+ as<SubtypeWitness>(as<Val>(entry.value)->resolve());
+ }
+
+ subContext->registerTypeCallback(
+ [=](IRGenContext* context, Type* type, IRType* irType)
+ {
+ if (!as<DeclRefType>(type))
+ return irType;
+
+ DeclRefBase* declRefBase = as<DeclRefType>(type)->getDeclRefBase();
+ if (resolveddiffTypeWitnessMap.containsKey(declRefBase))
+ {
+ auto irWitness =
+ lowerVal(subContext, resolveddiffTypeWitnessMap[declRefBase]).val;
+ if (irWitness)
+ {
+ IRInst* args[] = {irType, irWitness};
+ context->irBuilder->emitIntrinsicInst(
+ context->irBuilder->getVoidType(),
+ kIROp_DifferentiableTypeAnnotation,
+ 2,
+ args);
+ }
+ }
+
+ return irType;
+ });
+ }
+
FuncDeclBaseTypeInfo info;
_lowerFuncDeclBaseTypeInfo(
subContext,
@@ -10220,6 +10277,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
+ subContext->registerTypeCallback(nullptr);
+
getBuilder()->addHighLevelDeclDecoration(irFunc, decl);
addSpecializedForTargetDecorations(irFunc, decl);
@@ -10467,16 +10526,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
- if (auto diffAttr = decl->findModifier<DifferentiableAttribute>())
- {
- if (decl->body)
- {
- subContext->irBuilder->setInsertInto(irFunc->getParent());
- lowerDifferentiableAttribute(subContext, irFunc, diffAttr);
- subContext->irBuilder->setInsertInto(irFunc);
- }
- }
-
// For convenience, ensure that any additional global
// values that were emitted while outputting the function
// body appear before the function itself in the list