summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-20 14:42:50 -0800
committerGitHub <noreply@github.com>2023-02-20 14:42:50 -0800
commit47715e625337d489f3c0131bbc2b849378b48a5a (patch)
treebc737c8f03ef537b2ac39860bbb922c7600edc43 /source/slang/slang-ir-autodiff.cpp
parent8b05df4187117d61491f2fdbeb7d744146ad73f7 (diff)
Miscellaneous backward autodiff fixes. (#2665)
* Fix differentiable type registration * Fix use of non-differentiable return value in a differentiable func. * Fix use of primal inst that does not dominate the diff block. * Fix primal inst hoisting, and add missing type legalization logic. * Make `detach` defined on all differentiable T. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp24
1 files changed, 24 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 97cdb644e..b630b798d 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -563,6 +563,30 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst*
return false;
}
+bool canInstBeStored(IRInst* inst)
+{
+ if (as<IRBasicType>(inst->getDataType()))
+ return true;
+
+ switch (inst->getDataType()->getOp())
+ {
+ case kIROp_StructType:
+ case kIROp_OptionalType:
+ case kIROp_TupleType:
+ case kIROp_ArrayType:
+ case kIROp_DifferentialPairType:
+ case kIROp_InterfaceType:
+ case kIROp_AnyValueType:
+ case kIROp_ClassType:
+ case kIROp_FloatType:
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ return true;
+ default:
+ return false;
+ }
+}
+
struct AutoDiffPass : public InstPassBase
{
DiagnosticSink* getSink()