summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-10 12:42:55 -0800
committerGitHub <noreply@github.com>2023-01-10 12:42:55 -0800
commit2f422087ed04940f6b6b351605e61d48ce1989ce (patch)
tree522f8027173732d903a906081238b12863d73fb8 /source
parenteb813fbd8750ed1ab66d73f5fa29ae8f2407e8af (diff)
Nested bwd-diff func call context save/restore. (#2584)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-decl.cpp7
-rw-r--r--source/slang/slang-emit.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp48
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h3
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp177
-rw-r--r--source/slang/slang-ir-autodiff-rev.h33
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h2
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h28
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp171
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h31
-rw-r--r--source/slang/slang-ir-autodiff.cpp16
-rw-r--r--source/slang/slang-ir-autodiff.h12
-rw-r--r--source/slang/slang-ir-inst-defs.h6
-rw-r--r--source/slang/slang-ir-insts.h34
-rw-r--r--source/slang/slang-ir-legalize-types.cpp6
-rw-r--r--source/slang/slang-ir-specialize.cpp55
-rw-r--r--source/slang/slang-ir-specialize.h2
-rw-r--r--source/slang/slang-ir.cpp11
19 files changed, 407 insertions, 239 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 7c8e320c4..b8732a67f 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -5672,7 +5672,8 @@ namespace Slang
{
// Requirement for backward derivative.
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
- auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)));
+ auto originalFuncType = getFuncType(m_astBuilder, declRef);
+ auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(originalFuncType));
{
auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>();
cloneModifiers(reqDecl, decl);
@@ -5704,8 +5705,8 @@ namespace Slang
auto reqDecl = m_astBuilder->create<BackwardDerivativePrimalRequirementDecl>();
cloneModifiers(reqDecl, decl);
FuncType* primalFuncType = m_astBuilder->create<FuncType>();
- primalFuncType->resultType = diffFuncType->resultType;
- primalFuncType->paramTypes.addRange(diffFuncType->paramTypes);
+ primalFuncType->resultType = originalFuncType->resultType;
+ primalFuncType->paramTypes.addRange(originalFuncType->paramTypes);
auto outType = m_astBuilder->getOutType(intermediateType);
primalFuncType->paramTypes.add(outType);
setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType);
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 00db6bd96..d50cc45a3 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -387,6 +387,8 @@ Result linkAndOptimizeIR(
finalizeAutoDiffPass(irModule);
+ finalizeSpecialization(irModule);
+
// If we have a target that is GPU like we use the string hashing mechanism
// but for that to work we need to inline such that calls (or returns) of strings
// boil down into getStringHash(stringLiteral)
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 19678f402..e37415446 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -864,7 +864,7 @@ InstPair ForwardDiffTranscriber::transcribeSingleOperandInst(IRBuilder* builder,
IRInst* diffResult = nullptr;
- if (auto diffType = differentiateType(builder, primalType))
+ if (auto diffType = differentiateType(builder, origInst->getDataType()))
{
if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
{
@@ -930,7 +930,33 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu
{
if (auto bwdDecor = origFunc->findDecoration<IRForwardDerivativeDecoration>())
return InstPair(origFunc, bwdDecor->getForwardDerivativeFunc());
+
+ auto diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc);
+
+ if (auto outerGen = findOuterGeneric(diffFunc))
+ {
+ IRBuilder subBuilder = *inBuilder;
+ subBuilder.setInsertBefore(origFunc);
+ auto specialized =
+ specializeWithGeneric(subBuilder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc)));
+ subBuilder.addForwardDerivativeDecoration(origFunc, specialized);
+ }
+ else
+ {
+ inBuilder->addForwardDerivativeDecoration(origFunc, diffFunc);
+ }
+
+ FuncBodyTranscriptionTask task;
+ task.type = FuncBodyTranscriptionTaskType::Forward;
+ task.originalFunc = origFunc;
+ task.resultFunc = diffFunc;
+ autoDiffSharedContext->followUpFunctionsToTranscribe.add(task);
+
+ return InstPair(origFunc, diffFunc);
+}
+IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc)
+{
IRBuilder builder = *inBuilder;
IRFunc* primalFunc = origFunc;
@@ -955,17 +981,6 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu
newNameSb << "s_fwd_" << originalName;
builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
}
-
- if (auto outerGen = findOuterGeneric(diffFunc))
- {
- auto specialized =
- specializeWithGeneric(builder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc)));
- builder.addForwardDerivativeDecoration(origFunc, specialized);
- }
- else
- {
- builder.addForwardDerivativeDecoration(origFunc, diffFunc);
- }
// Mark the generated derivative function itself as differentiable.
builder.addForwardDifferentiableDecoration(diffFunc);
@@ -975,14 +990,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu
{
cloneDecoration(dictDecor, diffFunc);
}
-
- FuncBodyTranscriptionTask task;
- task.type = FuncBodyTranscriptionTaskType::Forward;
- task.originalFunc = primalFunc;
- task.resultFunc = diffFunc;
- autoDiffSharedContext->followUpFunctionsToTranscribe.add(task);
-
- return InstPair(primalFunc, diffFunc);
+ return diffFunc;
}
// Transcribe a function definition.
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 869b25ffd..828916c01 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -81,6 +81,9 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
// Transcribe a generic definition
InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric);
+ // Transcribe a function without marking the result as a decoration.
+ IRFunc* transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc);
+
// Create an empty func to represent the transcribed func of `origFunc`.
virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override;
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index b6704011c..817534065 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -60,7 +60,15 @@ namespace Slang
{
auto intermediateType = builder->getBackwardDiffIntermediateContextType(func);
auto outType = builder->getOutType(intermediateType);
- return differentiateFunctionTypeImpl(builder, funcType, outType);
+ List<IRType*> paramTypes;
+ for (UInt i = 0; i < funcType->getParamCount(); i++)
+ {
+ paramTypes.add(funcType->getParamType(i));
+ }
+ paramTypes.add(outType);
+ IRFuncType* primalFuncType = builder->getFuncType(
+ paramTypes, funcType->getResultType());
+ return primalFuncType;
}
InstPair BackwardDiffPrimalTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
@@ -210,8 +218,6 @@ namespace Slang
differentiableTypeConformanceContext.setFunc(origFunc);
- primalFunc = origFunc;
-
auto diffFunc = builder.createFunc();
SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
@@ -278,27 +284,65 @@ namespace Slang
builder.setInsertInto(header.differential);
builder.emitBlock();
auto funcType = as<IRFuncType>(header.differential->getDataType());
- List<IRInst*> args;
+ List<IRInst*> primalArgs, propagateArgs;
+ List<IRType*> primalTypes, propagateTypes;
for (UInt i = 0; i < funcType->getParamCount(); i++)
{
auto paramType = funcType->getParamType(i);
- args.add(builder.emitParam(paramType));
+ auto param = builder.emitParam(paramType);
+ if (i != funcType->getParamCount() - 1)
+ {
+ primalArgs.add(param);
+ }
+ propagateArgs.add(param);
+ propagateTypes.add(paramType);
+ }
+
+ // Fetch primal values to use as arguments in primal func call.
+ for (auto& arg : primalArgs)
+ {
+ IRInst* valueType = arg->getDataType();
+ auto inoutType = as<IRPtrTypeBase>(arg->getDataType());
+ if (inoutType)
+ {
+ valueType = inoutType->getValueType();
+ arg = builder.emitLoad(arg);
+ }
+ auto diffPairType = as<IRDifferentialPairType>(valueType);
+ if (!diffPairType) continue;
+ arg = builder.emitDifferentialPairGetPrimal(arg);
}
+
+ for (auto& arg : primalArgs)
+ {
+ primalTypes.add(arg->getFullType());
+ }
+
auto outerGeneric = findOuterGeneric(origFunc);
IRInst* specializedOriginalFunc = origFunc;
if (outerGeneric)
{
specializedOriginalFunc = maybeSpecializeWithGeneric(builder, outerGeneric, findOuterGeneric(header.differential));
}
+
auto intermediateType = builder.getBackwardDiffIntermediateContextType(specializedOriginalFunc);
auto intermediateVar = builder.emitVar(intermediateType);
- auto primalFunc = builder.emitBackwardDifferentiatePrimalInst(builder.getTypeKind(), specializedOriginalFunc);
- auto propagateFunc = builder.emitBackwardDifferentiatePropagateInst(builder.getTypeKind(), specializedOriginalFunc);
- args.add(intermediateVar);
- builder.emitCallInst(builder.getVoidType(), primalFunc, args);
- args.removeLast();
- args.add(builder.emitLoad(intermediateVar));
- builder.emitCallInst(builder.getVoidType(), propagateFunc, args);
+
+ auto origFuncType = as<IRFuncType>(origFunc->getDataType());
+ auto primalFuncType = builder.getFuncType(
+ primalTypes,
+ origFuncType->getResultType());
+ primalArgs.add(intermediateVar);
+ primalTypes.add(builder.getOutType(intermediateType));
+ auto primalFunc = builder.emitBackwardDifferentiatePrimalInst(primalFuncType, specializedOriginalFunc);
+ builder.emitCallInst(origFuncType->getResultType(), primalFunc, primalArgs);
+
+ propagateTypes.add(intermediateType);
+ propagateArgs.add(builder.emitLoad(intermediateVar));
+ auto propagateFuncType = builder.getFuncType(propagateTypes, builder.getVoidType());
+ auto propagateFunc = builder.emitBackwardDifferentiatePropagateInst(propagateFuncType, specializedOriginalFunc);
+ builder.emitCallInst(builder.getVoidType(), propagateFunc, propagateArgs);
+
builder.emitReturn();
return header;
}
@@ -339,98 +383,6 @@ namespace Slang
builder.emitBranch(firstBlock);
}
- void BackwardDiffTranscriberBase::cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType)
- {
- IRStructType* structType = as<IRStructType>(intermediateType);
- if (!structType)
- {
- auto genType = as<IRGeneric>(intermediateType);
- structType = as<IRStructType>(findGenericReturnVal(genType));
- SLANG_RELEASE_ASSERT(structType);
- }
-
- // Collect fields that are never fetched by reverse func.
- OrderedHashSet<IRStructKey*> fieldsToCleanup;
- for (auto children : structType->getChildren())
- {
- if (auto field = as<IRStructField>(children))
- {
- auto structKey = field->getKey();
- bool usedByRevFunc = false;
- for (auto use = structKey->firstUse; use; use = use->nextUse)
- {
- if (isChildInstOf(use->getUser(), func))
- {
- usedByRevFunc = true;
- break;
- }
- }
- if (!usedByRevFunc)
- {
- List<IRInst*> users;
- for (auto use = structKey->firstUse; use; use = use->nextUse)
- {
- users.add(use->getUser());
- }
- for (auto user : users)
- {
- if (!isChildInstOf(user, primalFunc))
- continue;
- if (auto addr = as<IRFieldAddress>(user))
- {
- if (addr->hasMoreThanOneUse())
- continue;
- if (addr->firstUse)
- {
- if (addr->firstUse->getUser()->getOp() == kIROp_Store)
- {
- addr->firstUse->getUser()->removeAndDeallocate();
- }
- addr->removeAndDeallocate();
- }
- }
- }
-
- bool hasNonTrivialUse = false;
- for (auto use = structKey->firstUse; use; use = use->nextUse)
- {
- switch (use->getUser()->getOp())
- {
- case kIROp_PrimalValueStructKeyDecoration:
- case kIROp_StructField:
- continue;
- default:
- hasNonTrivialUse = true;
- break;
- }
- }
- if (!hasNonTrivialUse)
- {
- fieldsToCleanup.Add(structKey);
- }
- }
- }
- }
-
- // Actually remove fields from struct.
- for (auto children : structType->getChildren())
- {
- if (auto field = as<IRStructField>(children))
- {
- if (fieldsToCleanup.Contains(field->getKey()))
- {
- auto key = field->getKey();
- List<IRInst*> keyUsers;
- for (auto use = key->firstUse; use; use = use->nextUse)
- keyUsers.add(use->getUser());
- for (auto keyUser : keyUsers)
- keyUser->removeAndDeallocate();
- key->removeAndDeallocate();
- }
- }
- }
- }
-
// Transcribe a function definition.
void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc)
{
@@ -442,11 +394,11 @@ namespace Slang
// Generate a temporary forward derivative function as an intermediate step.
IRBuilder tempBuilder = *builder;
tempBuilder.setInsertBefore(diffPropagateFunc);
- IRFunc* fwdDiffFunc = as<IRFunc>(
- fwdDiffTranscriber->transcribeFuncHeader(&tempBuilder, primalFunc).differential);
+ ForwardDiffTranscriber* fwdTranscriber = static_cast<ForwardDiffTranscriber*>(autoDiffSharedContext->transcriberSet.forwardTranscriber);
+ IRFunc* fwdDiffFunc = as<IRFunc>(fwdTranscriber->transcribeFuncHeaderImpl(&tempBuilder, primalFunc));
SLANG_ASSERT(fwdDiffFunc);
- fwdDiffTranscriber->transcribeFunc(&tempBuilder, primalFunc, fwdDiffFunc);
+ fwdTranscriber->transcribeFunc(&tempBuilder, primalFunc, fwdDiffFunc);
// Split first block into a paramter block.
this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc));
@@ -473,7 +425,8 @@ namespace Slang
// for that.
//
builder->setInsertInto(diffPropagateFunc->getParent());
- auto tempDiffFunc = as<IRFunc>(cloneInst(&cloneEnv, builder, unzippedFwdDiffFunc));
+ IRCloneEnv subCloneEnv;
+ auto tempDiffFunc = as<IRFunc>(cloneInst(&subCloneEnv, builder, unzippedFwdDiffFunc));
// Move blocks to the diffFunc shell.
{
@@ -496,18 +449,18 @@ namespace Slang
DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr};
diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info);
+ eliminateDeadCode(diffPropagateFunc);
+
// Extracts the primal computations into its own func, and replace the primal insts
// with the intermediate results computed from the extracted func.
IRInst* intermediateType = nullptr;
- auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffPropagateFunc, unzippedFwdDiffFunc, intermediateType);
+ auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(
+ diffPropagateFunc, primalFunc, intermediateType);
// Clean up by deallocating intermediate versions.
tempDiffFunc->removeAndDeallocate();
unzippedFwdDiffFunc->removeAndDeallocate();
fwdDiffFunc->removeAndDeallocate();
-
- eliminateDeadCode(diffPropagateFunc);
- cleanUpUnusedPrimalIntermediate(diffPropagateFunc, extractedPrimalFunc, intermediateType);
// If primal function is nested in a generic, we want to create separate generics for all the associated things
// we have just created.
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index 378300789..decbdf150 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -30,7 +30,6 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
Dictionary<IRInst*, IRInst*> orginalToTranscribed;
// References to other passes that for reverse-mode transcription.
- ForwardDiffTranscriber* fwdDiffTranscriber;
DiffTransposePass* diffTransposePass;
DiffPropagationPass* diffPropagationPass;
DiffUnzipPass* diffUnzipPass;
@@ -40,7 +39,11 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
DiffPropagationPass diffPropagationPassStorage;
DiffUnzipPass diffUnzipPassStorage;
- BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType taskType, AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
+ BackwardDiffTranscriberBase(
+ FuncBodyTranscriptionTaskType taskType,
+ AutoDiffSharedContext* shared,
+ SharedIRBuilder* inSharedBuilder,
+ DiagnosticSink* inSink)
: AutoDiffTranscriberBase(shared, inSharedBuilder, inSink)
, diffTaskType(taskType)
, diffTransposePassStorage(shared)
@@ -49,7 +52,7 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
, diffTransposePass(&diffTransposePassStorage)
, diffPropagationPass(&diffPropagationPassStorage)
, diffUnzipPass(&diffUnzipPassStorage)
- { }
+ {}
// Returns "dp<var-name>" to use as a name hint for parameters.
// If no primal name is available, returns a blank string.
@@ -63,8 +66,6 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
// Puts parameters into their own block.
void makeParameterBlock(IRBuilder* inBuilder, IRFunc* func);
- void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType);
-
// Transcribe a function definition.
virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) = 0;
@@ -103,8 +104,12 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
struct BackwardDiffPrimalTranscriber : BackwardDiffTranscriberBase
{
- BackwardDiffPrimalTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
- : BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType::BackwardPrimal, shared, inSharedBuilder, inSink)
+ BackwardDiffPrimalTranscriber(
+ AutoDiffSharedContext* shared,
+ SharedIRBuilder* inSharedBuilder,
+ DiagnosticSink* inSink)
+ : BackwardDiffTranscriberBase(
+ FuncBodyTranscriptionTaskType::BackwardPrimal, shared, inSharedBuilder, inSink)
{ }
virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
@@ -125,8 +130,15 @@ struct BackwardDiffPrimalTranscriber : BackwardDiffTranscriberBase
struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase
{
- BackwardDiffPropagateTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
- : BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType::BackwardPropagate, shared, inSharedBuilder, inSink)
+ BackwardDiffPropagateTranscriber(
+ AutoDiffSharedContext* shared,
+ SharedIRBuilder* inSharedBuilder,
+ DiagnosticSink* inSink)
+ : BackwardDiffTranscriberBase(
+ FuncBodyTranscriptionTaskType::BackwardPropagate,
+ shared,
+ inSharedBuilder,
+ inSink)
{ }
virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
@@ -153,7 +165,8 @@ struct BackwardDiffTranscriber : BackwardDiffTranscriberBase
AutoDiffSharedContext* shared,
SharedIRBuilder* inSharedBuilder,
DiagnosticSink* inSink)
- : BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType::Backward, shared, inSharedBuilder, inSink)
+ : BackwardDiffTranscriberBase(
+ FuncBodyTranscriptionTaskType::Backward, shared, inSharedBuilder, inSink)
{ }
virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 4aab0f835..c0404e036 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -390,7 +390,7 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder*
if (lookupKeyPath.getCount())
{
// `interfaceType` does conform to `IDifferentiable`.
- outWitnessTable = builder->emitExtractExistentialWitnessTable(origType->getOperand(0));
+ outWitnessTable = builder->emitExtractExistentialWitnessTable(lookupPrimalInstIfExists(origType->getOperand(0)));
for (auto node : lookupKeyPath)
{
outWitnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), outWitnessTable, node->getRequirementKey());
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index 4c3bbe05f..a6b832856 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -67,6 +67,8 @@ struct AutoDiffTranscriberBase
IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst);
+ IRInst* lookupPrimalInstIfExists(IRInst* origInst) { return lookupPrimalInst(origInst, origInst); }
+
bool hasPrimalInst(IRInst* origInst);
IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst);
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 436a17a7f..fa9f4ffb2 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -155,12 +155,15 @@ struct DiffTransposePass
//
firstRevDiffBlockMap[revDiffFunc] = revBlockMap[workList[0]];
+ IRInst* retVal = nullptr;
+
for (auto block : workList)
{
// Set dOutParameter as the transpose gradient for the return inst, if any.
if (auto returnInst = as<IRReturn>(block->getTerminator()))
{
this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr));
+ retVal = returnInst->getVal();
}
IRBlock* revBlock = revBlockMap[block];
@@ -187,7 +190,18 @@ struct DiffTransposePass
// There should be no parameters in the first reverse-mode block.
SLANG_ASSERT(terminalRevBlock->getFirstParam() == nullptr);
- subBuilder.emitBranch(terminalRevBlock);
+ auto branch = subBuilder.emitBranch(terminalRevBlock);
+
+ if (!retVal)
+ {
+ retVal = subBuilder.getVoidValue();
+ }
+ else
+ {
+ auto makePair = cast<IRMakeDifferentialPair>(retVal);
+ retVal = makePair->getPrimalValue();
+ }
+ subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal);
}
// Remove fwd-mode blocks.
@@ -498,6 +512,10 @@ struct DiffTransposePass
}
}
+ // The call must have been decorated with the continuation context after splitting.
+ auto primalContextDecor = fwdCall->findDecoration<IRBackwardDerivativePrimalContextDecoration>();
+ SLANG_RELEASE_ASSERT(primalContextDecor);
+
auto baseFn = fwdDiffCallee->getBaseFn();
List<IRInst*> args;
@@ -543,8 +561,14 @@ struct DiffTransposePass
args.add(revValue);
argTypes.add(revValue->getDataType());
+ args.add(primalContextDecor->getBackwardDerivativePrimalContextVar());
+ argTypes.add(builder->getOutType(
+ as<IRPtrTypeBase>(
+ primalContextDecor->getBackwardDerivativePrimalContextVar()->getDataType())
+ ->getValueType()));
+
auto revFnType = builder->getFuncType(argTypes, builder->getVoidType());
- auto revCallee = builder->emitBackwardDifferentiateInst(
+ auto revCallee = builder->emitBackwardDifferentiatePropagateInst(
revFnType,
baseFn);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 1496ae60f..43b48aa13 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -7,10 +7,12 @@ namespace Slang
struct ExtractPrimalFuncContext
{
SharedIRBuilder* sharedBuilder;
+ AutoDiffTranscriberBase* backwardPrimalTranscriber;
- void init(SharedIRBuilder* inSharedBuilder)
+ void init(SharedIRBuilder* inSharedBuilder, AutoDiffTranscriberBase* transcriber)
{
sharedBuilder = inSharedBuilder;
+ backwardPrimalTranscriber = transcriber;
}
IRInst* cloneGenericHeader(IRBuilder& builder, IRCloneEnv& cloneEnv, IRGeneric* gen)
@@ -65,14 +67,14 @@ struct ExtractPrimalFuncContext
}
IRInst* generatePrimalFuncType(
- IRGlobalValueWithCode* destFunc, IRGlobalValueWithCode* fwdFunc, IRInst*& outIntermediateType)
+ IRGlobalValueWithCode* destFunc, IRGlobalValueWithCode* originalFunc, IRInst*& outIntermediateType)
{
IRBuilder builder(sharedBuilder);
builder.setInsertBefore(destFunc);
IRFuncType* originalFuncType = nullptr;
outIntermediateType = createIntermediateType(destFunc);
- originalFuncType = as<IRFuncType>(fwdFunc->getDataType());
+ originalFuncType = as<IRFuncType>(originalFunc->getDataType());
SLANG_RELEASE_ASSERT(originalFuncType);
List<IRType*> paramTypes;
@@ -231,56 +233,46 @@ struct ExtractPrimalFuncContext
return true;
}
- void storeInst(
- IRBuilder& builder,
- IRInst* inst,
- IRInst* intermediateOutput)
+ IRStructField* addIntermediateContextField(IRInst* type, IRInst* intermediateOutput)
{
IRBuilder genTypeBuilder(sharedBuilder);
- auto ptrStructType = as<IRPtrTypeBase>(intermediateOutput->getDataType() );
+ auto ptrStructType = as<IRPtrTypeBase>(intermediateOutput->getDataType());
SLANG_RELEASE_ASSERT(ptrStructType);
auto structType = as<IRStructType>(ptrStructType->getValueType());
genTypeBuilder.setInsertBefore(structType);
- auto fieldType = inst->getDataType();
+ auto fieldType = type;
SLANG_RELEASE_ASSERT(structType);
auto structKey = genTypeBuilder.createStructKey();
- if (auto nameHint = inst->findDecoration<IRNameHintDecoration>())
- cloneDecoration(nameHint, structKey);
genTypeBuilder.setInsertInto(structType);
- genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType);
- builder.addPrimalValueStructKeyDecoration(inst, structKey);
+ return genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType);
+ }
+
+ void storeInst(
+ IRBuilder& builder,
+ IRInst* inst,
+ IRInst* intermediateOutput)
+ {
+ auto field = addIntermediateContextField(inst->getDataType(), intermediateOutput);
+ auto key = field->getKey();
+ if (auto nameHint = inst->findDecoration<IRNameHintDecoration>())
+ cloneDecoration(nameHint, key);
+ builder.addPrimalValueStructKeyDecoration(inst, key);
builder.emitStore(
builder.emitFieldAddress(
- builder.getPtrType(inst->getFullType()), intermediateOutput, structKey),
+ builder.getPtrType(inst->getFullType()), intermediateOutput, key),
inst);
}
- IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* fwdFunc, IRInst*& outIntermediateType)
+ IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, IRInst*& outIntermediateType)
{
- // Note: this transformation assumes the original func has only one return.
-
IRBuilder builder(sharedBuilder);
IRFunc* func = unzippedFunc;
IRInst* intermediateType = nullptr;
- auto newFuncType = generatePrimalFuncType(unzippedFunc, fwdFunc, intermediateType);
+ auto newFuncType = generatePrimalFuncType(unzippedFunc, originalFunc, intermediateType);
outIntermediateType = intermediateType;
func->setFullType((IRType*)newFuncType);
- // Go through all the insts and preserve the primal blocks.
- // Create a return block to replace all branches into a non-primal block.
- builder.setInsertInto(func);
- auto returnBlock = builder.emitBlock();
- for (auto block : func->getBlocks())
- {
- auto term = block->getTerminator();
- if (auto ret = as<IRReturn>(term))
- {
- insertIntoReturnBlock(builder, ret);
- break;
- }
- }
-
auto paramBlock = func->getFirstBlock();
builder.setInsertInto(paramBlock);
auto oldIntermediateParam = func->getLastParam();
@@ -317,53 +309,76 @@ struct ExtractPrimalFuncContext
builder.setInsertAfter(inst);
storeInst(builder, inst, outIntermediary);
}
+ else if (inst->getOp() == kIROp_Var)
+ {
+ // Always store intermediate context var.
+ if (inst->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
+ {
+ auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary);
+ builder.setInsertBefore(inst);
+ auto fieldAddr = builder.emitFieldAddress(
+ inst->getFullType(), outIntermediary, field->getKey());
+ inst->replaceUsesWith(fieldAddr);
+ builder.addPrimalValueStructKeyDecoration(inst, field->getKey());
+ }
+ }
}
}
- // Go over differential blocks and complete
- for (auto block : diffBlocksList)
+ for (auto block : primalBlocksList)
{
-
- if (block->getFirstParam() == nullptr)
- {
- // If the block does not have any PHI nodes, just remove it and
- // replace all its uses with returnBlock.
-
- // TODO: This invalides the next block in the chain. Make a list first.
- block->replaceUsesWith(returnBlock);
- block->removeAndDeallocate();
- }
- else
+ auto term = block->getTerminator();
+ builder.setInsertBefore(term);
+ if (auto decor = term->findDecoration<IRBackwardDerivativePrimalReturnDecoration>())
{
- // If the block has Phi nodes, we can't directly replace it with
- // `returnBlock`, but we can turn the block into a trivial branch
- // into `returnBlock` to safely preserve the invariants of Phi nodes.
- auto inst = block->getLastParam()->getNextInst();
- for (; inst;)
- {
- auto nextInst = inst->getNextInst();
- inst->removeAndDeallocate();
- inst = nextInst;
- }
-
- builder.setInsertInto(block);
- builder.emitBranch(returnBlock);
+ builder.emitReturn(decor->getBackwardDerivativePrimalReturnValue());
+ term->removeAndDeallocate();
}
}
-
+
List<IRBlock*> unusedBlocks;
for (auto block : func->getBlocks())
{
- if (!block->hasUses() && isDiffInst(block))
+ if (isDiffInst(block))
unusedBlocks.add(block);
}
-
for (auto block : unusedBlocks)
block->removeAndDeallocate();
builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
auto defVal = builder.emitDefaultConstructRaw((IRType*)intermediateType);
builder.emitStore(outIntermediary, defVal);
+
+ // The primal func will not have the result derivative param (second to last param), so we remove it.
+ auto resultDerivativeParam = func->getLastParam()->getPrevParam();
+ SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses());
+ resultDerivativeParam->removeAndDeallocate();
+
+ // Finally, go through parameters and turn DifferentiablePair<T> back to T.
+ for (auto param : func->getParams())
+ {
+ IRInst* valueType = param->getDataType();
+ auto inoutType = as<IRPtrTypeBase>(param->getDataType());
+ if (inoutType) valueType = inoutType->getValueType();
+ auto diffPairType = as<IRDifferentialPairType>(valueType);
+ if (!diffPairType) continue;
+ builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
+
+ auto originalValueType = diffPairType->getValueType();
+
+ // Create a local var to act as the old param.
+ auto tempVar = builder.emitVar(diffPairType);
+ param->replaceUsesWith(tempVar);
+ auto pairValue = builder.emitMakeDifferentialPair(
+ diffPairType,
+ param,
+ backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType));
+ builder.emitStore(tempVar, pairValue);
+
+ // Change the param type to original type.
+ param->setFullType(originalValueType);
+ }
+
return unzippedFunc;
}
};
@@ -386,7 +401,7 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE
}
IRFunc* DiffUnzipPass::extractPrimalFunc(
- IRFunc* func, IRFunc* fwdFunc, IRInst*& intermediateType)
+ IRFunc* func, IRFunc* originalFunc, IRInst*& intermediateType)
{
IRBuilder builder(this->autodiffContext->sharedBuilder);
builder.setInsertBefore(func);
@@ -397,15 +412,19 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func));
ExtractPrimalFuncContext context;
- context.init(autodiffContext->sharedBuilder);
+ context.init(autodiffContext->sharedBuilder, autodiffContext->transcriberSet.primalTranscriber);
intermediateType = nullptr;
- auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, fwdFunc, intermediateType);
+ auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, intermediateType);
if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>())
{
- auto primalName = String(nameHint->getName()) + "_primal";
- nameHint->setOperand(0, builder.getStringValue(primalName.getUnownedSlice()));
+ nameHint->removeAndDeallocate();
+ }
+ if (auto originalNameHint = originalFunc->findDecoration<IRNameHintDecoration>())
+ {
+ auto primalName = String("s_bwd_primal_") + UnownedStringSlice(originalNameHint->getName());
+ builder.addNameHintDecoration(primalFunc, builder.getStringValue(primalName.getUnownedSlice()));
}
// Copy PrimalValueStructKey decorations from primal func.
@@ -429,10 +448,26 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
builder.getPtrType(inst->getDataType()),
intermediateVar,
structKeyDecor->getStructKey());
- auto val = builder.emitLoad(addr);
- inst->replaceUsesWith(val);
+ if (inst->getOp() == kIROp_Var)
+ {
+ // This is a var for intermediate context.
+ inst->replaceUsesWith(addr);
+ }
+ else
+ {
+ // Orindary value.
+ auto val = builder.emitLoad(addr);
+ inst->replaceUsesWith(val);
+ }
instsToRemove.add(inst);
}
+ else if (auto primalCtx = inst->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
+ {
+ if (inst->getOp() == kIROp_Call)
+ {
+ builder.addSimpleDecoration<IRNoSideEffectDecoration>(inst);
+ }
+ }
}
}
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index f2ce3dc62..ba1e425db 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -8,6 +8,7 @@
#include "slang-ir-autodiff.h"
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-autodiff-propagate.h"
+#include "slang-ir-autodiff-transcriber-base.h"
namespace Slang
{
@@ -31,10 +32,10 @@ struct DiffUnzipPass
// might run into an issue here?
IRBlock* firstDiffBlock;
- // Dictionary<IRBlock*, List<IRBlock*>>
-
- DiffUnzipPass(AutoDiffSharedContext* autodiffContext) :
- autodiffContext(autodiffContext), diffTypeContext(autodiffContext)
+ DiffUnzipPass(
+ AutoDiffSharedContext* autodiffContext)
+ : autodiffContext(autodiffContext)
+ , diffTypeContext(autodiffContext)
{ }
IRInst* lookupPrimalInst(IRInst* inst)
@@ -71,9 +72,6 @@ struct DiffUnzipPass
SLANG_ASSERT(unzippedFunc->getFirstBlock() != nullptr);
SLANG_ASSERT(unzippedFunc->getFirstBlock()->getNextBlock() != nullptr);
- // Ignore the first block (this is reserved for parameters), start
- // at the second block. (For now, we work with only a single block of insts)
- // TODO: expand to handle multi-block functions later.
IRBlock* firstBlock = unzippedFunc->getFirstBlock()->getNextBlock();
List<IRBlock*> mixedBlocks;
@@ -132,7 +130,7 @@ struct DiffUnzipPass
return unzippedFunc;
}
- IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* fwdFunc, IRInst*& intermediateType);
+ IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* originalFunc, IRInst*& intermediateType);
bool isRelevantDifferentialPair(IRType* type)
{
@@ -160,6 +158,14 @@ struct DiffUnzipPass
auto fwdCalleeType = as<IRFuncType>(fwdCallee->getDataType());
auto baseFn = fwdCallee->getBaseFn();
+ auto primalFuncType = autodiffContext->transcriberSet.primalTranscriber->differentiateFunctionType(
+ primalBuilder, baseFn, as<IRFuncType>(baseFn->getDataType()));
+
+ auto intermediateVar = primalBuilder->emitVar(primalBuilder->getBackwardDiffIntermediateContextType(baseFn));
+ primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar);
+
+ auto primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn);
+
List<IRInst*> primalArgs;
for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++)
{
@@ -174,6 +180,7 @@ struct DiffUnzipPass
primalArgs.add(arg);
}
}
+ primalArgs.add(intermediateVar);
auto mixedDecoration = mixedCall->findDecoration<IRMixedDifferentialInstDecoration>();
SLANG_ASSERT(mixedDecoration);
@@ -184,8 +191,9 @@ struct DiffUnzipPass
auto primalType = fwdPairResultType->getValueType();
auto diffType = (IRType*) diffTypeContext.getDifferentialForType(&globalBuilder, primalType);
- auto primalVal = primalBuilder->emitCallInst(primalType, baseFn, primalArgs);
-
+ auto primalVal = primalBuilder->emitCallInst(primalType, primalFn, primalArgs);
+ primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar);
+
List<IRInst*> diffArgs;
for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++)
{
@@ -215,6 +223,7 @@ struct DiffUnzipPass
}
auto newFwdCallee = diffBuilder->emitForwardDifferentiateInst(fwdCalleeType, baseFn);
+
diffBuilder->markInstAsDifferential(newFwdCallee);
auto diffPairVal = diffBuilder->emitCallInst(
@@ -222,6 +231,7 @@ struct DiffUnzipPass
newFwdCallee,
diffArgs);
diffBuilder->markInstAsDifferential(diffPairVal, primalType);
+ diffBuilder->addBackwardDerivativePrimalContextDecoration(diffPairVal, intermediateVar);
auto diffVal = diffBuilder->emitDifferentialPairGetDifferential(diffType, diffPairVal);
diffBuilder->markInstAsDifferential(diffVal, primalType);
@@ -272,7 +282,6 @@ struct DiffUnzipPass
// Check that we have an unambiguous 'first' differential block.
SLANG_ASSERT(firstDiffBlock);
auto primalBranch = primalBuilder->emitBranch(firstDiffBlock);
-
auto pairVal = diffBuilder->emitMakeDifferentialPair(
pairType,
lookupPrimalInst(mixedReturn->getVal()),
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index d23271704..53e2ed0be 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -2,6 +2,7 @@
#include "slang-ir-autodiff-rev.h"
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-autodiff-pairs.h"
+#include "slang-ir-validate.h"
namespace Slang
{
@@ -405,6 +406,8 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
case kIROp_BackwardDerivativeIntermediateTypeDecoration:
case kIROp_BackwardDerivativePropagateDecoration:
case kIROp_BackwardDerivativePrimalDecoration:
+ case kIROp_BackwardDerivativePrimalContextDecoration:
+ case kIROp_BackwardDerivativePrimalReturnDecoration:
decor->removeAndDeallocate();
break;
default:
@@ -716,6 +719,10 @@ struct AutoDiffPass : public InstPassBase
autodiffCleanupList.clear();
+#if _DEBUG
+ validateIRModule(module, sink);
+#endif
+
if (!changed)
break;
hasChanges |= changed;
@@ -780,11 +787,14 @@ struct AutoDiffPass : public InstPassBase
forwardTranscriber.pairBuilder = &pairBuilderStorage;
backwardPrimalTranscriber.pairBuilder = &pairBuilderStorage;
- backwardPrimalTranscriber.fwdDiffTranscriber = &forwardTranscriber;
backwardPropagateTranscriber.pairBuilder = &pairBuilderStorage;
- backwardPropagateTranscriber.fwdDiffTranscriber = &forwardTranscriber;
backwardTranscriber.pairBuilder = &pairBuilderStorage;
- backwardTranscriber.fwdDiffTranscriber = &forwardTranscriber;
+
+ // Make the transcribers available to all sub passes via shared context.
+ context->transcriberSet.primalTranscriber = &backwardPrimalTranscriber;
+ context->transcriberSet.propagateTranscriber = &backwardPropagateTranscriber;
+ context->transcriberSet.forwardTranscriber = &forwardTranscriber;
+ context->transcriberSet.backwardTranscriber = &backwardTranscriber;
}
protected:
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 1415618e1..b4b97751f 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -47,6 +47,16 @@ struct FuncBodyTranscriptionTask
IRFunc* resultFunc;
};
+struct AutoDiffTranscriberBase;
+
+struct DiffTranscriberSet
+{
+ AutoDiffTranscriberBase* forwardTranscriber = nullptr;
+ AutoDiffTranscriberBase* primalTranscriber = nullptr;
+ AutoDiffTranscriberBase* propagateTranscriber = nullptr;
+ AutoDiffTranscriberBase* backwardTranscriber = nullptr;
+};
+
struct AutoDiffSharedContext
{
IRModuleInst* moduleInst = nullptr;
@@ -93,6 +103,8 @@ struct AutoDiffSharedContext
List<FuncBodyTranscriptionTask> followUpFunctionsToTranscribe;
+ DiffTranscriberSet transcriberSet;
+
AutoDiffSharedContext(IRModuleInst* inModuleInst);
private:
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index b721f4225..06f8b0e5d 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -645,6 +645,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// A `[keepAlive]` decoration marks an instruction that should not be eliminated.
INST(KeepAliveDecoration, keepAlive, 0, 0)
+ /// A `[NoSideEffect]` decoration marks a callee to be side-effect free.
+ INST(NoSideEffectDecoration, noSideEffect, 0, 0)
+
INST(BindExistentialSlotsDecoration, bindExistentialSlots, 0, 0)
/// A `[format(f)]` decoration specifies that the format of an image should be `f`
@@ -737,6 +740,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(BackwardDerivativeIntermediateTypeDecoration, backwardDiffIntermediateTypeReference, 1, 0)
INST(BackwardDerivativeDecoration, backwardDiffReference, 1, 0)
+ INST(BackwardDerivativePrimalContextDecoration, BackwardDerivativePrimalContextDecoration, 1, 0)
+ INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0)
+
/// Used by the auto-diff pass to mark insts that compute
/// a differential value.
INST(DifferentialInstDecoration, diffInstDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index d2a4c7ae3..1ff61a774 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -308,6 +308,7 @@ struct IRRequireGLSLExtensionDecoration : IRDecoration
};
IR_SIMPLE_DECORATION(ReadNoneDecoration)
+IR_SIMPLE_DECORATION(NoSideEffectDecoration)
IR_SIMPLE_DECORATION(EarlyDepthStencilDecoration)
IR_SIMPLE_DECORATION(GloballyCoherentDecoration)
IR_SIMPLE_DECORATION(PreciseDecoration)
@@ -607,6 +608,29 @@ struct IRBackwardDerivativePrimalDecoration : IRDecoration
IRInst* getBackwardDerivativePrimalFunc() { return getOperand(0); }
};
+// Used to associate the restore context var to use in a call to splitted backward propgate function.
+struct IRBackwardDerivativePrimalContextDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_BackwardDerivativePrimalContextDecoration
+ };
+ IR_LEAF_ISA(BackwardDerivativePrimalContextDecoration)
+
+ IRInst* getBackwardDerivativePrimalContextVar() { return getOperand(0); }
+};
+
+struct IRBackwardDerivativePrimalReturnDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_BackwardDerivativePrimalReturnDecoration
+ };
+ IR_LEAF_ISA(BackwardDerivativePrimalReturnDecoration)
+
+ IRInst* getBackwardDerivativePrimalReturnValue() { return getOperand(0); }
+};
+
struct IRBackwardDerivativePropagateDecoration : IRDecoration
{
enum
@@ -3478,6 +3502,11 @@ public:
addDecoration(value, kIROp_BackwardDerivativePrimalDecoration, jvpFn);
}
+ void addBackwardDerivativePrimalReturnDecoration(IRInst* value, IRInst* retVal)
+ {
+ addDecoration(value, kIROp_BackwardDerivativePrimalReturnDecoration, retVal);
+ }
+
void addBackwardDerivativePropagateDecoration(IRInst* value, IRInst* jvpFn)
{
addDecoration(value, kIROp_BackwardDerivativePropagateDecoration, jvpFn);
@@ -3493,6 +3522,11 @@ public:
addDecoration(value, kIROp_BackwardDerivativeIntermediateTypeDecoration, jvpFn);
}
+ void addBackwardDerivativePrimalContextDecoration(IRInst* value, IRInst* ctx)
+ {
+ addDecoration(value, kIROp_BackwardDerivativePrimalContextDecoration, ctx);
+ }
+
void markInstAsDifferential(IRInst* value)
{
addDecoration(value, kIROp_DifferentialInstDecoration, nullptr);
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 5b0afdd12..38503155d 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -1511,6 +1511,10 @@ static LegalVal legalizeMakeStruct(
List<IRInst*> args;
for(UInt aa = 0; aa < argCount; ++aa)
{
+ // Ignore none values.
+ if (legalArgs[aa].flavor == LegalVal::Flavor::none)
+ continue;
+
// Note: we assume that all the arguments
// must be simple here, because otherwise
// the `struct` type with them as fields
@@ -1521,7 +1525,7 @@ static LegalVal legalizeMakeStruct(
return LegalVal::simple(
builder->emitMakeStruct(
legalType.getSimple(),
- argCount,
+ args.getCount(),
args.getBuffer()));
}
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index de970fbca..cbb5ccf09 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -720,14 +720,27 @@ struct SpecializationContext
if (!item) continue;
IRSimpleSpecializationKey key;
bool shouldSkip = false;
- for (UInt i = 1; i < item->getOperandCount(); i++)
+ for (UInt i = 0; i < item->getOperandCount(); i++)
{
if (item->getOperand(i) == nullptr)
{
shouldSkip = true;
break;
}
- key.vals.add(item->getOperand(i));
+ if (item->getOperand(i)->getParent() == nullptr)
+ {
+ shouldSkip = true;
+ break;
+ }
+ if (item->getOperand(i)->getOp() == kIROp_undefined)
+ {
+ shouldSkip = true;
+ break;
+ }
+ if (i > 0)
+ {
+ key.vals.add(item->getOperand(i));
+ }
}
if (shouldSkip)
continue;
@@ -768,10 +781,19 @@ struct SpecializationContext
builder.setInsertInto(dictInst);
for (auto kv : dict)
{
- List<IRInst*> args;
- args.add(kv.Value);
- args.addRange(kv.Key.vals);
- builder.emitIntrinsicInst(nullptr, kIROp_SpecializationDictionaryItem, args.getCount(), args.getBuffer());
+ if (!kv.Value->parent)
+ continue;
+ for (auto keyVal : kv.Key.vals)
+ {
+ if (!keyVal->parent) goto next;
+ }
+ {
+ List<IRInst*> args;
+ args.add(kv.Value);
+ args.addRange(kv.Key.vals);
+ builder.emitIntrinsicInst(nullptr, kIROp_SpecializationDictionaryItem, args.getCount(), args.getBuffer());
+ }
+ next:;
}
}
void writeSpecializationDictionaries()
@@ -2312,6 +2334,27 @@ bool specializeModule(
return context.changed;
}
+void finalizeSpecialization(IRModule* module)
+{
+ for (auto inst : module->getModuleInst()->getChildren())
+ {
+ for (auto decor = inst->getFirstDecoration(); decor; )
+ {
+ auto next = decor->getNextDecoration();
+ switch (decor->getOp())
+ {
+ case kIROp_ExistentialFuncSpecializationDictionary:
+ case kIROp_ExistentialTypeSpecializationDictionary:
+ case kIROp_GenericSpecializationDictionary:
+ decor->removeAndDeallocate();
+ break;
+ default:
+ break;
+ }
+ decor = next;
+ }
+ }
+}
IRInst* specializeGenericImpl(
IRGeneric* genericVal,
diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h
index 1503c238e..20d65cb67 100644
--- a/source/slang/slang-ir-specialize.h
+++ b/source/slang/slang-ir-specialize.h
@@ -9,4 +9,6 @@ struct IRModule;
bool specializeModule(
IRModule* module);
+void finalizeSpecialization(IRModule* module);
+
}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 9e0e328bd..f37a7a1a0 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -6526,7 +6526,7 @@ namespace Slang
// By default, assume that we might have side effects,
// to safely cover all the instructions we haven't had time to think about.
default:
- return true;
+ break;
case kIROp_Call:
{
@@ -6553,7 +6553,7 @@ namespace Slang
return false;
}
}
- return true;
+ break;
// All of the cases for "global values" are side-effect-free.
case kIROp_StructType:
@@ -6665,6 +6665,13 @@ namespace Slang
case kIROp_BackwardDifferentiate:
return false;
}
+
+ // Check if the calle has been marked with a catch-all no-side-effect decoration.
+ if (findDecoration<IRNoSideEffectDecoration>())
+ {
+ return false;
+ }
+ return true;
}
IRModule* IRInst::getModule()