summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-06 13:39:06 -0800
committerGitHub <noreply@github.com>2023-01-06 13:39:06 -0800
commit33fb95980b0120cdd4d4f2d51f5f116e808dd4aa (patch)
tree318b1669a0e52aabd11f8694de1278ef7dbc0e3b
parente70cbe76ce74769069b7384f5f05c62da1ca45ed (diff)
Split bwd_diff op into separate ops for primal and propagate func. (#2582)
* Split bwd_diff op into separate ops for primal and propagate func. * Fix. * Download swiftshader with github actions instead of curl on linux. * Fix github action. Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--.github/workflows/linux.yml5
-rw-r--r--github_test.sh6
-rw-r--r--source/slang/slang-ast-decl.h15
-rw-r--r--source/slang/slang-ast-val.h15
-rw-r--r--source/slang/slang-check-decl.cpp94
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp30
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp278
-rw-r--r--source/slang/slang-ir-autodiff-rev.h106
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h2
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp160
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h2
-rw-r--r--source/slang/slang-ir-autodiff.cpp185
-rw-r--r--source/slang/slang-ir-autodiff.h2
-rw-r--r--source/slang/slang-ir-inst-defs.h23
-rw-r--r--source/slang/slang-ir-insts.h91
-rw-r--r--source/slang/slang-ir-specialize.cpp2
-rw-r--r--source/slang/slang-ir-util.cpp74
-rw-r--r--source/slang/slang-ir-util.h34
-rw-r--r--source/slang/slang-ir.cpp48
-rw-r--r--source/slang/slang-ir.h9
-rw-r--r--source/slang/slang-lower-to-ir.cpp53
-rw-r--r--source/slang/slang-mangle.cpp6
24 files changed, 935 insertions, 311 deletions
diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml
index 9dd8ff672..bf951ed23 100644
--- a/.github/workflows/linux.yml
+++ b/.github/workflows/linux.yml
@@ -20,6 +20,11 @@ jobs:
with:
submodules: 'true'
fetch-depth: '0'
+ - uses: robinraju/release-downloader@v1.7
+ with:
+ latest: true
+ repository: "shader-slang/swiftshader"
+ fileName: "vk_swiftshader_linux_${{matrix.platform}}.zip"
- name: build
run: |
CC=${{matrix.compiler}}
diff --git a/github_test.sh b/github_test.sh
index dc07c6250..fb94c95ac 100644
--- a/github_test.sh
+++ b/github_test.sh
@@ -30,11 +30,7 @@ TARGET=${PLATFORM}-${ARCHITECTURE}
OUTPUTDIR=bin/${TARGET}/${CONFIGURATION}/
if [ "${ARCHITECTURE}" == "x64" -a "${PLATFORM}" != "macosx" ]; then
- LOCATION=$(curl -s https://api.github.com/repos/shader-slang/swiftshader/releases/latest \
- | grep "tag_name" \
- | awk '{print "https://github.com/shader-slang/swiftshader/releases/download/" substr($2, 2, length($2)-3) "/vk_swiftshader_linux_x64.zip"}')
- curl -L -o libswiftshader.zip $LOCATION
- unzip libswiftshader.zip -d $OUTPUTDIR
+ unzip vk_swiftshader_linux_x64.zip -d $OUTPUTDIR
fi
SLANG_TEST=${OUTPUTDIR}slang-test
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index ccbac0286..81a6e3f7d 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -546,6 +546,21 @@ class BackwardDerivativeRequirementDecl : public DerivativeRequirementDecl
SLANG_AST_CLASS(BackwardDerivativeRequirementDecl)
};
+class BackwardDerivativePrimalRequirementDecl : public DerivativeRequirementDecl
+{
+ SLANG_AST_CLASS(BackwardDerivativePrimalRequirementDecl)
+};
+
+class BackwardDerivativePropagateRequirementDecl : public DerivativeRequirementDecl
+{
+ SLANG_AST_CLASS(BackwardDerivativePropagateRequirementDecl)
+};
+
+class BackwardDerivativeIntermediateTypeRequirementDecl : public DerivativeRequirementDecl
+{
+ SLANG_AST_CLASS(BackwardDerivativeIntermediateTypeRequirementDecl)
+};
+
bool isInterfaceRequirement(Decl* decl);
InterfaceDecl* findParentInterfaceDecl(Decl* decl);
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 503d63a76..8e5192536 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -521,4 +521,19 @@ class BackwardDifferentiateVal : public DifferentiateVal
SLANG_AST_CLASS(BackwardDifferentiateVal)
};
+class BackwardDifferentiateIntermediateTypeVal : public DifferentiateVal
+{
+ SLANG_AST_CLASS(BackwardDifferentiateIntermediateTypeVal)
+};
+
+class BackwardDifferentiatePrimalVal : public DifferentiateVal
+{
+ SLANG_AST_CLASS(BackwardDifferentiatePrimalVal)
+};
+
+class BackwardDifferentiatePropagateVal : public DifferentiateVal
+{
+ SLANG_AST_CLASS(BackwardDifferentiatePropagateVal)
+};
+
} // namespace Slang
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 80bf74e53..7c8e320c4 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2665,10 +2665,28 @@ namespace Slang
}
else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
{
- BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
+ DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
val->func = satisfyingMemberDeclRef;
witnessTable->add(bwdReq, RequirementWitness(val));
}
+ else if (auto primalReq = as<BackwardDerivativePrimalRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePrimalVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(primalReq, RequirementWitness(val));
+ }
+ else if (auto propReq = as<BackwardDerivativePropagateRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePropagateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(propReq, RequirementWitness(val));
+ }
+ else if (auto itypeReq = as<BackwardDerivativeIntermediateTypeRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateIntermediateTypeVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(itypeReq, RequirementWitness(val));
+ }
}
witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef));
}
@@ -5652,18 +5670,70 @@ namespace Slang
}
if (decl->hasModifier<BackwardDifferentiableAttribute>())
{
- auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>();
- cloneModifiers(reqDecl, decl);
+ // Requirement for backward derivative.
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
- auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef));
- setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType));
- interfaceDecl->members.add(reqDecl);
- reqDecl->parentDecl = interfaceDecl;
-
- auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
- reqRef->referencedDecl = reqDecl;
- reqRef->parentDecl = decl;
- decl->members.add(reqRef);
+ auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)));
+ {
+ auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>();
+ cloneModifiers(reqDecl, decl);
+ setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType);
+ interfaceDecl->members.add(reqDecl);
+ reqDecl->parentDecl = interfaceDecl;
+
+ auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
+ reqRef->referencedDecl = reqDecl;
+ reqRef->parentDecl = decl;
+ decl->members.add(reqRef);
+ }
+ // Requirement for backward derivative intermediate type.
+ auto intermediateTypeReqDecl = m_astBuilder->create<BackwardDerivativeIntermediateTypeRequirementDecl>();
+ auto intermediateType = m_astBuilder->getOrCreateDeclRefType(
+ intermediateTypeReqDecl, createDefaultSubstitutions(m_astBuilder, this, decl));
+ {
+ cloneModifiers(intermediateTypeReqDecl, decl);
+ interfaceDecl->members.add(intermediateTypeReqDecl);
+ intermediateTypeReqDecl->parentDecl = interfaceDecl;
+
+ auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
+ reqRef->referencedDecl = intermediateTypeReqDecl;
+ reqRef->parentDecl = decl;
+ decl->members.add(reqRef);
+ }
+ // Requirement for backward derivative primal func.
+ {
+ auto reqDecl = m_astBuilder->create<BackwardDerivativePrimalRequirementDecl>();
+ cloneModifiers(reqDecl, decl);
+ FuncType* primalFuncType = m_astBuilder->create<FuncType>();
+ primalFuncType->resultType = diffFuncType->resultType;
+ primalFuncType->paramTypes.addRange(diffFuncType->paramTypes);
+ auto outType = m_astBuilder->getOutType(intermediateType);
+ primalFuncType->paramTypes.add(outType);
+ setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType);
+ interfaceDecl->members.add(reqDecl);
+ reqDecl->parentDecl = interfaceDecl;
+
+ auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
+ reqRef->referencedDecl = reqDecl;
+ reqRef->parentDecl = decl;
+ decl->members.add(reqRef);
+ }
+ // Requirement for backward derivative propagate func.
+ {
+ auto reqDecl = m_astBuilder->create<BackwardDerivativePropagateRequirementDecl>();
+ cloneModifiers(reqDecl, decl);
+ interfaceDecl->members.add(reqDecl);
+ reqDecl->parentDecl = interfaceDecl;
+ FuncType* propagateFuncType = m_astBuilder->create<FuncType>();
+ propagateFuncType->resultType = diffFuncType->resultType;
+ propagateFuncType->paramTypes.addRange(diffFuncType->paramTypes);
+ propagateFuncType->paramTypes.add(intermediateType);
+ setFuncTypeIntoRequirementDecl(reqDecl, propagateFuncType);
+ auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
+ reqRef->referencedDecl = reqDecl;
+ reqRef->parentDecl = decl;
+ decl->members.add(reqRef);
+ }
+
isDiffFunc = true;
}
if (isDiffFunc)
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index c245701df..19678f402 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -11,8 +11,10 @@
namespace Slang
{
-IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType)
{
+ SLANG_UNUSED(func);
+
List<IRType*> newParameterTypes;
IRType* diffReturnType;
@@ -330,7 +332,8 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
// If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
// to generate the implementation.
diffCallee = builder->emitForwardDifferentiateInst(
- differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
+ differentiateFunctionType(
+ builder, primalCallee, as<IRFuncType>(primalCallee->getFullType())),
primalCallee);
}
@@ -615,8 +618,16 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec
{
args.add(primalSpecialize->getArg(i));
}
+
+ // A `ForwardDerivative` decoration on an inner func of a generic should always be a `specialize`.
+ auto diffBaseSpecialize = as<IRSpecialize>(diffBase);
+ SLANG_RELEASE_ASSERT(diffBaseSpecialize);
+
+ // Note: this assumes that the generic arguments to specialize the derivative is the same as the
+ // generic args to specialize the primal function. This is true for all of our stdlib functions,
+ // but we may need to rely on more general substitution logic here.
auto diffSpecialize = builder->emitSpecializeInst(
- builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ builder->getTypeKind(), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
else if (auto diffDecor = genericInnerVal->findDecoration<IRForwardDifferentiableDecoration>())
@@ -933,6 +944,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu
SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
IRType* diffFuncType = this->differentiateFunctionType(
&builder,
+ origFunc,
as<IRFuncType>(origFunc->getFullType()));
diffFunc->setFullType(diffFuncType);
@@ -943,7 +955,17 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu
newNameSb << "s_fwd_" << originalName;
builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
}
- builder.addForwardDerivativeDecoration(origFunc, diffFunc);
+
+ if (auto outerGen = findOuterGeneric(diffFunc))
+ {
+ auto specialized =
+ specializeWithGeneric(builder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc)));
+ builder.addForwardDerivativeDecoration(origFunc, specialized);
+ }
+ else
+ {
+ builder.addForwardDerivativeDecoration(origFunc, diffFunc);
+ }
// Mark the generated derivative function itself as differentiable.
builder.addForwardDifferentiableDecoration(diffFunc);
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 22ebf9d95..869b25ffd 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -73,7 +73,7 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeWrapExistential(IRBuilder* builder, IRInst* origInst);
- virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) override;
+ virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
// Transcribe a function definition.
InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc);
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index ae9b69f61..b6704011c 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -11,7 +11,7 @@
namespace Slang
{
- IRFuncType* BackwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+ IRFuncType* BackwardDiffTranscriberBase::differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermeidateType)
{
List<IRType*> newParameterTypes;
IRType* diffReturnType;
@@ -46,12 +46,53 @@ namespace Slang
newParameterTypes.add(differentiateType(builder, funcType->getResultType()));
+ if (intermeidateType)
+ {
+ newParameterTypes.add((IRType*)intermeidateType);
+ }
+
diffReturnType = builder->getVoidType();
return builder->getFuncType(newParameterTypes, diffReturnType);
}
+
+ IRFuncType* BackwardDiffPrimalTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType)
+ {
+ auto intermediateType = builder->getBackwardDiffIntermediateContextType(func);
+ auto outType = builder->getOutType(intermediateType);
+ return differentiateFunctionTypeImpl(builder, funcType, outType);
+ }
+
+ InstPair BackwardDiffPrimalTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
+ {
+ SLANG_UNUSED(builder);
+ SLANG_UNUSED(diffFunc);
+ auto intermediateTypeDecor = primalFunc->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>();
+ SLANG_RELEASE_ASSERT(intermediateTypeDecor);
+ auto primalDecor = primalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>();
+ return InstPair(primalFunc, primalDecor->getBackwardDerivativePrimalFunc());
+ }
+
+ IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType)
+ {
+ auto intermediateType = builder->getBackwardDiffIntermediateContextType(func);
+ return differentiateFunctionTypeImpl(builder, funcType, intermediateType);
+ }
+
+ IRFuncType* BackwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType)
+ {
+ SLANG_UNUSED(func);
+ return differentiateFunctionTypeImpl(builder, funcType, nullptr);
+ }
+
+ InstPair BackwardDiffPropagateTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
+ {
+ IRGlobalValueWithCode* diffPrimalFunc = nullptr;
+ transcribeFuncImpl(builder, primalFunc, diffFunc, diffPrimalFunc);
+ return InstPair(primalFunc, diffFunc);
+ }
- InstPair BackwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* origInst)
+ InstPair BackwardDiffTranscriberBase::transcribeInstImpl(IRBuilder* builder, IRInst* origInst)
{
switch (origInst->getOp())
{
@@ -90,7 +131,7 @@ namespace Slang
// Returns "dp<var-name>" to use as a name hint for parameters.
// If no primal name is available, returns a blank string.
//
- String BackwardDiffTranscriber::makeDiffPairName(IRInst* origVar)
+ String BackwardDiffTranscriberBase::makeDiffPairName(IRInst* origVar)
{
if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
{
@@ -100,47 +141,7 @@ namespace Slang
return String("");
}
-
- // In differential computation, the 'default' differential value is always zero.
- // This is a consequence of differential computing being inherently linear. As a
- // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
- // The current implementation requires that types are defined linearly.
- //
- IRInst* BackwardDiffTranscriber::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
- {
- if (auto diffType = differentiateType(builder, primalType))
- {
- switch (diffType->getOp())
- {
- case kIROp_DifferentialPairType:
- return builder->emitMakeDifferentialPair(
- diffType,
- getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
- getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
- }
- // Since primalType has a corresponding differential type, we can lookup the
- // definition for zero().
- auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
- SLANG_ASSERT(zeroMethod);
-
- auto emptyArgList = List<IRInst*>();
- return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
- }
- else
- {
- if (isScalarIntegerType(primalType))
- {
- return builder->getIntValue(primalType, 0);
- }
-
- getSink()->diagnose(primalType->sourceLoc,
- Diagnostics::internalCompilerError,
- "could not generate zero value for given type");
- return nullptr;
- }
- }
-
- InstPair BackwardDiffTranscriber::transposeBlock(IRBuilder* builder, IRBlock* origBlock)
+ InstPair BackwardDiffTranscriberBase::transposeBlock(IRBuilder* builder, IRBlock* origBlock)
{
IRBuilder subBuilder(builder->getSharedBuilder());
subBuilder.setInsertLoc(builder->getInsertLoc());
@@ -194,10 +195,10 @@ namespace Slang
}
// Create an empty func to represent the transcribed func of `origFunc`.
- InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
+ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc)
{
- if (auto bwdDecor = origFunc->findDecoration<IRBackwardDerivativeDecoration>())
- return InstPair(origFunc, bwdDecor->getBackwardDerivativeFunc());
+ if (auto bwdDiffFunc = findExistingDiffFunc(origFunc))
+ return InstPair(origFunc, bwdDiffFunc);
if (!isMarkedForBackwardDifferentiation(origFunc))
return InstPair(nullptr, nullptr);
@@ -216,6 +217,7 @@ namespace Slang
SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
IRType* diffFuncType = this->differentiateFunctionType(
&builder,
+ origFunc,
as<IRFuncType>(origFunc->getFullType()));
diffFunc->setFullType(diffFuncType);
@@ -226,7 +228,18 @@ namespace Slang
newNameSb << "s_bwd_" << originalName;
builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
}
- builder.addBackwardDerivativeDecoration(origFunc, diffFunc);
+
+ if (auto outerGen = findOuterGeneric(diffFunc))
+ {
+ builder.setInsertBefore(origFunc);
+ auto specialized =
+ specializeWithGeneric(builder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc)));
+ addExistingDiffFuncDecor(&builder, origFunc, specialized);
+ }
+ else
+ {
+ addExistingDiffFuncDecor(&builder, origFunc, diffFunc);
+ }
// Mark the generated derivative function itself as differentiable.
builder.addBackwardDifferentiableDecoration(diffFunc);
@@ -237,17 +250,61 @@ namespace Slang
cloneDecoration(dictDecor, diffFunc);
}
+ return InstPair(primalFunc, diffFunc);
+ }
+
+ InstPair BackwardDiffTranscriberBase::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
+ {
+ auto result = transcribeFuncHeaderImpl(inBuilder, origFunc);
+
FuncBodyTranscriptionTask task;
- task.originalFunc = primalFunc;
- task.resultFunc = diffFunc;
- task.type = FuncBodyTranscriptionTaskType::Backward;
- autoDiffSharedContext->followUpFunctionsToTranscribe.add(task);
+ task.originalFunc = as<IRFunc>(result.primal);
+ task.resultFunc = as<IRFunc>(result.differential);
+ task.type = diffTaskType;
+ if (task.resultFunc)
+ {
+ autoDiffSharedContext->followUpFunctionsToTranscribe.add(task);
+ }
+ return result;
+ }
- return InstPair(primalFunc, diffFunc);
+ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
+ {
+ auto header = transcribeFuncHeaderImpl(inBuilder, origFunc);
+ if (!header.differential)
+ return header;
+
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertInto(header.differential);
+ builder.emitBlock();
+ auto funcType = as<IRFuncType>(header.differential->getDataType());
+ List<IRInst*> args;
+ for (UInt i = 0; i < funcType->getParamCount(); i++)
+ {
+ auto paramType = funcType->getParamType(i);
+ args.add(builder.emitParam(paramType));
+ }
+ auto outerGeneric = findOuterGeneric(origFunc);
+ IRInst* specializedOriginalFunc = origFunc;
+ if (outerGeneric)
+ {
+ specializedOriginalFunc = maybeSpecializeWithGeneric(builder, outerGeneric, findOuterGeneric(header.differential));
+ }
+ auto intermediateType = builder.getBackwardDiffIntermediateContextType(specializedOriginalFunc);
+ auto intermediateVar = builder.emitVar(intermediateType);
+ auto primalFunc = builder.emitBackwardDifferentiatePrimalInst(builder.getTypeKind(), specializedOriginalFunc);
+ auto propagateFunc = builder.emitBackwardDifferentiatePropagateInst(builder.getTypeKind(), specializedOriginalFunc);
+ args.add(intermediateVar);
+ builder.emitCallInst(builder.getVoidType(), primalFunc, args);
+ args.removeLast();
+ args.add(builder.emitLoad(intermediateVar));
+ builder.emitCallInst(builder.getVoidType(), propagateFunc, args);
+ builder.emitReturn();
+ return header;
}
// Puts parameters into their own block.
- void BackwardDiffTranscriber::makeParameterBlock(IRBuilder* inBuilder, IRFunc* func)
+ void BackwardDiffTranscriberBase::makeParameterBlock(IRBuilder* inBuilder, IRFunc* func)
{
IRBuilder builder(inBuilder->getSharedBuilder());
@@ -282,7 +339,7 @@ namespace Slang
builder.emitBranch(firstBlock);
}
- void BackwardDiffTranscriber::cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType)
+ void BackwardDiffTranscriberBase::cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType)
{
IRStructType* structType = as<IRStructType>(intermediateType);
if (!structType)
@@ -375,22 +432,21 @@ namespace Slang
}
// Transcribe a function definition.
- InstPair BackwardDiffTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
+ void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc)
{
SLANG_ASSERT(primalFunc);
- SLANG_ASSERT(diffFunc);
+ SLANG_ASSERT(diffPropagateFunc);
// Reverse-mode transcription uses 4 separate steps:
// TODO(sai): Fill in documentation.
// Generate a temporary forward derivative function as an intermediate step.
IRBuilder tempBuilder = *builder;
- tempBuilder.setInsertBefore(diffFunc);
- IRFunc* fwdDiffFunc = as<IRFunc>(fwdDiffTranscriber->transcribeFuncHeader(&tempBuilder, (IRFunc*)primalFunc).differential);
+ tempBuilder.setInsertBefore(diffPropagateFunc);
+ IRFunc* fwdDiffFunc = as<IRFunc>(
+ fwdDiffTranscriber->transcribeFuncHeader(&tempBuilder, primalFunc).differential);
SLANG_ASSERT(fwdDiffFunc);
- // Transcribe the body of the primal function into it's linear (fwd-diff) form.
- // TODO(sai): Handle the case when we already have a user-defined fwd-derivative function.
- fwdDiffTranscriber->transcribeFunc(&tempBuilder, primalFunc, as<IRFunc>(fwdDiffFunc));
+ fwdDiffTranscriber->transcribeFunc(&tempBuilder, primalFunc, fwdDiffFunc);
// Split first block into a paramter block.
this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc));
@@ -416,7 +472,7 @@ namespace Slang
// only blocks, and right now there's no provision in slang-ir-clone.h
// for that.
//
- builder->setInsertInto(diffFunc->getParent());
+ builder->setInsertInto(diffPropagateFunc->getParent());
auto tempDiffFunc = as<IRFunc>(cloneInst(&cloneEnv, builder, unzippedFwdDiffFunc));
// Move blocks to the diffFunc shell.
@@ -426,37 +482,63 @@ namespace Slang
workList.add(block);
for (auto block : workList)
- block->insertAtEnd(diffFunc);
+ block->insertAtEnd(diffPropagateFunc);
}
// Transpose the first block (parameter block)
- transposeParameterBlock(builder, diffFunc);
+ transposeParameterBlock(builder, diffPropagateFunc);
- builder->setInsertInto(diffFunc);
+ builder->setInsertInto(diffPropagateFunc);
- auto dOutParameter = diffFunc->getLastParam();
+ auto dOutParameter = diffPropagateFunc->getLastParam()->getPrevParam();
// Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr};
- diffTransposePass->transposeDiffBlocksInFunc(diffFunc, info);
+ diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info);
// Extracts the primal computations into its own func, and replace the primal insts
// with the intermediate results computed from the extracted func.
IRInst* intermediateType = nullptr;
- auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType);
+ auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffPropagateFunc, unzippedFwdDiffFunc, intermediateType);
// Clean up by deallocating intermediate versions.
tempDiffFunc->removeAndDeallocate();
unzippedFwdDiffFunc->removeAndDeallocate();
fwdDiffFunc->removeAndDeallocate();
- eliminateDeadCode(diffFunc);
- cleanUpUnusedPrimalIntermediate(diffFunc, extractedPrimalFunc, intermediateType);
-
- return InstPair(primalFunc, diffFunc);
+ eliminateDeadCode(diffPropagateFunc);
+ cleanUpUnusedPrimalIntermediate(diffPropagateFunc, extractedPrimalFunc, intermediateType);
+
+ // If primal function is nested in a generic, we want to create separate generics for all the associated things
+ // we have just created.
+ auto primalOuterGeneric = findOuterGeneric(primalFunc);
+ IRInst* specializedFunc = nullptr;
+ auto intermediateTypeGeneric = hoistValueFromGeneric(*builder, intermediateType, specializedFunc);
+ auto specializedIntermeidateType = maybeSpecializeWithGeneric(*builder, intermediateTypeGeneric, primalOuterGeneric);
+ builder->addBackwardDerivativeIntermediateTypeDecoration(primalFunc, specializedIntermeidateType);
+
+ auto primalFuncGeneric = hoistValueFromGeneric(*builder, extractedPrimalFunc, specializedFunc);
+ builder->setInsertBefore(primalFunc);
+
+ if (auto existingDecor = primalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>())
+ {
+ // If we already created a header for primal func, move the body into the existing primal func header.
+ auto existingPrimalHeader = existingDecor->getBackwardDerivativePrimalFunc();
+ if (auto spec = as<IRSpecialize>(existingPrimalHeader))
+ existingPrimalHeader = spec->getBase();
+ moveInstChildren(existingPrimalHeader, primalFuncGeneric);
+ primalFuncGeneric->replaceUsesWith(existingPrimalHeader);
+ primalFuncGeneric->removeAndDeallocate();
+ }
+ else
+ {
+ auto specializedBackwardPrimalFunc = maybeSpecializeWithGeneric(*builder, primalFuncGeneric, primalOuterGeneric);
+ builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc);
+ }
+ diffPrimalFunc = as<IRGlobalValueWithCode>(primalOuterGeneric);
}
- void BackwardDiffTranscriber::transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc)
+ void BackwardDiffTranscriberBase::transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc)
{
IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock();
@@ -499,16 +581,19 @@ namespace Slang
auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount();
// 2. Add a parameter for 'derivative of the output' (d_out).
- // The type is the last parameter type of the function.
+ // The type is the second last parameter type of the function.
//
- auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1);
+ auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 2);
SLANG_ASSERT(dOutParamType);
builder->emitParam(dOutParamType);
+
+ // Add a parameter for intermediate val.
+ builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
}
- IRInst* BackwardDiffTranscriber::copyParam(IRBuilder* builder, IRParam* origParam)
+ IRInst* BackwardDiffTranscriberBase::copyParam(IRBuilder* builder, IRParam* origParam)
{
auto primalDataType = origParam->getDataType();
@@ -533,7 +618,7 @@ namespace Slang
return cloneInst(&cloneEnv, builder, origParam);
}
- InstPair BackwardDiffTranscriber::copyBinaryArith(IRBuilder* builder, IRInst* origArith)
+ InstPair BackwardDiffTranscriberBase::copyBinaryArith(IRBuilder* builder, IRInst* origArith)
{
SLANG_ASSERT(origArith->getOperandCount() == 2);
@@ -577,7 +662,7 @@ namespace Slang
return InstPair(newInst, nullptr);
}
- IRInst* BackwardDiffTranscriber::transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad)
+ IRInst* BackwardDiffTranscriberBase::transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad)
{
SLANG_ASSERT(origArith->getOperandCount() == 2);
@@ -645,7 +730,7 @@ namespace Slang
return nullptr;
}
- InstPair BackwardDiffTranscriber::copyInst(IRBuilder* builder, IRInst* origInst)
+ InstPair BackwardDiffTranscriberBase::copyInst(IRBuilder* builder, IRInst* origInst)
{
// Handle common SSA-style operations
switch (origInst->getOp())
@@ -670,7 +755,7 @@ namespace Slang
return InstPair(nullptr, nullptr);
}
- IRInst* BackwardDiffTranscriber::transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad)
+ IRInst* BackwardDiffTranscriberBase::transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad)
{
IRInOutType* inoutParam = as<IRInOutType>(param->getDataType());
auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType());
@@ -687,7 +772,7 @@ namespace Slang
return store;
}
- IRInst* BackwardDiffTranscriber::transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad)
+ IRInst* BackwardDiffTranscriberBase::transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad)
{
// Handle common SSA-style operations
switch (origInst->getOp())
@@ -727,7 +812,7 @@ namespace Slang
return nullptr;
}
- InstPair BackwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
+ InstPair BackwardDiffTranscriberBase::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
{
auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase());
List<IRInst*> primalArgs;
@@ -739,8 +824,7 @@ namespace Slang
auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst(
(IRType*)primalType, primalBase, primalArgs.getCount(), primalArgs.getBuffer());
- IRInst* diffBase = nullptr;
- if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase))
+ if (auto diffBase = instMapD.TryGetValue(origSpecialize->getBase()))
{
List<IRInst*> args;
for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
@@ -748,7 +832,7 @@ namespace Slang
args.add(primalSpecialize->getArg(i));
}
auto diffSpecialize = builder->emitSpecializeInst(
- builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ builder->getTypeKind(), *diffBase, args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
@@ -757,25 +841,31 @@ namespace Slang
// (Normally, this would be on the inner IRFunc, but in this case only the JVP func
// can be specialized, so we put a decoration on the IRSpecialize)
//
- if (auto backDecor = origSpecialize->findDecoration<IRBackwardDerivativeDecoration>())
+ if (auto derivativeFunc = findExistingDiffFunc(origSpecialize))
{
- auto derivativeFunc = backDecor->getBackwardDerivativeFunc();
-
// Make sure this isn't itself a specialize .
SLANG_RELEASE_ASSERT(!as<IRSpecialize>(derivativeFunc));
return InstPair(primalSpecialize, derivativeFunc);
}
- else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRBackwardDerivativeDecoration>())
+ else if (auto diffBase = findExistingDiffFunc(genericInnerVal))
{
- diffBase = derivativeDecoration->getBackwardDerivativeFunc();
List<IRInst*> args;
for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
{
args.add(primalSpecialize->getArg(i));
}
+
+ // A `BackwardDerivative` decoration on an inner func of a generic should always be a `specialize`.
+ auto diffBaseSpecialize = as<IRSpecialize>(diffBase);
+ SLANG_RELEASE_ASSERT(diffBaseSpecialize);
+
+ // Note: this assumes that the generic arguments to specialize the derivative is the same as the
+ // generic args to specialize the primal function. This is true for all of our stdlib functions,
+ // but we may need to rely on more general substitution logic here.
auto diffSpecialize = builder->emitSpecializeInst(
- builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ builder->getTypeKind(), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer());
+
return InstPair(primalSpecialize, diffSpecialize);
}
else if (auto diffDecor = genericInnerVal->findDecoration<IRBackwardDifferentiableDecoration>())
@@ -785,9 +875,9 @@ namespace Slang
{
args.add(primalSpecialize->getArg(i));
}
- diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase());
+ auto diffCallee = findOrTranscribeDiffInst(builder, origSpecialize->getBase());
auto diffSpecialize = builder->emitSpecializeInst(
- builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ builder->getTypeKind(), diffCallee, args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
else
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index f9ca6110c..378300789 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -20,8 +20,10 @@ struct IRReverseDerivativePassOptions
// Nothing for now..
};
-struct BackwardDiffTranscriber : AutoDiffTranscriberBase
+struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
{
+ FuncBodyTranscriptionTaskType diffTaskType;
+
// Map that stores the upper gradient given an IRInst*
Dictionary<IRInst*, List<IRInst*>> upperGradients;
Dictionary<IRInst*, IRInst*> primalToDiffPair;
@@ -38,8 +40,9 @@ struct BackwardDiffTranscriber : AutoDiffTranscriberBase
DiffPropagationPass diffPropagationPassStorage;
DiffUnzipPass diffUnzipPassStorage;
- BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
+ BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType taskType, AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
: AutoDiffTranscriberBase(shared, inSharedBuilder, inSink)
+ , diffTaskType(taskType)
, diffTransposePassStorage(shared)
, diffPropagationPassStorage(shared)
, diffUnzipPassStorage(shared)
@@ -52,13 +55,8 @@ struct BackwardDiffTranscriber : AutoDiffTranscriberBase
// If no primal name is available, returns a blank string.
//
String makeDiffPairName(IRInst* origVar);
-
- // In differential computation, the 'default' differential value is always zero.
- // This is a consequence of differential computing being inherently linear. As a
- // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
- // The current implementation requires that types are defined linearly.
- //
- IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType);
+
+ IRFuncType* differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermediateType);
InstPair transposeBlock(IRBuilder* builder, IRBlock* origBlock);
@@ -68,7 +66,7 @@ struct BackwardDiffTranscriber : AutoDiffTranscriberBase
void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType);
// Transcribe a function definition.
- InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc);
+ virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) = 0;
void transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc);
@@ -86,18 +84,98 @@ struct BackwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize);
- // Create an empty func to represent the transcribed func of `origFunc`.
- virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override;
+ void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc);
- virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) override;
+ InstPair transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc);
+
+ virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override;
virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override;
+ virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) = 0;
+ virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) = 0;
+
virtual IROp getDifferentiableMethodDictionaryItemOp() override
{
- return kIROp_ForwardDifferentiableMethodRequirementDictionaryItem;
+ return kIROp_BackwardDifferentiableMethodRequirementDictionaryItem;
+ }
+};
+
+struct BackwardDiffPrimalTranscriber : BackwardDiffTranscriberBase
+{
+ BackwardDiffPrimalTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
+ : BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType::BackwardPrimal, shared, inSharedBuilder, inSink)
+ { }
+
+ virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
+ virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override;
+ virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) override
+ {
+ if (auto backDecor = originalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>())
+ {
+ return backDecor->getBackwardDerivativePrimalFunc();
+ }
+ return nullptr;
+ }
+ virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) override
+ {
+ builder->addBackwardDerivativePrimalDecoration(inst, diffFunc);
+ }
+};
+
+struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase
+{
+ BackwardDiffPropagateTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
+ : BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType::BackwardPropagate, shared, inSharedBuilder, inSink)
+ { }
+
+ virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
+ virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override;
+ virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) override
+ {
+ if (auto backDecor = originalFunc->findDecoration<IRBackwardDerivativePropagateDecoration>())
+ {
+ return backDecor->getBackwardDerivativePropagateFunc();
+ }
+ return nullptr;
}
+ virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) override
+ {
+ builder->addBackwardDerivativePropagateDecoration(inst, diffFunc);
+ }
+};
+
+// A backward derivative function combines both primal + propagate functions and accepts no
+// intermediate value input.
+struct BackwardDiffTranscriber : BackwardDiffTranscriberBase
+{
+ BackwardDiffTranscriber(
+ AutoDiffSharedContext* shared,
+ SharedIRBuilder* inSharedBuilder,
+ DiagnosticSink* inSink)
+ : BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType::Backward, shared, inSharedBuilder, inSink)
+ { }
+ virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
+ virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override;
+ virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override
+ {
+ SLANG_UNUSED(builder);
+ // Don't need to do anything here, the body is generated in transcribeFuncHeader.
+ return InstPair(primalFunc, diffFunc);
+ }
+ virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) override
+ {
+ if (auto backDecor = originalFunc->findDecoration<IRBackwardDerivativeDecoration>())
+ {
+ return backDecor->getBackwardDerivativeFunc();
+ }
+ return nullptr;
+ }
+ virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) override
+ {
+ builder->addBackwardDerivativeDecoration(inst, diffFunc);
+ }
};
}
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 69cef941c..4aab0f835 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -259,7 +259,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
}
case kIROp_FuncType:
- return differentiateFunctionType(builder, as<IRFuncType>(primalType));
+ return differentiateFunctionType(builder, nullptr, as<IRFuncType>(primalType));
case kIROp_OutType:
if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
@@ -436,7 +436,7 @@ InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* o
{
auto primalDataType = findOrTranscribePrimalInst(builder, origParam->getDataType());
// Do not differentiate generic type (and witness table) parameters
- if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType))
+ if (isGenericParam(origParam))
{
return InstPair(
cloneInst(&cloneEnv, builder, origParam),
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index 8e4b7a901..4c3bbe05f 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -116,7 +116,7 @@ struct AutoDiffTranscriberBase
IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType);
- virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) = 0;
+ virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) = 0;
// Create an empty func to represent the transcribed func of `origFunc`.
virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) = 0;
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 2fd53dbd0..1496ae60f 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -72,23 +72,11 @@ struct ExtractPrimalFuncContext
IRFuncType* originalFuncType = nullptr;
outIntermediateType = createIntermediateType(destFunc);
- if (auto gen = as<IRGeneric>(destFunc))
- {
- auto func = findGenericReturnVal(gen);
- builder.setInsertBefore(func);
- outIntermediateType =
- specializeWithGeneric(builder, outIntermediateType, gen);
- SLANG_RELEASE_ASSERT(func);
- originalFuncType = as<IRFuncType>(as<IRGeneric>(fwdFunc)->getDataType());
- }
- else
- {
- originalFuncType = as<IRFuncType>(fwdFunc->getDataType());
- }
+ originalFuncType = as<IRFuncType>(fwdFunc->getDataType());
SLANG_RELEASE_ASSERT(originalFuncType);
List<IRType*> paramTypes;
- for (UInt i = 0; i < originalFuncType->getParamCount(); i++)
+ for (UInt i = 0; i < originalFuncType->getParamCount() - 1; i++)
paramTypes.add(originalFuncType->getParamType(i));
paramTypes.add(builder.getInOutType((IRType*)outIntermediateType));
auto newFuncType = builder.getFuncType(paramTypes, builder.getVoidType());
@@ -243,75 +231,9 @@ struct ExtractPrimalFuncContext
return true;
}
- // Given a `genericA<Param1, Param1,...> { instX(Param1, Param2) }`,
- // and a clone of it `genericB<ParamB_1, ParamB_2,...> { }`.
- // `GenericChildrenMigrationContext(genericA, genericB)::getCorrespondingInst(instX)`
- // returns a clone of `instX` in `genericB` that references the new generic params
- // as `instX_clone` in `genericB<ParamB_1, ParamB_2,...> { instX_clone(ParamB_1, ParamB_2) }`.
- struct GenericChildrenMigrationContext
- {
- IRCloneEnv cloneEnv;
- IRGeneric* oldGeneric = nullptr;
- IRGeneric* newGeneric = nullptr;
- IRInst* newGenericRetVal = nullptr;
-
- void init(IRGeneric* oldGen, IRGeneric* newGen)
- {
- oldGeneric = oldGen;
- newGeneric = newGen;
- newGenericRetVal = findGenericReturnVal(newGen);
-
- IRInst* oldParam = oldGen->getFirstParam();
- IRInst* newParam = newGen->getFirstParam();
- while (oldParam)
- {
- oldParam = as<IRParam>(oldParam->getNextInst());
- newParam = as<IRParam>(newParam->getNextInst());
- if (!oldParam)
- {
- SLANG_RELEASE_ASSERT(!newParam);
- break;
- }
- SLANG_RELEASE_ASSERT(newParam);
- cloneEnv.mapOldValToNew[oldParam] = newParam;
- }
- }
- IRInst* getCorrespondingInst(IRBuilder& builder, IRInst* oldChild)
- {
- if (!oldGeneric)
- return oldChild;
- auto parent = oldChild->getParent();
- bool found = false;
- while (parent)
- {
- if (parent == oldGeneric)
- {
- found = true;
- break;
- }
- parent = parent->getParent();
- }
- if (!found)
- return oldChild;
- for (UInt i = 0; i < oldChild->getOperandCount(); i++)
- {
- auto operand = oldChild->getOperand(i);
- if (cloneEnv.mapOldValToNew.ContainsKey(operand))
- {}
- else
- {
- getCorrespondingInst(builder, operand);
- }
- }
- auto cloned = cloneInst(&cloneEnv, &builder, oldChild);
- return cloned;
- }
- };
-
void storeInst(
IRBuilder& builder,
IRInst* inst,
- GenericChildrenMigrationContext& genericContext,
IRInst* intermediateOutput)
{
IRBuilder genTypeBuilder(sharedBuilder);
@@ -319,7 +241,7 @@ struct ExtractPrimalFuncContext
SLANG_RELEASE_ASSERT(ptrStructType);
auto structType = as<IRStructType>(ptrStructType->getValueType());
genTypeBuilder.setInsertBefore(structType);
- auto fieldType = genericContext.getCorrespondingInst(genTypeBuilder, inst->getDataType());
+ auto fieldType = inst->getDataType();
SLANG_RELEASE_ASSERT(structType);
auto structKey = genTypeBuilder.createStructKey();
if (auto nameHint = inst->findDecoration<IRNameHintDecoration>())
@@ -333,30 +255,16 @@ struct ExtractPrimalFuncContext
inst);
}
- IRGlobalValueWithCode* turnUnzippedFuncIntoPrimalFunc(IRGlobalValueWithCode* unzippedFunc, IRGlobalValueWithCode* fwdFunc, IRInst*& outIntermediateType)
+ IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* fwdFunc, IRInst*& outIntermediateType)
{
// Note: this transformation assumes the original func has only one return.
IRBuilder builder(sharedBuilder);
- IRFunc* func = nullptr;
+ IRFunc* func = unzippedFunc;
IRInst* intermediateType = nullptr;
auto newFuncType = generatePrimalFuncType(unzippedFunc, fwdFunc, intermediateType);
- if (auto gen = as<IRGeneric>(unzippedFunc))
- {
- func = as<IRFunc>(findGenericReturnVal(gen));
- SLANG_RELEASE_ASSERT(func);
- builder.setInsertBefore(func);
- auto spec = as<IRSpecialize>(intermediateType);
- SLANG_RELEASE_ASSERT(spec);
- outIntermediateType = spec->getBase();
- }
- else
- {
- func = as<IRFunc>(unzippedFunc);
- SLANG_RELEASE_ASSERT(func);
- outIntermediateType = intermediateType;
- }
+ outIntermediateType = intermediateType;
func->setFullType((IRType*)newFuncType);
// Go through all the insts and preserve the primal blocks.
@@ -375,19 +283,14 @@ struct ExtractPrimalFuncContext
auto paramBlock = func->getFirstBlock();
builder.setInsertInto(paramBlock);
+ auto oldIntermediateParam = func->getLastParam();
auto outIntermediary =
builder.emitParam(builder.getInOutType((IRType*)intermediateType));
+ oldIntermediateParam->replaceUsesWith(outIntermediary);
+ oldIntermediateParam->removeAndDeallocate();
auto firstBlock = *(paramBlock->getSuccessors().begin());
- GenericChildrenMigrationContext genericMigrationContext;
- if (auto gen = as<IRGeneric>(unzippedFunc))
- {
- auto spec = as<IRSpecialize>(intermediateType);
- SLANG_RELEASE_ASSERT(spec);
- genericMigrationContext.init(gen, as<IRGeneric>(spec->getBase()));
- }
-
List<IRBlock*> diffBlocksList;
List<IRBlock*> primalBlocksList;
@@ -412,7 +315,7 @@ struct ExtractPrimalFuncContext
if (shouldStoreInst(inst))
{
builder.setInsertAfter(inst);
- storeInst(builder, inst, genericMigrationContext, outIntermediary);
+ storeInst(builder, inst, outIntermediary);
}
}
}
@@ -482,8 +385,8 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE
}
}
-IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc(
- IRGlobalValueWithCode* func, IRGlobalValueWithCode* fwdFunc, IRInst*& intermediateType)
+IRFunc* DiffUnzipPass::extractPrimalFunc(
+ IRFunc* func, IRFunc* fwdFunc, IRInst*& intermediateType)
{
IRBuilder builder(this->autodiffContext->sharedBuilder);
builder.setInsertBefore(func);
@@ -491,46 +394,31 @@ IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc(
IRCloneEnv subEnv;
subEnv.squashChildrenMapping = true;
subEnv.parent = &cloneEnv;
- auto clonedFunc = as<IRGlobalValueWithCode>(cloneInst(&subEnv, &builder, func));
+ auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func));
ExtractPrimalFuncContext context;
context.init(autodiffContext->sharedBuilder);
intermediateType = nullptr;
auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, fwdFunc, intermediateType);
- IRInst* specializedPrimalFunc = primalFunc;
-
- // Copy PrimalValueStructKey decorations from primal func.
- copyPrimalValueStructKeyDecorations(func, subEnv);
-
- IRInst* specializedIntermediateType = intermediateType;
- auto innerFunc = as<IRFunc>(func);
- if (auto genFunc = as<IRGeneric>(func))
+ if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>())
{
- innerFunc = as<IRFunc>(findGenericReturnVal(genFunc));
- builder.setInsertBefore(innerFunc);
- specializedIntermediateType = specializeWithGeneric(builder, intermediateType, genFunc);
- specializedPrimalFunc = specializeWithGeneric(builder, primalFunc, genFunc);
+ auto primalName = String(nameHint->getName()) + "_primal";
+ nameHint->setOperand(0, builder.getStringValue(primalName.getUnownedSlice()));
}
- SLANG_RELEASE_ASSERT(innerFunc);
- // Insert a call to primal func at start of the function.
- auto paramBlock = innerFunc->getFirstBlock();
+ // Copy PrimalValueStructKey decorations from primal func.
+ copyPrimalValueStructKeyDecorations(func, subEnv);
+
+ auto paramBlock = func->getFirstBlock();
auto firstBlock = *(paramBlock->getSuccessors().begin());
builder.setInsertBefore(firstBlock->getFirstInst());
- auto intermediateVar = builder.emitVar((IRType*)specializedIntermediateType);
- List<IRInst*> args;
- for (auto param : paramBlock->getParams())
- {
- args.add(param);
- }
- args.add(intermediateVar);
- builder.emitCallInst(innerFunc->getResultType(), specializedPrimalFunc, args);
+ auto intermediateVar = func->getLastParam();
// Replace all insts that has intermediate results with a load of the intermediate.
List<IRInst*> instsToRemove;
- for (auto block : innerFunc->getBlocks())
+ for (auto block : func->getBlocks())
{
for (auto inst : block->getOrdinaryInsts())
{
@@ -554,8 +442,8 @@ IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc(
}
// Run simplification to DCE unnecessary insts.
- eliminateDeadCode(innerFunc);
- eliminateDeadCode(specializedPrimalFunc);
+ eliminateDeadCode(func);
+ eliminateDeadCode(primalFunc);
return primalFunc;
}
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 2c55b390b..f2ce3dc62 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -132,7 +132,7 @@ struct DiffUnzipPass
return unzippedFunc;
}
- IRGlobalValueWithCode* extractPrimalFunc(IRGlobalValueWithCode* func, IRGlobalValueWithCode* fwdFunc, IRInst*& intermediateType);
+ IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* fwdFunc, IRInst*& intermediateType);
bool isRelevantDifferentialPair(IRType* type)
{
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 40c24d11d..d23271704 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -401,6 +401,10 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
case kIROp_DifferentiableTypeDictionaryDecoration:
case kIROp_DifferentialInstDecoration:
case kIROp_MixedDifferentialInstDecoration:
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_BackwardDerivativeIntermediateTypeDecoration:
+ case kIROp_BackwardDerivativePropagateDecoration:
+ case kIROp_BackwardDerivativePrimalDecoration:
decor->removeAndDeallocate();
break;
default:
@@ -489,7 +493,7 @@ struct AutoDiffPass : public InstPassBase
// TODO(sai): Move this call.
forwardTranscriber.differentiableTypeConformanceContext.buildGlobalWitnessDictionary();
- IRBuilder builderStorage(this->autodiffContext->sharedBuilder);
+ IRBuilder builderStorage(&sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
// Process all ForwardDifferentiate and BackwardDifferentiate instructions by
@@ -500,6 +504,81 @@ struct AutoDiffPass : public InstPassBase
return modified;
}
+ IRInst* processIntermediateContextTypeBase(IRBuilder* builder, IRInst* base)
+ {
+ if (auto spec = as<IRSpecialize>(base))
+ {
+ List<IRInst*> args;
+ auto subBase = processIntermediateContextTypeBase(builder, spec->getBase());
+ for (UInt a = 0; a < spec->getArgCount(); a++)
+ args.add(spec->getArg(a));
+ auto actualType = builder->emitSpecializeInst(
+ builder->getTypeKind(),
+ subBase,
+ args.getCount(),
+ args.getBuffer());
+ return actualType;
+ }
+ else if (auto baseGeneric = as<IRGeneric>(base))
+ {
+ auto inner = findGenericReturnVal(baseGeneric);
+ if (auto typeDecor = inner->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>())
+ {
+ auto typeSpec = cast<IRSpecialize>(typeDecor->getBackwardDerivativeIntermediateType());
+ auto typeSpecBase = typeSpec->getBase();
+ return typeSpecBase;
+ }
+ }
+ else if (auto func = as<IRFunc>(base))
+ {
+ if (auto typeDecor = func->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>())
+ {
+ return typeDecor->getBackwardDerivativeIntermediateType();
+ }
+ }
+ else if (auto lookup = as<IRLookupWitnessMethod>(base))
+ {
+ auto key = lookup->getRequirementKey();
+ if (auto typeDecor = key->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>())
+ {
+ auto typeKey = typeDecor->getBackwardDerivativeIntermediateType();
+ auto typeLookup = builder->emitLookupInterfaceMethodInst(builder->getTypeKind(), lookup->getWitnessTable(), typeKey);
+ return typeLookup;
+ }
+ }
+ return nullptr;
+ }
+
+ bool lowerIntermediateContextType(IRBuilder* builder)
+ {
+ bool changed = false;
+ processAllInsts([&](IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_BackwardDiffIntermediateContextType:
+ {
+ auto differentiateInst = as<IRBackwardDiffIntermediateContextType>(inst);
+
+ auto baseFunc = differentiateInst->getOperand(0);
+ IRBuilder subBuilder = *builder;
+ subBuilder.setInsertBefore(inst);
+ auto type = processIntermediateContextTypeBase(&subBuilder, baseFunc);
+ if (type)
+ {
+ inst->replaceUsesWith(type);
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ });
+ return changed;
+ }
+
// Process all differentiate calls, and recursively generate code for forward and backward
// derivative functions.
//
@@ -518,6 +597,9 @@ struct AutoDiffPass : public InstPassBase
{
case kIROp_ForwardDifferentiate:
case kIROp_BackwardDifferentiate:
+ case kIROp_BackwardDifferentiatePrimal:
+ case kIROp_BackwardDifferentiatePropagate:
+ case kIROp_BackwardDiffIntermediateContextType:
// Only process now if the operand is a materialized function.
switch (inst->getOperand(0)->getOp())
{
@@ -538,29 +620,49 @@ struct AutoDiffPass : public InstPassBase
// Process collected differentiate insts and replace them with placeholders for
// differentiated functions.
- for (auto differentiateInst : autoDiffWorkList)
+ for (Index i = 0; i < autoDiffWorkList.getCount(); i++)
{
- if (auto diffInst = as<IRForwardDifferentiate>(differentiateInst))
+ auto differentiateInst = autoDiffWorkList[i];
+
+ IRInst* diffFunc = nullptr;
+ IRBuilder subBuilder(*builder);
+ subBuilder.setInsertBefore(differentiateInst);
+ switch (differentiateInst->getOp())
{
- IRBuilder subBuilder(*builder);
- subBuilder.setInsertBefore(differentiateInst);
- if (auto diffFunc = forwardTranscriber.transcribe(&subBuilder, diffInst->getBaseFn()))
+ case kIROp_ForwardDifferentiate:
{
- differentiateInst->replaceUsesWith(diffFunc);
- differentiateInst->removeAndDeallocate();
- changed = true;
+ auto baseFunc = as<IRForwardDifferentiate>(differentiateInst)->getBaseFn();
+ diffFunc = forwardTranscriber.transcribe(&subBuilder, baseFunc);
}
- }
- else if (auto backDiffInst = as<IRBackwardDifferentiate>(differentiateInst))
- {
- auto baseInst = backDiffInst->getBaseFn();
- if (auto diffFunc = backwardTranscriber.transcribe(builder, (IRFunc*)baseInst))
+ break;
+ case kIROp_BackwardDifferentiatePrimal:
+ {
+ auto baseFunc = differentiateInst->getOperand(0);
+ diffFunc = backwardPrimalTranscriber.transcribe(&subBuilder, baseFunc);
+ }
+ break;
+ case kIROp_BackwardDifferentiatePropagate:
{
- SLANG_ASSERT(diffFunc);
- differentiateInst->replaceUsesWith(diffFunc);
- differentiateInst->removeAndDeallocate();
- changed = true;
+ auto baseFunc = differentiateInst->getOperand(0);
+ diffFunc = backwardPropagateTranscriber.transcribe(&subBuilder, baseFunc);
}
+ break;
+ case kIROp_BackwardDifferentiate:
+ {
+ auto baseFunc = differentiateInst->getOperand(0);
+ diffFunc = backwardTranscriber.transcribe(&subBuilder, baseFunc);
+ }
+ break;
+ default:
+ break;
+ }
+
+ if (diffFunc)
+ {
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
+ changed = true;
}
}
@@ -591,8 +693,11 @@ struct AutoDiffPass : public InstPassBase
case FuncBodyTranscriptionTaskType::Forward:
forwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc);
break;
- case FuncBodyTranscriptionTaskType::Backward:
- backwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc);
+ case FuncBodyTranscriptionTaskType::BackwardPrimal:
+ // Don't need to do anything, they will be filled by `backwardPropagateTranscriber`.
+ break;
+ case FuncBodyTranscriptionTaskType::BackwardPropagate:
+ backwardPropagateTranscriber.transcribeFunc(builder, primalFunc, diffFunc);
break;
default:
break;
@@ -616,6 +721,11 @@ struct AutoDiffPass : public InstPassBase
hasChanges |= changed;
}
+ if (lowerIntermediateContextType(builder))
+ {
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+ hasChanges = true;
+ }
return hasChanges;
}
@@ -651,12 +761,28 @@ struct AutoDiffPass : public InstPassBase
AutoDiffPass(AutoDiffSharedContext* context, DiagnosticSink* sink) :
InstPassBase(context->moduleInst->getModule()),
sink(sink),
- forwardTranscriber(context, context->sharedBuilder, sink),
- backwardTranscriber(context, context->sharedBuilder, sink),
+ forwardTranscriber(context, &sharedBuilderStorage, sink),
+ backwardPrimalTranscriber(context, &sharedBuilderStorage, sink),
+ backwardPropagateTranscriber(context, &sharedBuilderStorage, sink),
+ backwardTranscriber(context, &sharedBuilderStorage, sink),
pairBuilderStorage(context),
autodiffContext(context)
{
+
+ // We start by initializing our shared IR building state,
+ // since we will re-use that state for any code we
+ // generate along the way.
+ //
+ sharedBuilderStorage.init(module);
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+
+ context->sharedBuilder = &sharedBuilderStorage;
+
forwardTranscriber.pairBuilder = &pairBuilderStorage;
+ backwardPrimalTranscriber.pairBuilder = &pairBuilderStorage;
+ backwardPrimalTranscriber.fwdDiffTranscriber = &forwardTranscriber;
+ backwardPropagateTranscriber.pairBuilder = &pairBuilderStorage;
+ backwardPropagateTranscriber.fwdDiffTranscriber = &forwardTranscriber;
backwardTranscriber.pairBuilder = &pairBuilderStorage;
backwardTranscriber.fwdDiffTranscriber = &forwardTranscriber;
}
@@ -667,8 +793,13 @@ protected:
//
ForwardDiffTranscriber forwardTranscriber;
+ BackwardDiffPrimalTranscriber backwardPrimalTranscriber;
+
+ BackwardDiffPropagateTranscriber backwardPropagateTranscriber;
+
BackwardDiffTranscriber backwardTranscriber;
+
// Diagnostic object from the compile request for
// error messages.
DiagnosticSink* sink;
@@ -691,16 +822,6 @@ bool processAutodiffCalls(
// Create shared context for all auto-diff related passes
AutoDiffSharedContext autodiffContext(module->getModuleInst());
- // We start by initializing our shared IR building state,
- // since we will re-use that state for any code we
- // generate along the way.
- //
- SharedIRBuilder sharedBuilder;
- sharedBuilder.init(module);
- sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
-
- autodiffContext.sharedBuilder = &sharedBuilder;
-
AutoDiffPass pass(&autodiffContext, sink);
modified |= pass.processModule();
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index e0508cef7..1415618e1 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -37,7 +37,7 @@ typedef DiffInstPair<IRInst*, IRInst*> InstPair;
enum class FuncBodyTranscriptionTaskType
{
- Forward, Backward, Primal
+ Forward, BackwardPrimal, BackwardPropagate, Backward
};
struct FuncBodyTranscriptionTask
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 8440f4181..b721f4225 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -60,6 +60,7 @@ INST(Nop, nop, 0, 0)
INST(OptionalType, Optional, 1, 0)
INST(DifferentialPairType, DiffPair, 1, 0)
+ INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, 0)
/* BindExistentialsTypeBase */
@@ -731,6 +732,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(BackwardDifferentiableDecoration, backwardDifferentiable, 1, 0)
/// Decorated function is marked for the reverse-mode differentiation pass.
+ INST(BackwardDerivativePrimalDecoration, backwardDiffPrimalReference, 1, 0)
+ INST(BackwardDerivativePropagateDecoration, backwardDiffPropagateReference, 1, 0)
+ INST(BackwardDerivativeIntermediateTypeDecoration, backwardDiffIntermediateTypeReference, 1, 0)
INST(BackwardDerivativeDecoration, backwardDiffReference, 1, 0)
/// Used by the auto-diff pass to mark insts that compute
@@ -815,8 +819,18 @@ INST(CastToVoid, castToVoid, 1, 0)
INST(IsType, IsType, 3, 0)
INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0)
-INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0)
-INST(DifferentialEqualityTypeCast, DifferentialEqualityTypeCast, 1, 0)
+
+// Produces the primal computation of backward derivatives, will return an intermediate context for
+// backward derivative func.
+INST(BackwardDifferentiatePrimal, BackwardDifferentiatePrimal, 1, 0)
+
+// Produces the actual backward derivative propagate function, using the intermediate context returned by the
+// primal func produced from `BackwardDifferentiatePrimal`.
+INST(BackwardDifferentiatePropagate, BackwardDifferentiatePropagate, 1, 0)
+
+// Represents the conceptual backward derivative function. Only produced by lower-to-ir and will be
+// replaced with `BackwardDifferentiatePrimal` and `BackwardDifferentiatePropagate`.
+INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0)
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
@@ -875,6 +889,11 @@ INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0)
/* DifferentiableMethodRequirementDictionaryItem */
INST(ForwardDifferentiableMethodRequirementDictionaryItem, DifferentiableMethodRequirementDictionaryItem, 0, 0)
INST(BackwardDifferentiableMethodRequirementDictionaryItem, DifferentiableMethodRequirementDictionaryItem, 0, 0)
+ INST(BackwardDifferentiablePrimalMethodRequirementDictionaryItem, DifferentiablePrimalMethodRequirementDictionaryItem, 0, 0)
+ INST(BackwardDifferentiablePropagateMethodRequirementDictionaryItem, DifferentiablePropagateMethodRequirementDictionaryItem, 0, 0)
+ INST(BackwardDifferentiableIntermediateTypeRequirementDictionaryItem, DifferentiableIntermediateTypeRequirementDictionaryItem, 0, 0)
+
+
INST_RANGE(DifferentiableMethodRequirementDictionaryItem, ForwardDifferentiableMethodRequirementDictionaryItem, BackwardDifferentiableMethodRequirementDictionaryItem)
#undef PARENT
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 03a3fb063..d2a4c7ae3 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -585,6 +585,38 @@ struct IRForwardDerivativeDecoration : IRDecoration
IRInst* getForwardDerivativeFunc() { return getOperand(0); }
};
+struct IRBackwardDerivativeIntermediateTypeDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_BackwardDerivativeIntermediateTypeDecoration
+ };
+ IR_LEAF_ISA(BackwardDerivativeIntermediateTypeDecoration)
+
+ IRInst* getBackwardDerivativeIntermediateType() { return getOperand(0); }
+};
+
+struct IRBackwardDerivativePrimalDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_BackwardDerivativePrimalDecoration
+ };
+ IR_LEAF_ISA(BackwardDerivativePrimalDecoration)
+
+ IRInst* getBackwardDerivativePrimalFunc() { return getOperand(0); }
+};
+
+struct IRBackwardDerivativePropagateDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_BackwardDerivativePropagateDecoration
+ };
+ IR_LEAF_ISA(BackwardDerivativePropagateDecoration)
+
+ IRInst* getBackwardDerivativePropagateFunc() { return getOperand(0); }
+};
struct IRBackwardDerivativeDecoration : IRDecoration
{
@@ -681,7 +713,45 @@ struct IRForwardDifferentiate : IRInst
};
// An instruction that replaces the function symbol
-// with it's derivative function.
+// with its backward derivative primal function.
+// A backward derivative primal function is the first pass
+// of backward derivative computation. It performs the primal
+// computations and returns the intermediates that will be used
+// by the actual backward derivative function.
+struct IRBackwardDifferentiatePrimal : IRInst
+{
+ enum
+ {
+ kOp = kIROp_BackwardDifferentiatePrimal
+ };
+ // The base function for the call.
+ IRUse base;
+ IRInst* getBaseFn() { return getOperand(0); }
+
+ IR_LEAF_ISA(BackwardDifferentiatePrimal)
+};
+
+// An instruction that replaces the function symbol with its backward derivative propagate function.
+// A backward derivative propagate function is the second pass of backward derivative computation. It uses the
+// intermediates computed in the bacward derivative primal function to perform the actual backward
+// derivative propagation.
+struct IRBackwardDifferentiatePropagate : IRInst
+{
+ enum
+ {
+ kOp = kIROp_BackwardDifferentiatePropagate
+ };
+ // The base function for the call.
+ IRUse base;
+ IRInst* getBaseFn() { return getOperand(0); }
+
+ IR_LEAF_ISA(BackwardDifferentiatePropagate)
+};
+
+// An instruction that replaces the function symbol with its backward derivative function.
+// A backward derivative function is a concept that combines both passes of backward derivative
+// computation. This inst should only be produced by lower-to-ir, and will be replaced with calls to
+// the primal function followed by the propagate function in the auto-diff pass.
struct IRBackwardDifferentiate : IRInst
{
enum
@@ -2556,6 +2626,8 @@ public:
IRType* valueType,
IRInst* witnessTable);
+ IRBackwardDiffIntermediateContextType* getBackwardDiffIntermediateContextType(IRInst* func);
+
IRFuncType* getFuncType(
UInt paramCount,
IRType* const* paramTypes,
@@ -2664,6 +2736,8 @@ public:
IRInst* emitForwardDifferentiateInst(IRType* type, IRInst* baseFn);
IRInst* emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn);
+ IRInst* emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn);
+ IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn);
IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential);
@@ -3399,11 +3473,26 @@ public:
addDecoration(value, kIROp_ForwardDerivativeDecoration, fwdFunc);
}
+ void addBackwardDerivativePrimalDecoration(IRInst* value, IRInst* jvpFn)
+ {
+ addDecoration(value, kIROp_BackwardDerivativePrimalDecoration, jvpFn);
+ }
+
+ void addBackwardDerivativePropagateDecoration(IRInst* value, IRInst* jvpFn)
+ {
+ addDecoration(value, kIROp_BackwardDerivativePropagateDecoration, jvpFn);
+ }
+
void addBackwardDerivativeDecoration(IRInst* value, IRInst* jvpFn)
{
addDecoration(value, kIROp_BackwardDerivativeDecoration, jvpFn);
}
+ void addBackwardDerivativeIntermediateTypeDecoration(IRInst* value, IRInst* jvpFn)
+ {
+ addDecoration(value, kIROp_BackwardDerivativeIntermediateTypeDecoration, jvpFn);
+ }
+
void markInstAsDifferential(IRInst* value)
{
addDecoration(value, kIROp_DifferentialInstDecoration, nullptr);
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 46c7b3363..de970fbca 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -403,7 +403,7 @@ struct SpecializationContext
// If the base is specialized, the JVP version must be also be a specialized
// generic.
//
- SLANG_ASSERT(specDiffFunc);
+ SLANG_RELEASE_ASSERT(specDiffFunc);
// Build specialization arguments from specInst.
// Note that if we've reached this point, we can safely assume
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 81b5d636a..8e3e879ad 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -1,5 +1,6 @@
#include "slang-ir-util.h"
#include "slang-ir-insts.h"
+#include "slang-ir-clone.h"
namespace Slang
{
@@ -143,4 +144,77 @@ IRInst* specializeWithGeneric(IRBuilder& builder, IRInst* genericToSpecialize, I
genArgs.getBuffer());
}
+IRInst* maybeSpecializeWithGeneric(IRBuilder& builder, IRInst* genericToSpecailize, IRInst* userGeneric)
+{
+ if (auto gen = as<IRGeneric>(userGeneric))
+ {
+ if (auto toSpecialize = as<IRGeneric>(genericToSpecailize))
+ {
+ return specializeWithGeneric(builder, toSpecialize, gen);
+ }
+ }
+ return genericToSpecailize;
+}
+
+IRInst* hoistValueFromGeneric(IRBuilder& builder, IRInst* value, IRInst*& outSpecializedVal, bool replaceExistingValue)
+{
+ auto outerGeneric = as<IRGeneric>(findOuterGeneric(value));
+ if (!outerGeneric) return value;
+
+ builder.setInsertBefore(outerGeneric);
+ auto newGeneric = builder.emitGeneric();
+ builder.setInsertInto(newGeneric);
+ builder.emitBlock();
+ IRInst* newResultVal = nullptr;
+
+ // Clone insts in outerGeneric up until `value`.
+ IRCloneEnv cloneEnv;
+ for (auto inst : outerGeneric->getFirstBlock()->getChildren())
+ {
+ auto newInst = cloneInst(&cloneEnv, &builder, inst);
+ if (inst == value)
+ {
+ builder.emitReturn(newInst);
+ newResultVal = newInst;
+ break;
+ }
+ }
+ SLANG_RELEASE_ASSERT(newResultVal);
+ if (newResultVal->getOp() == kIROp_Func)
+ {
+ IRBuilder subBuilder = builder;
+ IRInst* subOutSpecialized = nullptr;
+ auto genericFuncType = hoistValueFromGeneric(subBuilder, newResultVal->getFullType(), subOutSpecialized, false);
+ newGeneric->setFullType((IRType*)genericFuncType);
+ }
+ else
+ {
+ newGeneric->setFullType(builder.getTypeKind());
+ }
+ if (replaceExistingValue)
+ {
+ builder.setInsertBefore(value);
+ outSpecializedVal = specializeWithGeneric(builder, newGeneric, outerGeneric);
+ value->replaceUsesWith(outSpecializedVal);
+ value->removeAndDeallocate();
+ }
+ return newGeneric;
+}
+
+void moveInstChildren(IRInst* dest, IRInst* src)
+{
+ for (auto child = dest->getFirstDecorationOrChild(); child; )
+ {
+ auto next = child->getNextInst();
+ child->removeAndDeallocate();
+ child = next;
+ }
+ for (auto child = src->getFirstDecorationOrChild(); child; )
+ {
+ auto next = child->getNextInst();
+ child->insertAtEnd(dest);
+ child = next;
+ }
+}
+
}
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 2087ee4a7..49f46d0e3 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -61,6 +61,40 @@ inline bool isChildInstOf(IRInst* inst, IRInst* parent)
IRInst* specializeWithGeneric(
IRBuilder& builder, IRInst* genericToSpecialize, IRGeneric* userGeneric);
+IRInst* maybeSpecializeWithGeneric(IRBuilder& builder, IRInst* genericToSpecailize, IRInst* userGeneric);
+
+ // For a value inside a generic, create a standalone generic wrapping just the value, and replace the use of
+ // the original value with a specialization of the new generic using the current generic arguments if
+ // `replaceExistingValue` is true.
+ // For example, if we have
+ // ```
+ // generic G { param T; v = x(T); f = y(v); return f; }
+ // ```
+ // hoistValueFromGeneric(G, v) turns the code into:
+ // ```
+ // generic G1 { param T1; v1 = x(T); return v1; }
+ // generic G { param T; v = specialize(G1, T); f = y(v); return f; }
+ // ```
+ // This function returns newly created generic inst.
+ // if `value` is not inside any generic, this function makes no change to IR, and returns `value`.
+IRInst* hoistValueFromGeneric(
+ IRBuilder& builder,
+ IRInst* value,
+ IRInst*& outSpecializedVal,
+ bool replaceExistingValue = false);
+
+// Clear dest and move all chidlren from src to dest.
+void moveInstChildren(IRInst* dest, IRInst* src);
+
+inline bool isGenericParam(IRInst* param)
+{
+ auto parent = param->getParent();
+ if (auto block = as<IRBlock>(parent))
+ parent = block->getParent();
+ if (as<IRGeneric>(parent))
+ return true;
+ return false;
+}
inline IRInst* unwrapAttributedType(IRInst* type)
{
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index d8a8fb7c4..9e0e328bd 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -300,6 +300,11 @@ namespace Slang
return as<IRParam>(getNextInst());
}
+ IRParam* IRParam::getPrevParam()
+ {
+ return as<IRParam>(getPrevInst());
+ }
+
// IRArrayTypeBase
IRInst* IRArrayTypeBase::getElementCount()
@@ -2802,6 +2807,15 @@ namespace Slang
operands);
}
+ IRBackwardDiffIntermediateContextType* IRBuilder::getBackwardDiffIntermediateContextType(
+ IRInst* func)
+ {
+ return (IRBackwardDiffIntermediateContextType*)getType(
+ kIROp_BackwardDiffIntermediateContextType,
+ 1,
+ &func);
+ }
+
IRFuncType* IRBuilder::getFuncType(
UInt paramCount,
IRType* const* paramTypes,
@@ -3129,6 +3143,28 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn)
+ {
+ auto inst = createInst<IRBackwardDifferentiatePrimal>(
+ this,
+ kIROp_BackwardDifferentiatePrimal,
+ type,
+ baseFn);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn)
+ {
+ auto inst = createInst<IRBackwardDifferentiatePropagate>(
+ this,
+ kIROp_BackwardDifferentiatePropagate,
+ type,
+ baseFn);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential)
{
SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type));
@@ -6622,6 +6658,7 @@ namespace Slang
case kIROp_UnpackAnyValue:
case kIROp_Reinterpret:
case kIROp_GetNativePtr:
+ case kIROp_BackwardDiffIntermediateContextType:
return false;
case kIROp_ForwardDifferentiate:
@@ -6904,6 +6941,16 @@ namespace Slang
}
return nullptr;
}
+
+ IRInst* getGenericReturnVal(IRInst* inst)
+ {
+ if (auto gen = as<IRGeneric>(inst))
+ {
+ return findGenericReturnVal(gen);
+ }
+ return inst;
+ }
+
} // namespace Slang
#if SLANG_VC
@@ -6917,4 +6964,3 @@ SLANG_API const int SlangDebug__IROpStringLit = Slang::kIROp_StringLit;
SLANG_API const int SlangDebug__IROpIntLit = Slang::kIROp_IntLit;
#endif
#endif
-
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 56a33c02b..b4a657545 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1347,6 +1347,12 @@ struct IRDifferentialPairType : IRType
IR_LEAF_ISA(DifferentialPairType)
};
+struct IRBackwardDiffIntermediateContextType : IRType
+{
+ IRInst* getFunc() { return getOperand(0); }
+ IR_LEAF_ISA(BackwardDiffIntermediateContextType)
+};
+
struct IRVectorType : IRType
{
IRType* getElementType() { return (IRType*)getOperand(0); }
@@ -1743,6 +1749,9 @@ IRInst* findGenericReturnVal(IRGeneric* generic);
// Recursively find the inner most generic return value.
IRInst* findInnerMostGenericReturnVal(IRGeneric* generic);
+// Returns the generic return val if `inst` is a generic, otherwise returns `inst`.
+IRInst* getGenericReturnVal(IRInst* inst);
+
// Find the generic container, if any, that this inst is contained in
// Returns nullptr if there is no outer container.
IRInst* findOuterGeneric(IRInst* inst);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index a84cf9b8d..6803e1cb4 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1407,6 +1407,33 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(diff);
}
+ LoweredValInfo visitBackwardDifferentiatePropagateVal(BackwardDifferentiatePropagateVal* val)
+ {
+ auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ auto diff = getBuilder()->emitBackwardDifferentiatePropagateInst(getBuilder()->getTypeKind(), funcVal.val);
+ return LoweredValInfo::simple(diff);
+ }
+
+ LoweredValInfo visitBackwardDifferentiatePrimalVal(BackwardDifferentiatePrimalVal* val)
+ {
+ auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ auto diff = getBuilder()->emitBackwardDifferentiatePrimalInst(getBuilder()->getTypeKind(), funcVal.val);
+ return LoweredValInfo::simple(diff);
+ }
+
+ LoweredValInfo visitBackwardDifferentiateIntermediateTypeVal(BackwardDifferentiateIntermediateTypeVal* val)
+ {
+ auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ auto diff = getBuilder()->getBackwardDiffIntermediateContextType(funcVal.val);
+ return LoweredValInfo::simple(diff);
+ }
+
LoweredValInfo visitDifferentialBottomSubtypeWitness(DifferentialBottomSubtypeWitness*)
{
return LoweredValInfo();
@@ -6816,9 +6843,23 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
context->irBuilder->addDecoration(
interfaceType, kIROp_DifferentiableMethodRequirementDictionaryDecoration);
}
- auto op = as<ForwardDerivativeRequirementDecl>(requirementDecl)
- ? kIROp_ForwardDifferentiableMethodRequirementDictionaryItem
- : kIROp_BackwardDifferentiableMethodRequirementDictionaryItem;
+ IROp op = kIROp_ForwardDifferentiableMethodRequirementDictionaryItem;
+ if (as<BackwardDerivativeRequirementDecl>(requirementDecl))
+ {
+ op = kIROp_BackwardDifferentiableMethodRequirementDictionaryItem;
+ }
+ else if (as<BackwardDerivativePropagateRequirementDecl>(requirementDecl))
+ {
+ op = kIROp_BackwardDifferentiablePropagateMethodRequirementDictionaryItem;
+ }
+ else if (as<BackwardDerivativePrimalRequirementDecl>(requirementDecl))
+ {
+ op = kIROp_BackwardDifferentiablePrimalMethodRequirementDictionaryItem;
+ }
+ else if (as<BackwardDerivativeIntermediateTypeRequirementDecl>(requirementDecl))
+ {
+ op = kIROp_BackwardDifferentiableIntermediateTypeRequirementDictionaryItem;
+ }
IRInst* args[] = {originalKey, associatedKey};
auto assoc = context->irBuilder->emitIntrinsicInst(nullptr, op, 2, args);
assoc->insertAtEnd(decor);
@@ -8405,6 +8446,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
UNREACHABLE_RETURN(LoweredValInfo());
}
+ LoweredValInfo visitBackwardDerivativeIntermediateTypeRequirementDecl(BackwardDerivativeIntermediateTypeRequirementDecl* decl)
+ {
+ SLANG_UNUSED(decl);
+ return LoweredValInfo(getBuilder()->getTypeKind());
+ }
+
LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl)
{
// A function declaration may have multiple, target-specific
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index 6ea9ea01e..58d6aaae3 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -521,6 +521,12 @@ namespace Slang
emitRaw(context, "FwdReq_");
else if (as<BackwardDerivativeRequirementDecl>(decl))
emitRaw(context, "BwdReq_");
+ else if (as<BackwardDerivativePropagateRequirementDecl>(decl))
+ emitRaw(context, "BwdReq_Prop_");
+ else if (as<BackwardDerivativePrimalRequirementDecl>(decl))
+ emitRaw(context, "BwdReq_Primal_");
+ else if (as<BackwardDerivativeIntermediateTypeRequirementDecl>(decl))
+ emitRaw(context, "BwdReq_CtxType_");
else
{
// TODO: handle other cases