summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-23 06:59:25 -0800
committerGitHub <noreply@github.com>2023-01-23 06:59:25 -0800
commit46a4d98baa1d43b33717b4377aefeeaf46b9c2ff (patch)
treec89f3a1c416330f859887d00f896b18bcc7488a5 /source/slang/slang-ir-check-differentiability.cpp
parent263ca18ea516cfce43fda703c0a411aaf1938e42 (diff)
Full address insts elimination for backward autodiff. (#2604)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp82
1 files changed, 45 insertions, 37 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index cb7290036..67b7e92b0 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -5,12 +5,45 @@
namespace Slang
{
+bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
+{
+ HashSet<IRInst*> processedSet;
+ while (auto ptrType = as<IRPtrTypeBase>(typeInst))
+ {
+ typeInst = ptrType->getValueType();
+ if (!processedSet.Add(typeInst))
+ return false;
+ }
+ if (!typeInst)
+ return false;
+ switch (typeInst->getOp())
+ {
+ case kIROp_FloatType:
+ case kIROp_DifferentialPairType:
+ return true;
+ default:
+ break;
+ }
+ if (context.lookUpConformanceForType(typeInst))
+ return true;
+ // Look for equivalent types.
+ for (auto type : context.differentiableWitnessDictionary)
+ {
+ if (isTypeEqual(type.Key, (IRType*)typeInst))
+ {
+ context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value;
+ return true;
+ }
+ }
+ return false;
+}
struct CheckDifferentiabilityPassContext : public InstPassBase
{
public:
DiagnosticSink* sink;
AutoDiffSharedContext sharedContext;
+ SharedIRBuilder* sharedBuilder;
enum DifferentiableLevel
{
@@ -18,8 +51,8 @@ public:
};
Dictionary<IRInst*, DifferentiableLevel> differentiableFunctions;
- CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink)
- : InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst())
+ CheckDifferentiabilityPassContext(SharedIRBuilder* inSharedBuilder, IRModule* inModule, DiagnosticSink* inSink)
+ : InstPassBase(inModule), sharedBuilder(inSharedBuilder), sink(inSink), sharedContext(inModule->getModuleInst())
{}
IRInst* getSpecializedVal(IRInst* inst)
@@ -161,39 +194,6 @@ public:
return false;
}
- bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
- {
- HashSet<IRInst*> processedSet;
- while (auto ptrType = as<IRPtrTypeBase>(typeInst))
- {
- typeInst = ptrType->getValueType();
- if (!processedSet.Add(typeInst))
- return false;
- }
- if (!typeInst)
- return false;
- switch (typeInst->getOp())
- {
- case kIROp_FloatType:
- case kIROp_DifferentialPairType:
- return true;
- default:
- break;
- }
- if (context.lookUpConformanceForType(typeInst))
- return true;
- // Look for equivalent types.
- for (auto type : context.differentiableWitnessDictionary)
- {
- if (isTypeEqual(type.Key, (IRType*)typeInst))
- {
- context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value;
- return true;
- }
- }
- return false;
- }
-
int getParamIndexInBlock(IRParam* paramInst)
{
auto block = as<IRBlock>(paramInst->getParent());
@@ -228,6 +228,14 @@ public:
DifferentiableTypeConformanceContext diffTypeContext(&sharedContext);
diffTypeContext.setFunc(funcInst);
+ if (isBackwardDifferentiableFunc(funcInst) && !funcInst->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
+ {
+ if (auto func = as<IRFunc>(funcInst))
+ {
+ if (SLANG_FAILED(eliminateAddressInsts(sharedBuilder, diffTypeContext, func, sink)))
+ return;
+ }
+ }
HashSet<IRInst*> produceDiffSet;
HashSet<IRInst*> expectDiffSet;
@@ -468,9 +476,9 @@ public:
}
};
-void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink)
+void checkAutoDiffUsages(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink)
{
- CheckDifferentiabilityPassContext context(module, sink);
+ CheckDifferentiabilityPassContext context(sharedBuilder, module, sink);
context.processModule();
}