From aaf3f5e97aaa3a256f4ca938251d011c125b9491 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 28 Aug 2024 23:52:04 -0700 Subject: Make sure `NullDifferential` and its witness are removed after autodiff. (#4958) * Make sure `NullDifferential` and its witness are removed after autodiff. * Fix. * Add a test. --- source/slang/core.meta.slang | 3 +++ source/slang/diff.meta.slang | 1 + source/slang/slang-ast-modifier.h | 6 ++++++ source/slang/slang-ir-autodiff.cpp | 28 +++++++++++++++++++++++----- source/slang/slang-ir-inst-defs.h | 3 +++ source/slang/slang-ir-insts.h | 5 +++++ source/slang/slang-lower-to-ir.cpp | 2 ++ 7 files changed, 43 insertions(+), 5 deletions(-) (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 67fb201fe..7d5f4087c 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2929,6 +2929,9 @@ attribute_syntax [Specialize] : SpecializeAttribute; __attributeTarget(DeclBase) attribute_syntax [builtin] : BuiltinAttribute; +__attributeTarget(DeclBase) +attribute_syntax[__AutoDiffBuiltin] : AutoDiffBuiltinAttribute; + __attributeTarget(DeclBase) attribute_syntax [__requiresNVAPI] : RequiresNVAPIAttribute; diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 80aca230a..b9cc0b103 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -26,6 +26,7 @@ __attributeTarget(FunctionDeclBase) attribute_syntax [NoDiffThis] : NoDiffThisAttribute; // A 'none-type' that acts as a run-time sentinel for zero differentials. +[__AutoDiffBuiltin] export struct NullDifferential : IDifferentiable { // for now, we'll use at least one field to make sure the type is non-empty diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 8c9cb484f..a7af4a249 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1171,6 +1171,12 @@ class BuiltinAttribute : public Attribute { SLANG_AST_CLASS(BuiltinAttribute) }; + + /// An attribute that marks a decl as a compiler built-in object for the autodiff system. +class AutoDiffBuiltinAttribute : public Attribute +{ + SLANG_AST_CLASS(AutoDiffBuiltinAttribute) +}; /// An attribute that defines the size of `AnyValue` type to represent a polymoprhic value that conforms to /// the decorated interface type. diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 8ca7dbe76..0979c097c 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1179,6 +1179,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) { for (auto inst : parent->getChildren()) { + bool shouldRemoveKeepAliveDecorations = false; for (auto decor = inst->getFirstDecoration(); decor; ) { auto next = decor->getNextDecoration(); @@ -1204,12 +1205,34 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_IntermediateContextFieldDifferentialTypeDecoration: decor->removeAndDeallocate(); break; + case kIROp_AutoDiffBuiltinDecoration: + // Remove the builtin decoration, and also remove any export/keep-alive + // decorations. + shouldRemoveKeepAliveDecorations = true; + decor->removeAndDeallocate(); default: break; } decor = next; } + if (shouldRemoveKeepAliveDecorations) + { + for (auto decor = inst->getFirstDecoration(); decor; ) + { + auto next = decor->getNextDecoration(); + switch (decor->getOp()) + { + case kIROp_ExportDecoration: + case kIROp_HLSLExportDecoration: + case kIROp_KeepAliveDecoration: + decor->removeAndDeallocate(); + break; + } + decor = next; + } + } + if (inst->getFirstChild() != nullptr) { stripAutoDiffDecorationsFromChildren(inst); @@ -2274,11 +2297,6 @@ bool finalizeAutoDiffPass(TargetProgram* target, IRModule* module) stripNoDiffTypeAttribute(module); - // Remove keep-alive decorations from null-differential type - // so it can be DCE'd if unused. - // - releaseNullDifferentialType(&autodiffContext); - return modified; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 9fc1ab22c..eb4e88c41 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -968,6 +968,9 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) /// Decorates a auto-diff transcribed value with the original value that the inst is transcribed from. INST(AutoDiffOriginalValueDecoration, AutoDiffOriginalValueDecoration, 1, 0) + /// Decorates a type as auto-diff builtin type. + INST(AutoDiffBuiltinDecoration, AutoDiffBuiltinDecoration, 0, 0) + /// Used by the auto-diff pass to hold a reference to the /// generated derivative function. INST(ForwardDerivativeDecoration, fwdDerivative, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index db1571e50..f8836219e 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4992,6 +4992,11 @@ public: addDecoration(value, kIROp_DynamicUniformDecoration); } + void addAutoDiffBuiltinDecoration(IRInst* value) + { + addDecoration(value, kIROp_AutoDiffBuiltinDecoration); + } + /// Add a decoration that indicates that the given `inst` depends on the given `dependency`. /// /// This decoration can be used to ensure that a value that an instruction diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 87199734a..ab9e2a540 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8883,6 +8883,8 @@ struct DeclLoweringVisitor : DeclVisitor { if (as(modifier)) subBuilder->addNonCopyableTypeDecoration(irAggType); + else if (as(modifier)) + subBuilder->addAutoDiffBuiltinDecoration(irAggType); } -- cgit v1.2.3