summaryrefslogtreecommitdiffstats
path: root/prelude
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-26 13:59:11 -0700
committerGitHub <noreply@github.com>2023-03-26 13:59:11 -0700
commitd64ee86a3130f8eeb75d09193c38c621d7565eba (patch)
treefed25a0cc2a7372d26175774f5983bed693e6b64 /prelude
parent666af0962b6ab41489a3a3287db83f77c2f6461a (diff)
Add PyTorch C++ binding generation. (#2734)
* Add PyTorch C++ binding generation. * fix --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-cpp-types-core.h561
-rw-r--r--prelude/slang-cpp-types.h554
-rw-r--r--prelude/slang-cuda-prelude.h71
-rw-r--r--prelude/slang-torch-prelude.h126
4 files changed, 759 insertions, 553 deletions
diff --git a/prelude/slang-cpp-types-core.h b/prelude/slang-cpp-types-core.h
new file mode 100644
index 000000000..c49ee013c
--- /dev/null
+++ b/prelude/slang-cpp-types-core.h
@@ -0,0 +1,561 @@
+#ifndef SLANG_PRELUDE_CPP_TYPES_CORE_H
+#define SLANG_PRELUDE_CPP_TYPES_CORE_H
+
+#ifndef SLANG_PRELUDE_ASSERT
+# ifdef SLANG_PRELUDE_ENABLE_ASSERT
+# define SLANG_PRELUDE_ASSERT(VALUE) assert(VALUE)
+# else
+# define SLANG_PRELUDE_ASSERT(VALUE)
+# endif
+#endif
+
+// Since we are using unsigned arithmatic care is need in this comparison.
+// It is *assumed* that sizeInBytes >= elemSize. Which means (sizeInBytes >= elemSize) >= 0
+// Which means only a single test is needed
+
+// Asserts for bounds checking.
+// It is assumed index/count are unsigned types.
+#define SLANG_BOUND_ASSERT(index, count) SLANG_PRELUDE_ASSERT(index < count);
+#define SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_PRELUDE_ASSERT(index <= (sizeInBytes - elemSize) && (index & 3) == 0);
+
+// Macros to zero index if an access is out of range
+#define SLANG_BOUND_ZERO_INDEX(index, count) index = (index < count) ? index : 0;
+#define SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) index = (index <= (sizeInBytes - elemSize)) ? index : 0;
+
+// The 'FIX' macro define how the index is fixed. The default is to do nothing. If SLANG_ENABLE_BOUND_ZERO_INDEX
+// the fix macro will zero the index, if out of range
+#ifdef SLANG_ENABLE_BOUND_ZERO_INDEX
+# define SLANG_BOUND_FIX(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
+# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
+# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
+#else
+# define SLANG_BOUND_FIX(index, count)
+# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
+# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
+#endif
+
+#ifndef SLANG_BOUND_CHECK
+# define SLANG_BOUND_CHECK(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX(index, count)
+#endif
+
+#ifndef SLANG_BOUND_CHECK_BYTE_ADDRESS
+# define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
+#endif
+
+#ifndef SLANG_BOUND_CHECK_FIXED_ARRAY
+# define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
+#endif
+
+struct TypeInfo
+{
+ size_t typeSize;
+};
+
+template <typename T, size_t SIZE>
+struct FixedArray
+{
+ const T& operator[](size_t index) const { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; }
+ T& operator[](size_t index) { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; }
+
+ T m_data[SIZE];
+};
+
+// An array that has no specified size, becomes a 'Array'. This stores the size so it can potentially
+// do bounds checking.
+template <typename T>
+struct Array
+{
+ const T& operator[](size_t index) const { SLANG_BOUND_CHECK(index, count); return data[index]; }
+ T& operator[](size_t index) { SLANG_BOUND_CHECK(index, count); return data[index]; }
+
+ T* data;
+ size_t count;
+};
+
+/* Constant buffers become a pointer to the contained type, so ConstantBuffer<T> becomes T* in C++ code.
+*/
+
+template <typename T, int COUNT>
+struct Vector;
+
+template <typename T>
+struct Vector<T, 1>
+{
+ T x;
+ const T& operator[](size_t /*index*/) const { return x; }
+ T& operator[](size_t /*index*/) { return x; }
+ operator T() const { return x; }
+ Vector() = default;
+ Vector(T scalar)
+ {
+ x = scalar;
+ }
+ template <typename U>
+ Vector(Vector<U, 1> other)
+ {
+ x = (T)other.x;
+ }
+ template <typename U, int otherSize>
+ Vector(Vector<U, otherSize> other)
+ {
+ int minSize = 1;
+ if (otherSize < minSize) minSize = otherSize;
+ for (int i = 0; i < minSize; i++)
+ (*this)[i] = (T)other[i];
+ }
+};
+
+template <typename T>
+struct Vector<T, 2>
+{
+ T x, y;
+ const T& operator[](size_t index) const { return index == 0 ? x : y; }
+ T& operator[](size_t index) { return index == 0 ? x : y; }
+ Vector() = default;
+ Vector(T scalar)
+ {
+ x = y = scalar;
+ }
+ Vector(T _x, T _y)
+ {
+ x = _x;
+ y = _y;
+ }
+ template <typename U>
+ Vector(Vector<U, 2> other)
+ {
+ x = (T)other.x;
+ y = (T)other.y;
+ }
+ template <typename U, int otherSize>
+ Vector(Vector<U, otherSize> other)
+ {
+ int minSize = 2;
+ if (otherSize < minSize) minSize = otherSize;
+ for (int i = 0; i < minSize; i++)
+ (*this)[i] = (T)other[i];
+ }
+};
+
+template <typename T>
+struct Vector<T, 3>
+{
+ T x, y, z;
+ const T& operator[](size_t index) const { return *((T*)(this) + index); }
+ T& operator[](size_t index) { return *((T*)(this) + index); }
+
+ Vector() = default;
+ Vector(T scalar)
+ {
+ x = y = z = scalar;
+ }
+ Vector(T _x, T _y, T _z)
+ {
+ x = _x;
+ y = _y;
+ z = _z;
+ }
+ template <typename U>
+ Vector(Vector<U, 3> other)
+ {
+ x = (T)other.x;
+ y = (T)other.y;
+ z = (T)other.z;
+ }
+ template <typename U, int otherSize>
+ Vector(Vector<U, otherSize> other)
+ {
+ int minSize = 3;
+ if (otherSize < minSize) minSize = otherSize;
+ for (int i = 0; i < minSize; i++)
+ (*this)[i] = (T)other[i];
+ }
+};
+
+template <typename T>
+struct Vector<T, 4>
+{
+ T x, y, z, w;
+
+ const T& operator[](size_t index) const { return *((T*)(this) + index); }
+ T& operator[](size_t index) { return *((T*)(this) + index); }
+ Vector() = default;
+ Vector(T scalar)
+ {
+ x = y = z = w = scalar;
+ }
+ Vector(T _x, T _y, T _z, T _w)
+ {
+ x = _x;
+ y = _y;
+ z = _z;
+ w = _w;
+ }
+ template <typename U, int otherSize>
+ Vector(Vector<U, otherSize> other)
+ {
+ int minSize = 4;
+ if (otherSize < minSize) minSize = otherSize;
+ for (int i = 0; i < minSize; i++)
+ (*this)[i] = (T)other[i];
+ }
+
+};
+
+template<typename T, int N>
+SLANG_FORCE_INLINE T _slang_vector_get_element(Vector<T, N> x, int index)
+{
+ return x[index];
+}
+
+template<typename T, int N>
+SLANG_FORCE_INLINE const T* _slang_vector_get_element_ptr(const Vector<T, N>* x, int index)
+{
+ return &((*const_cast<Vector<T,N>*>(x))[index]);
+}
+
+template<typename T, int N>
+SLANG_FORCE_INLINE T* _slang_vector_get_element_ptr(Vector<T, N>* x, int index)
+{
+ return &((*x)[index]);
+}
+
+template<typename T, int n, typename OtherT, int m>
+SLANG_FORCE_INLINE Vector<T, n> _slang_vector_reshape(const Vector<OtherT, m> other)
+{
+ Vector<T, n> result;
+ for (int i = 0; i < n; i++)
+ {
+ OtherT otherElement = T(0);
+ if (i < m)
+ otherElement = _slang_vector_get_element(other, i);
+ *_slang_vector_get_element_ptr(&result, i) = (T)otherElement;
+ }
+ return result;
+}
+
+typedef uint32_t uint;
+
+#define SLANG_VECTOR_BINARY_OP(T, op) \
+ template<int n> \
+ SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \
+ { \
+ Vector<T, n> result;\
+ for (int i = 0; i < n; i++) \
+ result[i] = thisVal[i] op other[i]; \
+ return result;\
+ }
+#define SLANG_VECTOR_BINARY_COMPARE_OP(T, op) \
+ template<int n> \
+ SLANG_FORCE_INLINE Vector<bool, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \
+ { \
+ Vector<bool, n> result;\
+ for (int i = 0; i < n; i++) \
+ result[i] = thisVal[i] op other[i]; \
+ return result;\
+ }
+
+#define SLANG_VECTOR_UNARY_OP(T, op) \
+ template<int n> \
+ SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal) \
+ { \
+ Vector<T, n> result;\
+ for (int i = 0; i < n; i++) \
+ result[i] = op thisVal[i]; \
+ return result;\
+ }
+#define SLANG_INT_VECTOR_OPS(T) \
+ SLANG_VECTOR_BINARY_OP(T, +)\
+ SLANG_VECTOR_BINARY_OP(T, -)\
+ SLANG_VECTOR_BINARY_OP(T, *)\
+ SLANG_VECTOR_BINARY_OP(T, / )\
+ SLANG_VECTOR_BINARY_OP(T, &)\
+ SLANG_VECTOR_BINARY_OP(T, |)\
+ SLANG_VECTOR_BINARY_OP(T, &&)\
+ SLANG_VECTOR_BINARY_OP(T, ||)\
+ SLANG_VECTOR_BINARY_OP(T, ^)\
+ SLANG_VECTOR_BINARY_OP(T, %)\
+ SLANG_VECTOR_BINARY_OP(T, >>)\
+ SLANG_VECTOR_BINARY_OP(T, <<)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)\
+ SLANG_VECTOR_UNARY_OP(T, !)\
+ SLANG_VECTOR_UNARY_OP(T, ~)
+#define SLANG_FLOAT_VECTOR_OPS(T) \
+ SLANG_VECTOR_BINARY_OP(T, +)\
+ SLANG_VECTOR_BINARY_OP(T, -)\
+ SLANG_VECTOR_BINARY_OP(T, *)\
+ SLANG_VECTOR_BINARY_OP(T, /)\
+ SLANG_VECTOR_UNARY_OP(T, -)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\
+ SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)
+
+SLANG_INT_VECTOR_OPS(bool)
+SLANG_INT_VECTOR_OPS(int)
+SLANG_INT_VECTOR_OPS(int8_t)
+SLANG_INT_VECTOR_OPS(int16_t)
+SLANG_INT_VECTOR_OPS(int64_t)
+SLANG_INT_VECTOR_OPS(uint)
+SLANG_INT_VECTOR_OPS(uint8_t)
+SLANG_INT_VECTOR_OPS(uint16_t)
+SLANG_INT_VECTOR_OPS(uint64_t)
+
+SLANG_FLOAT_VECTOR_OPS(float)
+SLANG_FLOAT_VECTOR_OPS(double)
+
+#define SLANG_VECTOR_INT_NEG_OP(T) \
+ template<int N>\
+ Vector<T, N> operator-(const Vector<T, N>& thisVal) \
+ { \
+ Vector<T, N> result;\
+ for (int i = 0; i < N; i++) \
+ result[i] = 0 - thisVal[i]; \
+ return result;\
+ }
+SLANG_VECTOR_INT_NEG_OP(int)
+SLANG_VECTOR_INT_NEG_OP(int8_t)
+SLANG_VECTOR_INT_NEG_OP(int16_t)
+SLANG_VECTOR_INT_NEG_OP(int64_t)
+SLANG_VECTOR_INT_NEG_OP(uint)
+SLANG_VECTOR_INT_NEG_OP(uint8_t)
+SLANG_VECTOR_INT_NEG_OP(uint16_t)
+SLANG_VECTOR_INT_NEG_OP(uint64_t)
+
+#define SLANG_FLOAT_VECTOR_MOD(T)\
+ template<int N> \
+ Vector<T, N> operator%(const Vector<T, N>& left, const Vector<T, N>& right) \
+ {\
+ Vector<T, N> result;\
+ for (int i = 0; i < N; i++) \
+ result[i] = _slang_fmod(left[i], right[i]); \
+ return result;\
+ }
+
+SLANG_FLOAT_VECTOR_MOD(float)
+SLANG_FLOAT_VECTOR_MOD(double)
+#undef SLANG_FLOAT_VECTOR_MOD
+#undef SLANG_VECTOR_BINARY_OP
+#undef SLANG_VECTOR_UNARY_OP
+#undef SLANG_INT_VECTOR_OPS
+#undef SLANG_FLOAT_VECTOR_OPS
+#undef SLANG_VECTOR_INT_NEG_OP
+#undef SLANG_FLOAT_VECTOR_MOD
+
+template <typename T, int ROWS, int COLS>
+struct Matrix
+{
+ Vector<T, COLS> rows[ROWS];
+ Vector<T, COLS>& operator[](size_t index) { return rows[index]; }
+ Matrix() = default;
+ Matrix(T scalar)
+ {
+ for (int i = 0; i < ROWS; i++)
+ rows[i] = Vector<T, COLS>(scalar);
+ }
+ Matrix(const Vector<T, COLS>& row0)
+ {
+ rows[0] = row0;
+ }
+ Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1)
+ {
+ rows[0] = row0;
+ rows[1] = row1;
+ }
+ Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2)
+ {
+ rows[0] = row0;
+ rows[1] = row1;
+ rows[2] = row2;
+ }
+ Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2, const Vector<T, COLS>& row3)
+ {
+ rows[0] = row0;
+ rows[1] = row1;
+ rows[2] = row2;
+ rows[3] = row3;
+ }
+ template<typename U, int otherRow, int otherCol>
+ Matrix(const Matrix<U, otherRow, otherCol>& other)
+ {
+ int minRow = ROWS;
+ int minCol = COLS;
+ if (minRow > otherRow) minRow = otherRow;
+ if (minCol > otherCol) minCol = otherCol;
+ for (int i = 0; i < minRow; i++)
+ for (int j = 0; j < minCol; j++)
+ rows[i][j] = (T)other.rows[i][j];
+ }
+ Matrix(T v0, T v1, T v2, T v3)
+ {
+ rows[0][0] = v0; rows[0][1] = v1;
+ rows[1][0] = v2; rows[1][1] = v3;
+ }
+ Matrix(T v0, T v1, T v2, T v3, T v4, T v5)
+ {
+ if (COLS == 3)
+ {
+ rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
+ rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
+ }
+ else
+ {
+ rows[0][0] = v0; rows[0][1] = v1;
+ rows[1][0] = v2; rows[1][1] = v3;
+ rows[2][0] = v4; rows[2][1] = v5;
+ }
+ }
+ Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7)
+ {
+ if (COLS == 4)
+ {
+ rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
+ rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
+ }
+ else
+ {
+ rows[0][0] = v0; rows[0][1] = v1;
+ rows[1][0] = v2; rows[1][1] = v3;
+ rows[2][0] = v4; rows[2][1] = v5;
+ rows[3][0] = v6; rows[3][1] = v7;
+ }
+ }
+ Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8)
+ {
+ rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
+ rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
+ rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8;
+ }
+ Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11)
+ {
+ if (COLS == 4)
+ {
+ rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
+ rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
+ rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11;
+ }
+ else
+ {
+ rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
+ rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
+ rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8;
+ rows[3][0] = v9; rows[3][1] = v10; rows[3][2] = v11;
+ }
+ }
+ Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15)
+ {
+ rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
+ rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
+ rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11;
+ rows[3][0] = v12; rows[3][1] = v13; rows[3][2] = v14; rows[3][3] = v15;
+ }
+};
+
+#define SLANG_MATRIX_BINARY_OP(T, op) \
+ template<int R, int C> \
+ Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal, const Matrix<T, R, C>& other) \
+ { \
+ Matrix<T, R, C> result;\
+ for (int i = 0; i < R; i++) \
+ for (int j = 0; j < C; j++) \
+ result.rows[i][j] = thisVal.rows[i][j] op other.rows[i][j]; \
+ return result;\
+ }
+
+#define SLANG_MATRIX_UNARY_OP(T, op) \
+ template<int R, int C> \
+ Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal) \
+ { \
+ Matrix<T, R, C> result;\
+ for (int i = 0; i < R; i++) \
+ for (int j = 0; j < C; j++) \
+ result[i].rows[i][j] = op thisVal.rows[i][j]; \
+ return result;\
+ }
+#define SLANG_INT_MATRIX_OPS(T) \
+ SLANG_MATRIX_BINARY_OP(T, +)\
+ SLANG_MATRIX_BINARY_OP(T, -)\
+ SLANG_MATRIX_BINARY_OP(T, *)\
+ SLANG_MATRIX_BINARY_OP(T, / )\
+ SLANG_MATRIX_BINARY_OP(T, &)\
+ SLANG_MATRIX_BINARY_OP(T, |)\
+ SLANG_MATRIX_BINARY_OP(T, &&)\
+ SLANG_MATRIX_BINARY_OP(T, ||)\
+ SLANG_MATRIX_BINARY_OP(T, ^)\
+ SLANG_MATRIX_BINARY_OP(T, %)\
+ SLANG_MATRIX_UNARY_OP(T, !)\
+ SLANG_MATRIX_UNARY_OP(T, ~)
+#define SLANG_FLOAT_MATRIX_OPS(T) \
+ SLANG_MATRIX_BINARY_OP(T, +)\
+ SLANG_MATRIX_BINARY_OP(T, -)\
+ SLANG_MATRIX_BINARY_OP(T, *)\
+ SLANG_MATRIX_BINARY_OP(T, /)\
+ SLANG_MATRIX_UNARY_OP(T, -)
+SLANG_INT_MATRIX_OPS(int)
+SLANG_INT_MATRIX_OPS(int8_t)
+SLANG_INT_MATRIX_OPS(int16_t)
+SLANG_INT_MATRIX_OPS(int64_t)
+SLANG_INT_MATRIX_OPS(uint)
+SLANG_INT_MATRIX_OPS(uint8_t)
+SLANG_INT_MATRIX_OPS(uint16_t)
+SLANG_INT_MATRIX_OPS(uint64_t)
+
+SLANG_FLOAT_MATRIX_OPS(float)
+SLANG_FLOAT_MATRIX_OPS(double)
+
+#define SLANG_MATRIX_INT_NEG_OP(T) \
+ template<int R, int C>\
+ SLANG_FORCE_INLINE Matrix<T, R, C> operator-(Matrix<T, R, C> thisVal) \
+ { \
+ Matrix<T, R, C> result;\
+ for (int i = 0; i < R; i++) \
+ for (int j = 0; j < C; j++) \
+ result.rows[i][j] = 0 - thisVal.rows[i][j]; \
+ return result;\
+ }
+ SLANG_MATRIX_INT_NEG_OP(int)
+ SLANG_MATRIX_INT_NEG_OP(int8_t)
+ SLANG_MATRIX_INT_NEG_OP(int16_t)
+ SLANG_MATRIX_INT_NEG_OP(int64_t)
+ SLANG_MATRIX_INT_NEG_OP(uint)
+ SLANG_MATRIX_INT_NEG_OP(uint8_t)
+ SLANG_MATRIX_INT_NEG_OP(uint16_t)
+ SLANG_MATRIX_INT_NEG_OP(uint64_t)
+
+#define SLANG_FLOAT_MATRIX_MOD(T)\
+ template<int R, int C> \
+ SLANG_FORCE_INLINE Matrix<T, R, C> operator%(Matrix<T, R, C> left, Matrix<T, R, C> right) \
+ {\
+ Matrix<T, R, C> result;\
+ for (int i = 0; i < R; i++) \
+ for (int j = 0; j < C; j++) \
+ result.rows[i][j] = _slang_fmod(left.rows[i][j], right.rows[i][j]); \
+ return result;\
+ }
+
+ SLANG_FLOAT_MATRIX_MOD(float)
+ SLANG_FLOAT_MATRIX_MOD(double)
+#undef SLANG_FLOAT_MATRIX_MOD
+#undef SLANG_MATRIX_BINARY_OP
+#undef SLANG_MATRIX_UNARY_OP
+#undef SLANG_INT_MATRIX_OPS
+#undef SLANG_FLOAT_MATRIX_OPS
+#undef SLANG_MATRIX_INT_NEG_OP
+#undef SLANG_FLOAT_MATRIX_MOD
+
+template<typename TResult, typename TInput>
+TResult slang_bit_cast(TInput val)
+{
+ return *(TResult*)(&val);
+}
+
+#endif
+
+
diff --git a/prelude/slang-cpp-types.h b/prelude/slang-cpp-types.h
index 28fe3dd8d..ac66ad9f3 100644
--- a/prelude/slang-cpp-types.h
+++ b/prelude/slang-cpp-types.h
@@ -1,244 +1,12 @@
#ifndef SLANG_PRELUDE_CPP_TYPES_H
#define SLANG_PRELUDE_CPP_TYPES_H
-#ifndef SLANG_PRELUDE_ASSERT
-# ifdef SLANG_PRELUDE_ENABLE_ASSERT
-# define SLANG_PRELUDE_ASSERT(VALUE) assert(VALUE)
-# else
-# define SLANG_PRELUDE_ASSERT(VALUE)
-# endif
-#endif
-
-// Since we are using unsigned arithmatic care is need in this comparison.
-// It is *assumed* that sizeInBytes >= elemSize. Which means (sizeInBytes >= elemSize) >= 0
-// Which means only a single test is needed
-
-// Asserts for bounds checking.
-// It is assumed index/count are unsigned types.
-#define SLANG_BOUND_ASSERT(index, count) SLANG_PRELUDE_ASSERT(index < count);
-#define SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_PRELUDE_ASSERT(index <= (sizeInBytes - elemSize) && (index & 3) == 0);
-
-// Macros to zero index if an access is out of range
-#define SLANG_BOUND_ZERO_INDEX(index, count) index = (index < count) ? index : 0;
-#define SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) index = (index <= (sizeInBytes - elemSize)) ? index : 0;
-
-// The 'FIX' macro define how the index is fixed. The default is to do nothing. If SLANG_ENABLE_BOUND_ZERO_INDEX
-// the fix macro will zero the index, if out of range
-#ifdef SLANG_ENABLE_BOUND_ZERO_INDEX
-# define SLANG_BOUND_FIX(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
-# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
-# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
-#else
-# define SLANG_BOUND_FIX(index, count)
-# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
-# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
-#endif
-
-#ifndef SLANG_BOUND_CHECK
-# define SLANG_BOUND_CHECK(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX(index, count)
-#endif
-
-#ifndef SLANG_BOUND_CHECK_BYTE_ADDRESS
-# define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
-#endif
-
-#ifndef SLANG_BOUND_CHECK_FIXED_ARRAY
-# define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
-#endif
-
#ifdef SLANG_PRELUDE_NAMESPACE
namespace SLANG_PRELUDE_NAMESPACE {
#endif
-struct TypeInfo
-{
- size_t typeSize;
-};
-
-template <typename T, size_t SIZE>
-struct FixedArray
-{
- const T& operator[](size_t index) const { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; }
- T& operator[](size_t index) { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; }
-
- T m_data[SIZE];
-};
-
-// An array that has no specified size, becomes a 'Array'. This stores the size so it can potentially
-// do bounds checking.
-template <typename T>
-struct Array
-{
- const T& operator[](size_t index) const { SLANG_BOUND_CHECK(index, count); return data[index]; }
- T& operator[](size_t index) { SLANG_BOUND_CHECK(index, count); return data[index]; }
-
- T* data;
- size_t count;
-};
-
-/* Constant buffers become a pointer to the contained type, so ConstantBuffer<T> becomes T* in C++ code.
-*/
-
-template <typename T, int COUNT>
-struct Vector;
-
-template <typename T>
-struct Vector<T, 1>
-{
- T x;
- const T& operator[](size_t /*index*/) const { return x; }
- T& operator[](size_t /*index*/) { return x; }
- operator T() const { return x; }
- Vector() = default;
- Vector(T scalar)
- {
- x = scalar;
- }
- template <typename U>
- Vector(Vector<U, 1> other)
- {
- x = (T)other.x;
- }
- template <typename U, int otherSize>
- Vector(Vector<U, otherSize> other)
- {
- int minSize = 1;
- if (otherSize < minSize) minSize = otherSize;
- for (int i = 0; i < minSize; i++)
- (*this)[i] = (T)other[i];
- }
-};
-
-template <typename T>
-struct Vector<T, 2>
-{
- T x, y;
- const T& operator[](size_t index) const { return index == 0 ? x : y; }
- T& operator[](size_t index) { return index == 0 ? x : y; }
- Vector() = default;
- Vector(T scalar)
- {
- x = y = scalar;
- }
- Vector(T _x, T _y)
- {
- x = _x;
- y = _y;
- }
- template <typename U>
- Vector(Vector<U, 2> other)
- {
- x = (T)other.x;
- y = (T)other.y;
- }
- template <typename U, int otherSize>
- Vector(Vector<U, otherSize> other)
- {
- int minSize = 2;
- if (otherSize < minSize) minSize = otherSize;
- for (int i = 0; i < minSize; i++)
- (*this)[i] = (T)other[i];
- }
-};
-
-template <typename T>
-struct Vector<T, 3>
-{
- T x, y, z;
- const T& operator[](size_t index) const { return *((T*)(this) + index); }
- T& operator[](size_t index) { return *((T*)(this) + index); }
-
- Vector() = default;
- Vector(T scalar)
- {
- x = y = z = scalar;
- }
- Vector(T _x, T _y, T _z)
- {
- x = _x;
- y = _y;
- z = _z;
- }
- template <typename U>
- Vector(Vector<U, 3> other)
- {
- x = (T)other.x;
- y = (T)other.y;
- z = (T)other.z;
- }
- template <typename U, int otherSize>
- Vector(Vector<U, otherSize> other)
- {
- int minSize = 3;
- if (otherSize < minSize) minSize = otherSize;
- for (int i = 0; i < minSize; i++)
- (*this)[i] = (T)other[i];
- }
-};
-template <typename T>
-struct Vector<T, 4>
-{
- T x, y, z, w;
-
- const T& operator[](size_t index) const { return *((T*)(this) + index); }
- T& operator[](size_t index) { return *((T*)(this) + index); }
- Vector() = default;
- Vector(T scalar)
- {
- x = y = z = w = scalar;
- }
- Vector(T _x, T _y, T _z, T _w)
- {
- x = _x;
- y = _y;
- z = _z;
- w = _w;
- }
- template <typename U, int otherSize>
- Vector(Vector<U, otherSize> other)
- {
- int minSize = 4;
- if (otherSize < minSize) minSize = otherSize;
- for (int i = 0; i < minSize; i++)
- (*this)[i] = (T)other[i];
- }
-
-};
-
-template<typename T, int N>
-SLANG_FORCE_INLINE T _slang_vector_get_element(Vector<T, N> x, int index)
-{
- return x[index];
-}
-
-template<typename T, int N>
-SLANG_FORCE_INLINE const T* _slang_vector_get_element_ptr(const Vector<T, N>* x, int index)
-{
- return &((*const_cast<Vector<T,N>*>(x))[index]);
-}
-
-template<typename T, int N>
-SLANG_FORCE_INLINE T* _slang_vector_get_element_ptr(Vector<T, N>* x, int index)
-{
- return &((*x)[index]);
-}
-
-template<typename T, int n, typename OtherT, int m>
-SLANG_FORCE_INLINE Vector<T, n> _slang_vector_reshape(const Vector<OtherT, m> other)
-{
- Vector<T, n> result;
- for (int i = 0; i < n; i++)
- {
- OtherT otherElement = T(0);
- if (i < m)
- otherElement = _slang_vector_get_element(other, i);
- *_slang_vector_get_element_ptr(&result, i) = (T)otherElement;
- }
- return result;
-}
-
-typedef uint32_t uint;
+#include "slang-cpp-types-core.h"
typedef Vector<float, 2> float2;
typedef Vector<float, 3> float3;
@@ -252,320 +20,6 @@ typedef Vector<uint32_t, 2> uint2;
typedef Vector<uint32_t, 3> uint3;
typedef Vector<uint32_t, 4> uint4;
-#define SLANG_VECTOR_BINARY_OP(T, op) \
- template<int n> \
- SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \
- { \
- Vector<T, n> result;\
- for (int i = 0; i < n; i++) \
- result[i] = thisVal[i] op other[i]; \
- return result;\
- }
-#define SLANG_VECTOR_BINARY_COMPARE_OP(T, op) \
- template<int n> \
- SLANG_FORCE_INLINE Vector<bool, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \
- { \
- Vector<bool, n> result;\
- for (int i = 0; i < n; i++) \
- result[i] = thisVal[i] op other[i]; \
- return result;\
- }
-
-#define SLANG_VECTOR_UNARY_OP(T, op) \
- template<int n> \
- SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal) \
- { \
- Vector<T, n> result;\
- for (int i = 0; i < n; i++) \
- result[i] = op thisVal[i]; \
- return result;\
- }
-#define SLANG_INT_VECTOR_OPS(T) \
- SLANG_VECTOR_BINARY_OP(T, +)\
- SLANG_VECTOR_BINARY_OP(T, -)\
- SLANG_VECTOR_BINARY_OP(T, *)\
- SLANG_VECTOR_BINARY_OP(T, / )\
- SLANG_VECTOR_BINARY_OP(T, &)\
- SLANG_VECTOR_BINARY_OP(T, |)\
- SLANG_VECTOR_BINARY_OP(T, &&)\
- SLANG_VECTOR_BINARY_OP(T, ||)\
- SLANG_VECTOR_BINARY_OP(T, ^)\
- SLANG_VECTOR_BINARY_OP(T, %)\
- SLANG_VECTOR_BINARY_OP(T, >>)\
- SLANG_VECTOR_BINARY_OP(T, <<)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)\
- SLANG_VECTOR_UNARY_OP(T, !)\
- SLANG_VECTOR_UNARY_OP(T, ~)
-#define SLANG_FLOAT_VECTOR_OPS(T) \
- SLANG_VECTOR_BINARY_OP(T, +)\
- SLANG_VECTOR_BINARY_OP(T, -)\
- SLANG_VECTOR_BINARY_OP(T, *)\
- SLANG_VECTOR_BINARY_OP(T, /)\
- SLANG_VECTOR_UNARY_OP(T, -)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\
- SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)
-
-SLANG_INT_VECTOR_OPS(bool)
-SLANG_INT_VECTOR_OPS(int)
-SLANG_INT_VECTOR_OPS(int8_t)
-SLANG_INT_VECTOR_OPS(int16_t)
-SLANG_INT_VECTOR_OPS(int64_t)
-SLANG_INT_VECTOR_OPS(uint)
-SLANG_INT_VECTOR_OPS(uint8_t)
-SLANG_INT_VECTOR_OPS(uint16_t)
-SLANG_INT_VECTOR_OPS(uint64_t)
-
-SLANG_FLOAT_VECTOR_OPS(float)
-SLANG_FLOAT_VECTOR_OPS(double)
-
-#define SLANG_VECTOR_INT_NEG_OP(T) \
- template<int N>\
- Vector<T, N> operator-(const Vector<T, N>& thisVal) \
- { \
- Vector<T, N> result;\
- for (int i = 0; i < N; i++) \
- result[i] = 0 - thisVal[i]; \
- return result;\
- }
-SLANG_VECTOR_INT_NEG_OP(int)
-SLANG_VECTOR_INT_NEG_OP(int8_t)
-SLANG_VECTOR_INT_NEG_OP(int16_t)
-SLANG_VECTOR_INT_NEG_OP(int64_t)
-SLANG_VECTOR_INT_NEG_OP(uint)
-SLANG_VECTOR_INT_NEG_OP(uint8_t)
-SLANG_VECTOR_INT_NEG_OP(uint16_t)
-SLANG_VECTOR_INT_NEG_OP(uint64_t)
-
-#define SLANG_FLOAT_VECTOR_MOD(T)\
- template<int N> \
- Vector<T, N> operator%(const Vector<T, N>& left, const Vector<T, N>& right) \
- {\
- Vector<T, N> result;\
- for (int i = 0; i < N; i++) \
- result[i] = _slang_fmod(left[i], right[i]); \
- return result;\
- }
-
-SLANG_FLOAT_VECTOR_MOD(float)
-SLANG_FLOAT_VECTOR_MOD(double)
-#undef SLANG_FLOAT_VECTOR_MOD
-#undef SLANG_VECTOR_BINARY_OP
-#undef SLANG_VECTOR_UNARY_OP
-#undef SLANG_INT_VECTOR_OPS
-#undef SLANG_FLOAT_VECTOR_OPS
-#undef SLANG_VECTOR_INT_NEG_OP
-#undef SLANG_FLOAT_VECTOR_MOD
-
-template <typename T, int ROWS, int COLS>
-struct Matrix
-{
- Vector<T, COLS> rows[ROWS];
- Vector<T, COLS>& operator[](size_t index) { return rows[index]; }
- Matrix() = default;
- Matrix(T scalar)
- {
- for (int i = 0; i < ROWS; i++)
- rows[i] = Vector<T, COLS>(scalar);
- }
- Matrix(const Vector<T, COLS>& row0)
- {
- rows[0] = row0;
- }
- Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1)
- {
- rows[0] = row0;
- rows[1] = row1;
- }
- Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2)
- {
- rows[0] = row0;
- rows[1] = row1;
- rows[2] = row2;
- }
- Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2, const Vector<T, COLS>& row3)
- {
- rows[0] = row0;
- rows[1] = row1;
- rows[2] = row2;
- rows[3] = row3;
- }
- template<typename U, int otherRow, int otherCol>
- Matrix(const Matrix<U, otherRow, otherCol>& other)
- {
- int minRow = ROWS;
- int minCol = COLS;
- if (minRow > otherRow) minRow = otherRow;
- if (minCol > otherCol) minCol = otherCol;
- for (int i = 0; i < minRow; i++)
- for (int j = 0; j < minCol; j++)
- rows[i][j] = (T)other.rows[i][j];
- }
- Matrix(T v0, T v1, T v2, T v3)
- {
- rows[0][0] = v0; rows[0][1] = v1;
- rows[1][0] = v2; rows[1][1] = v3;
- }
- Matrix(T v0, T v1, T v2, T v3, T v4, T v5)
- {
- if (COLS == 3)
- {
- rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
- rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
- }
- else
- {
- rows[0][0] = v0; rows[0][1] = v1;
- rows[1][0] = v2; rows[1][1] = v3;
- rows[2][0] = v4; rows[2][1] = v5;
- }
- }
- Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7)
- {
- if (COLS == 4)
- {
- rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
- rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
- }
- else
- {
- rows[0][0] = v0; rows[0][1] = v1;
- rows[1][0] = v2; rows[1][1] = v3;
- rows[2][0] = v4; rows[2][1] = v5;
- rows[3][0] = v6; rows[3][1] = v7;
- }
- }
- Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8)
- {
- rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
- rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
- rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8;
- }
- Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11)
- {
- if (COLS == 4)
- {
- rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
- rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
- rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11;
- }
- else
- {
- rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2;
- rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5;
- rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8;
- rows[3][0] = v9; rows[3][1] = v10; rows[3][2] = v11;
- }
- }
- Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15)
- {
- rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3;
- rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7;
- rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11;
- rows[3][0] = v12; rows[3][1] = v13; rows[3][2] = v14; rows[3][3] = v15;
- }
-};
-
-#define SLANG_MATRIX_BINARY_OP(T, op) \
- template<int R, int C> \
- Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal, const Matrix<T, R, C>& other) \
- { \
- Matrix<T, R, C> result;\
- for (int i = 0; i < R; i++) \
- for (int j = 0; j < C; j++) \
- result.rows[i][j] = thisVal.rows[i][j] op other.rows[i][j]; \
- return result;\
- }
-
-#define SLANG_MATRIX_UNARY_OP(T, op) \
- template<int R, int C> \
- Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal) \
- { \
- Matrix<T, R, C> result;\
- for (int i = 0; i < R; i++) \
- for (int j = 0; j < C; j++) \
- result[i].rows[i][j] = op thisVal.rows[i][j]; \
- return result;\
- }
-#define SLANG_INT_MATRIX_OPS(T) \
- SLANG_MATRIX_BINARY_OP(T, +)\
- SLANG_MATRIX_BINARY_OP(T, -)\
- SLANG_MATRIX_BINARY_OP(T, *)\
- SLANG_MATRIX_BINARY_OP(T, / )\
- SLANG_MATRIX_BINARY_OP(T, &)\
- SLANG_MATRIX_BINARY_OP(T, |)\
- SLANG_MATRIX_BINARY_OP(T, &&)\
- SLANG_MATRIX_BINARY_OP(T, ||)\
- SLANG_MATRIX_BINARY_OP(T, ^)\
- SLANG_MATRIX_BINARY_OP(T, %)\
- SLANG_MATRIX_UNARY_OP(T, !)\
- SLANG_MATRIX_UNARY_OP(T, ~)
-#define SLANG_FLOAT_MATRIX_OPS(T) \
- SLANG_MATRIX_BINARY_OP(T, +)\
- SLANG_MATRIX_BINARY_OP(T, -)\
- SLANG_MATRIX_BINARY_OP(T, *)\
- SLANG_MATRIX_BINARY_OP(T, /)\
- SLANG_MATRIX_UNARY_OP(T, -)
-SLANG_INT_MATRIX_OPS(int)
-SLANG_INT_MATRIX_OPS(int8_t)
-SLANG_INT_MATRIX_OPS(int16_t)
-SLANG_INT_MATRIX_OPS(int64_t)
-SLANG_INT_MATRIX_OPS(uint)
-SLANG_INT_MATRIX_OPS(uint8_t)
-SLANG_INT_MATRIX_OPS(uint16_t)
-SLANG_INT_MATRIX_OPS(uint64_t)
-
-SLANG_FLOAT_MATRIX_OPS(float)
-SLANG_FLOAT_MATRIX_OPS(double)
-
-#define SLANG_MATRIX_INT_NEG_OP(T) \
- template<int R, int C>\
- SLANG_FORCE_INLINE Matrix<T, R, C> operator-(Matrix<T, R, C> thisVal) \
- { \
- Matrix<T, R, C> result;\
- for (int i = 0; i < R; i++) \
- for (int j = 0; j < C; j++) \
- result.rows[i][j] = 0 - thisVal.rows[i][j]; \
- return result;\
- }
- SLANG_MATRIX_INT_NEG_OP(int)
- SLANG_MATRIX_INT_NEG_OP(int8_t)
- SLANG_MATRIX_INT_NEG_OP(int16_t)
- SLANG_MATRIX_INT_NEG_OP(int64_t)
- SLANG_MATRIX_INT_NEG_OP(uint)
- SLANG_MATRIX_INT_NEG_OP(uint8_t)
- SLANG_MATRIX_INT_NEG_OP(uint16_t)
- SLANG_MATRIX_INT_NEG_OP(uint64_t)
-
-#define SLANG_FLOAT_MATRIX_MOD(T)\
- template<int R, int C> \
- SLANG_FORCE_INLINE Matrix<T, R, C> operator%(Matrix<T, R, C> left, Matrix<T, R, C> right) \
- {\
- Matrix<T, R, C> result;\
- for (int i = 0; i < R; i++) \
- for (int j = 0; j < C; j++) \
- result.rows[i][j] = _slang_fmod(left.rows[i][j], right.rows[i][j]); \
- return result;\
- }
-
- SLANG_FLOAT_MATRIX_MOD(float)
- SLANG_FLOAT_MATRIX_MOD(double)
-#undef SLANG_FLOAT_MATRIX_MOD
-#undef SLANG_MATRIX_BINARY_OP
-#undef SLANG_MATRIX_UNARY_OP
-#undef SLANG_INT_MATRIX_OPS
-#undef SLANG_FLOAT_MATRIX_OPS
-#undef SLANG_MATRIX_INT_NEG_OP
-#undef SLANG_FLOAT_MATRIX_MOD
-
// We can just map `NonUniformResourceIndex` type directly to the index type on CPU, as CPU does not require
// any special handling around such accesses.
typedef size_t NonUniformResourceIndex;
@@ -1484,12 +938,6 @@ struct ComputeVaryingInput
typedef void(*ComputeThreadFunc)(ComputeThreadVaryingInput* varyingInput, void* uniformEntryPointParams, void* uniformState);
typedef void(*ComputeFunc)(ComputeVaryingInput* varyingInput, void* uniformEntryPointParams, void* uniformState);
-template<typename TResult, typename TInput>
-TResult slang_bit_cast(TInput val)
-{
- return *(TResult*)(&val);
-}
-
#ifdef SLANG_PRELUDE_NAMESPACE
}
#endif
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 7a4c5a918..9a55aed57 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -6,6 +6,8 @@
#define SLANG_CUDA_RTC 0
#endif
+#include <stdio.h>
+
// Define SLANG_CUDA_ENABLE_HALF to use the cuda_fp16 include to add half support.
// For this to work NVRTC needs to have the path to the CUDA SDK.
//
@@ -2080,4 +2082,73 @@ __forceinline__ __device__ void *traceOptiXRay(
r0, r1
);
}
+
#endif
+
+
+// TensorView
+struct TensorView
+{
+ uint8_t* data;
+ uint32_t* strides;
+ uint32_t* sizes;
+ uint32_t dimensionCount;
+
+ template<typename T>
+ __device__ T* data_ptr()
+ {
+ return reinterpret_cast<T*>(data);
+ }
+
+ template<typename T>
+ __device__ T load(uint32_t x)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * x);
+ }
+ template<typename T>
+ __device__ T load(uint32_t x, uint32_t y)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y);
+ }
+ template<typename T>
+ __device__ T load(uint32_t x, uint32_t y, uint32_t z)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z);
+ }
+ template<typename T>
+ __device__ T load(uint32_t x, uint32_t y, uint32_t z, uint32_t w)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w);
+ }
+ template<typename T>
+ __device__ T load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4)
+ {
+ return *reinterpret_cast<T*>(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4);
+ }
+ template<typename T>
+ __device__ void store(uint32_t x, T val)
+ {
+ *reinterpret_cast<T*>(data + strides[0] * x) = val;
+ }
+ template<typename T>
+ __device__ void store(uint32_t x, uint32_t y, T val)
+ {
+ *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y) = val;
+ }
+ template<typename T>
+ __device__ void store(uint32_t x, uint32_t y, uint32_t z, T val)
+ {
+ *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z) = val;
+ }
+ template<typename T>
+ __device__ void store(uint32_t x, uint32_t y, uint32_t z, uint32_t w, T val)
+ {
+ *reinterpret_cast<T*>(
+ data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w) = val;
+ }
+ template<typename T>
+ __device__ void store(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4, T val)
+ {
+ *reinterpret_cast<T*>(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4) = val;
+ }
+};
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
new file mode 100644
index 000000000..f2accc149
--- /dev/null
+++ b/prelude/slang-torch-prelude.h
@@ -0,0 +1,126 @@
+// Prelude for PyTorch cpp binding.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAUtils.h>
+#include <vector>
+
+#ifndef SLANG_NO_THROW
+# define SLANG_NO_THROW
+#endif
+
+#ifndef SLANG_STDCALL
+# define SLANG_STDCALL
+#endif
+#ifndef SLANG_MCALL
+# define SLANG_MCALL SLANG_STDCALL
+#endif
+#ifndef SLANG_FORCE_INLINE
+# define SLANG_FORCE_INLINE inline
+#endif
+
+#ifdef SLANG_LLVM
+#include "slang-llvm.h"
+#else // SLANG_LLVM
+# if SLANG_GCC_FAMILY && __GNUC__ < 6
+# include <cmath>
+# define SLANG_PRELUDE_STD std::
+# else
+# include <math.h>
+# define SLANG_PRELUDE_STD
+# endif
+
+# include <assert.h>
+# include <stdlib.h>
+# include <string.h>
+# include <stdint.h>
+#endif // SLANG_LLVM
+
+#include "slang-cpp-types-core.h"
+#include "slang-cpp-scalar-intrinsics.h"
+
+struct TensorView
+{
+ uint8_t* data;
+ uint32_t* strides;
+ uint32_t* sizes;
+ uint32_t dimensionCount;
+};
+
+struct CudaTaskMemoryAllocator
+{
+ std::vector<void*> allocations;
+
+ uint32_t* allocUIntArray(uint32_t size)
+ {
+ void* ptr = nullptr;
+ cudaMallocManaged(&ptr, size * sizeof(uint32_t));
+ AT_CUDA_CHECK(cudaGetLastError());
+ return (uint32_t*)ptr;
+ }
+
+ ~CudaTaskMemoryAllocator()
+ {
+ for (auto ptr : allocations)
+ cudaFree(ptr);
+ }
+};
+
+TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val)
+{
+ val = val.to(torch::kCUDA);
+ TensorView res = {};
+ res.dimensionCount = val.dim();
+ res.strides = allocator->allocUIntArray(val.dim());
+ res.sizes = allocator->allocUIntArray(val.dim());
+ res.data = nullptr;
+ size_t elementSize = 4;
+ switch (val.scalar_type())
+ {
+ case torch::kInt8:
+ case torch::kUInt8:
+ elementSize = 1;
+ res.data = (uint8_t*)val.data_ptr<uint8_t>();
+ break;
+ case torch::kBFloat16:
+ elementSize = 2;
+ res.data = (uint8_t*)val.data_ptr<torch::BFloat16>();
+ break;
+ case torch::kInt16:
+ elementSize = 2;
+ res.data = (uint8_t*)val.data_ptr<int16_t>();
+ break;
+ case torch::kFloat32:
+ elementSize = 4;
+ res.data = (uint8_t*)val.data_ptr<float>();
+ break;
+ case torch::kInt32:
+ elementSize = 4;
+ res.data = (uint8_t*)val.data_ptr<int32_t>();
+ break;
+ case torch::kFloat64:
+ elementSize = 8;
+ res.data = (uint8_t*)val.data_ptr<double>();
+ break;
+ case torch::kInt64:
+ elementSize = 8;
+ res.data = (uint8_t*)val.data_ptr<int64_t>();
+ break;
+ }
+ for (int i = 0; i < val.dim(); ++i)
+ {
+ res.strides[i] = val.stride(i) * elementSize;
+ res.sizes[i] = val.size(i);
+ }
+ return res;
+}
+
+size_t slangGetCudaKernelSharedMemSize(const void* func)
+{
+ cudaFuncAttributes attr = {};
+ cudaFuncGetAttributes(&attr, func);
+ AT_CUDA_CHECK(cudaGetLastError());
+ return attr.sharedSizeBytes;
+}
+
+#define SLANG_PRELUDE_EXPORT