summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-26 13:59:11 -0700
committerGitHub <noreply@github.com>2023-03-26 13:59:11 -0700
commitd64ee86a3130f8eeb75d09193c38c621d7565eba (patch)
treefed25a0cc2a7372d26175774f5983bed693e6b64 /source/slang/slang-emit.cpp
parent666af0962b6ab41489a3a3287db83f77c2f6461a (diff)
Add PyTorch C++ binding generation. (#2734)
* Add PyTorch C++ binding generation. * fix --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit.cpp')
-rw-r--r--source/slang/slang-emit.cpp90
1 files changed, 60 insertions, 30 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 1b4eed8fd..fe72efcc7 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -56,6 +56,7 @@
#include "slang-ir-glsl-liveness.h"
#include "slang-ir-string-hash.h"
#include "slang-ir-simplify-for-emit.h"
+#include "slang-ir-pytorch-cpp-binding.h"
#include "slang-legalize-types.h"
#include "slang-lower-to-ir.h"
#include "slang-mangle.h"
@@ -74,6 +75,7 @@
#include "slang-emit-hlsl.h"
#include "slang-emit-cpp.h"
#include "slang-emit-cuda.h"
+#include "slang-emit-torch.h"
#include "../compiler-core/slang-artifact-desc-util.h"
#include "../compiler-core/slang-artifact-util.h"
@@ -83,6 +85,7 @@
#include <assert.h>
Slang::String get_slang_cpp_host_prelude();
+Slang::String get_slang_torch_prelude();
namespace Slang {
@@ -402,6 +405,18 @@ Result linkAndOptimizeIR(
finalizeSpecialization(irModule);
+ switch (target)
+ {
+ case CodeGenTarget::PyTorchCppBinding:
+ generatePyTorchCppBinding(irModule, sink);
+ break;
+ case CodeGenTarget::CUDASource:
+ removeTorchKernels(irModule);
+ break;
+ default:
+ break;
+ }
+
// If we have a target that is GPU like we use the string hashing mechanism
// but for that to work we need to inline such that calls (or returns) of strings
// boil down into getStringHash(stringLiteral)
@@ -969,31 +984,39 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
LinkedIR linkedIR;
RefPtr<CLikeSourceEmitter> sourceEmitter;
-
SourceLanguage sourceLanguage = CLikeSourceEmitter::getSourceLanguage(target);
- switch (sourceLanguage)
+
+ switch (target)
{
- case SourceLanguage::CPP:
- {
- sourceEmitter = new CPPSourceEmitter(desc);
- break;
- }
- case SourceLanguage::GLSL:
- {
- sourceEmitter = new GLSLSourceEmitter(desc);
- break;
- }
- case SourceLanguage::HLSL:
- {
- sourceEmitter = new HLSLSourceEmitter(desc);
- break;
- }
- case SourceLanguage::CUDA:
+ default:
+ switch (sourceLanguage)
{
- sourceEmitter = new CUDASourceEmitter(desc);
- break;
+ case SourceLanguage::CPP:
+ {
+ sourceEmitter = new CPPSourceEmitter(desc);
+ break;
+ }
+ case SourceLanguage::GLSL:
+ {
+ sourceEmitter = new GLSLSourceEmitter(desc);
+ break;
+ }
+ case SourceLanguage::HLSL:
+ {
+ sourceEmitter = new HLSLSourceEmitter(desc);
+ break;
+ }
+ case SourceLanguage::CUDA:
+ {
+ sourceEmitter = new CUDASourceEmitter(desc);
+ break;
+ }
+ default: break;
}
- default: break;
+ break;
+ case CodeGenTarget::PyTorchCppBinding:
+ sourceEmitter = new TorchCppSourceEmitter(desc);
+ break;
}
if (!sourceEmitter)
@@ -1072,16 +1095,23 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
// Emit any front matter
sourceEmitter->emitFrontMatter(targetRequest);
- // If heterogeneous we output the prelude before everything else
- if (isHeterogeneousTarget(target))
- {
- sourceWriter.emit(get_slang_cpp_host_prelude());
- }
- else
+ switch (target)
{
- // Get the prelude
- String prelude = session->getPreludeForLanguage(sourceLanguage);
- sourceWriter.emit(prelude);
+ case CodeGenTarget::PyTorchCppBinding:
+ sourceWriter.emit(get_slang_torch_prelude());
+ break;
+ default:
+ if (isHeterogeneousTarget(target))
+ {
+ sourceWriter.emit(get_slang_cpp_host_prelude());
+ }
+ else
+ {
+ // Get the prelude
+ String prelude = session->getPreludeForLanguage(sourceLanguage);
+ sourceWriter.emit(prelude);
+ }
+ break;
}
// Emit anything that goes before the contents of the code generated for the module