summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-pytorch-cpp-binding.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp7
1 files changed, 5 insertions, 2 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index eb81bfd8c..971e87a6f 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -245,6 +245,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
return;
}
auto newParam = builder.emitParam(newParamType);
+ param->transferDecorationsTo(newParam);
newParams.add(newParam);
}
@@ -361,14 +362,16 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
// Remove all [TorchEntryPoint] functions when emitting CUDA source.
void removeTorchKernels(IRModule* module)
{
+ List<IRInst*> toRemove;
for (auto globalInst : module->getGlobalInsts())
{
if (!as<IRFunc>(globalInst))
continue;
if (globalInst->findDecoration<IRTorchEntryPointDecoration>())
- globalInst->removeAndDeallocate();
+ toRemove.add(globalInst);
}
-
+ for (auto inst : toRemove)
+ inst->removeAndDeallocate();
}
}