diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-06 13:39:06 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-06 13:39:06 -0800 |
| commit | 33fb95980b0120cdd4d4f2d51f5f116e808dd4aa (patch) | |
| tree | 318b1669a0e52aabd11f8694de1278ef7dbc0e3b /source/slang/slang-ir-autodiff.cpp | |
| parent | e70cbe76ce74769069b7384f5f05c62da1ca45ed (diff) | |
Split bwd_diff op into separate ops for primal and propagate func. (#2582)
* Split bwd_diff op into separate ops for primal and propagate func.
* Fix.
* Download swiftshader with github actions instead of curl on linux.
* Fix github action.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 185 |
1 files changed, 153 insertions, 32 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 40c24d11d..d23271704 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -401,6 +401,10 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_DifferentiableTypeDictionaryDecoration: case kIROp_DifferentialInstDecoration: case kIROp_MixedDifferentialInstDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_BackwardDerivativeIntermediateTypeDecoration: + case kIROp_BackwardDerivativePropagateDecoration: + case kIROp_BackwardDerivativePrimalDecoration: decor->removeAndDeallocate(); break; default: @@ -489,7 +493,7 @@ struct AutoDiffPass : public InstPassBase // TODO(sai): Move this call. forwardTranscriber.differentiableTypeConformanceContext.buildGlobalWitnessDictionary(); - IRBuilder builderStorage(this->autodiffContext->sharedBuilder); + IRBuilder builderStorage(&sharedBuilderStorage); IRBuilder* builder = &builderStorage; // Process all ForwardDifferentiate and BackwardDifferentiate instructions by @@ -500,6 +504,81 @@ struct AutoDiffPass : public InstPassBase return modified; } + IRInst* processIntermediateContextTypeBase(IRBuilder* builder, IRInst* base) + { + if (auto spec = as<IRSpecialize>(base)) + { + List<IRInst*> args; + auto subBase = processIntermediateContextTypeBase(builder, spec->getBase()); + for (UInt a = 0; a < spec->getArgCount(); a++) + args.add(spec->getArg(a)); + auto actualType = builder->emitSpecializeInst( + builder->getTypeKind(), + subBase, + args.getCount(), + args.getBuffer()); + return actualType; + } + else if (auto baseGeneric = as<IRGeneric>(base)) + { + auto inner = findGenericReturnVal(baseGeneric); + if (auto typeDecor = inner->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>()) + { + auto typeSpec = cast<IRSpecialize>(typeDecor->getBackwardDerivativeIntermediateType()); + auto typeSpecBase = typeSpec->getBase(); + return typeSpecBase; + } + } + else if (auto func = as<IRFunc>(base)) + { + if (auto typeDecor = func->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>()) + { + return typeDecor->getBackwardDerivativeIntermediateType(); + } + } + else if (auto lookup = as<IRLookupWitnessMethod>(base)) + { + auto key = lookup->getRequirementKey(); + if (auto typeDecor = key->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>()) + { + auto typeKey = typeDecor->getBackwardDerivativeIntermediateType(); + auto typeLookup = builder->emitLookupInterfaceMethodInst(builder->getTypeKind(), lookup->getWitnessTable(), typeKey); + return typeLookup; + } + } + return nullptr; + } + + bool lowerIntermediateContextType(IRBuilder* builder) + { + bool changed = false; + processAllInsts([&](IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_BackwardDiffIntermediateContextType: + { + auto differentiateInst = as<IRBackwardDiffIntermediateContextType>(inst); + + auto baseFunc = differentiateInst->getOperand(0); + IRBuilder subBuilder = *builder; + subBuilder.setInsertBefore(inst); + auto type = processIntermediateContextTypeBase(&subBuilder, baseFunc); + if (type) + { + inst->replaceUsesWith(type); + inst->removeAndDeallocate(); + changed = true; + } + } + break; + default: + break; + } + }); + return changed; + } + // Process all differentiate calls, and recursively generate code for forward and backward // derivative functions. // @@ -518,6 +597,9 @@ struct AutoDiffPass : public InstPassBase { case kIROp_ForwardDifferentiate: case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: + case kIROp_BackwardDiffIntermediateContextType: // Only process now if the operand is a materialized function. switch (inst->getOperand(0)->getOp()) { @@ -538,29 +620,49 @@ struct AutoDiffPass : public InstPassBase // Process collected differentiate insts and replace them with placeholders for // differentiated functions. - for (auto differentiateInst : autoDiffWorkList) + for (Index i = 0; i < autoDiffWorkList.getCount(); i++) { - if (auto diffInst = as<IRForwardDifferentiate>(differentiateInst)) + auto differentiateInst = autoDiffWorkList[i]; + + IRInst* diffFunc = nullptr; + IRBuilder subBuilder(*builder); + subBuilder.setInsertBefore(differentiateInst); + switch (differentiateInst->getOp()) { - IRBuilder subBuilder(*builder); - subBuilder.setInsertBefore(differentiateInst); - if (auto diffFunc = forwardTranscriber.transcribe(&subBuilder, diffInst->getBaseFn())) + case kIROp_ForwardDifferentiate: { - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - changed = true; + auto baseFunc = as<IRForwardDifferentiate>(differentiateInst)->getBaseFn(); + diffFunc = forwardTranscriber.transcribe(&subBuilder, baseFunc); } - } - else if (auto backDiffInst = as<IRBackwardDifferentiate>(differentiateInst)) - { - auto baseInst = backDiffInst->getBaseFn(); - if (auto diffFunc = backwardTranscriber.transcribe(builder, (IRFunc*)baseInst)) + break; + case kIROp_BackwardDifferentiatePrimal: + { + auto baseFunc = differentiateInst->getOperand(0); + diffFunc = backwardPrimalTranscriber.transcribe(&subBuilder, baseFunc); + } + break; + case kIROp_BackwardDifferentiatePropagate: { - SLANG_ASSERT(diffFunc); - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - changed = true; + auto baseFunc = differentiateInst->getOperand(0); + diffFunc = backwardPropagateTranscriber.transcribe(&subBuilder, baseFunc); } + break; + case kIROp_BackwardDifferentiate: + { + auto baseFunc = differentiateInst->getOperand(0); + diffFunc = backwardTranscriber.transcribe(&subBuilder, baseFunc); + } + break; + default: + break; + } + + if (diffFunc) + { + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); + differentiateInst->removeAndDeallocate(); + changed = true; } } @@ -591,8 +693,11 @@ struct AutoDiffPass : public InstPassBase case FuncBodyTranscriptionTaskType::Forward: forwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc); break; - case FuncBodyTranscriptionTaskType::Backward: - backwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc); + case FuncBodyTranscriptionTaskType::BackwardPrimal: + // Don't need to do anything, they will be filled by `backwardPropagateTranscriber`. + break; + case FuncBodyTranscriptionTaskType::BackwardPropagate: + backwardPropagateTranscriber.transcribeFunc(builder, primalFunc, diffFunc); break; default: break; @@ -616,6 +721,11 @@ struct AutoDiffPass : public InstPassBase hasChanges |= changed; } + if (lowerIntermediateContextType(builder)) + { + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + hasChanges = true; + } return hasChanges; } @@ -651,12 +761,28 @@ struct AutoDiffPass : public InstPassBase AutoDiffPass(AutoDiffSharedContext* context, DiagnosticSink* sink) : InstPassBase(context->moduleInst->getModule()), sink(sink), - forwardTranscriber(context, context->sharedBuilder, sink), - backwardTranscriber(context, context->sharedBuilder, sink), + forwardTranscriber(context, &sharedBuilderStorage, sink), + backwardPrimalTranscriber(context, &sharedBuilderStorage, sink), + backwardPropagateTranscriber(context, &sharedBuilderStorage, sink), + backwardTranscriber(context, &sharedBuilderStorage, sink), pairBuilderStorage(context), autodiffContext(context) { + + // We start by initializing our shared IR building state, + // since we will re-use that state for any code we + // generate along the way. + // + sharedBuilderStorage.init(module); + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + + context->sharedBuilder = &sharedBuilderStorage; + forwardTranscriber.pairBuilder = &pairBuilderStorage; + backwardPrimalTranscriber.pairBuilder = &pairBuilderStorage; + backwardPrimalTranscriber.fwdDiffTranscriber = &forwardTranscriber; + backwardPropagateTranscriber.pairBuilder = &pairBuilderStorage; + backwardPropagateTranscriber.fwdDiffTranscriber = &forwardTranscriber; backwardTranscriber.pairBuilder = &pairBuilderStorage; backwardTranscriber.fwdDiffTranscriber = &forwardTranscriber; } @@ -667,8 +793,13 @@ protected: // ForwardDiffTranscriber forwardTranscriber; + BackwardDiffPrimalTranscriber backwardPrimalTranscriber; + + BackwardDiffPropagateTranscriber backwardPropagateTranscriber; + BackwardDiffTranscriber backwardTranscriber; + // Diagnostic object from the compile request for // error messages. DiagnosticSink* sink; @@ -691,16 +822,6 @@ bool processAutodiffCalls( // Create shared context for all auto-diff related passes AutoDiffSharedContext autodiffContext(module->getModuleInst()); - // We start by initializing our shared IR building state, - // since we will re-use that state for any code we - // generate along the way. - // - SharedIRBuilder sharedBuilder; - sharedBuilder.init(module); - sharedBuilder.deduplicateAndRebuildGlobalNumberingMap(); - - autodiffContext.sharedBuilder = &sharedBuilder; - AutoDiffPass pass(&autodiffContext, sink); modified |= pass.processModule(); |
