From d64ee86a3130f8eeb75d09193c38c621d7565eba Mon Sep 17 00:00:00 2001 From: Yong He Date: Sun, 26 Mar 2023 13:59:11 -0700 Subject: Add PyTorch C++ binding generation. (#2734) * Add PyTorch C++ binding generation. * fix --------- Co-authored-by: Yong He --- source/slang/slang-emit.cpp | 90 ++++++++++++++++++++++++++++++--------------- 1 file changed, 60 insertions(+), 30 deletions(-) (limited to 'source/slang/slang-emit.cpp') diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 1b4eed8fd..fe72efcc7 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -56,6 +56,7 @@ #include "slang-ir-glsl-liveness.h" #include "slang-ir-string-hash.h" #include "slang-ir-simplify-for-emit.h" +#include "slang-ir-pytorch-cpp-binding.h" #include "slang-legalize-types.h" #include "slang-lower-to-ir.h" #include "slang-mangle.h" @@ -74,6 +75,7 @@ #include "slang-emit-hlsl.h" #include "slang-emit-cpp.h" #include "slang-emit-cuda.h" +#include "slang-emit-torch.h" #include "../compiler-core/slang-artifact-desc-util.h" #include "../compiler-core/slang-artifact-util.h" @@ -83,6 +85,7 @@ #include Slang::String get_slang_cpp_host_prelude(); +Slang::String get_slang_torch_prelude(); namespace Slang { @@ -402,6 +405,18 @@ Result linkAndOptimizeIR( finalizeSpecialization(irModule); + switch (target) + { + case CodeGenTarget::PyTorchCppBinding: + generatePyTorchCppBinding(irModule, sink); + break; + case CodeGenTarget::CUDASource: + removeTorchKernels(irModule); + break; + default: + break; + } + // If we have a target that is GPU like we use the string hashing mechanism // but for that to work we need to inline such that calls (or returns) of strings // boil down into getStringHash(stringLiteral) @@ -969,31 +984,39 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr& outAr LinkedIR linkedIR; RefPtr sourceEmitter; - SourceLanguage sourceLanguage = CLikeSourceEmitter::getSourceLanguage(target); - switch (sourceLanguage) + + switch (target) { - case SourceLanguage::CPP: - { - sourceEmitter = new CPPSourceEmitter(desc); - break; - } - case SourceLanguage::GLSL: - { - sourceEmitter = new GLSLSourceEmitter(desc); - break; - } - case SourceLanguage::HLSL: - { - sourceEmitter = new HLSLSourceEmitter(desc); - break; - } - case SourceLanguage::CUDA: + default: + switch (sourceLanguage) { - sourceEmitter = new CUDASourceEmitter(desc); - break; + case SourceLanguage::CPP: + { + sourceEmitter = new CPPSourceEmitter(desc); + break; + } + case SourceLanguage::GLSL: + { + sourceEmitter = new GLSLSourceEmitter(desc); + break; + } + case SourceLanguage::HLSL: + { + sourceEmitter = new HLSLSourceEmitter(desc); + break; + } + case SourceLanguage::CUDA: + { + sourceEmitter = new CUDASourceEmitter(desc); + break; + } + default: break; } - default: break; + break; + case CodeGenTarget::PyTorchCppBinding: + sourceEmitter = new TorchCppSourceEmitter(desc); + break; } if (!sourceEmitter) @@ -1072,16 +1095,23 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr& outAr // Emit any front matter sourceEmitter->emitFrontMatter(targetRequest); - // If heterogeneous we output the prelude before everything else - if (isHeterogeneousTarget(target)) - { - sourceWriter.emit(get_slang_cpp_host_prelude()); - } - else + switch (target) { - // Get the prelude - String prelude = session->getPreludeForLanguage(sourceLanguage); - sourceWriter.emit(prelude); + case CodeGenTarget::PyTorchCppBinding: + sourceWriter.emit(get_slang_torch_prelude()); + break; + default: + if (isHeterogeneousTarget(target)) + { + sourceWriter.emit(get_slang_cpp_host_prelude()); + } + else + { + // Get the prelude + String prelude = session->getPreludeForLanguage(sourceLanguage); + sourceWriter.emit(prelude); + } + break; } // Emit anything that goes before the contents of the code generated for the module -- cgit v1.2.3