summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-30 11:41:54 -0700
committerGitHub <noreply@github.com>2023-03-30 11:41:54 -0700
commitd01e28a49b47c9fadf2b764a74f318e3f95061e5 (patch)
treefcb1dc172690d9ef83abb108d0b943408da43dee /source
parent37594df883a7b4d62b2aae80ee73f195dbfb6d77 (diff)
Fix autodiff pass duplicates exported functions. (#2759)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp2
-rw-r--r--source/slang/slang-ir-util.cpp26
-rw-r--r--source/slang/slang-ir-util.h3
4 files changed, 32 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 869f8920c..df94bf69f 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1564,6 +1564,7 @@ struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy
SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
{
insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func);
+ removeLinkageDecorations(func);
AutoDiffAddressConversionPolicy cvtPolicty;
cvtPolicty.diffTypeContext = &differentiableTypeConformanceContext;
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index f23e45be0..66c85647f 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -519,6 +519,8 @@ namespace Slang
SlangResult BackwardDiffTranscriberBase::prepareFuncForBackwardDiff(IRFunc* func)
{
+ removeLinkageDecorations(func);
+
DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext);
diffTypeContext.setFunc(func);
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index bff80392f..c5cebb8b5 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -582,6 +582,32 @@ void sortBlocksInFunc(IRGlobalValueWithCode* func)
block->insertAtEnd(func);
}
+void removeLinkageDecorations(IRGlobalValueWithCode* func)
+{
+ List<IRInst*> toRemove;
+ for (auto inst : func->getDecorations())
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_ImportDecoration:
+ case kIROp_ExportDecoration:
+ case kIROp_ExternCppDecoration:
+ case kIROp_PublicDecoration:
+ case kIROp_KeepAliveDecoration:
+ case kIROp_DllImportDecoration:
+ case kIROp_CudaDeviceExportDecoration:
+ case kIROp_DllExportDecoration:
+ case kIROp_HLSLExportDecoration:
+ toRemove.add(inst);
+ break;
+ default:
+ break;
+ }
+ }
+ for (auto inst : toRemove)
+ inst->removeAndDeallocate();
+}
+
void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst)
{
if (as<IRParam>(inst))
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 4f1c15459..f8e53c38f 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -191,6 +191,9 @@ void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst);
IRInst* emitLoopBlocks(IRBuilder* builder, IRInst* initVal, IRInst* finalVal, IRBlock*& loopBodyBlock, IRBlock*& loopBreakBlock);
void sortBlocksInFunc(IRGlobalValueWithCode* func);
+
+// Remove all linkage decorations from func.
+void removeLinkageDecorations(IRGlobalValueWithCode* func);
}
#endif