summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-29 18:17:33 -0800
committerGitHub <noreply@github.com>2022-11-29 18:17:33 -0800
commitf52b4de3b29ee27213b7d60fb620a0d5d50b49f9 (patch)
treed4570c53045bca8e9411e884b0905d9384430a58 /source/slang/slang-ir-autodiff.cpp
parentf5581786a1891cedb165adb1afe71fe34f26e030 (diff)
Allow `no_diff` modifier on parameters (#2538)
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp35
1 files changed, 34 insertions, 1 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index b0dbf62fa..4373cf44b 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -413,6 +413,36 @@ void stripAutoDiffDecorations(IRModule* module)
stripAutoDiffDecorationsFromChildren(module->getModuleInst());
}
+struct StripNoDiffTypeAttributePass : InstPassBase
+{
+ StripNoDiffTypeAttributePass(IRModule* module) :
+ InstPassBase(module)
+ {
+ }
+ void processModule()
+ {
+ processInstsOfType<IRAttributedType>(kIROp_AttributedType, [&](IRAttributedType* attrType)
+ {
+ if (attrType->getAllAttrs().getCount() == 1)
+ {
+ if (attrType->findAttr<IRNoDiffAttr>())
+ {
+ attrType->replaceUsesWith(attrType->getBaseType());
+ attrType->removeAndDeallocate();
+ }
+ }
+ });
+ sharedBuilderStorage.init(module);
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+ }
+};
+
+void stripNoDiffTypeAttribute(IRModule* module)
+{
+ StripNoDiffTypeAttributePass pass(module);
+ pass.processModule();
+}
+
bool processAutodiffCalls(
IRModule* module,
DiagnosticSink* sink,
@@ -452,11 +482,14 @@ bool processAutodiffCalls(
//
modified |= processPairTypes(&autodiffContext);
+ stripNoDiffTypeAttribute(module);
+
// Remove auto-diff related decorations.
stripAutoDiffDecorations(module);
+
return modified;
}
-} \ No newline at end of file
+}