summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-26 13:59:11 -0700
committerGitHub <noreply@github.com>2023-03-26 13:59:11 -0700
commitd64ee86a3130f8eeb75d09193c38c621d7565eba (patch)
treefed25a0cc2a7372d26175774f5983bed693e6b64
parent666af0962b6ab41489a3a3287db83f77c2f6461a (diff)
Add PyTorch C++ binding generation. (#2734)
* Add PyTorch C++ binding generation. * fix --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--build/visual-studio/run-generators/run-generators.vcxproj18
-rw-r--r--build/visual-studio/run-generators/run-generators.vcxproj.filters6
-rw-r--r--build/visual-studio/slang/slang.vcxproj5
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters15
-rw-r--r--prelude/slang-cpp-types-core.h561
-rw-r--r--prelude/slang-cpp-types.h554
-rw-r--r--prelude/slang-cuda-prelude.h71
-rw-r--r--prelude/slang-torch-prelude.h126
-rw-r--r--premake5.lua9
-rw-r--r--slang.h1
-rw-r--r--source/compiler-core/slang-artifact-desc-util.cpp1
-rw-r--r--source/compiler-core/slang-artifact.h1
-rw-r--r--source/core/slang-type-convert-util.cpp1
-rw-r--r--source/core/slang-type-text-util.cpp1
-rw-r--r--source/slang/core.meta.slang3
-rw-r--r--source/slang/diff.meta.slang84
-rw-r--r--source/slang/hlsl.meta.slang10
-rw-r--r--source/slang/slang-ast-modifier.h5
-rw-r--r--source/slang/slang-ast-type.cpp7
-rw-r--r--source/slang/slang-ast-type.h9
-rw-r--r--source/slang/slang-compiler.cpp2
-rwxr-xr-xsource/slang/slang-compiler.h1
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-emit-c-like.cpp1
-rw-r--r--source/slang/slang-emit-cpp.cpp25
-rw-r--r--source/slang/slang-emit-cuda.cpp5
-rw-r--r--source/slang/slang-emit-torch.cpp181
-rw-r--r--source/slang/slang-emit-torch.h28
-rw-r--r--source/slang/slang-emit.cpp90
-rw-r--r--source/slang/slang-ir-inst-defs.h49
-rw-r--r--source/slang/slang-ir-insts.h44
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp248
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.h12
-rw-r--r--source/slang/slang-ir.cpp57
-rw-r--r--source/slang/slang-ir.h25
-rw-r--r--source/slang/slang-lower-to-ir.cpp94
-rw-r--r--source/slang/slang-options.cpp3
-rw-r--r--source/slang/slang-parser.cpp4
-rw-r--r--source/slang/slang-syntax.cpp7
-rw-r--r--source/slangc/main.cpp4
-rw-r--r--tests/autodiff/cuda-kernel-export.slang38
-rw-r--r--tools/gfx/slang.slang3
-rw-r--r--tools/slang-test/slang-test-main.cpp1
43 files changed, 1733 insertions, 679 deletions
diff --git a/build/visual-studio/run-generators/run-generators.vcxproj b/build/visual-studio/run-generators/run-generators.vcxproj
index 045601edf..31037fc27 100644
--- a/build/visual-studio/run-generators/run-generators.vcxproj
+++ b/build/visual-studio/run-generators/run-generators.vcxproj
@@ -141,6 +141,7 @@
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="..\..\..\prelude\slang-cpp-scalar-intrinsics.h" />
+ <ClInclude Include="..\..\..\prelude\slang-cpp-types-core.h" />
<ClInclude Include="..\..\..\prelude\slang-cpp-types.h" />
<ClInclude Include="..\..\..\prelude\slang-llvm.h" />
</ItemGroup>
@@ -216,6 +217,23 @@
<AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Release|x64'">../../../bin/windows-x64/release/slang-embed.exe</AdditionalInputs>
<AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Release aarch64|ARM64'">../../../bin/windows-aarch64/release/slang-embed.exe</AdditionalInputs>
</CustomBuild>
+ <CustomBuild Include="..\..\..\prelude\slang-torch-prelude.h">
+ <FileType>Document</FileType>
+ <Command Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">"../../../bin/windows-x86/debug/slang-embed" %(Identity)</Command>
+ <Command Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">"../../../bin/windows-x64/debug/slang-embed" %(Identity)</Command>
+ <Command Condition="'$(Configuration)|$(Platform)'=='Debug aarch64|ARM64'">"../../../bin/windows-aarch64/debug/slang-embed" %(Identity)</Command>
+ <Command Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">"../../../bin/windows-x86/release/slang-embed" %(Identity)</Command>
+ <Command Condition="'$(Configuration)|$(Platform)'=='Release|x64'">"../../../bin/windows-x64/release/slang-embed" %(Identity)</Command>
+ <Command Condition="'$(Configuration)|$(Platform)'=='Release aarch64|ARM64'">"../../../bin/windows-aarch64/release/slang-embed" %(Identity)</Command>
+ <Outputs>../../../prelude/slang-torch-prelude.h.cpp</Outputs>
+ <Message>slang-embed %(Identity)</Message>
+ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">../../../bin/windows-x86/debug/slang-embed.exe</AdditionalInputs>
+ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">../../../bin/windows-x64/debug/slang-embed.exe</AdditionalInputs>
+ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Debug aarch64|ARM64'">../../../bin/windows-aarch64/debug/slang-embed.exe</AdditionalInputs>
+ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">../../../bin/windows-x86/release/slang-embed.exe</AdditionalInputs>
+ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Release|x64'">../../../bin/windows-x64/release/slang-embed.exe</AdditionalInputs>
+ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Release aarch64|ARM64'">../../../bin/windows-aarch64/release/slang-embed.exe</AdditionalInputs>
+ </CustomBuild>
<CustomBuild Include="..\..\..\source\slang\core.meta.slang">
<FileType>Document</FileType>
<Command Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">"../../../bin/windows-x86/debug/slang-generate" %(Identity)</Command>
diff --git a/build/visual-studio/run-generators/run-generators.vcxproj.filters b/build/visual-studio/run-generators/run-generators.vcxproj.filters
index d91507ae6..233899d08 100644
--- a/build/visual-studio/run-generators/run-generators.vcxproj.filters
+++ b/build/visual-studio/run-generators/run-generators.vcxproj.filters
@@ -12,6 +12,9 @@
<ClInclude Include="..\..\..\prelude\slang-cpp-scalar-intrinsics.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\prelude\slang-cpp-types-core.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\prelude\slang-cpp-types.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -37,6 +40,9 @@
<CustomBuild Include="..\..\..\prelude\slang-hlsl-prelude.h">
<Filter>Header Files</Filter>
</CustomBuild>
+ <CustomBuild Include="..\..\..\prelude\slang-torch-prelude.h">
+ <Filter>Header Files</Filter>
+ </CustomBuild>
<CustomBuild Include="..\..\..\source\slang\core.meta.slang">
<Filter>Source Files</Filter>
</CustomBuild>
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index fbc827f02..71545a5d1 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -336,6 +336,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-emit-hlsl.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-precedence.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-source-writer.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-emit-torch.h" />
<ClInclude Include="..\..\..\source\slang\slang-glsl-extension-tracker.h" />
<ClInclude Include="..\..\..\source\slang\slang-image-format-defs.h" />
<ClInclude Include="..\..\..\source\slang\slang-intrinsic-expand.h" />
@@ -411,6 +412,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-ir-optix-entry-point-uniforms.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-peephole.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-propagate-func-properties.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-pytorch-cpp-binding.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-redundancy-removal.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-restructure-scoping.h" />
@@ -489,6 +491,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\prelude\slang-cpp-prelude.h.cpp" />
<ClCompile Include="..\..\..\prelude\slang-cuda-prelude.h.cpp" />
<ClCompile Include="..\..\..\prelude\slang-hlsl-prelude.h.cpp" />
+ <ClCompile Include="..\..\..\prelude\slang-torch-prelude.h.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-api.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-artifact-output-util.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ast-builder.cpp" />
@@ -527,6 +530,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-emit-precedence.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-source-writer.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-spirv.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-emit-torch.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-glsl-extension-tracker.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-intrinsic-expand.cpp" />
@@ -599,6 +603,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-ir-optix-entry-point-uniforms.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-peephole.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-propagate-func-properties.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-pytorch-cpp-binding.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-redundancy-removal.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-remove-unused-generic-param.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-restructure-scoping.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index 87655e974..f353468e4 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -114,6 +114,9 @@
<ClInclude Include="..\..\..\source\slang\slang-emit-source-writer.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-emit-torch.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-glsl-extension-tracker.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -339,6 +342,9 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-propagate-func-properties.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-pytorch-cpp-binding.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-redundancy-removal.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -569,6 +575,9 @@
<ClCompile Include="..\..\..\prelude\slang-hlsl-prelude.h.cpp">
<Filter>Header Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\prelude\slang-torch-prelude.h.cpp">
+ <Filter>Header Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-api.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -683,6 +692,9 @@
<ClCompile Include="..\..\..\source\slang\slang-emit-spirv.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-emit-torch.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-emit.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -899,6 +911,9 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-propagate-func-properties.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-pytorch-cpp-binding.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-redundancy-removal.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/prelude/slang-cpp-types-core.h b/prelude/slang-cpp-types-core.h
new file mode 100644
index 000000000..c49ee013c
--- /dev/null
+++ b/prelude/slang-cpp-types-core.h
@@ -0,0 +1,561 @@
+#ifndef SLANG_PRELUDE_CPP_TYPES_CORE_H
+#define SLANG_PRELUDE_CPP_TYPES_CORE_H
+
+#ifndef SLANG_PRELUDE_ASSERT
+# ifdef SLANG_PRELUDE_ENABLE_ASSERT
+# define SLANG_PRELUDE_ASSERT(VALUE) assert(VALUE)
+# else
+# define SLANG_PRELUDE_ASSERT(VALUE)
+# endif
+#endif
+
+// Since we are using unsigned arithmetic, care is needed in this comparison.
+// It is *assumed* that sizeInBytes >= elemSize, which means (sizeInBytes - elemSize) >= 0,
+// which means only a single test is needed.
+
+// Asserts for bounds checking. Arguments are parenthesized so expression arguments expand safely.
+// It is assumed index/count are unsigned types.
+#define SLANG_BOUND_ASSERT(index, count) SLANG_PRELUDE_ASSERT((index) < (count));
+#define SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_PRELUDE_ASSERT((index) <= ((sizeInBytes) - (elemSize)) && ((index) & 3) == 0);
+
+// Macros to zero index if an access is out of range (index must be an assignable lvalue)
+#define SLANG_BOUND_ZERO_INDEX(index, count) index = ((index) < (count)) ? (index) : 0;
+#define SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) index = ((index) <= ((sizeInBytes) - (elemSize))) ? (index) : 0;
+
+// The 'FIX' macros define how the index is fixed. The default is to do nothing. If SLANG_ENABLE_BOUND_ZERO_INDEX
+// is defined, the fix macros will zero the index if it is out of range.
+#ifdef SLANG_ENABLE_BOUND_ZERO_INDEX
+# define SLANG_BOUND_FIX(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
+# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
+# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
+#else
+# define SLANG_BOUND_FIX(index, count)
+# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
+# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
+#endif
+
+#ifndef SLANG_BOUND_CHECK
+# define SLANG_BOUND_CHECK(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX(index, count)
+#endif
+
+#ifndef SLANG_BOUND_CHECK_BYTE_ADDRESS
+# define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
+#endif
+
+#ifndef SLANG_BOUND_CHECK_FIXED_ARRAY
+# define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
+#endif
+
+struct TypeInfo
+{
+ size_t typeSize;
+};
+
+template <typename T, size_t SIZE>
+struct FixedArray
+{
+ const T& operator[](size_t index) const { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; }
+ T& operator[](size_t index) { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; }
+
+ T m_data[SIZE];
+};
+
+// An array that has no specified size becomes an 'Array'. This stores the size so it can potentially
+// do bounds checking.
+template <typename T>
+struct Array
+{
+ const T& operator[](size_t index) const { SLANG_BOUND_CHECK(index, count); return data[index]; }
+ T& operator[](size_t index) { SLANG_BOUND_CHECK(index, count); return data[index]; }
+
+ T* data;
+ size_t count;
+};
+
+/* Constant buffers become a pointer to the contained type, so ConstantBuffer<T> becomes T* in C++ code.
+*/
+
+template <typename T, int COUNT>
+struct Vector;
+
+template <typename T>
+struct Vector<T, 1>
+{
+ T x;
+ const T& operator[](size_t /*index*/) const { return x; }
+ T& operator[](size_t /*index*/) { return x; }
+ operator T() const { return x; }
+ Vector() = default;
+ Vector(T scalar)
+ {
+ x = scalar;
+ }
+ template <typename U>
+ Vector(Vector<U, 1> other)
+ {
+ x = (T)other.x;
+ }
+ template <typename U, int otherSize>
+ Vector(Vector<U, otherSize> other)
+ {
+ int minSize = 1;
+ if (otherSize < minSize) minSize = otherSize;
+ for (int i = 0; i < minSize; i++)
+ (*this)[i] = (T)other[i];
+ }
+};
+
+template <typename T>
+struct Vector<T, 2>
+{
+ T x, y;
+ const T& operator[](size_t index) const { return index == 0 ? x : y; }
+ T& operator[](size_t index) { return index == 0 ? x : y; }
+ Vector() = default;
+ Vector(T scalar)
+ {
+ x = y = scalar;
+ }
+ Vector(T _x, T _y)
+ {
+ x = _x;
+ y = _y;
+ }
+ template <typename U>
+ Vector(Vector<U, 2> other)
+ {
+ x = (T)other.x;
+ y = (T)other.y;
+ }
+ template <typename U, int otherSize>
+ Vector(Vector<U, otherSize> other)
+ {
+ int minSize = 2;
+ if (otherSize < minSize) minSize = otherSize;
+ for (int i = 0; i < minSize; i++)
+ (*this)[i] = (T)other[i];
+ }
+};
+
+template <typename T>
+struct Vector<T, 3>
+{
+ T x, y, z;
+ const T& operator[](size_t index) const { return *((T*)(this) + index); }
+ T& operator[](size_t index) { return *((T*)(this) + index); }
+
+ Vector() = default;
+ Vector(T scalar)
+ {
+ x = y = z = scalar;
+ }
+ Vector(T _x, T _y, T _z)
+ {
+ x = _x;
+ y = _y;
+ z = _z;
+ }
+ template <typename U>
+ Vector(Vector<U, 3> other)
+ {
+ x = (T)other.x;
+ y = (T)other.y;
+ z = (T)other.z;
+ }
+ template <typename U, int otherSize>
+ Vector(Vector<U, otherSize> other)
+ {
+ int minSize = 3;
+ if (otherSize < minSize) minSize = otherSize;
+ for (int i = 0; i < minSize; i++)
+ (*this)[i] = (T)other[i];
+ }
+};
+
+template <typename T>
+struct Vector<T, 4>
+{
+ T x, y, z, w;
+
+ const T& operator[](size_t index) const { return *((T*)(this) + index); }
+ T& operator[](size_t index) { return *((T*)(this) + index); }
+ Vector() = default;
+ Vector(T scalar)
+ {
+ x = y = z = w = scalar;
+ }
+ Vector(T _x, T _y, T _z, T _w)
+ {
+ x = _x;
+ y = _y;
+ z = _z;
+ w = _w;
+ }
+ template <typename U, int otherSize>
+ Vector(Vector<U, otherSize> other)
+ {
+ int minSize = 4;
+ if (otherSize < minSize) minSize = otherSize;
+ for (int i = 0; i < minSize; i++)
+ (*this)[i] = (T)other[i];
+ }
+
+};
+
+template<typename T, int N>
+SLANG_FORCE_INLINE T _slang_vector_get_element(Vector<T, N> x, int index)
+{
+ return x[index];
+}
+
+template<typename T, int N>
+SLANG_FORCE_INLINE const T* _slang_vector_get_element_ptr(const Vector<T, N>* x, int index)
+{
+ return &((*const_cast<Vector<T,N>*>(x))[index]);
+}
+
+template<typename T, int N>
+SLANG_FORCE_INLINE T* _slang_vector_get_element_ptr(Vector<T, N>* x, int index)
+{
+ return &((*x)[index]);
+}
+
+template<typename T, int n, typename OtherT, int m> // Converts `other` (m lanes of OtherT) to n lanes of T.
+SLANG_FORCE_INLINE Vector<T, n> _slang_vector_reshape(const Vector<OtherT, m> other)
+{
+ Vector<T, n> result;
+ for (int i = 0; i < n; i++)
+ {
+ OtherT otherElement = OtherT(0); // lanes at or beyond m are zero-filled; initialise in the temporary's own type
+ if (i < m)
+ otherElement = _slang_vector_get_element(other, i);
+ *_slang_vector_get_element_ptr(&result, i) = (T)otherElement; // element-wise cast to the destination type
+ }
+ return result;
+}
+
+typedef uint32_t uint;
+
+#define SLANG_VECTOR_BINARY_OP(T, op) \
+ template<int n> \
+ SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \
+ { \
+ Vector<T, n> result;\
+ for (int i = 0; i < n; i++) \
+ result[i] = thisVal[i] op other[i]; \
+ return result;\
+ }
+#define SLANG_VECTOR_BINARY_COMPARE_OP(T, op) \
+ template<int n> \
+ SLANG_FORCE_INLINE Vector<bool, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \
+ { \
+ Vector<bool, n> result;\
+ for (int i = 0; i < n; i++) \
+ result[i] = thisVal[i] op other[i]; \
+ return result;\
+ }
+
+#define SLANG_VECTOR_UNARY_OP(T, op) \
+ template<int n> \
+ SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal) \
+ { \
+ Vector<T, n> result;\
+ for (int i = 0; i < n; i++) \
+ result[i] = op thisVal[i]; \
+ return result;\
+ }
+#define SLANG_INT_VECTOR_OPS(T) \
+ SLANG_VECTOR_BINARY_OP(T, +)\
+ SLANG_VECTOR_BINARY_OP(T, -)\
+ SLANG_VECTOR_BINARY_OP(T, *)\
+ SLANG_VECTOR_BINARY_OP(T, / )\
+ SLANG_VECTOR_BINARY_OP(T, &)\
+ SLANG_VECTOR_BINARY_OP(T, |)\
+ SLANG_VECTOR_BINARY_OP(T, &&)\
+ SLANG_VECTOR_BINARY_OP(T, ||)\
+ SLANG_VECTOR_BINARY_OP(T, ^)\
+ SLANG_VECTOR_BINARY_OP(T, %)\
+ SLANG_VECTOR_BINARY_OP(T, >>)\
+ SLANG_VECTOR_BINARY_OP(T, <<)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)\
+ SLANG_VECTOR_UNARY_OP(T, !)\
+ SLANG_VECTOR_UNARY_OP(T, ~)
+#define SLANG_FLOAT_VECTOR_OPS(T) \
+ SLANG_VECTOR_BINARY_OP(T, +)\
+ SLANG_VECTOR_BINARY_OP(T, -)\
+ SLANG_VECTOR_BINARY_OP(T, *)\
+ SLANG_VECTOR_BINARY_OP(T, /)\
+ SLANG_VECTOR_UNARY_OP(T, -)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)
+
+SLANG_INT_VECTOR_OPS(bool)
+SLANG_INT_VECTOR_OPS(int)
+SLANG_INT_VECTOR_OPS(int8_t)
+SLANG_INT_VECTOR_OPS(int16_t)
+SLANG_INT_VECTOR_OPS(int64_t)
+SLANG_INT_VECTOR_OPS(uint)
+SLANG_INT_VECTOR_OPS(uint8_t)
+SLANG_INT_VECTOR_OPS(uint16_t)
+SLANG_INT_VECTOR_OPS(uint64_t)
+
+SLANG_FLOAT_VECTOR_OPS(float)
+SLANG_FLOAT_VECTOR_OPS(double)
+
+#define SLANG_VECTOR_INT_NEG_OP(T) \
+ template<int N>\
+ Vector<T, N> operator-(const Vector<T, N>& thisVal) \
+ { \
+ Vector<T, N> result;\
+ for (int i = 0; i < N; i++) \
+ result[i] = 0 - thisVal[i]; \
+ return result;\
+ }
+SLANG_VECTOR_INT_NEG_OP(int)
+SLANG_VECTOR_INT_NEG_OP(int8_t)
+SLANG_VECTOR_INT_NEG_OP(int16_t)
+SLANG_VECTOR_INT_NEG_OP(int64_t)
+SLANG_VECTOR_INT_NEG_OP(uint)
+SLANG_VECTOR_INT_NEG_OP(uint8_t)
+SLANG_VECTOR_INT_NEG_OP(uint16_t)
+SLANG_VECTOR_INT_NEG_OP(uint64_t)
+
+#define SLANG_FLOAT_VECTOR_MOD(T)\
+ template<int N> \
+ Vector<T, N> operator%(const Vector<T, N>& left, const Vector<T, N>& right) \
+ {\
+ Vector<T, N> result;\
+ for (int i = 0; i < N; i++) \
+ result[i] = _slang_fmod(left[i], right[i]); \
+ return result;\
+ }
+
+SLANG_FLOAT_VECTOR_MOD(float)
+SLANG_FLOAT_VECTOR_MOD(double)
+#undef SLANG_FLOAT_VECTOR_MOD
+#undef SLANG_VECTOR_BINARY_OP
+#undef SLANG_VECTOR_UNARY_OP
+#undef SLANG_INT_VECTOR_OPS
+#undef SLANG_FLOAT_VECTOR_OPS
+#undef SLANG_VECTOR_INT_NEG_OP
+#undef SLANG_VECTOR_BINARY_COMPARE_OP // helper macro defined above; undefined so it does not leak to includers
+
+template <typename T, int ROWS, int COLS>
+struct Matrix
+{
+ Vector<T, COLS> rows[ROWS];
+ Vector<T, COLS>& operator[](size_t index) { return rows[index]; }
+ Matrix() = default;
+ Matrix(T scalar)
+ {
+ for (int i = 0; i < ROWS; i++)
+ rows[i] = Vector<T, COLS>(scalar);
+ }
+ Matrix(const Vector<T, COLS>& row0)
+ {
+ rows[0] = row0;
+ }
+ Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1)
+ {
+ rows[0] = row0;
+ rows[1] = row1;
+ }
+ Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2)
+ {
+ rows[0] = row0;
+ rows[1] = row1;
+ rows[2] = row2;
+ }
+ Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2, const Vector<T, COLS>& row3)
+ {
+ rows[0] = row0;
+ rows[1] = row1;
+ rows[2] = row2;
+ rows[3] = row3;
+ }
+ template<typename U, int otherRow, int otherCol>
+ Matrix(const Matrix<U, otherRow, otherCol>& other)
+ {
+ int minRow = ROWS;
+ int minCol = COLS;
+ if (minRow > otherRow) minRow = otherRow;
+ if (minCol > otherCol) minCol = otherCol;
+ for (int i = 0; i < minRow; i++)
+ for (int j = 0; j < minCol; j++)
+ rows[i][j] = (T)other.rows[i][j];
+ }
+ Matrix(T v0, T v1, T v2, T v3)
+ {
+ rows[0][0] = v0; rows[0][1] = v1;
+ rows[1][0] = v2; rows[1][1] = v3;
+ }
+ Matrix(T v0, T v1, T v2, T v3, T v4, T v5)
+ {
+ if (COLS == 3)
+ {
+ rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
+ rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
+ }
+ else
+ {
+ rows[0][0] = v0; rows[0][1] = v1;
+ rows[1][0] = v2; rows[1][1] = v3;
+ rows[2][0] = v4; rows[2][1] = v5;
+ }
+ }
+ Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7)
+ {
+ if (COLS == 4)
+ {
+ rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
+ rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
+ }
+ else
+ {
+ rows[0][0] = v0; rows[0][1] = v1;
+ rows[1][0] = v2; rows[1][1] = v3;
+ rows[2][0] = v4; rows[2][1] = v5;
+ rows[3][0] = v6; rows[3][1] = v7;
+ }
+ }
+ Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8)
+ {
+ rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
+ rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
+ rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8;
+ }
+ Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11)
+ {
+ if (COLS == 4)
+ {
+ rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
+ rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
+ rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11;
+ }
+ else
+ {
+ rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
+ rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
+ rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8;
+ rows[3][0] = v9; rows[3][1] = v10; rows[3][2] = v11;
+ }
+ }
+ Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15)
+ {
+ rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
+ rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
+ rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11;
+ rows[3][0] = v12; rows[3][1] = v13; rows[3][2] = v14; rows[3][3] = v15;
+ }
+};
+
+#define SLANG_MATRIX_BINARY_OP(T, op) \
+ template<int R, int C> \
+ Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal, const Matrix<T, R, C>& other) \
+ { \
+ Matrix<T, R, C> result;\
+ for (int i = 0; i < R; i++) \
+ for (int j = 0; j < C; j++) \
+ result.rows[i][j] = thisVal.rows[i][j] op other.rows[i][j]; \
+ return result;\
+ }
+
+#define SLANG_MATRIX_UNARY_OP(T, op) /* defines element-wise unary `op` for Matrix<T, R, C> */ \
+ template<int R, int C> \
+ Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal) \
+ { \
+ Matrix<T, R, C> result;\
+ for (int i = 0; i < R; i++) \
+ for (int j = 0; j < C; j++) \
+ result.rows[i][j] = op thisVal.rows[i][j]; /* index the rows array directly; a row Vector has no `rows` member */ \
+ return result;\
+ }
+#define SLANG_INT_MATRIX_OPS(T) \
+ SLANG_MATRIX_BINARY_OP(T, +)\
+ SLANG_MATRIX_BINARY_OP(T, -)\
+ SLANG_MATRIX_BINARY_OP(T, *)\
+ SLANG_MATRIX_BINARY_OP(T, / )\
+ SLANG_MATRIX_BINARY_OP(T, &)\
+ SLANG_MATRIX_BINARY_OP(T, |)\
+ SLANG_MATRIX_BINARY_OP(T, &&)\
+ SLANG_MATRIX_BINARY_OP(T, ||)\
+ SLANG_MATRIX_BINARY_OP(T, ^)\
+ SLANG_MATRIX_BINARY_OP(T, %)\
+ SLANG_MATRIX_UNARY_OP(T, !)\
+ SLANG_MATRIX_UNARY_OP(T, ~)
+#define SLANG_FLOAT_MATRIX_OPS(T) \
+ SLANG_MATRIX_BINARY_OP(T, +)\
+ SLANG_MATRIX_BINARY_OP(T, -)\
+ SLANG_MATRIX_BINARY_OP(T, *)\
+ SLANG_MATRIX_BINARY_OP(T, /)\
+ SLANG_MATRIX_UNARY_OP(T, -)
+SLANG_INT_MATRIX_OPS(int)
+SLANG_INT_MATRIX_OPS(int8_t)
+SLANG_INT_MATRIX_OPS(int16_t)
+SLANG_INT_MATRIX_OPS(int64_t)
+SLANG_INT_MATRIX_OPS(uint)
+SLANG_INT_MATRIX_OPS(uint8_t)
+SLANG_INT_MATRIX_OPS(uint16_t)
+SLANG_INT_MATRIX_OPS(uint64_t)
+
+SLANG_FLOAT_MATRIX_OPS(float)
+SLANG_FLOAT_MATRIX_OPS(double)
+
+#define SLANG_MATRIX_INT_NEG_OP(T) \
+ template<int R, int C>\
+ SLANG_FORCE_INLINE Matrix<T, R, C> operator-(Matrix<T, R, C> thisVal) \
+ { \
+ Matrix<T, R, C> result;\
+ for (int i = 0; i < R; i++) \
+ for (int j = 0; j < C; j++) \
+ result.rows[i][j] = 0 - thisVal.rows[i][j]; \
+ return result;\
+ }
+ SLANG_MATRIX_INT_NEG_OP(int)
+ SLANG_MATRIX_INT_NEG_OP(int8_t)
+ SLANG_MATRIX_INT_NEG_OP(int16_t)
+ SLANG_MATRIX_INT_NEG_OP(int64_t)
+ SLANG_MATRIX_INT_NEG_OP(uint)
+ SLANG_MATRIX_INT_NEG_OP(uint8_t)
+ SLANG_MATRIX_INT_NEG_OP(uint16_t)
+ SLANG_MATRIX_INT_NEG_OP(uint64_t)
+
+#define SLANG_FLOAT_MATRIX_MOD(T)\
+ template<int R, int C> \
+ SLANG_FORCE_INLINE Matrix<T, R, C> operator%(Matrix<T, R, C> left, Matrix<T, R, C> right) \
+ {\
+ Matrix<T, R, C> result;\
+ for (int i = 0; i < R; i++) \
+ for (int j = 0; j < C; j++) \
+ result.rows[i][j] = _slang_fmod(left.rows[i][j], right.rows[i][j]); \
+ return result;\
+ }
+
+ SLANG_FLOAT_MATRIX_MOD(float)
+ SLANG_FLOAT_MATRIX_MOD(double)
+#undef SLANG_FLOAT_MATRIX_MOD
+#undef SLANG_MATRIX_BINARY_OP
+#undef SLANG_MATRIX_UNARY_OP
+#undef SLANG_INT_MATRIX_OPS
+#undef SLANG_FLOAT_MATRIX_OPS
+#undef SLANG_MATRIX_INT_NEG_OP
+#undef SLANG_FLOAT_MATRIX_MOD
+
+template<typename TResult, typename TInput> // Reinterprets the bytes of a TInput value as a TResult.
+TResult slang_bit_cast(TInput val)
+{
+ return *(TResult*)(&val); // NOTE(review): pointer-cast type punning reads sizeof(TResult) bytes from `val`; presumably callers pass equal-sized types — confirm
+}
+
+#endif
+
+
diff --git a/prelude/slang-cpp-types.h b/prelude/slang-cpp-types.h
index 28fe3dd8d..ac66ad9f3 100644
--- a/prelude/slang-cpp-types.h
+++ b/prelude/slang-cpp-types.h
@@ -1,244 +1,12 @@
#ifndef SLANG_PRELUDE_CPP_TYPES_H
#define SLANG_PRELUDE_CPP_TYPES_H
-#ifndef SLANG_PRELUDE_ASSERT
-# ifdef SLANG_PRELUDE_ENABLE_ASSERT
-# define SLANG_PRELUDE_ASSERT(VALUE) assert(VALUE)
-# else
-# define SLANG_PRELUDE_ASSERT(VALUE)
-# endif
-#endif
-
-// Since we are using unsigned arithmatic care is need in this comparison.
-// It is *assumed* that sizeInBytes >= elemSize. Which means (sizeInBytes >= elemSize) >= 0
-// Which means only a single test is needed
-
-// Asserts for bounds checking.
-// It is assumed index/count are unsigned types.
-#define SLANG_BOUND_ASSERT(index, count) SLANG_PRELUDE_ASSERT(index < count);
-#define SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_PRELUDE_ASSERT(index <= (sizeInBytes - elemSize) && (index & 3) == 0);
-
-// Macros to zero index if an access is out of range
-#define SLANG_BOUND_ZERO_INDEX(index, count) index = (index < count) ? index : 0;
-#define SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) index = (index <= (sizeInBytes - elemSize)) ? index : 0;
-
-// The 'FIX' macro define how the index is fixed. The default is to do nothing. If SLANG_ENABLE_BOUND_ZERO_INDEX
-// the fix macro will zero the index, if out of range
-#ifdef SLANG_ENABLE_BOUND_ZERO_INDEX
-# define SLANG_BOUND_FIX(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
-# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
-# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
-#else
-# define SLANG_BOUND_FIX(index, count)
-# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
-# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
-#endif
-
-#ifndef SLANG_BOUND_CHECK
-# define SLANG_BOUND_CHECK(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX(index, count)
-#endif
-
-#ifndef SLANG_BOUND_CHECK_BYTE_ADDRESS
-# define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
-#endif
-
-#ifndef SLANG_BOUND_CHECK_FIXED_ARRAY
-# define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
-#endif
-
#ifdef SLANG_PRELUDE_NAMESPACE
namespace SLANG_PRELUDE_NAMESPACE {
#endif
-struct TypeInfo
-{
- size_t typeSize;
-};
-
-template <typename T, size_t SIZE>
-struct FixedArray
-{
- const T& operator[](size_t index) const { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; }
- T& operator[](size_t index) { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; }
-
- T m_data[SIZE];
-};
-
-// An array that has no specified size, becomes a 'Array'. This stores the size so it can potentially
-// do bounds checking.
-template <typename T>
-struct Array
-{
- const T& operator[](size_t index) const { SLANG_BOUND_CHECK(index, count); return data[index]; }
- T& operator[](size_t index) { SLANG_BOUND_CHECK(index, count); return data[index]; }
-
- T* data;
- size_t count;
-};
-
-/* Constant buffers become a pointer to the contained type, so ConstantBuffer<T> becomes T* in C++ code.
-*/
-
-template <typename T, int COUNT>
-struct Vector;
-
-template <typename T>
-struct Vector<T, 1>
-{
- T x;
- const T& operator[](size_t /*index*/) const { return x; }
- T& operator[](size_t /*index*/) { return x; }
- operator T() const { return x; }
- Vector() = default;
- Vector(T scalar)
- {
- x = scalar;
- }
- template <typename U>
- Vector(Vector<U, 1> other)
- {
- x = (T)other.x;
- }
- template <typename U, int otherSize>
- Vector(Vector<U, otherSize> other)
- {
- int minSize = 1;
- if (otherSize < minSize) minSize = otherSize;
- for (int i = 0; i < minSize; i++)
- (*this)[i] = (T)other[i];
- }
-};
-
-template <typename T>
-struct Vector<T, 2>
-{
- T x, y;
- const T& operator[](size_t index) const { return index == 0 ? x : y; }
- T& operator[](size_t index) { return index == 0 ? x : y; }
- Vector() = default;
- Vector(T scalar)
- {
- x = y = scalar;
- }
- Vector(T _x, T _y)
- {
- x = _x;
- y = _y;
- }
- template <typename U>
- Vector(Vector<U, 2> other)
- {
- x = (T)other.x;
- y = (T)other.y;
- }
- template <typename U, int otherSize>
- Vector(Vector<U, otherSize> other)
- {
- int minSize = 2;
- if (otherSize < minSize) minSize = otherSize;
- for (int i = 0; i < minSize; i++)
- (*this)[i] = (T)other[i];
- }
-};
-
-template <typename T>
-struct Vector<T, 3>
-{
- T x, y, z;
- const T& operator[](size_t index) const { return *((T*)(this) + index); }
- T& operator[](size_t index) { return *((T*)(this) + index); }
-
- Vector() = default;
- Vector(T scalar)
- {
- x = y = z = scalar;
- }
- Vector(T _x, T _y, T _z)
- {
- x = _x;
- y = _y;
- z = _z;
- }
- template <typename U>
- Vector(Vector<U, 3> other)
- {
- x = (T)other.x;
- y = (T)other.y;
- z = (T)other.z;
- }
- template <typename U, int otherSize>
- Vector(Vector<U, otherSize> other)
- {
- int minSize = 3;
- if (otherSize < minSize) minSize = otherSize;
- for (int i = 0; i < minSize; i++)
- (*this)[i] = (T)other[i];
- }
-};
-template <typename T>
-struct Vector<T, 4>
-{
- T x, y, z, w;
-
- const T& operator[](size_t index) const { return *((T*)(this) + index); }
- T& operator[](size_t index) { return *((T*)(this) + index); }
- Vector() = default;
- Vector(T scalar)
- {
- x = y = z = w = scalar;
- }
- Vector(T _x, T _y, T _z, T _w)
- {
- x = _x;
- y = _y;
- z = _z;
- w = _w;
- }
- template <typename U, int otherSize>
- Vector(Vector<U, otherSize> other)
- {
- int minSize = 4;
- if (otherSize < minSize) minSize = otherSize;
- for (int i = 0; i < minSize; i++)
- (*this)[i] = (T)other[i];
- }
-
-};
-
-template<typename T, int N>
-SLANG_FORCE_INLINE T _slang_vector_get_element(Vector<T, N> x, int index)
-{
- return x[index];
-}
-
-template<typename T, int N>
-SLANG_FORCE_INLINE const T* _slang_vector_get_element_ptr(const Vector<T, N>* x, int index)
-{
- return &((*const_cast<Vector<T,N>*>(x))[index]);
-}
-
-template<typename T, int N>
-SLANG_FORCE_INLINE T* _slang_vector_get_element_ptr(Vector<T, N>* x, int index)
-{
- return &((*x)[index]);
-}
-
-template<typename T, int n, typename OtherT, int m>
-SLANG_FORCE_INLINE Vector<T, n> _slang_vector_reshape(const Vector<OtherT, m> other)
-{
- Vector<T, n> result;
- for (int i = 0; i < n; i++)
- {
- OtherT otherElement = T(0);
- if (i < m)
- otherElement = _slang_vector_get_element(other, i);
- *_slang_vector_get_element_ptr(&result, i) = (T)otherElement;
- }
- return result;
-}
-
-typedef uint32_t uint;
+#include "slang-cpp-types-core.h"
typedef Vector<float, 2> float2;
typedef Vector<float, 3> float3;
@@ -252,320 +20,6 @@ typedef Vector<uint32_t, 2> uint2;
typedef Vector<uint32_t, 3> uint3;
typedef Vector<uint32_t, 4> uint4;
-#define SLANG_VECTOR_BINARY_OP(T, op) \
- template<int n> \
- SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \
- { \
- Vector<T, n> result;\
- for (int i = 0; i < n; i++) \
- result[i] = thisVal[i] op other[i]; \
- return result;\
- }
-#define SLANG_VECTOR_BINARY_COMPARE_OP(T, op) \
- template<int n> \
- SLANG_FORCE_INLINE Vector<bool, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \
- { \
- Vector<bool, n> result;\
- for (int i = 0; i < n; i++) \
- result[i] = thisVal[i] op other[i]; \
- return result;\
- }
-
-#define SLANG_VECTOR_UNARY_OP(T, op) \
- template<int n> \
- SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal) \
- { \
- Vector<T, n> result;\
- for (int i = 0; i < n; i++) \
- result[i] = op thisVal[i]; \
- return result;\
- }
-#define SLANG_INT_VECTOR_OPS(T) \
- SLANG_VECTOR_BINARY_OP(T, +)\
- SLANG_VECTOR_BINARY_OP(T, -)\
- SLANG_VECTOR_BINARY_OP(T, *)\
- SLANG_VECTOR_BINARY_OP(T, / )\
- SLANG_VECTOR_BINARY_OP(T, &)\
- SLANG_VECTOR_BINARY_OP(T, |)\
- SLANG_VECTOR_BINARY_OP(T, &&)\
- SLANG_VECTOR_BINARY_OP(T, ||)\
- SLANG_VECTOR_BINARY_OP(T, ^)\
- SLANG_VECTOR_BINARY_OP(T, %)\
- SLANG_VECTOR_BINARY_OP(T, >>)\
- SLANG_VECTOR_BINARY_OP(T, <<)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)\
- SLANG_VECTOR_UNARY_OP(T, !)\
- SLANG_VECTOR_UNARY_OP(T, ~)
-#define SLANG_FLOAT_VECTOR_OPS(T) \
- SLANG_VECTOR_BINARY_OP(T, +)\
- SLANG_VECTOR_BINARY_OP(T, -)\
- SLANG_VECTOR_BINARY_OP(T, *)\
- SLANG_VECTOR_BINARY_OP(T, /)\
- SLANG_VECTOR_UNARY_OP(T, -)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)
-
-SLANG_INT_VECTOR_OPS(bool)
-SLANG_INT_VECTOR_OPS(int)
-SLANG_INT_VECTOR_OPS(int8_t)
-SLANG_INT_VECTOR_OPS(int16_t)
-SLANG_INT_VECTOR_OPS(int64_t)
-SLANG_INT_VECTOR_OPS(uint)
-SLANG_INT_VECTOR_OPS(uint8_t)
-SLANG_INT_VECTOR_OPS(uint16_t)
-SLANG_INT_VECTOR_OPS(uint64_t)
-
-SLANG_FLOAT_VECTOR_OPS(float)
-SLANG_FLOAT_VECTOR_OPS(double)
-
-#define SLANG_VECTOR_INT_NEG_OP(T) \
- template<int N>\
- Vector<T, N> operator-(const Vector<T, N>& thisVal) \
- { \
- Vector<T, N> result;\
- for (int i = 0; i < N; i++) \
- result[i] = 0 - thisVal[i]; \
- return result;\
- }
-SLANG_VECTOR_INT_NEG_OP(int)
-SLANG_VECTOR_INT_NEG_OP(int8_t)
-SLANG_VECTOR_INT_NEG_OP(int16_t)
-SLANG_VECTOR_INT_NEG_OP(int64_t)
-SLANG_VECTOR_INT_NEG_OP(uint)
-SLANG_VECTOR_INT_NEG_OP(uint8_t)
-SLANG_VECTOR_INT_NEG_OP(uint16_t)
-SLANG_VECTOR_INT_NEG_OP(uint64_t)
-
-#define SLANG_FLOAT_VECTOR_MOD(T)\
- template<int N> \
- Vector<T, N> operator%(const Vector<T, N>& left, const Vector<T, N>& right) \
- {\
- Vector<T, N> result;\
- for (int i = 0; i < N; i++) \
- result[i] = _slang_fmod(left[i], right[i]); \
- return result;\
- }
-
-SLANG_FLOAT_VECTOR_MOD(float)
-SLANG_FLOAT_VECTOR_MOD(double)
-#undef SLANG_FLOAT_VECTOR_MOD
-#undef SLANG_VECTOR_BINARY_OP
-#undef SLANG_VECTOR_UNARY_OP
-#undef SLANG_INT_VECTOR_OPS
-#undef SLANG_FLOAT_VECTOR_OPS
-#undef SLANG_VECTOR_INT_NEG_OP
-#undef SLANG_FLOAT_VECTOR_MOD
-
-template <typename T, int ROWS, int COLS>
-struct Matrix
-{
- Vector<T, COLS> rows[ROWS];
- Vector<T, COLS>& operator[](size_t index) { return rows[index]; }
- Matrix() = default;
- Matrix(T scalar)
- {
- for (int i = 0; i < ROWS; i++)
- rows[i] = Vector<T, COLS>(scalar);
- }
- Matrix(const Vector<T, COLS>& row0)
- {
- rows[0] = row0;
- }
- Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1)
- {
- rows[0] = row0;
- rows[1] = row1;
- }
- Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2)
- {
- rows[0] = row0;
- rows[1] = row1;
- rows[2] = row2;
- }
- Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2, const Vector<T, COLS>& row3)
- {
- rows[0] = row0;
- rows[1] = row1;
- rows[2] = row2;
- rows[3] = row3;
- }
- template<typename U, int otherRow, int otherCol>
- Matrix(const Matrix<U, otherRow, otherCol>& other)
- {
- int minRow = ROWS;
- int minCol = COLS;
- if (minRow > otherRow) minRow = otherRow;
- if (minCol > otherCol) minCol = otherCol;
- for (int i = 0; i < minRow; i++)
- for (int j = 0; j < minCol; j++)
- rows[i][j] = (T)other.rows[i][j];
- }
- Matrix(T v0, T v1, T v2, T v3)
- {
- rows[0][0] = v0; rows[0][1] = v1;
- rows[1][0] = v2; rows[1][1] = v3;
- }
- Matrix(T v0, T v1, T v2, T v3, T v4, T v5)
- {
- if (COLS == 3)
- {
- rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
- rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
- }
- else
- {
- rows[0][0] = v0; rows[0][1] = v1;
- rows[1][0] = v2; rows[1][1] = v3;
- rows[2][0] = v4; rows[2][1] = v5;
- }
- }
- Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7)
- {
- if (COLS == 4)
- {
- rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
- rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
- }
- else
- {
- rows[0][0] = v0; rows[0][1] = v1;
- rows[1][0] = v2; rows[1][1] = v3;
- rows[2][0] = v4; rows[2][1] = v5;
- rows[3][0] = v6; rows[3][1] = v7;
- }
- }
- Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8)
- {
- rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
- rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
- rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8;
- }
- Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11)
- {
- if (COLS == 4)
- {
- rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
- rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
- rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11;
- }
- else
- {
- rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
- rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
- rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8;
- rows[3][0] = v9; rows[3][1] = v10; rows[3][2] = v11;
- }
- }
- Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15)
- {
- rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
- rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
- rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11;
- rows[3][0] = v12; rows[3][1] = v13; rows[3][2] = v14; rows[3][3] = v15;
- }
-};
-
-#define SLANG_MATRIX_BINARY_OP(T, op) \
- template<int R, int C> \
- Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal, const Matrix<T, R, C>& other) \
- { \
- Matrix<T, R, C> result;\
- for (int i = 0; i < R; i++) \
- for (int j = 0; j < C; j++) \
- result.rows[i][j] = thisVal.rows[i][j] op other.rows[i][j]; \
- return result;\
- }
-
-#define SLANG_MATRIX_UNARY_OP(T, op) \
- template<int R, int C> \
- Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal) \
- { \
- Matrix<T, R, C> result;\
- for (int i = 0; i < R; i++) \
- for (int j = 0; j < C; j++) \
- result[i].rows[i][j] = op thisVal.rows[i][j]; \
- return result;\
- }
-#define SLANG_INT_MATRIX_OPS(T) \
- SLANG_MATRIX_BINARY_OP(T, +)\
- SLANG_MATRIX_BINARY_OP(T, -)\
- SLANG_MATRIX_BINARY_OP(T, *)\
- SLANG_MATRIX_BINARY_OP(T, / )\
- SLANG_MATRIX_BINARY_OP(T, &)\
- SLANG_MATRIX_BINARY_OP(T, |)\
- SLANG_MATRIX_BINARY_OP(T, &&)\
- SLANG_MATRIX_BINARY_OP(T, ||)\
- SLANG_MATRIX_BINARY_OP(T, ^)\
- SLANG_MATRIX_BINARY_OP(T, %)\
- SLANG_MATRIX_UNARY_OP(T, !)\
- SLANG_MATRIX_UNARY_OP(T, ~)
-#define SLANG_FLOAT_MATRIX_OPS(T) \
- SLANG_MATRIX_BINARY_OP(T, +)\
- SLANG_MATRIX_BINARY_OP(T, -)\
- SLANG_MATRIX_BINARY_OP(T, *)\
- SLANG_MATRIX_BINARY_OP(T, /)\
- SLANG_MATRIX_UNARY_OP(T, -)
-SLANG_INT_MATRIX_OPS(int)
-SLANG_INT_MATRIX_OPS(int8_t)
-SLANG_INT_MATRIX_OPS(int16_t)
-SLANG_INT_MATRIX_OPS(int64_t)
-SLANG_INT_MATRIX_OPS(uint)
-SLANG_INT_MATRIX_OPS(uint8_t)
-SLANG_INT_MATRIX_OPS(uint16_t)
-SLANG_INT_MATRIX_OPS(uint64_t)
-
-SLANG_FLOAT_MATRIX_OPS(float)
-SLANG_FLOAT_MATRIX_OPS(double)
-
-#define SLANG_MATRIX_INT_NEG_OP(T) \
- template<int R, int C>\
- SLANG_FORCE_INLINE Matrix<T, R, C> operator-(Matrix<T, R, C> thisVal) \
- { \
- Matrix<T, R, C> result;\
- for (int i = 0; i < R; i++) \
- for (int j = 0; j < C; j++) \
- result.rows[i][j] = 0 - thisVal.rows[i][j]; \
- return result;\
- }
- SLANG_MATRIX_INT_NEG_OP(int)
- SLANG_MATRIX_INT_NEG_OP(int8_t)
- SLANG_MATRIX_INT_NEG_OP(int16_t)
- SLANG_MATRIX_INT_NEG_OP(int64_t)
- SLANG_MATRIX_INT_NEG_OP(uint)
- SLANG_MATRIX_INT_NEG_OP(uint8_t)
- SLANG_MATRIX_INT_NEG_OP(uint16_t)
- SLANG_MATRIX_INT_NEG_OP(uint64_t)
-
-#define SLANG_FLOAT_MATRIX_MOD(T)\
- template<int R, int C> \
- SLANG_FORCE_INLINE Matrix<T, R, C> operator%(Matrix<T, R, C> left, Matrix<T, R, C> right) \
- {\
- Matrix<T, R, C> result;\
- for (int i = 0; i < R; i++) \
- for (int j = 0; j < C; j++) \
- result.rows[i][j] = _slang_fmod(left.rows[i][j], right.rows[i][j]); \
- return result;\
- }
-
- SLANG_FLOAT_MATRIX_MOD(float)
- SLANG_FLOAT_MATRIX_MOD(double)
-#undef SLANG_FLOAT_MATRIX_MOD
-#undef SLANG_MATRIX_BINARY_OP
-#undef SLANG_MATRIX_UNARY_OP
-#undef SLANG_INT_MATRIX_OPS
-#undef SLANG_FLOAT_MATRIX_OPS
-#undef SLANG_MATRIX_INT_NEG_OP
-#undef SLANG_FLOAT_MATRIX_MOD
-
// We can just map `NonUniformResourceIndex` type directly to the index type on CPU, as CPU does not require
// any special handling around such accesses.
typedef size_t NonUniformResourceIndex;
@@ -1484,12 +938,6 @@ struct ComputeVaryingInput
typedef void(*ComputeThreadFunc)(ComputeThreadVaryingInput* varyingInput, void* uniformEntryPointParams, void* uniformState);
typedef void(*ComputeFunc)(ComputeVaryingInput* varyingInput, void* uniformEntryPointParams, void* uniformState);
-template<typename TResult, typename TInput>
-TResult slang_bit_cast(TInput val)
-{
- return *(TResult*)(&val);
-}
-
#ifdef SLANG_PRELUDE_NAMESPACE
}
#endif
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 7a4c5a918..9a55aed57 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -6,6 +6,8 @@
#define SLANG_CUDA_RTC 0
#endif
+#include <stdio.h>
+
// Define SLANG_CUDA_ENABLE_HALF to use the cuda_fp16 include to add half support.
// For this to work NVRTC needs to have the path to the CUDA SDK.
//
@@ -2080,4 +2082,73 @@ __forceinline__ __device__ void *traceOptiXRay(
r0, r1
);
}
+
#endif
+
+
+// TensorView
+struct TensorView
+{
+ uint8_t* data;
+ uint32_t* strides;
+ uint32_t* sizes;
+ uint32_t dimensionCount;
+
+ template<typename T>
+ __device__ T* data_ptr()
+ {
+ return reinterpret_cast<T*>(data);
+ }
+
+ template<typename T>
+ __device__ T load(uint32_t x)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * x);
+ }
+ template<typename T>
+ __device__ T load(uint32_t x, uint32_t y)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y);
+ }
+ template<typename T>
+ __device__ T load(uint32_t x, uint32_t y, uint32_t z)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z);
+ }
+ template<typename T>
+ __device__ T load(uint32_t x, uint32_t y, uint32_t z, uint32_t w)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w);
+ }
+ template<typename T>
+ __device__ T load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4);
+ }
+ template<typename T>
+ __device__ void store(uint32_t x, T val)
+ {
+ *reinterpret_cast<T*>(data + strides[0] * x) = val;
+ }
+ template<typename T>
+ __device__ void store(uint32_t x, uint32_t y, T val)
+ {
+ *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y) = val;
+ }
+ template<typename T>
+ __device__ void store(uint32_t x, uint32_t y, uint32_t z, T val)
+ {
+ *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z) = val;
+ }
+ template<typename T>
+ __device__ void store(uint32_t x, uint32_t y, uint32_t z, uint32_t w, T val)
+ {
+ *reinterpret_cast<T*>(
+ data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w) = val;
+ }
+ template<typename T>
+ __device__ void store(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4, T val)
+ {
+ *reinterpret_cast<T*>(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4) = val;
+ }
+};
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
new file mode 100644
index 000000000..f2accc149
--- /dev/null
+++ b/prelude/slang-torch-prelude.h
@@ -0,0 +1,126 @@
+// Prelude for PyTorch cpp binding.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAUtils.h>
+#include <vector>
+
+#ifndef SLANG_NO_THROW
+# define SLANG_NO_THROW
+#endif
+
+#ifndef SLANG_STDCALL
+# define SLANG_STDCALL
+#endif
+#ifndef SLANG_MCALL
+# define SLANG_MCALL SLANG_STDCALL
+#endif
+#ifndef SLANG_FORCE_INLINE
+# define SLANG_FORCE_INLINE inline
+#endif
+
+#ifdef SLANG_LLVM
+#include "slang-llvm.h"
+#else // SLANG_LLVM
+# if SLANG_GCC_FAMILY && __GNUC__ < 6
+# include <cmath>
+# define SLANG_PRELUDE_STD std::
+# else
+# include <math.h>
+# define SLANG_PRELUDE_STD
+# endif
+
+# include <assert.h>
+# include <stdlib.h>
+# include <string.h>
+# include <stdint.h>
+#endif // SLANG_LLVM
+
+#include "slang-cpp-types-core.h"
+#include "slang-cpp-scalar-intrinsics.h"
+
+struct TensorView
+{
+ uint8_t* data;
+ uint32_t* strides;
+ uint32_t* sizes;
+ uint32_t dimensionCount;
+};
+
+struct CudaTaskMemoryAllocator
+{
+ std::vector<void*> allocations;
+
+ uint32_t* allocUIntArray(uint32_t size)
+ {
+ void* ptr = nullptr;
+ cudaMallocManaged(&ptr, size * sizeof(uint32_t));
+ AT_CUDA_CHECK(cudaGetLastError());
+ return (uint32_t*)ptr;
+ }
+
+ ~CudaTaskMemoryAllocator()
+ {
+ for (auto ptr : allocations)
+ cudaFree(ptr);
+ }
+};
+
+TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val)
+{
+ val = val.to(torch::kCUDA);
+ TensorView res = {};
+ res.dimensionCount = val.dim();
+ res.strides = allocator->allocUIntArray(val.dim());
+ res.sizes = allocator->allocUIntArray(val.dim());
+ res.data = nullptr;
+ size_t elementSize = 4;
+ switch (val.scalar_type())
+ {
+ case torch::kInt8:
+ case torch::kUInt8:
+ elementSize = 1;
+ res.data = (uint8_t*)val.data_ptr<uint8_t>();
+ break;
+ case torch::kBFloat16:
+ elementSize = 2;
+ res.data = (uint8_t*)val.data_ptr<torch::BFloat16>();
+ break;
+ case torch::kInt16:
+ elementSize = 2;
+ res.data = (uint8_t*)val.data_ptr<int16_t>();
+ break;
+ case torch::kFloat32:
+ elementSize = 4;
+ res.data = (uint8_t*)val.data_ptr<float>();
+ break;
+ case torch::kInt32:
+ elementSize = 4;
+ res.data = (uint8_t*)val.data_ptr<int32_t>();
+ break;
+ case torch::kFloat64:
+ elementSize = 8;
+ res.data = (uint8_t*)val.data_ptr<double>();
+ break;
+ case torch::kInt64:
+ elementSize = 8;
+ res.data = (uint8_t*)val.data_ptr<int64_t>();
+ break;
+ }
+ for (int i = 0; i < val.dim(); ++i)
+ {
+ res.strides[i] = val.stride(i) * elementSize;
+ res.sizes[i] = val.size(i);
+ }
+ return res;
+}
+
+size_t slangGetCudaKernelSharedMemSize(const void* func)
+{
+ cudaFuncAttributes attr = {};
+ cudaFuncGetAttributes(&attr, func);
+ AT_CUDA_CHECK(cudaGetLastError());
+ return attr.sharedSizeBytes;
+}
+
+#define SLANG_PRELUDE_EXPORT
diff --git a/premake5.lua b/premake5.lua
index add7544f2..093b4659d 100644
--- a/premake5.lua
+++ b/premake5.lua
@@ -1350,7 +1350,8 @@ if enableEmbedStdLib then
"prelude/slang-cuda-prelude.h.cpp",
"prelude/slang-hlsl-prelude.h.cpp",
"prelude/slang-cpp-prelude.h.cpp",
- "prelude/slang-cpp-host-prelude.h.cpp"
+ "prelude/slang-cpp-host-prelude.h.cpp",
+ "prelude/slang-torch-prelude.h.cpp"
}
end
@@ -1460,7 +1461,8 @@ standardProject("slang", "source/slang")
"prelude/slang-cuda-prelude.h.cpp",
"prelude/slang-hlsl-prelude.h.cpp",
"prelude/slang-cpp-prelude.h.cpp",
- "prelude/slang-cpp-host-prelude.h.cpp"
+ "prelude/slang-cpp-host-prelude.h.cpp",
+ "prelude/slang-torch-prelude.h.cpp"
}
-- Similarly for any generated lookup tables
@@ -1553,7 +1555,8 @@ if enableProfile then
"prelude/slang-cuda-prelude.h.cpp",
"prelude/slang-hlsl-prelude.h.cpp",
"prelude/slang-cpp-prelude.h.cpp",
- "prelude/slang-cpp-host-prelude.h.cpp"
+ "prelude/slang-cpp-host-prelude.h.cpp",
+ "prelude/slang-torch-prelude.h.cpp"
}
-- Add the slang source
diff --git a/slang.h b/slang.h
index 1a72c4348..ea165fbf7 100644
--- a/slang.h
+++ b/slang.h
@@ -566,6 +566,7 @@ extern "C"
SLANG_DXIL_ASM,
SLANG_C_SOURCE, ///< The C language
SLANG_CPP_SOURCE, ///< C++ code for shader kernels.
+ SLANG_CPP_PYTORCH_BINDING, ///< C++ PyTorch binding code.
SLANG_HOST_EXECUTABLE, ///< Standalone binary executable (for hosting CPU/OS)
SLANG_SHADER_SHARED_LIBRARY, ///< A shared library/Dll for shader kernels (for hosting CPU/OS)
SLANG_SHADER_HOST_CALLABLE, ///< A CPU target that makes the compiled shader code available to be run immediately
diff --git a/source/compiler-core/slang-artifact-desc-util.cpp b/source/compiler-core/slang-artifact-desc-util.cpp
index ca7dcb70f..3be0448c4 100644
--- a/source/compiler-core/slang-artifact-desc-util.cpp
+++ b/source/compiler-core/slang-artifact-desc-util.cpp
@@ -273,6 +273,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL
case SLANG_C_SOURCE: return Desc::make(Kind::Source, Payload::C, Style::Kernel, 0);
case SLANG_CPP_SOURCE: return Desc::make(Kind::Source, Payload::Cpp, Style::Kernel, 0);
case SLANG_HOST_CPP_SOURCE: return Desc::make(Kind::Source, Payload::Cpp, Style::Host, 0);
+ case SLANG_CPP_PYTORCH_BINDING: return Desc::make(Kind::Source, Payload::Cpp, Style::Host, 0);
case SLANG_HOST_EXECUTABLE: return Desc::make(Kind::Executable, Payload::HostCPU, Style::Host, 0);
case SLANG_SHADER_SHARED_LIBRARY: return Desc::make(Kind::SharedLibrary, Payload::HostCPU, Style::Kernel, 0);
case SLANG_SHADER_HOST_CALLABLE: return Desc::make(Kind::HostCallable, Payload::HostCPU, Style::Kernel, 0);
diff --git a/source/compiler-core/slang-artifact.h b/source/compiler-core/slang-artifact.h
index cc4d3e9fd..65d4d1bf9 100644
--- a/source/compiler-core/slang-artifact.h
+++ b/source/compiler-core/slang-artifact.h
@@ -208,7 +208,6 @@ enum class ArtifactStyle : uint8_t
Kernel, ///< Compiled as `GPU kernel` style.
Host, ///< Compiled in `host` style
-
Obfuscated, ///< Holds something specific to obfuscation, such as an obfuscated source map
CountOf,
diff --git a/source/core/slang-type-convert-util.cpp b/source/core/slang-type-convert-util.cpp
index 6e6598357..c763b2835 100644
--- a/source/core/slang-type-convert-util.cpp
+++ b/source/core/slang-type-convert-util.cpp
@@ -17,6 +17,7 @@ namespace Slang
case SLANG_HLSL: return SLANG_SOURCE_LANGUAGE_HLSL;
case SLANG_C_SOURCE: return SLANG_SOURCE_LANGUAGE_C;
case SLANG_CPP_SOURCE: return SLANG_SOURCE_LANGUAGE_CPP;
+ case SLANG_CPP_PYTORCH_BINDING:return SLANG_SOURCE_LANGUAGE_CPP;
case SLANG_HOST_CPP_SOURCE: return SLANG_SOURCE_LANGUAGE_CPP;
case SLANG_CUDA_SOURCE: return SLANG_SOURCE_LANGUAGE_CUDA;
default: break;
diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp
index d37051e47..9d2f93ba3 100644
--- a/source/core/slang-type-text-util.cpp
+++ b/source/core/slang-type-text-util.cpp
@@ -78,6 +78,7 @@ static const CompileTargetInfo s_compileTargetInfos[] =
{ SLANG_SPIRV_ASM, "spv-asm", "spirv-asm,spirv-assembly" },
{ SLANG_C_SOURCE, "c", "c" },
{ SLANG_CPP_SOURCE, "cpp,c++,cxx", "cpp,c++,cxx" },
+ { SLANG_CPP_PYTORCH_BINDING, "cpp,c++,cxx", "torch,torch-binding,torch-cpp,torch-cpp-binding" },
{ SLANG_HOST_CPP_SOURCE, "cpp,c++,cxx", "host-cpp,host-c++,host-cxx"},
{ SLANG_HOST_EXECUTABLE,"exe", "exe,executable" },
{ SLANG_SHADER_SHARED_LIBRARY, "dll,so", "sharedlib,sharedlibrary,dll" },
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 82a60a612..c45ad5bd6 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -3081,6 +3081,9 @@ __attributeTarget(FuncDecl)
attribute_syntax [DllExport] : DllExportAttribute;
__attributeTarget(FuncDecl)
+attribute_syntax [TorchEntryPoint] : TorchEntryPointAttribute;
+
+__attributeTarget(FuncDecl)
attribute_syntax [CudaDeviceExport] : CudaDeviceExportAttribute;
__attributeTarget(FuncDecl)
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index bbe94dbc2..d5b70bbb3 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -9,7 +9,6 @@ attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute;
-
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;
@@ -26,6 +25,89 @@ attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [NoDiffThis] : NoDiffThisAttribute;
+__generic<T>
+__magic_type(TensorViewType)
+__intrinsic_type($(kIROp_TensorViewType))
+struct TensorView
+{
+ __target_intrinsic(cuda, "$0.data_ptr<$G0>()")
+ Ptr<T> data_ptr();
+
+ __implicit_conversion($(kConversionCost_ImplicitDereference))
+ __intrinsic_op($(kIROp_TorchTensorGetView))
+ __init(TorchTensor<T> t);
+
+ __target_intrinsic(cuda, "$0.load<$G0>($1)")
+ T load(uint x);
+ __target_intrinsic(cuda, "$0.load<$G0>($1, $2)")
+ T load(uint x, uint y);
+ __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3)")
+ T load(uint x, uint y, uint z);
+ __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4)")
+ T load(uint x, uint y, uint z, uint w);
+ __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4, $5)")
+ T load(uint i0, uint i1, uint i2, uint i3, uint i4);
+
+ __target_intrinsic(cuda, "$0.store<$G0>($1, $2)")
+ void store(uint x, T val);
+ __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3)")
+ void store(uint x, uint y, T val);
+ __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4)")
+ void store(uint x, uint y, uint z, T val);
+ __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5)")
+ void store(uint x, uint y, uint z, uint w, T val);
+ __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5, $6)")
+ void store(uint i0, uint i1, uint i2, uint i3, uint i4, T val);
+
+ __target_intrinsic(cuda, "$0.dimensionCount")
+ uint dims();
+
+ __target_intrinsic(cuda, "$0.sizes[$1]")
+ uint size(uint i);
+
+ __target_intrinsic(cuda, "$0.strides[$1]")
+ uint stride(uint i);
+}
+
+__generic<T>
+__intrinsic_type($(kIROp_TorchTensorType))
+struct TorchTensor
+{
+ __intrinsic_op($(kIROp_TorchTensorGetView))
+ TensorView<T> getView();
+
+ __target_intrinsic(cuda, "$0.dims()")
+ __target_intrinsic(cpp, "$0.dims()")
+ uint dims();
+
+ __target_intrinsic(cuda, "$0.size($1)")
+ __target_intrinsic(cpp, "$0.size($1)")
+ uint size(uint i);
+
+ __target_intrinsic(cuda, "$0.stride($1)")
+ __target_intrinsic(cpp, "$0.stride($1)")
+ uint stride(uint i);
+
+ __target_intrinsic(cuda, "$0.data_ptr<$G0>()")
+ __target_intrinsic(cpp, "$0.data_ptr<$G0>()")
+ Ptr<T> data_ptr();
+
+ __intrinsic_op($(kIROp_AllocateTorchTensor))
+ static TorchTensor<T> alloc(uint x);
+
+ __intrinsic_op($(kIROp_AllocateTorchTensor))
+ static TorchTensor<T> alloc(uint x, uint y);
+
+ __intrinsic_op($(kIROp_AllocateTorchTensor))
+ static TorchTensor<T> alloc(uint x, uint y, uint z);
+
+ __intrinsic_op($(kIROp_AllocateTorchTensor))
+ static TorchTensor<T> alloc(uint x, uint y, uint z, uint w);
+
+ __intrinsic_op($(kIROp_AllocateTorchTensor))
+ static TorchTensor<T> alloc(uint i0, uint i1, uint i2, uint i3, uint i4);
+}
+
__generic<T: IDifferentiable>
__intrinsic_op($(kIROp_MakeDifferentialPairUserCode))
DifferentialPair<T> diffPair(T primal, T.Differential diff);
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index d417e3b7c..4c4aaa4f0 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -6834,3 +6834,13 @@ void debugBreak();
__specialized_for_target(glsl)
[[vk::spirv_instruction(1, "NonSemantic.DebugBreak")]]
void debugBreak();
+
+
+__target_intrinsic(cuda, "(threadIdx)")
+uint3 cudaThreadIdx();
+
+__target_intrinsic(cuda, "(blockIdx)")
+uint3 cudaBlockIdx();
+
+__target_intrinsic(cuda, "(blockDim)")
+uint3 cudaBlockDim();
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 0d2e27e5f..00a6570ef 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1055,6 +1055,11 @@ class DllExportAttribute : public Attribute
SLANG_AST_CLASS(DllExportAttribute)
};
+class TorchEntryPointAttribute : public Attribute
+{
+ SLANG_AST_CLASS(TorchEntryPointAttribute)
+};
+
class CudaDeviceExportAttribute : public Attribute
{
SLANG_AST_CLASS(CudaDeviceExportAttribute)
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index fdbd56377..1fed2d52a 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -275,6 +275,13 @@ BasicExpressionType* BasicExpressionType::_getScalarTypeOverride()
return this;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TensorViewType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+Type* TensorViewType::getElementType()
+{
+ return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]);
+}
+
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! VectorExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void VectorExpressionType::_toTextOverride(StringBuilder& out)
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 47608405a..cb3fde9f9 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -524,6 +524,15 @@ private:
MatrixExpressionType(Type*, IntVal*, IntVal*) {}
};
+class TensorViewType : public BuiltinType
+{
+ SLANG_AST_CLASS(TensorViewType)
+
+ Type* getElementType();
+private:
+ TensorViewType(Type*) {}
+};
+
// Base class for built in string types
class StringTypeBase : public BuiltinType
{
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 8cbe12ef0..ff38bbbde 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -470,6 +470,7 @@ namespace Slang
case CodeGenTarget::CUDASource:
case CodeGenTarget::CPPSource:
case CodeGenTarget::HostCPPSource:
+ case CodeGenTarget::PyTorchCppBinding:
case CodeGenTarget::CSource:
{
return PassThroughMode::None;
@@ -1570,6 +1571,7 @@ namespace Slang
case CodeGenTarget::CUDASource:
case CodeGenTarget::CPPSource:
case CodeGenTarget::HostCPPSource:
+ case CodeGenTarget::PyTorchCppBinding:
case CodeGenTarget::CSource:
{
RefPtr<ExtensionTracker> extensionTracker = _newExtensionTracker(target);
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index b287b21eb..b49ce4fc3 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -75,6 +75,7 @@ namespace Slang
DXILAssembly = SLANG_DXIL_ASM,
CSource = SLANG_C_SOURCE,
CPPSource = SLANG_CPP_SOURCE,
+ PyTorchCppBinding = SLANG_CPP_PYTORCH_BINDING,
HostCPPSource = SLANG_HOST_CPP_SOURCE,
HostExecutable = SLANG_HOST_EXECUTABLE,
ShaderSharedLibrary = SLANG_SHADER_SHARED_LIBRARY,
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index c3e0adbca..39ceb6678 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -664,6 +664,8 @@ DIAGNOSTIC(54002, Error, meshOutputMustBeArray, "HLSL style mesh shader outputs
DIAGNOSTIC(54003, Error, meshOutputArrayMustHaveSize, "HLSL style mesh shader output arrays must have a length specified")
DIAGNOSTIC(54004, Warning, unnecessaryHLSLMeshOutputModifier, "Unnecessary HLSL style mesh shader output modifier")
+DIAGNOSTIC(55101, Error, invalidTorchKernelReturnType, "'$0' is not a valid return type for a pytorch kernel function.")
+DIAGNOSTIC(55102, Error, invalidTorchKernelParamType, "'$0' is not a valid parameter type for a pytorch kernel function.")
//
// 8xxxx - Issues specific to a particular library/technology/platform/etc.
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 356f1c7ce..ebc312560 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -129,6 +129,7 @@ Index LocationTracker::getValue(Kind kind, IRInst* inst, IRDecoration* decoratio
}
case CodeGenTarget::CPPSource:
case CodeGenTarget::HostCPPSource:
+ case CodeGenTarget::PyTorchCppBinding:
{
return SourceLanguage::CPP;
}
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 2a2ae06c6..346926712 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -304,6 +304,18 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S
out << ">";
return SLANG_OK;
}
+ case kIROp_TargetTupleType:
+ {
+ out << "std::tuple<";
+ for (UInt i = 0; i < type->getOperandCount(); i++)
+ {
+ if (i > 0) out << ", ";
+ auto elementType = (IRType*)type->getOperand(i);
+ SLANG_RETURN_ON_FAIL(calcTypeName(elementType, target, out));
+ }
+ out << ">";
+ return SLANG_OK;
+ }
default:
{
if (isNominalOp(type->getOp()))
@@ -1187,6 +1199,19 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
return true;
}
+ case kIROp_MakeTargetTuple:
+ {
+ m_writer->emit("std::make_tuple(");
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (i > 0)
+ m_writer->emit(", ");
+ auto arg = inst->getOperand(i);
+ emitOperand(arg, getInfo(EmitOp::General));
+ }
+ m_writer->emit(")");
+ return true;
+ }
case kIROp_CastFloatToInt:
case kIROp_CastIntToFloat:
case kIROp_FloatCast:
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index 846b3b1f2..d2fa892ba 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -146,6 +146,11 @@ SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target,
out << prefix << vecCount;
return SLANG_OK;
}
+ case kIROp_TensorViewType:
+ {
+ out << "TensorView";
+ return SLANG_OK;
+ }
default:
{
if (isNominalOp(type->getOp()))
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
new file mode 100644
index 000000000..ef67c520a
--- /dev/null
+++ b/source/slang/slang-emit-torch.cpp
@@ -0,0 +1,181 @@
+// slang-emit-torch.cpp
+#include "slang-emit-torch.h"
+
+#include "../core/slang-writer.h"
+
+#include "slang-emit-source-writer.h"
+#include "slang-mangled-lexer.h"
+
+#include <assert.h>
+
+namespace Slang
+{
+bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
+{ // Writes the expression text for the torch-binding opcodes handled below; returns true once text was written.
+ switch (inst->getOp())
+ {
+ default:
+ {
+ return Super::tryEmitInstExprImpl(inst, inOuterPrec); // every other opcode is delegated to the base C++ emitter
+ }
+ case kIROp_MakeTensorView:
+ { // writes `make_tensor_view(<operand 0>, <operand 1>)`
+ m_writer->emit("make_tensor_view(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ case kIROp_CudaKernelLaunch:
+ { // writes a `cudaLaunchKernel(func, gridDim, blockDim, args, sharedMem, stream)` call from operands 0..4
+ m_writer->emit("cudaLaunchKernel(");
+ // func: operand 0, cast to `const void*`
+ m_writer->emit("(const void*)(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+
+ // gridDim: operand 1, reinterpreted as `dim3` through slang_bit_cast
+ m_writer->emit("slang_bit_cast<dim3>(");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+
+ // blockDim: operand 2, reinterpreted as `dim3` through slang_bit_cast
+ m_writer->emit("slang_bit_cast<dim3>(");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+
+ // args: operand 3, emitted as-is
+ emitOperand(inst->getOperand(3), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+
+ // shared mem: obtained by calling slangGetCudaKernelSharedMemSize on the same function operand (operand 0)
+ m_writer->emit("slangGetCudaKernelSharedMemSize((const void*)(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")), ");
+
+ // stream: operand 4, cast to `cudaStream_t`
+ m_writer->emit("((cudaStream_t)");
+ emitOperand(inst->getOperand(4), getInfo(EmitOp::General));
+ m_writer->emit("))");
+ return true;
+ }
+ case kIROp_TorchGetCudaStream:
+ { // takes no operands; always writes the same call expression
+ m_writer->emit("at::cuda::getCurrentCUDAStream()");
+ return true;
+ }
+ case kIROp_AllocateTorchTensor:
+ {
+ /*
+ Emit something like:
+ ```
+ torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... },
+ torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));
+ ```
+ */
+ m_writer->emit("torch::empty({ ");
+ for (UInt i = 0; i < inst->getOperandCount(); i++) // every operand becomes one entry of the braced size list
+ {
+ if (i > 0)
+ m_writer->emit(", ");
+ auto arg = inst->getOperand(i);
+ emitOperand(arg, getInfo(EmitOp::General));
+ }
+ m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::");
+ switch (inst->getDataType()->getOperand(0)->getOp()) // opcode of operand 0 of the result type picks the dtype constant
+ {
+ case kIROp_FloatType:
+ m_writer->emit("kFloat32");
+ break;
+ case kIROp_HalfType:
+ m_writer->emit("kFloat16");
+ break;
+ case kIROp_DoubleType:
+ m_writer->emit("kFloat64");
+ break;
+ case kIROp_UInt8Type:
+ m_writer->emit("kUInt8");
+ break;
+ case kIROp_UInt16Type:
+ m_writer->emit("kUInt16");
+ break;
+ case kIROp_UIntType:
+ m_writer->emit("kUInt32");
+ break;
+ case kIROp_UInt64Type:
+ m_writer->emit("kUInt64");
+ break;
+ case kIROp_Int8Type:
+ m_writer->emit("kInt8");
+ break;
+ case kIROp_Int16Type:
+ m_writer->emit("kInt16");
+ break;
+ case kIROp_IntType:
+ m_writer->emit("kInt32");
+ break;
+ case kIROp_Int64Type:
+ m_writer->emit("kInt64");
+ break;
+ default:
+ SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); // any element type not listed above is reported here
+ break;
+ }
+ m_writer->emit("))");
+ return true;
+ }
+ }
+}
+
+SlangResult TorchCppSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out)
+{
+    // Three IR types get a fixed C++ spelling in the torch binding output.
+    // Every other type is named by the base C++ emitter.
+    const char* fixedName = nullptr;
+    switch (type->getOp())
+    {
+    case kIROp_TensorViewType:
+        fixedName = "TensorView";
+        break;
+    case kIROp_TorchTensorType:
+        fixedName = "torch::Tensor";
+        break;
+    case kIROp_TorchKernelMemoryAllocatorType:
+        fixedName = "CudaTaskMemoryAllocator";
+        break;
+    default:
+        return Super::calcTypeName(type, target, out);
+    }
+
+    out << fixedName;
+    return SLANG_OK;
+}
+
+void TorchCppSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink)
+{
+    // The ordinary C++ module contents are written first by the base emitter.
+    Super::emitModuleImpl(module, sink);
+
+    // Afterwards, write a PYBIND11_MODULE block containing one `m.def(...)` line
+    // for each global function that carries an IRTorchEntryPointDecoration.
+    m_writer->emit("PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n");
+    m_writer->indent();
+    for (auto globalInst : module->getGlobalInsts())
+    {
+        if (auto func = as<IRFunc>(globalInst))
+        {
+            if (auto entryPoint = func->findDecoration<IRTorchEntryPointDecoration>())
+            {
+                // The decoration's function name is used three times:
+                // as the first string argument, as the C++ symbol whose address is taken,
+                // and as the trailing string argument.
+                const auto exportedName = entryPoint->getFunctionName();
+                m_writer->emit("m.def(");
+                emitStringLiteral(exportedName);
+                m_writer->emit(", &");
+                m_writer->emit(exportedName);
+                m_writer->emit(", ");
+                emitStringLiteral(exportedName);
+                m_writer->emit(");\n");
+            }
+        }
+    }
+    m_writer->dedent();
+    m_writer->emit("}\n");
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-emit-torch.h b/source/slang/slang-emit-torch.h
new file mode 100644
index 000000000..84ce42331
--- /dev/null
+++ b/source/slang/slang-emit-torch.h
@@ -0,0 +1,28 @@
+// slang-emit-torch.h
+#ifndef SLANG_EMIT_TORCH_H
+#define SLANG_EMIT_TORCH_H
+
+#include "slang-emit-cpp.h"
+
+namespace Slang
+{
+
+class TorchCppSourceEmitter : public CPPSourceEmitter // C++ source emitter specialization declared for the torch binding output
+{
+public:
+ typedef CPPSourceEmitter Super; // alias used by the overrides to reach the base implementation
+
+ TorchCppSourceEmitter(const Desc& desc) : // forwards the emitter description to the base class unchanged
+ Super(desc)
+ {
+ }
+
+protected:
+ // CPPSourceEmitter overrides
+ virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) override; // expression emission hook
+ virtual SlangResult calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) override; // type-name hook; appends to `out`
+ virtual void emitModuleImpl(IRModule* module, DiagnosticSink* sink) override; // whole-module emission hook
+};
+
+}
+#endif
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 1b4eed8fd..fe72efcc7 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -56,6 +56,7 @@
#include "slang-ir-glsl-liveness.h"
#include "slang-ir-string-hash.h"
#include "slang-ir-simplify-for-emit.h"
+#include "slang-ir-pytorch-cpp-binding.h"
#include "slang-legalize-types.h"
#include "slang-lower-to-ir.h"
#include "slang-mangle.h"
@@ -74,6 +75,7 @@
#include "slang-emit-hlsl.h"
#include "slang-emit-cpp.h"
#include "slang-emit-cuda.h"
+#include "slang-emit-torch.h"
#include "../compiler-core/slang-artifact-desc-util.h"
#include "../compiler-core/slang-artifact-util.h"
@@ -83,6 +85,7 @@
#include <assert.h>
Slang::String get_slang_cpp_host_prelude();
+Slang::String get_slang_torch_prelude();
namespace Slang {
@@ -402,6 +405,18 @@ Result linkAndOptimizeIR(
finalizeSpecialization(irModule);
+ switch (target)
+ {
+ case CodeGenTarget::PyTorchCppBinding:
+ generatePyTorchCppBinding(irModule, sink);
+ break;
+ case CodeGenTarget::CUDASource:
+ removeTorchKernels(irModule);
+ break;
+ default:
+ break;
+ }
+
// If we have a target that is GPU like we use the string hashing mechanism
// but for that to work we need to inline such that calls (or returns) of strings
// boil down into getStringHash(stringLiteral)
@@ -969,31 +984,39 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
LinkedIR linkedIR;
RefPtr<CLikeSourceEmitter> sourceEmitter;
-
SourceLanguage sourceLanguage = CLikeSourceEmitter::getSourceLanguage(target);
- switch (sourceLanguage)
+
+ switch (target)
{
- case SourceLanguage::CPP:
- {
- sourceEmitter = new CPPSourceEmitter(desc);
- break;
- }
- case SourceLanguage::GLSL:
- {
- sourceEmitter = new GLSLSourceEmitter(desc);
- break;
- }
- case SourceLanguage::HLSL:
- {
- sourceEmitter = new HLSLSourceEmitter(desc);
- break;
- }
- case SourceLanguage::CUDA:
+ default:
+ switch (sourceLanguage)
{
- sourceEmitter = new CUDASourceEmitter(desc);
- break;
+ case SourceLanguage::CPP:
+ {
+ sourceEmitter = new CPPSourceEmitter(desc);
+ break;
+ }
+ case SourceLanguage::GLSL:
+ {
+ sourceEmitter = new GLSLSourceEmitter(desc);
+ break;
+ }
+ case SourceLanguage::HLSL:
+ {
+ sourceEmitter = new HLSLSourceEmitter(desc);
+ break;
+ }
+ case SourceLanguage::CUDA:
+ {
+ sourceEmitter = new CUDASourceEmitter(desc);
+ break;
+ }
+ default: break;
}
- default: break;
+ break;
+ case CodeGenTarget::PyTorchCppBinding:
+ sourceEmitter = new TorchCppSourceEmitter(desc);
+ break;
}
if (!sourceEmitter)
@@ -1072,16 +1095,23 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
// Emit any front matter
sourceEmitter->emitFrontMatter(targetRequest);
- // If heterogeneous we output the prelude before everything else
- if (isHeterogeneousTarget(target))
- {
- sourceWriter.emit(get_slang_cpp_host_prelude());
- }
- else
+ switch (target)
{
- // Get the prelude
- String prelude = session->getPreludeForLanguage(sourceLanguage);
- sourceWriter.emit(prelude);
+ case CodeGenTarget::PyTorchCppBinding:
+ sourceWriter.emit(get_slang_torch_prelude());
+ break;
+ default:
+ if (isHeterogeneousTarget(target))
+ {
+ sourceWriter.emit(get_slang_cpp_host_prelude());
+ }
+ else
+ {
+ // Get the prelude
+ String prelude = session->getPreludeForLanguage(sourceLanguage);
+ sourceWriter.emit(prelude);
+ }
+ break;
}
// Emit anything that goes before the contents of the code generated for the module
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 04e08293f..68f1a28e6 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -67,6 +67,11 @@ INST(Nop, nop, 0, 0)
INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, HOISTABLE)
+ INST(TensorViewType, TensorView, 1, HOISTABLE)
+ INST(TorchTensorType, TorchTensor, 0, HOISTABLE)
+ INST(TorchKernelMemoryAllocatorType, TorchMemAllocatorType, 0, HOISTABLE)
+ INST(ArrayListType, ArrayListVector, 1, HOISTABLE)
+
/* BindExistentialsTypeBase */
// A `BindExistentials<B, T0,w0, T1,w1, ...>` represents
@@ -220,6 +225,7 @@ INST(ThisType, this_type, 0, HOISTABLE)
INST(RTTIType, rtti_type, 0, HOISTABLE)
INST(RTTIHandleType, rtti_handle_type, 0, HOISTABLE)
INST(TupleType, tuple_type, 0, HOISTABLE)
+INST(TargetTupleType, TargetTuple, 0, HOISTABLE)
// A type that identifies it's contained type as being emittable as `spirv_literal.
INST(SPIRVLiteralType, spirvLiteralType, 1, HOISTABLE)
@@ -308,6 +314,7 @@ INST(MakeArray, makeArray, 0, 0)
INST(MakeArrayFromElement, makeArrayFromElement, 1, 0)
INST(MakeStruct, makeStruct, 0, 0)
INST(MakeTuple, makeTuple, 0, 0)
+INST(MakeTargetTuple, makeTuple, 0, 0)
INST(GetTupleElement, getTupleElement, 2, 0)
INST(MakeResultValue, makeResultValue, 1, 0)
INST(MakeResultError, makeResultError, 1, 0)
@@ -509,24 +516,24 @@ INST(SwizzledStore, swizzledStore, 2, 0)
/* IRConditionalbranch */
// conditionalBranch <condition> <trueBlock> <falseBlock>
- INST(conditionalBranch, conditionalBranch, 3, 0)
+INST(conditionalBranch, conditionalBranch, 3, 0)
- // ifElse <condition> <trueBlock> <falseBlock> <mergeBlock>
- INST(ifElse, ifElse, 4, 0)
- INST_RANGE(ConditionalBranch, conditionalBranch, ifElse)
+// ifElse <condition> <trueBlock> <falseBlock> <mergeBlock>
+INST(ifElse, ifElse, 4, 0)
+INST_RANGE(ConditionalBranch, conditionalBranch, ifElse)
- INST(Throw, throw, 1, 0)
- // tryCall <successBlock> <failBlock> <callee> <args>...
- INST(TryCall, tryCall, 3, 0)
- // switch <val> <break> <default> <caseVal1> <caseBlock1> ...
- INST(Switch, switch, 3, 0)
+INST(Throw, throw, 1, 0)
+// tryCall <successBlock> <failBlock> <callee> <args>...
+INST(TryCall, tryCall, 3, 0)
+// switch <val> <break> <default> <caseVal1> <caseBlock1> ...
+INST(Switch, switch, 3, 0)
- INST(discard, discard, 0, 0)
+INST(discard, discard, 0, 0)
- /* IRUnreachable */
- INST(MissingReturn, missingReturn, 0, 0)
- INST(Unreachable, unreachable, 0, 0)
- INST_RANGE(Unreachable, MissingReturn, Unreachable)
+/* IRUnreachable */
+INST(MissingReturn, missingReturn, 0, 0)
+INST(Unreachable, unreachable, 0, 0)
+INST_RANGE(Unreachable, MissingReturn, Unreachable)
INST_RANGE(TerminatorInst, Return, Unreachable)
@@ -575,10 +582,10 @@ INST(GetStringHash, getStringHash, 1, 0)
INST(WaveGetActiveMask, waveGetActiveMask, 0, 0)
- /// trueMask = waveMaskBallot(mask, condition)
+/// trueMask = waveMaskBallot(mask, condition)
INST(WaveMaskBallot, waveMaskBallot, 2, 0)
- /// matchMask = waveMaskBallot(mask, value)
+/// matchMask = waveMaskBallot(mask, value)
INST(WaveMaskMatch, waveMaskMatch, 2, 0)
// Texture sampling operation of the form `t.Sample(s,u)`
@@ -604,6 +611,12 @@ INST(GetOptiXHitAttribute, getOptiXHitAttribute, 2, 0)
// using a pointer.
INST(GetOptiXSbtDataPtr, getOptiXSbtDataPointer, 0, 0)
+INST(MakeArrayList, makeArrayList, 0, 0)
+INST(MakeTensorView, makeTensorView, 0, 0)
+INST(AllocateTorchTensor, allocTorchTensor , 0, 0)
+INST(TorchGetCudaStream, TorchGetCudaStream, 0, 0)
+INST(TorchTensorGetView, TorchTensorGetView, 0, 0)
+
/* Decoration */
INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
@@ -669,7 +682,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(CudaKernelDecoration, CudaKernel, 0, 0)
INST(CudaHostDecoration, CudaHost, 0, 0)
-
+ INST(TorchEntryPointDecoration, TorchEntryPoint, 0, 0)
+
/// Used to mark parameters that are moved from entry point parameters to global params as coming from the entry point.
INST(EntryPointParamDecoration, entryPointParam, 0, 0)
@@ -908,6 +922,7 @@ INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0)
INST(PrimalSubstitute, PrimalSubstitute, 1, 0)
INST(DispatchKernel, DispatchKernel, 3, 0)
+INST(CudaKernelLaunch, CudaKernelLaunch, 6, 0)
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 9bb66823b..4cdf6c749 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -518,6 +518,18 @@ struct IRDllExportDecoration : IRDecoration
UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); }
};
+struct IRTorchEntryPointDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_TorchEntryPointDecoration
+ };
+ IR_LEAF_ISA(TorchEntryPointDecoration)
+
+ IRStringLit* getFunctionNameOperand() { return cast<IRStringLit>(getOperand(0)); }
+ UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); }
+};
+
struct IRFormatDecoration : IRDecoration
{
enum { kOp = kIROp_FormatDecoration };
@@ -936,6 +948,15 @@ struct IRDispatchKernel : IRInst
IR_LEAF_ISA(DispatchKernel)
};
+struct IRTorchTensorGetView : IRInst
+{
+ enum
+ {
+ kOp = kIROp_TorchTensorGetView
+ };
+ IR_LEAF_ISA(TorchTensorGetView)
+};
+
// Dictionary item mapping a type with a corresponding
// IDifferentiable witness table
//
@@ -2720,6 +2741,8 @@ public:
IRAnyValueType* getAnyValueType(IRInst* size);
IRDynamicType* getDynamicType();
+ IRTargetTupleType* getTargetTupleType(UInt count, IRType* const* types);
+
IRTupleType* getTupleType(UInt count, IRType* const* types);
IRTupleType* getTupleType(List<IRType*> const& types)
{
@@ -2775,6 +2798,10 @@ public:
IRInst* rowCount,
IRInst* columnCount);
+ IRArrayListType* getArrayListType(IRType* elementType);
+ IRTensorViewType* getTensorViewType(IRType* elementType);
+ IRTorchTensorType* getTorchTensorType();
+
IRDifferentialPairType* getDifferentialPairType(
IRType* valueType,
IRInst* witnessTable);
@@ -2896,7 +2923,10 @@ public:
IRInst* emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn);
IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn);
IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn);
+
IRInst* emitDispatchKernelInst(IRType* type, IRInst* baseFn, IRInst* threadGroupSize, IRInst* dispatchSize, Int argCount, IRInst* const* inArgs);
+ IRInst* emitCudaKernelLaunch(IRInst* baseFn, IRInst* gridDim, IRInst* blockDim, IRInst* argsArray, IRInst* cudaStream);
+ IRInst* emitGetTorchCudaStream();
IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential);
IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential);
@@ -2999,6 +3029,8 @@ public:
// Creates an RTTI object. Result is of `IRRTTIType`.
IRInst* emitMakeRTTIObject(IRInst* typeInst);
+ IRInst* emitMakeTargetTuple(IRType* type, UInt count, IRInst* const* args);
+
IRInst* emitMakeTuple(IRType* type, UInt count, IRInst* const* args);
IRInst* emitMakeTuple(UInt count, IRInst* const* args);
@@ -3067,6 +3099,11 @@ public:
UInt argCount,
IRInst* const* args);
+ IRInst* emitMakeArrayList(
+ IRType* type,
+ UInt argCount,
+ IRInst* const* args);
+
IRInst* emitMakeArrayFromElement(
IRType* type,
IRInst* element);
@@ -3083,6 +3120,8 @@ public:
return emitMakeStruct(type, args.getCount(), args.getBuffer());
}
+ IRInst* emitMakeTensorView(IRType* type, IRInst* allocator, IRInst* val);
+
IRInst* emitMakeExistential(
IRType* type,
IRInst* value,
@@ -3785,6 +3824,11 @@ public:
addDecoration(value, kIROp_DllExportDecoration, getStringValue(functionName));
}
+ void addTorchEntryPointDecoration(IRInst* value, UnownedStringSlice const& functionName)
+ {
+ addDecoration(value, kIROp_TorchEntryPointDecoration, getStringValue(functionName));
+ }
+
void addCudaDeviceExportDecoration(IRInst* value, UnownedStringSlice const& functionName)
{
addDecoration(value, kIROp_CudaDeviceExportDecoration, getStringValue(functionName));
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
new file mode 100644
index 000000000..e33adec1d
--- /dev/null
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -0,0 +1,248 @@
+#include "slang-ir-pytorch-cpp-binding.h"
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-diagnostics.h"
+
+namespace Slang
+{
+// Flattens `type` into the list of leaf types that form the host-side return tuple,
+// appending them to `elementTypes` in declaration order.
+//
+// - void contributes nothing.
+// - basic types and torch tensors contribute themselves.
+// - vectors contribute one element-type entry per component.
+// - arrays contribute the flattened element type repeated `elementCount` times.
+// - structs contribute the flattening of each field, in field order.
+//
+// Returns false when `type`, or anything nested inside it, is of another kind, or when a
+// vector/array element count is not an integer literal. `elementTypes` may already hold
+// partial results in that case. `builder` is only passed through to recursive calls.
+static bool getHostReturnTypeImpl(List<IRType*>& elementTypes, IRBuilder& builder, IRType* type)
+{
+    bool isValid = true;
+    if (as<IRVoidType>(type))
+        return true;
+    if (as<IRBasicType>(type))
+        elementTypes.add(type);
+    else if (as<IRTorchTensorType>(type))
+        elementTypes.add(type);
+    else if (auto vectorType = as<IRVectorType>(type))
+    {
+        auto count = as<IRIntLit>(vectorType->getElementCount());
+        if (!count)
+        {
+            return false;
+        }
+        for (IRIntegerValue i = 0; i < count->getValue(); i++)
+        {
+            // One tuple slot per vector component. A single type is appended here, so use
+            // `add` like the scalar branches above rather than the range-append API.
+            elementTypes.add(vectorType->getElementType());
+        }
+    }
+    else if (auto arrayType = as<IRArrayType>(type))
+    {
+        auto arraySize = as<IRIntLit>(arrayType->getElementCount());
+        if (!arraySize)
+        {
+            return false;
+        }
+        // Flatten the element type once, then repeat that sequence for every array element.
+        List<IRType*> subElementTypes;
+        isValid &= getHostReturnTypeImpl(subElementTypes, builder, arrayType->getElementType());
+        for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
+        {
+            elementTypes.addRange(subElementTypes);
+        }
+    }
+    else if (auto structType = as<IRStructType>(type))
+    {
+        for (auto field : structType->getFields())
+        {
+            isValid &= getHostReturnTypeImpl(elementTypes, builder, field->getFieldType());
+        }
+    }
+    else
+    {
+        return false;
+    }
+    return isValid;
+}
+
+// Builds the target tuple type that a host binding function returns for a kernel whose
+// result type is `type`. Yields nullptr when the result type cannot be flattened.
+static IRType* getHostReturnType(IRBuilder& builder, IRType* type)
+{
+    List<IRType*> flattenedTypes;
+    if (!getHostReturnTypeImpl(flattenedTypes, builder, type))
+        return nullptr;
+
+    const UInt flattenedCount = (UInt)flattenedTypes.getCount();
+    return builder.getTargetTupleType(flattenedCount, flattenedTypes.getBuffer());
+}
+
+// Flattens the value `val` into its leaf values, appending them to `result` in the same
+// order that getHostReturnTypeImpl appends the corresponding leaf types.
+//
+// Extraction instructions are emitted through `builder` at its current insert location.
+// void values contribute nothing. Vectors and arrays whose element count is not an integer
+// literal, and types of any other kind, are silently skipped.
+static void flattenToTupleImpl(List<IRInst*>& result, IRBuilder& builder, IRInst* val)
+{
+    auto type = val->getDataType();
+    if (as<IRVoidType>(type))
+        return;
+    if (as<IRBasicType>(type))
+        result.add(val);
+    else if (as<IRTorchTensorType>(type))
+        result.add(val);
+    else if (auto vectorType = as<IRVectorType>(type))
+    {
+        auto count = as<IRIntLit>(vectorType->getElementCount());
+        if (!count)
+        {
+            return;
+        }
+        for (IRIntegerValue i = 0; i < count->getValue(); i++)
+        {
+            // Extract component `i` from the vector value itself. The base operand must be
+            // `val`, exactly as in the array branch below, not the vector's element type.
+            result.add(builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)));
+        }
+    }
+    else if (auto arrayType = as<IRArrayType>(type))
+    {
+        auto arraySize = as<IRIntLit>(arrayType->getElementCount());
+        if (!arraySize)
+        {
+            return;
+        }
+        for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
+        {
+            // Array elements may themselves be aggregates, so recurse on each one.
+            auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
+            flattenToTupleImpl(result, builder, elementVal);
+        }
+    }
+    else if (auto structType = as<IRStructType>(type))
+    {
+        for (auto field : structType->getFields())
+        {
+            auto elementVal = builder.emitFieldExtract(field->getFieldType(), val, field->getKey());
+            flattenToTupleImpl(result, builder, elementVal);
+        }
+    }
+}
+
+// Flattens `val` into its leaf values and packs them into a target tuple of type `type`.
+// All instructions are emitted at `builder`'s current insert location.
+static IRInst* flattenToHostReturnTuple(IRBuilder& builder, IRType* type, IRInst* val)
+{
+    List<IRInst*> tupleElements;
+    flattenToTupleImpl(tupleElements, builder, val);
+
+    const UInt elementCount = (UInt)tupleElements.getCount();
+    return builder.emitMakeTargetTuple(type, elementCount, tupleElements.getBuffer());
+}
+
+static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
+{ // Rewrites `func` in place: tuple return type, kernel dispatches turned into CUDA launches, tensor views materialized.
+ IRBuilder builder(func);
+
+ builder.setInsertBefore(func);
+ auto hostReturnType = getHostReturnType(builder, func->getResultType());
+ if (!hostReturnType) // the result type could not be flattened into a tuple
+ {
+ sink->diagnose(func->sourceLoc, Diagnostics::invalidTorchKernelReturnType, func->getResultType());
+ return; // the function is left unmodified after the diagnostic
+ }
+ List<IRType*> hostParamTypes;
+ auto funcType = as<IRFuncType>(func->getDataType());
+ for (UInt i = 0; i < funcType->getParamCount(); i++) // parameter types are carried over unchanged
+ {
+ hostParamTypes.add(funcType->getParamType(i));
+ }
+ auto bindingFuncType = builder.getFuncType(hostParamTypes, hostReturnType);
+ func->setFullType(bindingFuncType); // only the result type differs from the original function type
+
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); // NOTE(review): assumes `func` has a first block - confirm for declarations
+ auto allocator = builder.emitVar(builder.getType(kIROp_TorchKernelMemoryAllocatorType)); // one allocator var per function, passed to every tensor view created below
+
+ List<IRInst*> instsToRemove; // replaced instructions are collected here and removed after the walk
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (auto kernelDispatch = as<IRDispatchKernel>(inst))
+ {
+ builder.setInsertBefore(kernelDispatch);
+ List<IRInst*> kernelArgs; // NOTE(review): never read or written below - appears to be unused
+ auto kernelArgCount = kernelDispatch->getArgCount();
+ auto argArrayType = builder.getArrayType(builder.getPtrType(builder.getVoidType()),
+ builder.getIntValue(builder.getIntType(), kernelArgCount));
+ auto argArrayVar = builder.emitVar(argArrayType); // array of `void*` with one slot per kernel argument
+ for (UInt i = 0; i < kernelArgCount; i++)
+ {
+ auto arg = kernelDispatch->getArg(i);
+ auto argVar = builder.emitVar(arg->getFullType()); // each argument value is copied into its own variable
+ builder.emitStore(argVar, arg);
+ auto addr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), i));
+ builder.emitStore(addr, argVar); // slot i receives the argument variable itself, not the argument value
+ }
+ auto argArrayPtr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), 0)); // address of slot 0 is handed over as the args array
+ builder.emitCudaKernelLaunch(
+ kernelDispatch->getBaseFn(),
+ kernelDispatch->getDispatchSize(), // passed in the gridDim position
+ kernelDispatch->getThreadGroupSize(), // passed in the blockDim position
+ argArrayPtr,
+ builder.emitGetTorchCudaStream());
+ instsToRemove.add(inst);
+ }
+ else if (auto getView = as<IRTorchTensorGetView>(inst))
+ {
+ builder.setInsertBefore(getView);
+ auto makeView = builder.emitMakeTensorView(getView->getFullType(), allocator, inst->getOperand(0)); // same result type; operand 0 is forwarded as the viewed value
+ getView->replaceUsesWith(makeView);
+ instsToRemove.add(getView);
+ }
+ else if (auto ret = as<IRReturn>(inst))
+ {
+ builder.setInsertBefore(ret);
+ auto retVal = flattenToHostReturnTuple(builder, hostReturnType, ret->getVal()); // flattening instructions are emitted right before the return
+ ret->setOperand(0, retVal); // the return itself is kept; only its value is swapped for the tuple
+ }
+ }
+ }
+
+ for (auto inst : instsToRemove)
+ inst->removeAndDeallocate();
+}
+
+// Prepares `module` for PyTorch C++ binding output.
+// Global functions are sorted into three groups:
+//   - torch entry points are rewritten by generateCppBindingForFunc;
+//   - CUDA kernels have all of their blocks removed;
+//   - every other function loses its public / HLSL-export / keep-alive / dll-export
+//     decorations.
+void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
+{
+    List<IRFunc*> torchEntryPoints;
+    List<IRFunc*> kernelFuncs;
+
+    for (auto globalInst : module->getGlobalInsts())
+    {
+        auto func = as<IRFunc>(globalInst);
+        if (!func)
+            continue;
+
+        if (func->findDecoration<IRTorchEntryPointDecoration>())
+        {
+            torchEntryPoints.add(func);
+            continue;
+        }
+        if (func->findDecoration<IRCudaKernelDecoration>())
+        {
+            kernelFuncs.add(func);
+            continue;
+        }
+
+        // Remove all other export decorations if this is not a cuda host func.
+        if (auto publicDecor = func->findDecoration<IRPublicDecoration>())
+            publicDecor->removeAndDeallocate();
+        if (auto hlslExportDecor = func->findDecoration<IRHLSLExportDecoration>())
+            hlslExportDecor->removeAndDeallocate();
+        if (auto keepAliveDecor = func->findDecoration<IRKeepAliveDecoration>())
+            keepAliveDecor->removeAndDeallocate();
+        if (auto dllExportDecor = func->findDecoration<IRDllExportDecoration>())
+            dllExportDecor->removeAndDeallocate();
+    }
+
+    for (auto entryPoint : torchEntryPoints)
+        generateCppBindingForFunc(entryPoint, sink);
+
+    for (auto kernel : kernelFuncs)
+    {
+        // Fetch the successor before removing each block, so the walk never reads a
+        // block that has already been deallocated.
+        auto block = kernel->getFirstBlock();
+        while (block)
+        {
+            auto following = block->getNextBlock();
+            block->removeAndDeallocate();
+            block = following;
+        }
+    }
+}
+
+// Remove all [TorchEntryPoint] functions when emitting CUDA source.
+//
+// The functions are collected first and removed afterwards: removing an instruction while
+// the range-for over `getGlobalInsts()` is still positioned on it would leave the loop
+// advancing from an instruction that is no longer in the list.
+void removeTorchKernels(IRModule* module)
+{
+    List<IRInst*> toRemove;
+    for (auto globalInst : module->getGlobalInsts())
+    {
+        if (!as<IRFunc>(globalInst))
+            continue;
+        if (globalInst->findDecoration<IRTorchEntryPointDecoration>())
+            toRemove.add(globalInst);
+    }
+
+    for (auto inst : toRemove)
+        inst->removeAndDeallocate();
+}
+
+}
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.h b/source/slang/slang-ir-pytorch-cpp-binding.h
new file mode 100644
index 000000000..c35b6a8eb
--- /dev/null
+++ b/source/slang/slang-ir-pytorch-cpp-binding.h
@@ -0,0 +1,12 @@
+#pragma once
+
+namespace Slang
+{
+struct IRModule; // forward declarations keep this header free of includes
+class DiagnosticSink;
+
+void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink); // IR pass taking the module and a sink for diagnostics
+void removeTorchKernels(IRModule* module); // IR pass that takes only the module
+
+}
+
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 69870c128..6ce54a948 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2594,6 +2594,11 @@ namespace Slang
IRDynamicType* IRBuilder::getDynamicType() { return (IRDynamicType*)getType(kIROp_DynamicType); }
+ IRTargetTupleType* IRBuilder::getTargetTupleType(UInt count, IRType* const* types)
+ {
+ return (IRTargetTupleType*)getType(kIROp_TargetTupleType, count, (IRInst* const*)types);
+ }
+
IRAssociatedType* IRBuilder::getAssociatedType(ArrayView<IRInterfaceType*> constraintTypes)
{
return (IRAssociatedType*)getType(kIROp_AssociatedType,
@@ -2788,6 +2793,27 @@ namespace Slang
operands);
}
+ IRArrayListType* IRBuilder::getArrayListType(IRType* elementType)
+ {
+ return (IRArrayListType*)getType(
+ kIROp_ArrayListType,
+ 1,
+ (IRInst**)&elementType);
+ }
+
+ IRTensorViewType* IRBuilder::getTensorViewType(IRType* elementType)
+ {
+ return (IRTensorViewType*)getType(
+ kIROp_TensorViewType,
+ 1,
+ (IRInst**)&elementType);
+ }
+
+ IRTorchTensorType* IRBuilder::getTorchTensorType()
+ {
+ return (IRTorchTensorType*)getType(kIROp_TorchTensorType, 0, nullptr);
+ }
+
IRDifferentialPairType* IRBuilder::getDifferentialPairType(
IRType* valueType,
IRInst* witnessTable)
@@ -3173,6 +3199,21 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitCudaKernelLaunch(IRInst* baseFn, IRInst* gridDim, IRInst* blockDim, IRInst* argsArray, IRInst* cudaStream)
+ {
+ IRInst* args[5] = {baseFn, gridDim, blockDim, argsArray, cudaStream};
+ return emitIntrinsicInst(
+ getVoidType(),
+ kIROp_CudaKernelLaunch,
+ 5,
+ args);
+ }
+
+ IRInst* IRBuilder::emitGetTorchCudaStream()
+ {
+ return emitIntrinsicInst(getPtrType(getVoidType()), kIROp_TorchGetCudaStream, 0, nullptr);
+ }
+
IRInst* IRBuilder::emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn)
{
auto inst = createInst<IRBackwardDifferentiatePrimal>(
@@ -3659,6 +3700,11 @@ namespace Slang
return emitIntrinsicInst(type, kIROp_MakeTuple, count, args);
}
+ IRInst* IRBuilder::emitMakeTargetTuple(IRType* type, UInt count, IRInst* const* args)
+ {
+ return emitIntrinsicInst(type, kIROp_MakeTargetTuple, count, args);
+ }
+
IRInst* IRBuilder::emitMakeTuple(UInt count, IRInst* const* args)
{
List<IRType*> types;
@@ -3851,6 +3897,11 @@ namespace Slang
return emitIntrinsicInst(type, kIROp_MakeArray, argCount, args);
}
+ IRInst* IRBuilder::emitMakeArrayList(IRType* type, UInt argCount, IRInst* const* args)
+ {
+ return emitIntrinsicInst(type, kIROp_MakeArrayList, argCount, args);
+ }
+
IRInst* IRBuilder::emitMakeArrayFromElement(
IRType* type,
IRInst* element)
@@ -3866,6 +3917,12 @@ namespace Slang
return emitIntrinsicInst(type, kIROp_MakeStruct, argCount, args);
}
+ IRInst* IRBuilder::emitMakeTensorView(IRType* type, IRInst* allocator, IRInst* val)
+ {
+ IRInst* args[2] = { allocator, val };
+ return emitIntrinsicInst(type, kIROp_MakeTensorView, 2, args);
+ }
+
IRInst* IRBuilder::emitMakeExistential(
IRType* type,
IRInst* value,
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 025812f83..d74a679d3 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1490,6 +1490,25 @@ struct IRMatrixType : IRType
IR_LEAF_ISA(MatrixType)
};
+struct IRArrayListType : IRType
+{
+ IRType* getElementType() { return (IRType*)getOperand(0); }
+
+ IR_LEAF_ISA(ArrayListType)
+};
+
+struct IRTensorViewType : IRType
+{
+ IRType* getElementType() { return (IRType*)getOperand(0); }
+
+ IR_LEAF_ISA(TensorViewType)
+};
+
+struct IRTorchTensorType : IRType
+{
+ IR_LEAF_ISA(TorchTensorType)
+};
+
struct IRSPIRVLiteralType : IRType
{
IR_LEAF_ISA(SPIRVLiteralType)
@@ -1699,6 +1718,12 @@ struct IRTupleType : IRType
IR_LEAF_ISA(TupleType)
};
+/// Represents a tuple in target language. TargetTupleType will not be lowered to structs.
+struct IRTargetTupleType : IRType
+{
+ IR_LEAF_ISA(TargetTupleType)
+};
+
/// Represents an `Result<T,E>`, used by functions that throws error codes.
struct IRResultType : IRType
{
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 7144b3450..9d424d1e8 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1150,51 +1150,67 @@ static void addLinkageDecoration(
{
builder->addExportDecoration(inst, mangledName);
}
- if (decl->findModifier<PublicModifier>())
+ for (auto modifier : decl->modifiers)
{
- builder->addPublicDecoration(inst);
- builder->addKeepAliveDecoration(inst);
- }
- if (decl->findModifier<HLSLExportModifier>())
- {
- builder->addHLSLExportDecoration(inst);
- builder->addKeepAliveDecoration(inst);
- }
- if (decl->findModifier<ExternCppModifier>())
- {
- builder->addExternCppDecoration(inst, mangledName);
+ if (as<PublicModifier>(modifier))
+ {
+ builder->addPublicDecoration(inst);
+ builder->addKeepAliveDecoration(inst);
+ }
+ else if (as<HLSLExportModifier>(modifier))
+ {
+ builder->addHLSLExportDecoration(inst);
+ builder->addKeepAliveDecoration(inst);
+ }
+ else if (as<ExternCppModifier>(modifier))
+ {
+ builder->addExternCppDecoration(inst, mangledName);
+ }
+ else if (auto dllImportModifier = as<DllImportAttribute>(modifier))
+ {
+ auto libraryName = dllImportModifier->modulePath;
+ auto functionName = dllImportModifier->functionName.getLength()
+ ? dllImportModifier->functionName.getUnownedSlice()
+ : decl->getName()->text.getUnownedSlice();
+ builder->addDllImportDecoration(inst, libraryName.getUnownedSlice(), functionName);
+ }
+ else if (as<DllExportAttribute>(modifier))
+ {
+ builder->addDllExportDecoration(inst, decl->getName()->text.getUnownedSlice());
+ builder->addPublicDecoration(inst);
+ }
+ else if (as<CudaDeviceExportAttribute>(modifier))
+ {
+ builder->addCudaDeviceExportDecoration(inst, decl->getName()->text.getUnownedSlice());
+ builder->addPublicDecoration(inst);
+ builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice());
+ }
+ else if (as<CudaHostAttribute>(modifier))
+ {
+ builder->addCudaHostDecoration(inst);
+ builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice());
+ }
+ else if (as<CudaKernelAttribute>(modifier))
+ {
+ builder->addCudaKernelDecoration(inst);
+ builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice());
+ builder->addPublicDecoration(inst);
+ builder->addKeepAliveDecoration(inst);
+ }
+ else if (as<TorchEntryPointAttribute>(modifier))
+ {
+ builder->addTorchEntryPointDecoration(inst, decl->getName()->text.getUnownedSlice());
+ builder->addCudaHostDecoration(inst);
+ builder->addPublicDecoration(inst);
+ builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice());
+ }
}
if (as<InterfaceDecl>(decl->parentDecl) &&
- decl->parentDecl->hasModifier<ComInterfaceAttribute>())
+ decl->parentDecl->hasModifier<ComInterfaceAttribute>() &&
+ !inst->findDecoration<IRExternCppDecoration>())
{
builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice());
}
- if (auto dllImportModifier = decl->findModifier<DllImportAttribute>())
- {
- auto libraryName = dllImportModifier->modulePath;
- auto functionName = dllImportModifier->functionName.getLength()
- ? dllImportModifier->functionName.getUnownedSlice()
- : decl->getName()->text.getUnownedSlice();
- builder->addDllImportDecoration(inst, libraryName.getUnownedSlice(), functionName);
- }
- if (decl->findModifier<DllExportAttribute>())
- {
- builder->addDllExportDecoration(inst, decl->getName()->text.getUnownedSlice());
- builder->addPublicDecoration(inst);
- }
- if (decl->findModifier<CudaDeviceExportAttribute>())
- {
- builder->addCudaDeviceExportDecoration(inst, decl->getName()->text.getUnownedSlice());
- builder->addPublicDecoration(inst);
- }
- if (decl->findModifier<CudaHostAttribute>())
- {
- builder->addCudaHostDecoration(inst);
- }
- if (decl->findModifier<CudaKernelAttribute>())
- {
- builder->addCudaKernelDecoration(inst);
- }
}
static void addLinkageDecoration(
diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp
index 714e2c99d..d30c02484 100644
--- a/source/slang/slang-options.cpp
+++ b/source/slang/slang-options.cpp
@@ -2193,6 +2193,7 @@ struct OptionsParser
if (rawOutputs.getCount() == 0 &&
rawTargets.getCount() == 1 &&
(rawTargets[0].format == CodeGenTarget::HostCPPSource ||
+ rawTargets[0].format == CodeGenTarget::PyTorchCppBinding ||
rawTargets[0].format == CodeGenTarget::CUDASource ||
ArtifactDescUtil::makeDescForCompileTarget(asExternal(rawTargets[0].format)).kind == ArtifactKind::HostCallable))
{
@@ -2258,7 +2259,7 @@ struct OptionsParser
case CodeGenTarget::ShaderHostCallable:
case CodeGenTarget::HostExecutable:
case CodeGenTarget::ShaderSharedLibrary:
-
+ case CodeGenTarget::PyTorchCppBinding:
case CodeGenTarget::DXIL:
rawOutput.isWholeProgram = true;
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index cdeb0b259..45f4be477 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -2166,9 +2166,9 @@ namespace Slang
dispatchExpr->baseFunction = parser->ParseArgExpr();
parser->ReadToken(TokenType::Comma);
- dispatchExpr->threadGroupSize = parser->ParseArgExpr();
- parser->ReadToken(TokenType::Comma);
dispatchExpr->dispatchSize = parser->ParseArgExpr();
+ parser->ReadToken(TokenType::Comma);
+ dispatchExpr->threadGroupSize = parser->ParseArgExpr();
parser->ReadToken(TokenType::RParent);
return dispatchExpr;
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 27aba435f..1c2726551 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -530,6 +530,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
matType->declRef = declRef;
return matType;
}
+ else if (magicMod->magicName == "TensorViewType")
+ {
+ SLANG_ASSERT(subst && subst->getArgs().getCount() == 1);
+ auto vecType = astBuilder->getOrCreate<TensorViewType>(ExtractGenericArgType(subst->getArgs()[0]));
+ vecType->declRef = declRef;
+ return vecType;
+ }
else if (magicMod->magicName == "Texture")
{
SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1);
diff --git a/source/slangc/main.cpp b/source/slangc/main.cpp
index 2870a5a3c..2fe9d19a1 100644
--- a/source/slangc/main.cpp
+++ b/source/slangc/main.cpp
@@ -79,15 +79,15 @@ SLANG_TEST_TOOL_API SlangResult innerMain(StdWriters* stdWriters, slang::IGlobal
if (TestToolUtil::hasDeferredStdLib(Index(argc - 1), argv + 1))
{
SLANG_RETURN_ON_FAIL(slang_createGlobalSessionWithoutStdLib(SLANG_API_VERSION, session.writeRef()));
- TestToolUtil::setSessionDefaultPreludeFromExePath(argv[0], session);
}
else if (!session)
{
// Just create the global session in the regular way if there isn't one set
SLANG_RETURN_ON_FAIL(slang_createGlobalSession(SLANG_API_VERSION, session.writeRef()));
- TestToolUtil::setSessionDefaultPreludeFromExePath(argv[0], session);
}
+ TestToolUtil::setSessionDefaultPreludeFromExePath(argv[0], session);
+
SlangCompileRequest* compileRequest = spCreateCompileRequest(session);
compileRequest->addSearchPath(Path::getParentDirectory(Path::getExecutablePath()).getBuffer());
SlangResult res = _compile(compileRequest, argc, argv);
diff --git a/tests/autodiff/cuda-kernel-export.slang b/tests/autodiff/cuda-kernel-export.slang
index 54442498b..2700fb054 100644
--- a/tests/autodiff/cuda-kernel-export.slang
+++ b/tests/autodiff/cuda-kernel-export.slang
@@ -3,39 +3,19 @@
// Verify that we can output a cuda device function with [CudaDeviceExport].
// Disabled until we have FileCheck.
-struct MixedType : IDifferentiable
-{
- no_diff float noDiffField;
- float field;
-}
-
-[BackwardDifferentiable]
-float f1(MixedType m)
-{
- return 2.0 * m.field;
-}
-
-[BackwardDifferentiable]
-float f(MixedType m)
-{
- MixedType m1 = { m.noDiffField, m.field };
- return f1(m1);
-}
-
-[CudaDeviceExport]
-void diffF(inout DifferentialPair<MixedType> m, float dout)
-{
- __bwd_diff(f)(m, dout);
-}
[CudaKernel]
-void myKernel(float* inValues, float* outValues)
+void myKernel(TensorView<float> inValues, TensorView<float> outValues)
{
- outValues[0] = sin(inValues[0]);
+ if (cudaThreadIdx().x > 0)
+ return;
+ outValues.store(cudaThreadIdx().x, sin(inValues.load(cudaThreadIdx().x)));
}
-[CudaHost]
-public __extern_cpp void runCompute(float *inValues, float *outValues, uint3 dispathcSize)
+[TorchEntryPoint]
+public __extern_cpp TorchTensor<float> runCompute(TorchTensor<float> inValues)
{
- __dispatch_kernel(myKernel, uint3(128, 1, 1), dispathcSize)(inValues, outValues);
+ var outValues = TorchTensor<float>.alloc(1);
+ __dispatch_kernel(myKernel, uint3(1, 1, 1), uint3(32, 1, 1))(inValues, outValues);
+ return outValues;
} \ No newline at end of file
diff --git a/tools/gfx/slang.slang b/tools/gfx/slang.slang
index 4250cb62e..1fd06560f 100644
--- a/tools/gfx/slang.slang
+++ b/tools/gfx/slang.slang
@@ -53,6 +53,7 @@ enum SlangCompileTarget
SLANG_DXIL_ASM,
SLANG_C_SOURCE, ///< The C language
SLANG_CPP_SOURCE, ///< C++ code for shader kernels.
+ SLANG_CPP_PYTORCH_BINDING,
SLANG_HOST_EXECUTABLE, ///< Standalone binary executable (for hosting CPU/OS)
SLANG_SHADER_SHARED_LIBRARY, ///< A shared library/Dll for shader kernels (for hosting CPU/OS)
SLANG_SHADER_HOST_CALLABLE, ///< A CPU target that makes the compiled shader code available to be run immediately
@@ -448,4 +449,4 @@ struct SpecializationArg
TypeReflection *type;
}
-} \ No newline at end of file
+}
diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp
index d2023a0f0..71f5c04bd 100644
--- a/tools/slang-test/slang-test-main.cpp
+++ b/tools/slang-test/slang-test-main.cpp
@@ -765,6 +765,7 @@ static PassThroughFlags _getPassThroughFlagsForTarget(SlangCompileTarget target)
case SLANG_GLSL:
case SLANG_C_SOURCE:
case SLANG_CPP_SOURCE:
+ case SLANG_CPP_PYTORCH_BINDING:
case SLANG_HOST_CPP_SOURCE:
case SLANG_CUDA_SOURCE:
{