summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-06 13:39:06 -0800
committerGitHub <noreply@github.com>2023-01-06 13:39:06 -0800
commit33fb95980b0120cdd4d4f2d51f5f116e808dd4aa (patch)
tree318b1669a0e52aabd11f8694de1278ef7dbc0e3b /source/slang/slang-ir-autodiff.cpp
parente70cbe76ce74769069b7384f5f05c62da1ca45ed (diff)
Split bwd_diff op into separate ops for primal and propagate func. (#2582)
* Split bwd_diff op into separate ops for primal and propagate func. * Fix. * Download swiftshader with github actions instead of curl on linux. * Fix github action. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp185
1 files changed, 153 insertions, 32 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 40c24d11d..d23271704 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -401,6 +401,10 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
case kIROp_DifferentiableTypeDictionaryDecoration:
case kIROp_DifferentialInstDecoration:
case kIROp_MixedDifferentialInstDecoration:
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_BackwardDerivativeIntermediateTypeDecoration:
+ case kIROp_BackwardDerivativePropagateDecoration:
+ case kIROp_BackwardDerivativePrimalDecoration:
decor->removeAndDeallocate();
break;
default:
@@ -489,7 +493,7 @@ struct AutoDiffPass : public InstPassBase
// TODO(sai): Move this call.
forwardTranscriber.differentiableTypeConformanceContext.buildGlobalWitnessDictionary();
- IRBuilder builderStorage(this->autodiffContext->sharedBuilder);
+ IRBuilder builderStorage(&sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
// Process all ForwardDifferentiate and BackwardDifferentiate instructions by
@@ -500,6 +504,81 @@ struct AutoDiffPass : public InstPassBase
return modified;
}
+ IRInst* processIntermediateContextTypeBase(IRBuilder* builder, IRInst* base)
+ {
+ if (auto spec = as<IRSpecialize>(base))
+ {
+ List<IRInst*> args;
+ auto subBase = processIntermediateContextTypeBase(builder, spec->getBase());
+ for (UInt a = 0; a < spec->getArgCount(); a++)
+ args.add(spec->getArg(a));
+ auto actualType = builder->emitSpecializeInst(
+ builder->getTypeKind(),
+ subBase,
+ args.getCount(),
+ args.getBuffer());
+ return actualType;
+ }
+ else if (auto baseGeneric = as<IRGeneric>(base))
+ {
+ auto inner = findGenericReturnVal(baseGeneric);
+ if (auto typeDecor = inner->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>())
+ {
+ auto typeSpec = cast<IRSpecialize>(typeDecor->getBackwardDerivativeIntermediateType());
+ auto typeSpecBase = typeSpec->getBase();
+ return typeSpecBase;
+ }
+ }
+ else if (auto func = as<IRFunc>(base))
+ {
+ if (auto typeDecor = func->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>())
+ {
+ return typeDecor->getBackwardDerivativeIntermediateType();
+ }
+ }
+ else if (auto lookup = as<IRLookupWitnessMethod>(base))
+ {
+ auto key = lookup->getRequirementKey();
+ if (auto typeDecor = key->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>())
+ {
+ auto typeKey = typeDecor->getBackwardDerivativeIntermediateType();
+ auto typeLookup = builder->emitLookupInterfaceMethodInst(builder->getTypeKind(), lookup->getWitnessTable(), typeKey);
+ return typeLookup;
+ }
+ }
+ return nullptr;
+ }
+
+ bool lowerIntermediateContextType(IRBuilder* builder)
+ {
+ bool changed = false;
+ processAllInsts([&](IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_BackwardDiffIntermediateContextType:
+ {
+ auto differentiateInst = as<IRBackwardDiffIntermediateContextType>(inst);
+
+ auto baseFunc = differentiateInst->getOperand(0);
+ IRBuilder subBuilder = *builder;
+ subBuilder.setInsertBefore(inst);
+ auto type = processIntermediateContextTypeBase(&subBuilder, baseFunc);
+ if (type)
+ {
+ inst->replaceUsesWith(type);
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ });
+ return changed;
+ }
+
// Process all differentiate calls, and recursively generate code for forward and backward
// derivative functions.
//
@@ -518,6 +597,9 @@ struct AutoDiffPass : public InstPassBase
{
case kIROp_ForwardDifferentiate:
case kIROp_BackwardDifferentiate:
+ case kIROp_BackwardDifferentiatePrimal:
+ case kIROp_BackwardDifferentiatePropagate:
+ case kIROp_BackwardDiffIntermediateContextType:
// Only process now if the operand is a materialized function.
switch (inst->getOperand(0)->getOp())
{
@@ -538,29 +620,49 @@ struct AutoDiffPass : public InstPassBase
// Process collected differentiate insts and replace them with placeholders for
// differentiated functions.
- for (auto differentiateInst : autoDiffWorkList)
+ for (Index i = 0; i < autoDiffWorkList.getCount(); i++)
{
- if (auto diffInst = as<IRForwardDifferentiate>(differentiateInst))
+ auto differentiateInst = autoDiffWorkList[i];
+
+ IRInst* diffFunc = nullptr;
+ IRBuilder subBuilder(*builder);
+ subBuilder.setInsertBefore(differentiateInst);
+ switch (differentiateInst->getOp())
{
- IRBuilder subBuilder(*builder);
- subBuilder.setInsertBefore(differentiateInst);
- if (auto diffFunc = forwardTranscriber.transcribe(&subBuilder, diffInst->getBaseFn()))
+ case kIROp_ForwardDifferentiate:
{
- differentiateInst->replaceUsesWith(diffFunc);
- differentiateInst->removeAndDeallocate();
- changed = true;
+ auto baseFunc = as<IRForwardDifferentiate>(differentiateInst)->getBaseFn();
+ diffFunc = forwardTranscriber.transcribe(&subBuilder, baseFunc);
}
- }
- else if (auto backDiffInst = as<IRBackwardDifferentiate>(differentiateInst))
- {
- auto baseInst = backDiffInst->getBaseFn();
- if (auto diffFunc = backwardTranscriber.transcribe(builder, (IRFunc*)baseInst))
+ break;
+ case kIROp_BackwardDifferentiatePrimal:
+ {
+ auto baseFunc = differentiateInst->getOperand(0);
+ diffFunc = backwardPrimalTranscriber.transcribe(&subBuilder, baseFunc);
+ }
+ break;
+ case kIROp_BackwardDifferentiatePropagate:
{
- SLANG_ASSERT(diffFunc);
- differentiateInst->replaceUsesWith(diffFunc);
- differentiateInst->removeAndDeallocate();
- changed = true;
+ auto baseFunc = differentiateInst->getOperand(0);
+ diffFunc = backwardPropagateTranscriber.transcribe(&subBuilder, baseFunc);
}
+ break;
+ case kIROp_BackwardDifferentiate:
+ {
+ auto baseFunc = differentiateInst->getOperand(0);
+ diffFunc = backwardTranscriber.transcribe(&subBuilder, baseFunc);
+ }
+ break;
+ default:
+ break;
+ }
+
+ if (diffFunc)
+ {
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
+ changed = true;
}
}
@@ -591,8 +693,11 @@ struct AutoDiffPass : public InstPassBase
case FuncBodyTranscriptionTaskType::Forward:
forwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc);
break;
- case FuncBodyTranscriptionTaskType::Backward:
- backwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc);
+ case FuncBodyTranscriptionTaskType::BackwardPrimal:
+ // Don't need to do anything, they will be filled by `backwardPropagateTranscriber`.
+ break;
+ case FuncBodyTranscriptionTaskType::BackwardPropagate:
+ backwardPropagateTranscriber.transcribeFunc(builder, primalFunc, diffFunc);
break;
default:
break;
@@ -616,6 +721,11 @@ struct AutoDiffPass : public InstPassBase
hasChanges |= changed;
}
+ if (lowerIntermediateContextType(builder))
+ {
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+ hasChanges = true;
+ }
return hasChanges;
}
@@ -651,12 +761,28 @@ struct AutoDiffPass : public InstPassBase
AutoDiffPass(AutoDiffSharedContext* context, DiagnosticSink* sink) :
InstPassBase(context->moduleInst->getModule()),
sink(sink),
- forwardTranscriber(context, context->sharedBuilder, sink),
- backwardTranscriber(context, context->sharedBuilder, sink),
+ forwardTranscriber(context, &sharedBuilderStorage, sink),
+ backwardPrimalTranscriber(context, &sharedBuilderStorage, sink),
+ backwardPropagateTranscriber(context, &sharedBuilderStorage, sink),
+ backwardTranscriber(context, &sharedBuilderStorage, sink),
pairBuilderStorage(context),
autodiffContext(context)
{
+
+ // We start by initializing our shared IR building state,
+ // since we will re-use that state for any code we
+ // generate along the way.
+ //
+ sharedBuilderStorage.init(module);
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+
+ context->sharedBuilder = &sharedBuilderStorage;
+
forwardTranscriber.pairBuilder = &pairBuilderStorage;
+ backwardPrimalTranscriber.pairBuilder = &pairBuilderStorage;
+ backwardPrimalTranscriber.fwdDiffTranscriber = &forwardTranscriber;
+ backwardPropagateTranscriber.pairBuilder = &pairBuilderStorage;
+ backwardPropagateTranscriber.fwdDiffTranscriber = &forwardTranscriber;
backwardTranscriber.pairBuilder = &pairBuilderStorage;
backwardTranscriber.fwdDiffTranscriber = &forwardTranscriber;
}
@@ -667,8 +793,13 @@ protected:
//
ForwardDiffTranscriber forwardTranscriber;
+ BackwardDiffPrimalTranscriber backwardPrimalTranscriber;
+
+ BackwardDiffPropagateTranscriber backwardPropagateTranscriber;
+
BackwardDiffTranscriber backwardTranscriber;
+
// Diagnostic object from the compile request for
// error messages.
DiagnosticSink* sink;
@@ -691,16 +822,6 @@ bool processAutodiffCalls(
// Create shared context for all auto-diff related passes
AutoDiffSharedContext autodiffContext(module->getModuleInst());
- // We start by initializing our shared IR building state,
- // since we will re-use that state for any code we
- // generate along the way.
- //
- SharedIRBuilder sharedBuilder;
- sharedBuilder.init(module);
- sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
-
- autodiffContext.sharedBuilder = &sharedBuilder;
-
AutoDiffPass pass(&autodiffContext, sink);
modified |= pass.processModule();