summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-06-13 17:30:16 -0400
committerGitHub <noreply@github.com>2024-06-13 17:30:16 -0400
commitfba316f0e7dacc7f93bee3a95fb93b2ab02bdd80 (patch)
tree4687141e1581193de2d6990122c3190d3c2fcc9f
parentf0d40ad5e1d0a0dec39fe8a141d3f81d88fc576a (diff)
Remove `IRHLSLExportDecoration` and `IRKeepAliveDecoration` for non-CUDA/Torch targets (#4364)
* Remove `IRHLSLExportDecoration` and `IRKeepAliveDecoration` for non-CUDA/Torch targets * Update hlsl-torch-cross-compile.slang
-rw-r--r--source/slang/slang-emit.cpp28
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp26
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.h1
-rw-r--r--source/slang/slang-lower-to-ir.cpp12
-rw-r--r--tests/autodiff/hlsl-torch-cross-compile.slang55
5 files changed, 110 insertions, 12 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 38f066c6c..0f53f74cd 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -557,6 +557,17 @@ Result linkAndOptimizeIR(
switch (target)
{
+ case CodeGenTarget::CUDASource:
+ case CodeGenTarget::PyTorchCppBinding:
+ break;
+
+ default:
+ removeTorchAndCUDAEntryPoints(irModule);
+ break;
+ }
+
+ switch (target)
+ {
case CodeGenTarget::CPPSource:
case CodeGenTarget::HostCPPSource:
{
@@ -605,10 +616,19 @@ Result linkAndOptimizeIR(
if (!targetProgram->getOptionSet().shouldPerformMinimumOptimizations())
fuseCallsToSaturatedCooperation(irModule);
- // Generate any requested derivative wrappers
- if (requiredLoweringPassSet.derivativePyBindWrapper)
- generateDerivativeWrappers(irModule, sink);
-
+ switch (target)
+ {
+ case CodeGenTarget::CUDASource:
+ case CodeGenTarget::PyTorchCppBinding:
+ {
+ // Generate any requested derivative wrappers
+ if (requiredLoweringPassSet.derivativePyBindWrapper)
+ generateDerivativeWrappers(irModule, sink);
+ break;
+ }
+ default:
+ break;
+ }
// Next, we need to ensure that the code we emit for
// the target doesn't contain any operations that would
// be illegal on the target platform. For example,
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index 5105b2e81..6922984d6 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -632,7 +632,6 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host
builder->addExternCppDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice());
builder->addTorchEntryPointDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice());
- builder->addHLSLExportDecoration(reflectionFunc);
builder->addKeepAliveDecoration(reflectionFunc);
}
@@ -817,7 +816,6 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
- builder.addHLSLExportDecoration(reflFunc);
builder.addKeepAliveDecoration(reflFunc);
}
@@ -899,7 +897,6 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink)
// Mark for host-side emit logic.
builder.addCudaHostDecoration(hostFunc);
// Keep alive. This method will be accessed externally.
- builder.addHLSLExportDecoration(hostFunc);
builder.addKeepAliveDecoration(hostFunc);
}
@@ -1163,6 +1160,27 @@ void handleAutoBindNames(IRModule* module)
}
}
+// Strips the keep-alive and HLSL-export decorations from every global function
+// decorated as an auto-py-bind CUDA function, Torch entry point or CUDA kernel.
+void removeTorchAndCUDAEntryPoints(IRModule* module)
+{
+    for (auto globalInst : module->getGlobalInsts())
+    {
+        auto func = as<IRFunc>(globalInst);
+        if (!func)
+            continue;
+        if (!func->findDecoration<IRAutoPyBindCudaDecoration>() &&
+            !func->findDecoration<IRTorchEntryPointDecoration>() &&
+            !func->findDecoration<IRCudaKernelDecoration>())
+            continue;
+        // Drop both markers; only the first decoration of each kind is removed.
+        if (auto keepAlive = func->findDecoration<IRKeepAliveDecoration>())
+            keepAlive->removeAndDeallocate();
+        if (auto hlslExport = func->findDecoration<IRHLSLExportDecoration>())
+            hlslExport->removeAndDeallocate();
+    }
+}
+
void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
{
SLANG_UNUSED(sink);
@@ -1237,7 +1255,6 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
}
- builder.addHLSLExportDecoration(wrapperFunc);
builder.addKeepAliveDecoration(wrapperFunc);
builder.addCudaKernelForwardDerivativeDecoration(func, wrapperFunc);
@@ -1296,7 +1313,6 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
}
- builder.addHLSLExportDecoration(wrapperFunc);
builder.addKeepAliveDecoration(wrapperFunc);
builder.addCudaKernelBackwardDerivativeDecoration(func, wrapperFunc);
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.h b/source/slang/slang-ir-pytorch-cpp-binding.h
index a761dbc03..6d022db7b 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.h
+++ b/source/slang/slang-ir-pytorch-cpp-binding.h
@@ -11,6 +11,7 @@ void removeTorchKernels(IRModule* module);
void handleAutoBindNames(IRModule* module);
void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink);
void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink* sink);
+void removeTorchAndCUDAEntryPoints(IRModule* module);
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index f8faf7c07..cc4704fe9 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1385,21 +1385,27 @@ static void addLinkageDecoration(
{
builder->addCudaKernelDecoration(inst);
builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice());
- builder->addHLSLExportDecoration(inst);
+
+ // Temp decorations to get this function through the linker.
builder->addKeepAliveDecoration(inst);
+ builder->addHLSLExportDecoration(inst);
}
else if (as<TorchEntryPointAttribute>(modifier))
{
builder->addTorchEntryPointDecoration(inst, decl->getName()->text.getUnownedSlice());
builder->addCudaHostDecoration(inst);
- builder->addHLSLExportDecoration(inst);
- builder->addKeepAliveDecoration(inst);
builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice());
+
+ // Temp decorations to get this function through the linker.
+ builder->addKeepAliveDecoration(inst);
+ builder->addHLSLExportDecoration(inst);
}
else if (as<AutoPyBindCudaAttribute>(modifier))
{
builder->addAutoPyBindCudaDecoration(inst, decl->getName()->text.getUnownedSlice());
builder->addAutoPyBindExportInfoDecoration(inst);
+
+ // Temp decorations to get this function through the linker.
builder->addKeepAliveDecoration(inst);
builder->addHLSLExportDecoration(inst);
}
diff --git a/tests/autodiff/hlsl-torch-cross-compile.slang b/tests/autodiff/hlsl-torch-cross-compile.slang
new file mode 100644
index 000000000..5568f26c5
--- /dev/null
+++ b/tests/autodiff/hlsl-torch-cross-compile.slang
@@ -0,0 +1,55 @@
+//TEST:SIMPLE(filecheck=HLSL): -target hlsl -line-directive-mode none -entry computeMain -stage compute
+//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none
+//TEST:SIMPLE(filecheck=TORCH): -target torch -line-directive-mode none
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+[Differentiable]
+float func1(float x)
+{
+ return x * 4;
+}
+
+[AutoPyBindCUDA]
+[CUDAKernel]
+void torchMain(TensorView<float> v)
+{
+ v[0] = func1(v[0]);
+ v[1] = func1(v[1]);
+}
+
+// Shouldn't see torchMain (or its transformations) anywhere in the HLSL output
+// HLSL-NOT:torchMain
+// HLSL:func1
+// HLSL-NOT:torchMain
+// HLSL:computeMain
+// HLSL-NOT:torchMain
+
+[Differentiable]
+float func2(float a)
+{
+ return a;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(2.0, 1.0);
+ dpfloat dpb = dpfloat(1.5, 1.0);
+
+ outputBuffer[0] = fwd_diff(func1)(dpa).d; // Expect: 4 (func1(x) = x * 4, input differential 1.0)
+ outputBuffer[1] = fwd_diff(func2)(dpfloat(dpa.p, 0.0)).d; // Expect: 0
+ }
+}
+
+// Ensure that the generated CUDA and Torch kernels do have torchMain & its transformations
+
+// TORCH: {{^SLANG_PRELUDE_EXPORT$}}
+// TORCH-NEXT: void __kernel__torchMain(TensorView {{[[:alnum:]_]+}});
+
+// CUDA: __global__ void __kernel__torchMain(TensorView {{[[:alnum:]_]+}})