summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-pytorch-cpp-binding.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-06-13 17:30:16 -0400
committerGitHub <noreply@github.com>2024-06-13 17:30:16 -0400
commitfba316f0e7dacc7f93bee3a95fb93b2ab02bdd80 (patch)
tree4687141e1581193de2d6990122c3190d3c2fcc9f /source/slang/slang-ir-pytorch-cpp-binding.cpp
parentf0d40ad5e1d0a0dec39fe8a141d3f81d88fc576a (diff)
Remove `IRHLSLExportDecoration` and `IRKeepAliveDecoration` for non-CUDA/Torch targets (#4364)
* Remove `IRHLSLExportDecoration` and `IRKeepAliveDecoration` for non-CUDA/Torch targets * Update hlsl-torch-cross-compile.slang
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp26
1 files changed, 21 insertions, 5 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index 5105b2e81..6922984d6 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -632,7 +632,6 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host
builder->addExternCppDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice());
builder->addTorchEntryPointDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice());
- builder->addHLSLExportDecoration(reflectionFunc);
builder->addKeepAliveDecoration(reflectionFunc);
}
@@ -817,7 +816,6 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
- builder.addHLSLExportDecoration(reflFunc);
builder.addKeepAliveDecoration(reflFunc);
}
@@ -899,7 +897,6 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink)
// Mark for host-side emit logic.
builder.addCudaHostDecoration(hostFunc);
// Keep alive. This method will be accessed externally.
- builder.addHLSLExportDecoration(hostFunc);
builder.addKeepAliveDecoration(hostFunc);
}
@@ -1163,6 +1160,27 @@ void handleAutoBindNames(IRModule* module)
}
}
+void removeTorchAndCUDAEntryPoints(IRModule* module)
+{
+ // Go through global insts, find cuda & torch related entry points and remove the keep-alive decoration.
+ IRBuilder builder(module);
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto func = as<IRFunc>(globalInst))
+ {
+ if (func->findDecoration<IRAutoPyBindCudaDecoration>() ||
+ func->findDecoration<IRTorchEntryPointDecoration>() ||
+ func->findDecoration<IRCudaKernelDecoration>())
+ {
+ if (auto keepAlive = func->findDecoration<IRKeepAliveDecoration>())
+ keepAlive->removeAndDeallocate();
+ if (auto hlslExport = func->findDecoration<IRHLSLExportDecoration>())
+ hlslExport->removeAndDeallocate();
+ }
+ }
+ }
+}
+
void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
{
SLANG_UNUSED(sink);
@@ -1237,7 +1255,6 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
}
- builder.addHLSLExportDecoration(wrapperFunc);
builder.addKeepAliveDecoration(wrapperFunc);
builder.addCudaKernelForwardDerivativeDecoration(func, wrapperFunc);
@@ -1296,7 +1313,6 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
}
- builder.addHLSLExportDecoration(wrapperFunc);
builder.addKeepAliveDecoration(wrapperFunc);
builder.addCudaKernelBackwardDerivativeDecoration(func, wrapperFunc);