summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-09-19 18:51:24 -0400
committerGitHub <noreply@github.com>2023-09-19 18:51:24 -0400
commit739c3a7b53dc6489065fcd5e9f0a04370c5f9c8f (patch)
tree593c86cbc184476479c66554cc6784b454bdec66 /source/slang/slang-emit.cpp
parent359fdc9d556b4c493c588c5b8f93df85933634f8 (diff)
Added `[AutoPyBindCUDA]` for automatic kernel binding + `[PyExport]` for exporting type information (#3209)
* Initial: add a DiffTensor impl * Auto-binding and diff tensor implementations now work * Refactored diff-tensor implementation + added py-export for struct types * Cleanup * Update slang-ir-pytorch-cpp-binding.cpp * Updated test names * Update autodiff-data-flow.slang.expected * Add more versions of load/store & default generic args for DiffTensorView. * Add diagnostic for default generic arg and more tests * Add more `[AutoPyBind]` tests
Diffstat (limited to 'source/slang/slang-emit.cpp')
-rw-r--r--source/slang/slang-emit.cpp6
1 files changed, 6 insertions, 0 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index c77f2a6ce..98c9c9803 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -45,6 +45,7 @@
#include "slang-ir-legalize-vector-types.h"
#include "slang-ir-metadata.h"
#include "slang-ir-optix-entry-point-uniforms.h"
+#include "slang-ir-pytorch-cpp-binding.h"
#include "slang-ir-restructure.h"
#include "slang-ir-restructure-scoping.h"
#include "slang-ir-sccp.h"
@@ -369,6 +370,9 @@ Result linkAndOptimizeIR(
// being passed to saturated_cooperation
fuseCallsToSaturatedCooperation(irModule);
+ // Generate any requested derivative wrappers
+ generateDerivativeWrappers(irModule, sink);
+
// Next, we need to ensure that the code we emit for
// the target doesn't contain any operations that would
// be illegal on the target platform. For example,
@@ -444,9 +448,11 @@ Result linkAndOptimizeIR(
{
case CodeGenTarget::PyTorchCppBinding:
generatePyTorchCppBinding(irModule, sink);
+ handleAutoBindNames(irModule);
break;
case CodeGenTarget::CUDASource:
removeTorchKernels(irModule);
+ handleAutoBindNames(irModule);
break;
default:
break;