summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-28 23:52:04 -0700
committerGitHub <noreply@github.com>2024-08-28 23:52:04 -0700
commitaaf3f5e97aaa3a256f4ca938251d011c125b9491 (patch)
tree139f4dc52d3c6c84ae4fb0c148fa309cdfc0ed3c /source
parente9f52a694710d793c7032bbb6175a452618f1b23 (diff)
Make sure `NullDifferential` and its witness are removed after autodiff. (#4958)
* Make sure `NullDifferential` and its witness are removed after autodiff. * Fix. * Add a test.
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang3
-rw-r--r--source/slang/diff.meta.slang1
-rw-r--r--source/slang/slang-ast-modifier.h6
-rw-r--r--source/slang/slang-ir-autodiff.cpp28
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h5
-rw-r--r--source/slang/slang-lower-to-ir.cpp2
7 files changed, 43 insertions, 5 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 67fb201fe..7d5f4087c 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2930,6 +2930,9 @@ __attributeTarget(DeclBase)
attribute_syntax [builtin] : BuiltinAttribute;
__attributeTarget(DeclBase)
+attribute_syntax[__AutoDiffBuiltin] : AutoDiffBuiltinAttribute;
+
+__attributeTarget(DeclBase)
attribute_syntax [__requiresNVAPI] : RequiresNVAPIAttribute;
__attributeTarget(AggTypeDecl)
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 80aca230a..b9cc0b103 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -26,6 +26,7 @@ __attributeTarget(FunctionDeclBase)
attribute_syntax [NoDiffThis] : NoDiffThisAttribute;
// A 'none-type' that acts as a run-time sentinel for zero differentials.
+[__AutoDiffBuiltin]
export struct NullDifferential : IDifferentiable
{
// for now, we'll use at least one field to make sure the type is non-empty
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 8c9cb484f..a7af4a249 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1171,6 +1171,12 @@ class BuiltinAttribute : public Attribute
{
SLANG_AST_CLASS(BuiltinAttribute)
};
+
+ /// An attribute that marks a decl as a compiler built-in object for the autodiff system.
+class AutoDiffBuiltinAttribute : public Attribute
+{
+ SLANG_AST_CLASS(AutoDiffBuiltinAttribute)
+};
/// An attribute that defines the size of `AnyValue` type to represent a polymoprhic value that conforms to
/// the decorated interface type.
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 8ca7dbe76..0979c097c 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1179,6 +1179,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
{
for (auto inst : parent->getChildren())
{
+ bool shouldRemoveKeepAliveDecorations = false;
for (auto decor = inst->getFirstDecoration(); decor; )
{
auto next = decor->getNextDecoration();
@@ -1204,12 +1205,34 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
case kIROp_IntermediateContextFieldDifferentialTypeDecoration:
decor->removeAndDeallocate();
break;
+ case kIROp_AutoDiffBuiltinDecoration:
+ // Remove the builtin decoration, and also remove any export/keep-alive
+ // decorations.
+ shouldRemoveKeepAliveDecorations = true;
+ decor->removeAndDeallocate();
default:
break;
}
decor = next;
}
+ if (shouldRemoveKeepAliveDecorations)
+ {
+ for (auto decor = inst->getFirstDecoration(); decor; )
+ {
+ auto next = decor->getNextDecoration();
+ switch (decor->getOp())
+ {
+ case kIROp_ExportDecoration:
+ case kIROp_HLSLExportDecoration:
+ case kIROp_KeepAliveDecoration:
+ decor->removeAndDeallocate();
+ break;
+ }
+ decor = next;
+ }
+ }
+
if (inst->getFirstChild() != nullptr)
{
stripAutoDiffDecorationsFromChildren(inst);
@@ -2274,11 +2297,6 @@ bool finalizeAutoDiffPass(TargetProgram* target, IRModule* module)
stripNoDiffTypeAttribute(module);
- // Remove keep-alive decorations from null-differential type
- // so it can be DCE'd if unused.
- //
- releaseNullDifferentialType(&autodiffContext);
-
return modified;
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 9fc1ab22c..eb4e88c41 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -968,6 +968,9 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
/// Decorates a auto-diff transcribed value with the original value that the inst is transcribed from.
INST(AutoDiffOriginalValueDecoration, AutoDiffOriginalValueDecoration, 1, 0)
+ /// Decorates a type as auto-diff builtin type.
+ INST(AutoDiffBuiltinDecoration, AutoDiffBuiltinDecoration, 0, 0)
+
/// Used by the auto-diff pass to hold a reference to the
/// generated derivative function.
INST(ForwardDerivativeDecoration, fwdDerivative, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index db1571e50..f8836219e 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4992,6 +4992,11 @@ public:
addDecoration(value, kIROp_DynamicUniformDecoration);
}
+ void addAutoDiffBuiltinDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_AutoDiffBuiltinDecoration);
+ }
+
/// Add a decoration that indicates that the given `inst` depends on the given `dependency`.
///
/// This decoration can be used to ensure that a value that an instruction
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 87199734a..ab9e2a540 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8883,6 +8883,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
if (as<NonCopyableTypeAttribute>(modifier))
subBuilder->addNonCopyableTypeDecoration(irAggType);
+ else if (as<AutoDiffBuiltinAttribute>(modifier))
+ subBuilder->addAutoDiffBuiltinDecoration(irAggType);
}