diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-06-13 17:30:16 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-06-13 17:30:16 -0400 |
| commit | fba316f0e7dacc7f93bee3a95fb93b2ab02bdd80 (patch) | |
| tree | 4687141e1581193de2d6990122c3190d3c2fcc9f | |
| parent | f0d40ad5e1d0a0dec39fe8a141d3f81d88fc576a (diff) | |
Remove `IRHLSLExportDecoration` and `IRKeepAliveDecoration` for non-CUDA/Torch targets (#4364)
* Remove `IRHLSLExportDecoration` and `IRKeepAliveDecoration` for non-CUDA/Torch targets
* Update hlsl-torch-cross-compile.slang
| -rw-r--r-- | source/slang/slang-emit.cpp | 28 | ||||
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 26 | ||||
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 12 | ||||
| -rw-r--r-- | tests/autodiff/hlsl-torch-cross-compile.slang | 55 |
5 files changed, 110 insertions, 12 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 38f066c6c..0f53f74cd 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -557,6 +557,17 @@ Result linkAndOptimizeIR( switch (target) { + case CodeGenTarget::CUDASource: + case CodeGenTarget::PyTorchCppBinding: + break; + + default: + removeTorchAndCUDAEntryPoints(irModule); + break; + } + + switch (target) + { case CodeGenTarget::CPPSource: case CodeGenTarget::HostCPPSource: { @@ -605,10 +616,19 @@ Result linkAndOptimizeIR( if (!targetProgram->getOptionSet().shouldPerformMinimumOptimizations()) fuseCallsToSaturatedCooperation(irModule); - // Generate any requested derivative wrappers - if (requiredLoweringPassSet.derivativePyBindWrapper) - generateDerivativeWrappers(irModule, sink); - + switch (target) + { + case CodeGenTarget::CUDASource: + case CodeGenTarget::PyTorchCppBinding: + { + // Generate any requested derivative wrappers + if (requiredLoweringPassSet.derivativePyBindWrapper) + generateDerivativeWrappers(irModule, sink); + break; + } + default: + break; + } // Next, we need to ensure that the code we emit for // the target doesn't contain any operations that would // be illegal on the target platform. For example, diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index 5105b2e81..6922984d6 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -632,7 +632,6 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host builder->addExternCppDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice()); builder->addTorchEntryPointDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice()); - builder->addHLSLExportDecoration(reflectionFunc); builder->addKeepAliveDecoration(reflectionFunc); } @@ -817,7 +816,6 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink) builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); - builder.addHLSLExportDecoration(reflFunc); builder.addKeepAliveDecoration(reflFunc); } @@ -899,7 +897,6 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink) // Mark for host-side emit logic. builder.addCudaHostDecoration(hostFunc); // Keep alive. This method will be accessed externally. - builder.addHLSLExportDecoration(hostFunc); builder.addKeepAliveDecoration(hostFunc); } @@ -1163,6 +1160,27 @@ void handleAutoBindNames(IRModule* module) } } +void removeTorchAndCUDAEntryPoints(IRModule* module) +{ + // Go through global insts, find cuda & torch related entry points and remove the keep-alive decoration. + IRBuilder builder(module); + for (auto globalInst : module->getGlobalInsts()) + { + if (auto func = as<IRFunc>(globalInst)) + { + if (func->findDecoration<IRAutoPyBindCudaDecoration>() || + func->findDecoration<IRTorchEntryPointDecoration>() || + func->findDecoration<IRCudaKernelDecoration>()) + { + if (auto keepAlive = func->findDecoration<IRKeepAliveDecoration>()) + keepAlive->removeAndDeallocate(); + if (auto hlslExport = func->findDecoration<IRHLSLExportDecoration>()) + hlslExport->removeAndDeallocate(); + } + } + } +} + void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); @@ -1237,7 +1255,6 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); } - builder.addHLSLExportDecoration(wrapperFunc); builder.addKeepAliveDecoration(wrapperFunc); builder.addCudaKernelForwardDerivativeDecoration(func, wrapperFunc); @@ -1296,7 +1313,6 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); } - builder.addHLSLExportDecoration(wrapperFunc); builder.addKeepAliveDecoration(wrapperFunc); builder.addCudaKernelBackwardDerivativeDecoration(func, wrapperFunc); diff --git a/source/slang/slang-ir-pytorch-cpp-binding.h b/source/slang/slang-ir-pytorch-cpp-binding.h index a761dbc03..6d022db7b 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.h +++ b/source/slang/slang-ir-pytorch-cpp-binding.h @@ -11,6 +11,7 @@ void removeTorchKernels(IRModule* module); void handleAutoBindNames(IRModule* module); void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink); void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink* sink); +void removeTorchAndCUDAEntryPoints(IRModule* module); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index f8faf7c07..cc4704fe9 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1385,21 +1385,27 @@ static void addLinkageDecoration( { builder->addCudaKernelDecoration(inst); builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); - builder->addHLSLExportDecoration(inst); + + // Temp decorations to get this function through the linker. builder->addKeepAliveDecoration(inst); + builder->addHLSLExportDecoration(inst); } else if (as<TorchEntryPointAttribute>(modifier)) { builder->addTorchEntryPointDecoration(inst, decl->getName()->text.getUnownedSlice()); builder->addCudaHostDecoration(inst); - builder->addHLSLExportDecoration(inst); - builder->addKeepAliveDecoration(inst); builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); + + // Temp decorations to get this function through the linker. + builder->addKeepAliveDecoration(inst); + builder->addHLSLExportDecoration(inst); } else if (as<AutoPyBindCudaAttribute>(modifier)) { builder->addAutoPyBindCudaDecoration(inst, decl->getName()->text.getUnownedSlice()); builder->addAutoPyBindExportInfoDecoration(inst); + + // Temp decorations to get this function through the linker. builder->addKeepAliveDecoration(inst); builder->addHLSLExportDecoration(inst); } diff --git a/tests/autodiff/hlsl-torch-cross-compile.slang b/tests/autodiff/hlsl-torch-cross-compile.slang new file mode 100644 index 000000000..5568f26c5 --- /dev/null +++ b/tests/autodiff/hlsl-torch-cross-compile.slang @@ -0,0 +1,55 @@ +//TEST:SIMPLE(filecheck=HLSL): -target hlsl -line-directive-mode none -entry computeMain -stage compute +//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none +//TEST:SIMPLE(filecheck=TORCH): -target torch -line-directive-mode none + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[Differentiable] +float func1(float x) +{ + return x * 4; +} + +[AutoPyBindCUDA] +[CUDAKernel] +void torchMain(TensorView<float> v) +{ + v[0] = func1(v[0]); + v[1] = func1(v[1]); +} + +// Shouldn't see torchMain (or its transformations) anywhere in the HLSL output +// HLSL-NOT:torchMain +// HLSL:func1 +// HLSL-NOT:torchMain +// HLSL:computeMain +// HLSL-NOT:torchMain + +[Differentiable] +float func2(float a) +{ + return a; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 1.0); + dpfloat dpb = dpfloat(1.5, 1.0); + + outputBuffer[0] = fwd_diff(func1)(dpa).d; // Expect: 1 + outputBuffer[1] = fwd_diff(func2)(dpfloat(dpa.p, 0.0)).d; // Expect: 0 + } +} + +// Ensure that the generated CUDA and Torch kernels do have torchMain & its transformations + +// TORCH: {{^SLANG_PRELUDE_EXPORT$}} +// TORCH-NEXT: void __kernel__torchMain(TensorView {{[[:alnum:]_]+}}); + +// CUDA: __global__ void __kernel__torchMain(TensorView {{[[:alnum:]_]+}}) |
