diff options
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index eb81bfd8c..971e87a6f 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -245,6 +245,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) return; } auto newParam = builder.emitParam(newParamType); + param->transferDecorationsTo(newParam); newParams.add(newParam); } @@ -361,14 +362,16 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) // Remove all [TorchEntryPoint] functions when emitting CUDA source. void removeTorchKernels(IRModule* module) { + List<IRInst*> toRemove; for (auto globalInst : module->getGlobalInsts()) { if (!as<IRFunc>(globalInst)) continue; if (globalInst->findDecoration<IRTorchEntryPointDecoration>()) - globalInst->removeAndDeallocate(); + toRemove.add(globalInst); } - + for (auto inst : toRemove) + inst->removeAndDeallocate(); } } |
