summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-pytorch-cpp-binding.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-28 15:19:03 -0700
committerGitHub <noreply@github.com>2023-03-28 15:19:03 -0700
commita61f089fbc4b944d058e6417d8a0d22d57ca5c92 (patch)
tree4fa1a0c6370b8d34262d297653239f48aa004c71 /source/slang/slang-ir-pytorch-cpp-binding.cpp
parent8f03af5e5b580170fab3fd2fe6144f92038c7701 (diff)
Add slangpy doc, fix cuda prelude. (#2748)
* Add slangpy doc, fix cuda prelude. * more bug fix. * fix. * fix. * More fix. * fix. * f * fix prelude. * update prelude. * update doc * Update prelude. * add zeros_like * update doc. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp7
1 files changed, 5 insertions, 2 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index eb81bfd8c..971e87a6f 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -245,6 +245,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
return;
}
auto newParam = builder.emitParam(newParamType);
+ param->transferDecorationsTo(newParam);
newParams.add(newParam);
}
@@ -361,14 +362,16 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
// Remove all [TorchEntryPoint] functions when emitting CUDA source.
void removeTorchKernels(IRModule* module)
{
+ List<IRInst*> toRemove;
for (auto globalInst : module->getGlobalInsts())
{
if (!as<IRFunc>(globalInst))
continue;
if (globalInst->findDecoration<IRTorchEntryPointDecoration>())
- globalInst->removeAndDeallocate();
+ toRemove.add(globalInst);
}
-
+ for (auto inst : toRemove)
+ inst->removeAndDeallocate();
}
}