From d64ee86a3130f8eeb75d09193c38c621d7565eba Mon Sep 17 00:00:00 2001 From: Yong He Date: Sun, 26 Mar 2023 13:59:11 -0700 Subject: Add PyTorch C++ binding generation. (#2734) * Add PyTorch C++ binding generation. * fix --------- Co-authored-by: Yong He --- prelude/slang-cpp-types.h | 554 +--------------------------------------------- 1 file changed, 1 insertion(+), 553 deletions(-) (limited to 'prelude/slang-cpp-types.h') diff --git a/prelude/slang-cpp-types.h b/prelude/slang-cpp-types.h index 28fe3dd8d..ac66ad9f3 100644 --- a/prelude/slang-cpp-types.h +++ b/prelude/slang-cpp-types.h @@ -1,244 +1,12 @@ #ifndef SLANG_PRELUDE_CPP_TYPES_H #define SLANG_PRELUDE_CPP_TYPES_H -#ifndef SLANG_PRELUDE_ASSERT -# ifdef SLANG_PRELUDE_ENABLE_ASSERT -# define SLANG_PRELUDE_ASSERT(VALUE) assert(VALUE) -# else -# define SLANG_PRELUDE_ASSERT(VALUE) -# endif -#endif - -// Since we are using unsigned arithmatic care is need in this comparison. -// It is *assumed* that sizeInBytes >= elemSize. Which means (sizeInBytes >= elemSize) >= 0 -// Which means only a single test is needed - -// Asserts for bounds checking. -// It is assumed index/count are unsigned types. -#define SLANG_BOUND_ASSERT(index, count) SLANG_PRELUDE_ASSERT(index < count); -#define SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_PRELUDE_ASSERT(index <= (sizeInBytes - elemSize) && (index & 3) == 0); - -// Macros to zero index if an access is out of range -#define SLANG_BOUND_ZERO_INDEX(index, count) index = (index < count) ? index : 0; -#define SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) index = (index <= (sizeInBytes - elemSize)) ? index : 0; - -// The 'FIX' macro define how the index is fixed. The default is to do nothing. If SLANG_ENABLE_BOUND_ZERO_INDEX -// the fix macro will zero the index, if out of range -#ifdef SLANG_ENABLE_BOUND_ZERO_INDEX -# define SLANG_BOUND_FIX(index, count) SLANG_BOUND_ZERO_INDEX(index, count) -# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) -# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) SLANG_BOUND_ZERO_INDEX(index, count) -#else -# define SLANG_BOUND_FIX(index, count) -# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) -# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) -#endif - -#ifndef SLANG_BOUND_CHECK -# define SLANG_BOUND_CHECK(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX(index, count) -#endif - -#ifndef SLANG_BOUND_CHECK_BYTE_ADDRESS -# define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) -#endif - -#ifndef SLANG_BOUND_CHECK_FIXED_ARRAY -# define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX_FIXED_ARRAY(index, count) -#endif - #ifdef SLANG_PRELUDE_NAMESPACE namespace SLANG_PRELUDE_NAMESPACE { #endif -struct TypeInfo -{ - size_t typeSize; -}; - -template -struct FixedArray -{ - const T& operator[](size_t index) const { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; } - T& operator[](size_t index) { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; } - - T m_data[SIZE]; -}; - -// An array that has no specified size, becomes a 'Array'. This stores the size so it can potentially -// do bounds checking. -template -struct Array -{ - const T& operator[](size_t index) const { SLANG_BOUND_CHECK(index, count); return data[index]; } - T& operator[](size_t index) { SLANG_BOUND_CHECK(index, count); return data[index]; } - - T* data; - size_t count; -}; - -/* Constant buffers become a pointer to the contained type, so ConstantBuffer becomes T* in C++ code. -*/ - -template -struct Vector; - -template -struct Vector -{ - T x; - const T& operator[](size_t /*index*/) const { return x; } - T& operator[](size_t /*index*/) { return x; } - operator T() const { return x; } - Vector() = default; - Vector(T scalar) - { - x = scalar; - } - template - Vector(Vector other) - { - x = (T)other.x; - } - template - Vector(Vector other) - { - int minSize = 1; - if (otherSize < minSize) minSize = otherSize; - for (int i = 0; i < minSize; i++) - (*this)[i] = (T)other[i]; - } -}; - -template -struct Vector -{ - T x, y; - const T& operator[](size_t index) const { return index == 0 ? x : y; } - T& operator[](size_t index) { return index == 0 ? x : y; } - Vector() = default; - Vector(T scalar) - { - x = y = scalar; - } - Vector(T _x, T _y) - { - x = _x; - y = _y; - } - template - Vector(Vector other) - { - x = (T)other.x; - y = (T)other.y; - } - template - Vector(Vector other) - { - int minSize = 2; - if (otherSize < minSize) minSize = otherSize; - for (int i = 0; i < minSize; i++) - (*this)[i] = (T)other[i]; - } -}; - -template -struct Vector -{ - T x, y, z; - const T& operator[](size_t index) const { return *((T*)(this) + index); } - T& operator[](size_t index) { return *((T*)(this) + index); } - - Vector() = default; - Vector(T scalar) - { - x = y = z = scalar; - } - Vector(T _x, T _y, T _z) - { - x = _x; - y = _y; - z = _z; - } - template - Vector(Vector other) - { - x = (T)other.x; - y = (T)other.y; - z = (T)other.z; - } - template - Vector(Vector other) - { - int minSize = 3; - if (otherSize < minSize) minSize = otherSize; - for (int i = 0; i < minSize; i++) - (*this)[i] = (T)other[i]; - } -}; -template -struct Vector -{ - T x, y, z, w; - - const T& operator[](size_t index) const { return *((T*)(this) + index); } - T& operator[](size_t index) { return *((T*)(this) + index); } - Vector() = default; - Vector(T scalar) - { - x = y = z = w = scalar; - } - Vector(T _x, T _y, T _z, T _w) - { - x = _x; - y = _y; - z = _z; - w = _w; - } - template - Vector(Vector other) - { - int minSize = 4; - if (otherSize < minSize) minSize = otherSize; - for (int i = 0; i < minSize; i++) - (*this)[i] = (T)other[i]; - } - -}; - -template -SLANG_FORCE_INLINE T _slang_vector_get_element(Vector x, int index) -{ - return x[index]; -} - -template -SLANG_FORCE_INLINE const T* _slang_vector_get_element_ptr(const Vector* x, int index) -{ - return &((*const_cast*>(x))[index]); -} - -template -SLANG_FORCE_INLINE T* _slang_vector_get_element_ptr(Vector* x, int index) -{ - return &((*x)[index]); -} - -template -SLANG_FORCE_INLINE Vector _slang_vector_reshape(const Vector other) -{ - Vector result; - for (int i = 0; i < n; i++) - { - OtherT otherElement = T(0); - if (i < m) - otherElement = _slang_vector_get_element(other, i); - *_slang_vector_get_element_ptr(&result, i) = (T)otherElement; - } - return result; -} - -typedef uint32_t uint; +#include "slang-cpp-types-core.h" typedef Vector float2; typedef Vector float3; @@ -252,320 +20,6 @@ typedef Vector uint2; typedef Vector uint3; typedef Vector uint4; -#define SLANG_VECTOR_BINARY_OP(T, op) \ - template \ - SLANG_FORCE_INLINE Vector operator op(const Vector& thisVal, const Vector& other) \ - { \ - Vector result;\ - for (int i = 0; i < n; i++) \ - result[i] = thisVal[i] op other[i]; \ - return result;\ - } -#define SLANG_VECTOR_BINARY_COMPARE_OP(T, op) \ - template \ - SLANG_FORCE_INLINE Vector operator op(const Vector& thisVal, const Vector& other) \ - { \ - Vector result;\ - for (int i = 0; i < n; i++) \ - result[i] = thisVal[i] op other[i]; \ - return result;\ - } - -#define SLANG_VECTOR_UNARY_OP(T, op) \ - template \ - SLANG_FORCE_INLINE Vector operator op(const Vector& thisVal) \ - { \ - Vector result;\ - for (int i = 0; i < n; i++) \ - result[i] = op thisVal[i]; \ - return result;\ - } -#define SLANG_INT_VECTOR_OPS(T) \ - SLANG_VECTOR_BINARY_OP(T, +)\ - SLANG_VECTOR_BINARY_OP(T, -)\ - SLANG_VECTOR_BINARY_OP(T, *)\ - SLANG_VECTOR_BINARY_OP(T, / )\ - SLANG_VECTOR_BINARY_OP(T, &)\ - SLANG_VECTOR_BINARY_OP(T, |)\ - SLANG_VECTOR_BINARY_OP(T, &&)\ - SLANG_VECTOR_BINARY_OP(T, ||)\ - SLANG_VECTOR_BINARY_OP(T, ^)\ - SLANG_VECTOR_BINARY_OP(T, %)\ - SLANG_VECTOR_BINARY_OP(T, >>)\ - SLANG_VECTOR_BINARY_OP(T, <<)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)\ - SLANG_VECTOR_UNARY_OP(T, !)\ - SLANG_VECTOR_UNARY_OP(T, ~) -#define SLANG_FLOAT_VECTOR_OPS(T) \ - SLANG_VECTOR_BINARY_OP(T, +)\ - SLANG_VECTOR_BINARY_OP(T, -)\ - SLANG_VECTOR_BINARY_OP(T, *)\ - SLANG_VECTOR_BINARY_OP(T, /)\ - SLANG_VECTOR_UNARY_OP(T, -)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, !=) - -SLANG_INT_VECTOR_OPS(bool) -SLANG_INT_VECTOR_OPS(int) -SLANG_INT_VECTOR_OPS(int8_t) -SLANG_INT_VECTOR_OPS(int16_t) -SLANG_INT_VECTOR_OPS(int64_t) -SLANG_INT_VECTOR_OPS(uint) -SLANG_INT_VECTOR_OPS(uint8_t) -SLANG_INT_VECTOR_OPS(uint16_t) -SLANG_INT_VECTOR_OPS(uint64_t) - -SLANG_FLOAT_VECTOR_OPS(float) -SLANG_FLOAT_VECTOR_OPS(double) - -#define SLANG_VECTOR_INT_NEG_OP(T) \ - template\ - Vector operator-(const Vector& thisVal) \ - { \ - Vector result;\ - for (int i = 0; i < N; i++) \ - result[i] = 0 - thisVal[i]; \ - return result;\ - } -SLANG_VECTOR_INT_NEG_OP(int) -SLANG_VECTOR_INT_NEG_OP(int8_t) -SLANG_VECTOR_INT_NEG_OP(int16_t) -SLANG_VECTOR_INT_NEG_OP(int64_t) -SLANG_VECTOR_INT_NEG_OP(uint) -SLANG_VECTOR_INT_NEG_OP(uint8_t) -SLANG_VECTOR_INT_NEG_OP(uint16_t) -SLANG_VECTOR_INT_NEG_OP(uint64_t) - -#define SLANG_FLOAT_VECTOR_MOD(T)\ - template \ - Vector operator%(const Vector& left, const Vector& right) \ - {\ - Vector result;\ - for (int i = 0; i < N; i++) \ - result[i] = _slang_fmod(left[i], right[i]); \ - return result;\ - } - -SLANG_FLOAT_VECTOR_MOD(float) -SLANG_FLOAT_VECTOR_MOD(double) -#undef SLANG_FLOAT_VECTOR_MOD -#undef SLANG_VECTOR_BINARY_OP -#undef SLANG_VECTOR_UNARY_OP -#undef SLANG_INT_VECTOR_OPS -#undef SLANG_FLOAT_VECTOR_OPS -#undef SLANG_VECTOR_INT_NEG_OP -#undef SLANG_FLOAT_VECTOR_MOD - -template -struct Matrix -{ - Vector rows[ROWS]; - Vector& operator[](size_t index) { return rows[index]; } - Matrix() = default; - Matrix(T scalar) - { - for (int i = 0; i < ROWS; i++) - rows[i] = Vector(scalar); - } - Matrix(const Vector& row0) - { - rows[0] = row0; - } - Matrix(const Vector& row0, const Vector& row1) - { - rows[0] = row0; - rows[1] = row1; - } - Matrix(const Vector& row0, const Vector& row1, const Vector& row2) - { - rows[0] = row0; - rows[1] = row1; - rows[2] = row2; - } - Matrix(const Vector& row0, const Vector& row1, const Vector& row2, const Vector& row3) - { - rows[0] = row0; - rows[1] = row1; - rows[2] = row2; - rows[3] = row3; - } - template - Matrix(const Matrix& other) - { - int minRow = ROWS; - int minCol = COLS; - if (minRow > otherRow) minRow = otherRow; - if (minCol > otherCol) minCol = otherCol; - for (int i = 0; i < minRow; i++) - for (int j = 0; j < minCol; j++) - rows[i][j] = (T)other.rows[i][j]; - } - Matrix(T v0, T v1, T v2, T v3) - { - rows[0][0] = v0; rows[0][1] = v1; - rows[1][0] = v2; rows[1][1] = v3; - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5) - { - if (COLS == 3) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; - rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; - } - else - { - rows[0][0] = v0; rows[0][1] = v1; - rows[1][0] = v2; rows[1][1] = v3; - rows[2][0] = v4; rows[2][1] = v5; - } - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7) - { - if (COLS == 4) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; - rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; - } - else - { - rows[0][0] = v0; rows[0][1] = v1; - rows[1][0] = v2; rows[1][1] = v3; - rows[2][0] = v4; rows[2][1] = v5; - rows[3][0] = v6; rows[3][1] = v7; - } - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; - rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; - rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8; - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11) - { - if (COLS == 4) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; - rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; - rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11; - } - else - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; - rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; - rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8; - rows[3][0] = v9; rows[3][1] = v10; rows[3][2] = v11; - } - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; - rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; - rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11; - rows[3][0] = v12; rows[3][1] = v13; rows[3][2] = v14; rows[3][3] = v15; - } -}; - -#define SLANG_MATRIX_BINARY_OP(T, op) \ - template \ - Matrix operator op(const Matrix& thisVal, const Matrix& other) \ - { \ - Matrix result;\ - for (int i = 0; i < R; i++) \ - for (int j = 0; j < C; j++) \ - result.rows[i][j] = thisVal.rows[i][j] op other.rows[i][j]; \ - return result;\ - } - -#define SLANG_MATRIX_UNARY_OP(T, op) \ - template \ - Matrix operator op(const Matrix& thisVal) \ - { \ - Matrix result;\ - for (int i = 0; i < R; i++) \ - for (int j = 0; j < C; j++) \ - result[i].rows[i][j] = op thisVal.rows[i][j]; \ - return result;\ - } -#define SLANG_INT_MATRIX_OPS(T) \ - SLANG_MATRIX_BINARY_OP(T, +)\ - SLANG_MATRIX_BINARY_OP(T, -)\ - SLANG_MATRIX_BINARY_OP(T, *)\ - SLANG_MATRIX_BINARY_OP(T, / )\ - SLANG_MATRIX_BINARY_OP(T, &)\ - SLANG_MATRIX_BINARY_OP(T, |)\ - SLANG_MATRIX_BINARY_OP(T, &&)\ - SLANG_MATRIX_BINARY_OP(T, ||)\ - SLANG_MATRIX_BINARY_OP(T, ^)\ - SLANG_MATRIX_BINARY_OP(T, %)\ - SLANG_MATRIX_UNARY_OP(T, !)\ - SLANG_MATRIX_UNARY_OP(T, ~) -#define SLANG_FLOAT_MATRIX_OPS(T) \ - SLANG_MATRIX_BINARY_OP(T, +)\ - SLANG_MATRIX_BINARY_OP(T, -)\ - SLANG_MATRIX_BINARY_OP(T, *)\ - SLANG_MATRIX_BINARY_OP(T, /)\ - SLANG_MATRIX_UNARY_OP(T, -) -SLANG_INT_MATRIX_OPS(int) -SLANG_INT_MATRIX_OPS(int8_t) -SLANG_INT_MATRIX_OPS(int16_t) -SLANG_INT_MATRIX_OPS(int64_t) -SLANG_INT_MATRIX_OPS(uint) -SLANG_INT_MATRIX_OPS(uint8_t) -SLANG_INT_MATRIX_OPS(uint16_t) -SLANG_INT_MATRIX_OPS(uint64_t) - -SLANG_FLOAT_MATRIX_OPS(float) -SLANG_FLOAT_MATRIX_OPS(double) - -#define SLANG_MATRIX_INT_NEG_OP(T) \ - template\ - SLANG_FORCE_INLINE Matrix operator-(Matrix thisVal) \ - { \ - Matrix result;\ - for (int i = 0; i < R; i++) \ - for (int j = 0; j < C; j++) \ - result.rows[i][j] = 0 - thisVal.rows[i][j]; \ - return result;\ - } - SLANG_MATRIX_INT_NEG_OP(int) - SLANG_MATRIX_INT_NEG_OP(int8_t) - SLANG_MATRIX_INT_NEG_OP(int16_t) - SLANG_MATRIX_INT_NEG_OP(int64_t) - SLANG_MATRIX_INT_NEG_OP(uint) - SLANG_MATRIX_INT_NEG_OP(uint8_t) - SLANG_MATRIX_INT_NEG_OP(uint16_t) - SLANG_MATRIX_INT_NEG_OP(uint64_t) - -#define SLANG_FLOAT_MATRIX_MOD(T)\ - template \ - SLANG_FORCE_INLINE Matrix operator%(Matrix left, Matrix right) \ - {\ - Matrix result;\ - for (int i = 0; i < R; i++) \ - for (int j = 0; j < C; j++) \ - result.rows[i][j] = _slang_fmod(left.rows[i][j], right.rows[i][j]); \ - return result;\ - } - - SLANG_FLOAT_MATRIX_MOD(float) - SLANG_FLOAT_MATRIX_MOD(double) -#undef SLANG_FLOAT_MATRIX_MOD -#undef SLANG_MATRIX_BINARY_OP -#undef SLANG_MATRIX_UNARY_OP -#undef SLANG_INT_MATRIX_OPS -#undef SLANG_FLOAT_MATRIX_OPS -#undef SLANG_MATRIX_INT_NEG_OP -#undef SLANG_FLOAT_MATRIX_MOD - // We can just map `NonUniformResourceIndex` type directly to the index type on CPU, as CPU does not require // any special handling around such accesses. typedef size_t NonUniformResourceIndex; @@ -1484,12 +938,6 @@ struct ComputeVaryingInput typedef void(*ComputeThreadFunc)(ComputeThreadVaryingInput* varyingInput, void* uniformEntryPointParams, void* uniformState); typedef void(*ComputeFunc)(ComputeVaryingInput* varyingInput, void* uniformEntryPointParams, void* uniformState); -template -TResult slang_bit_cast(TInput val) -{ - return *(TResult*)(&val); -} - #ifdef SLANG_PRELUDE_NAMESPACE } #endif -- cgit v1.2.3