diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-06-13 17:30:16 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-06-13 17:30:16 -0400 |
| commit | fba316f0e7dacc7f93bee3a95fb93b2ab02bdd80 (patch) | |
| tree | 4687141e1581193de2d6990122c3190d3c2fcc9f /source/slang/slang-ir-pytorch-cpp-binding.cpp | |
| parent | f0d40ad5e1d0a0dec39fe8a141d3f81d88fc576a (diff) | |
Remove `IRHLSLExportDecoration` and `IRKeepAliveDecoration` for non-CUDA/Torch targets (#4364)
* Remove `IRHLSLExportDecoration` and `IRKeepAliveDecoration` for non-CUDA/Torch targets
* Update hlsl-torch-cross-compile.slang
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index 5105b2e81..6922984d6 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -632,7 +632,6 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host builder->addExternCppDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice()); builder->addTorchEntryPointDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice()); - builder->addHLSLExportDecoration(reflectionFunc); builder->addKeepAliveDecoration(reflectionFunc); } @@ -817,7 +816,6 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink) builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); - builder.addHLSLExportDecoration(reflFunc); builder.addKeepAliveDecoration(reflFunc); } @@ -899,7 +897,6 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink) // Mark for host-side emit logic. builder.addCudaHostDecoration(hostFunc); // Keep alive. This method will be accessed externally. - builder.addHLSLExportDecoration(hostFunc); builder.addKeepAliveDecoration(hostFunc); } @@ -1163,6 +1160,27 @@ void handleAutoBindNames(IRModule* module) } } +void removeTorchAndCUDAEntryPoints(IRModule* module) +{ + // Go through global insts, find cuda & torch related entry points and remove the keep-alive decoration. + IRBuilder builder(module); + for (auto globalInst : module->getGlobalInsts()) + { + if (auto func = as<IRFunc>(globalInst)) + { + if (func->findDecoration<IRAutoPyBindCudaDecoration>() || + func->findDecoration<IRTorchEntryPointDecoration>() || + func->findDecoration<IRCudaKernelDecoration>()) + { + if (auto keepAlive = func->findDecoration<IRKeepAliveDecoration>()) + keepAlive->removeAndDeallocate(); + if (auto hlslExport = func->findDecoration<IRHLSLExportDecoration>()) + hlslExport->removeAndDeallocate(); + } + } + } +} + void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); @@ -1237,7 +1255,6 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); } - builder.addHLSLExportDecoration(wrapperFunc); builder.addKeepAliveDecoration(wrapperFunc); builder.addCudaKernelForwardDerivativeDecoration(func, wrapperFunc); @@ -1296,7 +1313,6 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); } - builder.addHLSLExportDecoration(wrapperFunc); builder.addKeepAliveDecoration(wrapperFunc); builder.addCudaKernelBackwardDerivativeDecoration(func, wrapperFunc); |
