diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-23 06:59:25 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-23 06:59:25 -0800 |
| commit | 46a4d98baa1d43b33717b4377aefeeaf46b9c2ff (patch) | |
| tree | c89f3a1c416330f859887d00f896b18bcc7488a5 /source/slang/slang-ir-check-differentiability.cpp | |
| parent | 263ca18ea516cfce43fda703c0a411aaf1938e42 (diff) | |
Full address insts elimination for backward autodiff. (#2604)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 82 |
1 files changed, 45 insertions, 37 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index cb7290036..67b7e92b0 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -5,12 +5,45 @@ namespace Slang { +bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst) +{ + HashSet<IRInst*> processedSet; + while (auto ptrType = as<IRPtrTypeBase>(typeInst)) + { + typeInst = ptrType->getValueType(); + if (!processedSet.Add(typeInst)) + return false; + } + if (!typeInst) + return false; + switch (typeInst->getOp()) + { + case kIROp_FloatType: + case kIROp_DifferentialPairType: + return true; + default: + break; + } + if (context.lookUpConformanceForType(typeInst)) + return true; + // Look for equivalent types. + for (auto type : context.differentiableWitnessDictionary) + { + if (isTypeEqual(type.Key, (IRType*)typeInst)) + { + context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value; + return true; + } + } + return false; +} struct CheckDifferentiabilityPassContext : public InstPassBase { public: DiagnosticSink* sink; AutoDiffSharedContext sharedContext; + SharedIRBuilder* sharedBuilder; enum DifferentiableLevel { @@ -18,8 +51,8 @@ public: }; Dictionary<IRInst*, DifferentiableLevel> differentiableFunctions; - CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink) - : InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst()) + CheckDifferentiabilityPassContext(SharedIRBuilder* inSharedBuilder, IRModule* inModule, DiagnosticSink* inSink) + : InstPassBase(inModule), sharedBuilder(inSharedBuilder), sink(inSink), sharedContext(inModule->getModuleInst()) {} IRInst* getSpecializedVal(IRInst* inst) @@ -161,39 +194,6 @@ public: return false; } - bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst) - { - HashSet<IRInst*> processedSet; - while (auto ptrType = as<IRPtrTypeBase>(typeInst)) - { - typeInst = ptrType->getValueType(); - if (!processedSet.Add(typeInst)) - return false; - } - if (!typeInst) - return false; - switch (typeInst->getOp()) - { - case kIROp_FloatType: - case kIROp_DifferentialPairType: - return true; - default: - break; - } - if (context.lookUpConformanceForType(typeInst)) - return true; - // Look for equivalent types. - for (auto type : context.differentiableWitnessDictionary) - { - if (isTypeEqual(type.Key, (IRType*)typeInst)) - { - context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value; - return true; - } - } - return false; - } - int getParamIndexInBlock(IRParam* paramInst) { auto block = as<IRBlock>(paramInst->getParent()); @@ -228,6 +228,14 @@ public: DifferentiableTypeConformanceContext diffTypeContext(&sharedContext); diffTypeContext.setFunc(funcInst); + if (isBackwardDifferentiableFunc(funcInst) && !funcInst->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) + { + if (auto func = as<IRFunc>(funcInst)) + { + if (SLANG_FAILED(eliminateAddressInsts(sharedBuilder, diffTypeContext, func, sink))) + return; + } + } HashSet<IRInst*> produceDiffSet; HashSet<IRInst*> expectDiffSet; @@ -468,9 +476,9 @@ public: } }; -void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink) +void checkAutoDiffUsages(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink) { - CheckDifferentiabilityPassContext context(module, sink); + CheckDifferentiabilityPassContext context(sharedBuilder, module, sink); context.processModule(); } |
