From d64ee86a3130f8eeb75d09193c38c621d7565eba Mon Sep 17 00:00:00 2001 From: Yong He Date: Sun, 26 Mar 2023 13:59:11 -0700 Subject: Add PyTorch C++ binding generation. (#2734) * Add PyTorch C++ binding generation. * fix --------- Co-authored-by: Yong He --- .../run-generators/run-generators.vcxproj | 18 + .../run-generators/run-generators.vcxproj.filters | 6 + build/visual-studio/slang/slang.vcxproj | 5 + build/visual-studio/slang/slang.vcxproj.filters | 15 + prelude/slang-cpp-types-core.h | 561 +++++++++++++++++++++ prelude/slang-cpp-types.h | 554 +------------------- prelude/slang-cuda-prelude.h | 71 +++ prelude/slang-torch-prelude.h | 126 +++++ premake5.lua | 9 +- slang.h | 1 + source/compiler-core/slang-artifact-desc-util.cpp | 1 + source/compiler-core/slang-artifact.h | 1 - source/core/slang-type-convert-util.cpp | 1 + source/core/slang-type-text-util.cpp | 1 + source/slang/core.meta.slang | 3 + source/slang/diff.meta.slang | 84 ++- source/slang/hlsl.meta.slang | 10 + source/slang/slang-ast-modifier.h | 5 + source/slang/slang-ast-type.cpp | 7 + source/slang/slang-ast-type.h | 9 + source/slang/slang-compiler.cpp | 2 + source/slang/slang-compiler.h | 1 + source/slang/slang-diagnostic-defs.h | 2 + source/slang/slang-emit-c-like.cpp | 1 + source/slang/slang-emit-cpp.cpp | 25 + source/slang/slang-emit-cuda.cpp | 5 + source/slang/slang-emit-torch.cpp | 181 +++++++ source/slang/slang-emit-torch.h | 28 + source/slang/slang-emit.cpp | 90 ++-- source/slang/slang-ir-inst-defs.h | 49 +- source/slang/slang-ir-insts.h | 44 ++ source/slang/slang-ir-pytorch-cpp-binding.cpp | 248 +++++++++ source/slang/slang-ir-pytorch-cpp-binding.h | 12 + source/slang/slang-ir.cpp | 57 +++ source/slang/slang-ir.h | 25 + source/slang/slang-lower-to-ir.cpp | 94 ++-- source/slang/slang-options.cpp | 3 +- source/slang/slang-parser.cpp | 4 +- source/slang/slang-syntax.cpp | 7 + source/slangc/main.cpp | 4 +- tests/autodiff/cuda-kernel-export.slang | 38 +- tools/gfx/slang.slang | 3 +- tools/slang-test/slang-test-main.cpp | 1 + 43 files changed, 1733 insertions(+), 679 deletions(-) create mode 100644 prelude/slang-cpp-types-core.h create mode 100644 prelude/slang-torch-prelude.h create mode 100644 source/slang/slang-emit-torch.cpp create mode 100644 source/slang/slang-emit-torch.h create mode 100644 source/slang/slang-ir-pytorch-cpp-binding.cpp create mode 100644 source/slang/slang-ir-pytorch-cpp-binding.h diff --git a/build/visual-studio/run-generators/run-generators.vcxproj b/build/visual-studio/run-generators/run-generators.vcxproj index 045601edf..31037fc27 100644 --- a/build/visual-studio/run-generators/run-generators.vcxproj +++ b/build/visual-studio/run-generators/run-generators.vcxproj @@ -141,6 +141,7 @@ + @@ -216,6 +217,23 @@ ../../../bin/windows-x64/release/slang-embed.exe ../../../bin/windows-aarch64/release/slang-embed.exe + + Document + "../../../bin/windows-x86/debug/slang-embed" %(Identity) + "../../../bin/windows-x64/debug/slang-embed" %(Identity) + "../../../bin/windows-aarch64/debug/slang-embed" %(Identity) + "../../../bin/windows-x86/release/slang-embed" %(Identity) + "../../../bin/windows-x64/release/slang-embed" %(Identity) + "../../../bin/windows-aarch64/release/slang-embed" %(Identity) + ../../../prelude/slang-torch-prelude.h.cpp + slang-embed %(Identity) + ../../../bin/windows-x86/debug/slang-embed.exe + ../../../bin/windows-x64/debug/slang-embed.exe + ../../../bin/windows-aarch64/debug/slang-embed.exe + ../../../bin/windows-x86/release/slang-embed.exe + ../../../bin/windows-x64/release/slang-embed.exe + ../../../bin/windows-aarch64/release/slang-embed.exe + Document "../../../bin/windows-x86/debug/slang-generate" %(Identity) diff --git a/build/visual-studio/run-generators/run-generators.vcxproj.filters b/build/visual-studio/run-generators/run-generators.vcxproj.filters index d91507ae6..233899d08 100644 --- a/build/visual-studio/run-generators/run-generators.vcxproj.filters +++ b/build/visual-studio/run-generators/run-generators.vcxproj.filters @@ -12,6 +12,9 @@ Header Files + + Header Files + Header Files @@ -37,6 +40,9 @@ Header Files + + Header Files + Source Files diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index fbc827f02..71545a5d1 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -336,6 +336,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla + @@ -411,6 +412,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla + @@ -489,6 +491,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla + @@ -527,6 +530,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla + @@ -599,6 +603,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla + diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index 87655e974..f353468e4 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -114,6 +114,9 @@ Header Files + + Header Files + Header Files @@ -339,6 +342,9 @@ Header Files + + Header Files + Header Files @@ -569,6 +575,9 @@ Header Files + + Header Files + Source Files @@ -683,6 +692,9 @@ Source Files + + Source Files + Source Files @@ -899,6 +911,9 @@ Source Files + + Source Files + Source Files diff --git a/prelude/slang-cpp-types-core.h b/prelude/slang-cpp-types-core.h new file mode 100644 index 000000000..c49ee013c --- /dev/null +++ b/prelude/slang-cpp-types-core.h @@ -0,0 +1,561 @@ +#ifndef SLANG_PRELUDE_CPP_TYPES_CORE_H +#define SLANG_PRELUDE_CPP_TYPES_CORE_H + +#ifndef SLANG_PRELUDE_ASSERT +# ifdef SLANG_PRELUDE_ENABLE_ASSERT +# define SLANG_PRELUDE_ASSERT(VALUE) assert(VALUE) +# else +# define SLANG_PRELUDE_ASSERT(VALUE) +# endif +#endif + +// Since we are using unsigned arithmatic care is need in this comparison. +// It is *assumed* that sizeInBytes >= elemSize. Which means (sizeInBytes >= elemSize) >= 0 +// Which means only a single test is needed + +// Asserts for bounds checking. +// It is assumed index/count are unsigned types. +#define SLANG_BOUND_ASSERT(index, count) SLANG_PRELUDE_ASSERT(index < count); +#define SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_PRELUDE_ASSERT(index <= (sizeInBytes - elemSize) && (index & 3) == 0); + +// Macros to zero index if an access is out of range +#define SLANG_BOUND_ZERO_INDEX(index, count) index = (index < count) ? index : 0; +#define SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) index = (index <= (sizeInBytes - elemSize)) ? index : 0; + +// The 'FIX' macro define how the index is fixed. The default is to do nothing. If SLANG_ENABLE_BOUND_ZERO_INDEX +// the fix macro will zero the index, if out of range +#ifdef SLANG_ENABLE_BOUND_ZERO_INDEX +# define SLANG_BOUND_FIX(index, count) SLANG_BOUND_ZERO_INDEX(index, count) +# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) +# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) SLANG_BOUND_ZERO_INDEX(index, count) +#else +# define SLANG_BOUND_FIX(index, count) +# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) +# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) +#endif + +#ifndef SLANG_BOUND_CHECK +# define SLANG_BOUND_CHECK(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX(index, count) +#endif + +#ifndef SLANG_BOUND_CHECK_BYTE_ADDRESS +# define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) +#endif + +#ifndef SLANG_BOUND_CHECK_FIXED_ARRAY +# define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX_FIXED_ARRAY(index, count) +#endif + +struct TypeInfo +{ + size_t typeSize; +}; + +template +struct FixedArray +{ + const T& operator[](size_t index) const { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; } + T& operator[](size_t index) { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; } + + T m_data[SIZE]; +}; + +// An array that has no specified size, becomes a 'Array'. This stores the size so it can potentially +// do bounds checking. +template +struct Array +{ + const T& operator[](size_t index) const { SLANG_BOUND_CHECK(index, count); return data[index]; } + T& operator[](size_t index) { SLANG_BOUND_CHECK(index, count); return data[index]; } + + T* data; + size_t count; +}; + +/* Constant buffers become a pointer to the contained type, so ConstantBuffer becomes T* in C++ code. +*/ + +template +struct Vector; + +template +struct Vector +{ + T x; + const T& operator[](size_t /*index*/) const { return x; } + T& operator[](size_t /*index*/) { return x; } + operator T() const { return x; } + Vector() = default; + Vector(T scalar) + { + x = scalar; + } + template + Vector(Vector other) + { + x = (T)other.x; + } + template + Vector(Vector other) + { + int minSize = 1; + if (otherSize < minSize) minSize = otherSize; + for (int i = 0; i < minSize; i++) + (*this)[i] = (T)other[i]; + } +}; + +template +struct Vector +{ + T x, y; + const T& operator[](size_t index) const { return index == 0 ? x : y; } + T& operator[](size_t index) { return index == 0 ? x : y; } + Vector() = default; + Vector(T scalar) + { + x = y = scalar; + } + Vector(T _x, T _y) + { + x = _x; + y = _y; + } + template + Vector(Vector other) + { + x = (T)other.x; + y = (T)other.y; + } + template + Vector(Vector other) + { + int minSize = 2; + if (otherSize < minSize) minSize = otherSize; + for (int i = 0; i < minSize; i++) + (*this)[i] = (T)other[i]; + } +}; + +template +struct Vector +{ + T x, y, z; + const T& operator[](size_t index) const { return *((T*)(this) + index); } + T& operator[](size_t index) { return *((T*)(this) + index); } + + Vector() = default; + Vector(T scalar) + { + x = y = z = scalar; + } + Vector(T _x, T _y, T _z) + { + x = _x; + y = _y; + z = _z; + } + template + Vector(Vector other) + { + x = (T)other.x; + y = (T)other.y; + z = (T)other.z; + } + template + Vector(Vector other) + { + int minSize = 3; + if (otherSize < minSize) minSize = otherSize; + for (int i = 0; i < minSize; i++) + (*this)[i] = (T)other[i]; + } +}; + +template +struct Vector +{ + T x, y, z, w; + + const T& operator[](size_t index) const { return *((T*)(this) + index); } + T& operator[](size_t index) { return *((T*)(this) + index); } + Vector() = default; + Vector(T scalar) + { + x = y = z = w = scalar; + } + Vector(T _x, T _y, T _z, T _w) + { + x = _x; + y = _y; + z = _z; + w = _w; + } + template + Vector(Vector other) + { + int minSize = 4; + if (otherSize < minSize) minSize = otherSize; + for (int i = 0; i < minSize; i++) + (*this)[i] = (T)other[i]; + } + +}; + +template +SLANG_FORCE_INLINE T _slang_vector_get_element(Vector x, int index) +{ + return x[index]; +} + +template +SLANG_FORCE_INLINE const T* _slang_vector_get_element_ptr(const Vector* x, int index) +{ + return &((*const_cast*>(x))[index]); +} + +template +SLANG_FORCE_INLINE T* _slang_vector_get_element_ptr(Vector* x, int index) +{ + return &((*x)[index]); +} + +template +SLANG_FORCE_INLINE Vector _slang_vector_reshape(const Vector other) +{ + Vector result; + for (int i = 0; i < n; i++) + { + OtherT otherElement = T(0); + if (i < m) + otherElement = _slang_vector_get_element(other, i); + *_slang_vector_get_element_ptr(&result, i) = (T)otherElement; + } + return result; +} + +typedef uint32_t uint; + +#define SLANG_VECTOR_BINARY_OP(T, op) \ + template \ + SLANG_FORCE_INLINE Vector operator op(const Vector& thisVal, const Vector& other) \ + { \ + Vector result;\ + for (int i = 0; i < n; i++) \ + result[i] = thisVal[i] op other[i]; \ + return result;\ + } +#define SLANG_VECTOR_BINARY_COMPARE_OP(T, op) \ + template \ + SLANG_FORCE_INLINE Vector operator op(const Vector& thisVal, const Vector& other) \ + { \ + Vector result;\ + for (int i = 0; i < n; i++) \ + result[i] = thisVal[i] op other[i]; \ + return result;\ + } + +#define SLANG_VECTOR_UNARY_OP(T, op) \ + template \ + SLANG_FORCE_INLINE Vector operator op(const Vector& thisVal) \ + { \ + Vector result;\ + for (int i = 0; i < n; i++) \ + result[i] = op thisVal[i]; \ + return result;\ + } +#define SLANG_INT_VECTOR_OPS(T) \ + SLANG_VECTOR_BINARY_OP(T, +)\ + SLANG_VECTOR_BINARY_OP(T, -)\ + SLANG_VECTOR_BINARY_OP(T, *)\ + SLANG_VECTOR_BINARY_OP(T, / )\ + SLANG_VECTOR_BINARY_OP(T, &)\ + SLANG_VECTOR_BINARY_OP(T, |)\ + SLANG_VECTOR_BINARY_OP(T, &&)\ + SLANG_VECTOR_BINARY_OP(T, ||)\ + SLANG_VECTOR_BINARY_OP(T, ^)\ + SLANG_VECTOR_BINARY_OP(T, %)\ + SLANG_VECTOR_BINARY_OP(T, >>)\ + SLANG_VECTOR_BINARY_OP(T, <<)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)\ + SLANG_VECTOR_UNARY_OP(T, !)\ + SLANG_VECTOR_UNARY_OP(T, ~) +#define SLANG_FLOAT_VECTOR_OPS(T) \ + SLANG_VECTOR_BINARY_OP(T, +)\ + SLANG_VECTOR_BINARY_OP(T, -)\ + SLANG_VECTOR_BINARY_OP(T, *)\ + SLANG_VECTOR_BINARY_OP(T, /)\ + SLANG_VECTOR_UNARY_OP(T, -)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, !=) + +SLANG_INT_VECTOR_OPS(bool) +SLANG_INT_VECTOR_OPS(int) +SLANG_INT_VECTOR_OPS(int8_t) +SLANG_INT_VECTOR_OPS(int16_t) +SLANG_INT_VECTOR_OPS(int64_t) +SLANG_INT_VECTOR_OPS(uint) +SLANG_INT_VECTOR_OPS(uint8_t) +SLANG_INT_VECTOR_OPS(uint16_t) +SLANG_INT_VECTOR_OPS(uint64_t) + +SLANG_FLOAT_VECTOR_OPS(float) +SLANG_FLOAT_VECTOR_OPS(double) + +#define SLANG_VECTOR_INT_NEG_OP(T) \ + template\ + Vector operator-(const Vector& thisVal) \ + { \ + Vector result;\ + for (int i = 0; i < N; i++) \ + result[i] = 0 - thisVal[i]; \ + return result;\ + } +SLANG_VECTOR_INT_NEG_OP(int) +SLANG_VECTOR_INT_NEG_OP(int8_t) +SLANG_VECTOR_INT_NEG_OP(int16_t) +SLANG_VECTOR_INT_NEG_OP(int64_t) +SLANG_VECTOR_INT_NEG_OP(uint) +SLANG_VECTOR_INT_NEG_OP(uint8_t) +SLANG_VECTOR_INT_NEG_OP(uint16_t) +SLANG_VECTOR_INT_NEG_OP(uint64_t) + +#define SLANG_FLOAT_VECTOR_MOD(T)\ + template \ + Vector operator%(const Vector& left, const Vector& right) \ + {\ + Vector result;\ + for (int i = 0; i < N; i++) \ + result[i] = _slang_fmod(left[i], right[i]); \ + return result;\ + } + +SLANG_FLOAT_VECTOR_MOD(float) +SLANG_FLOAT_VECTOR_MOD(double) +#undef SLANG_FLOAT_VECTOR_MOD +#undef SLANG_VECTOR_BINARY_OP +#undef SLANG_VECTOR_UNARY_OP +#undef SLANG_INT_VECTOR_OPS +#undef SLANG_FLOAT_VECTOR_OPS +#undef SLANG_VECTOR_INT_NEG_OP +#undef SLANG_FLOAT_VECTOR_MOD + +template +struct Matrix +{ + Vector rows[ROWS]; + Vector& operator[](size_t index) { return rows[index]; } + Matrix() = default; + Matrix(T scalar) + { + for (int i = 0; i < ROWS; i++) + rows[i] = Vector(scalar); + } + Matrix(const Vector& row0) + { + rows[0] = row0; + } + Matrix(const Vector& row0, const Vector& row1) + { + rows[0] = row0; + rows[1] = row1; + } + Matrix(const Vector& row0, const Vector& row1, const Vector& row2) + { + rows[0] = row0; + rows[1] = row1; + rows[2] = row2; + } + Matrix(const Vector& row0, const Vector& row1, const Vector& row2, const Vector& row3) + { + rows[0] = row0; + rows[1] = row1; + rows[2] = row2; + rows[3] = row3; + } + template + Matrix(const Matrix& other) + { + int minRow = ROWS; + int minCol = COLS; + if (minRow > otherRow) minRow = otherRow; + if (minCol > otherCol) minCol = otherCol; + for (int i = 0; i < minRow; i++) + for (int j = 0; j < minCol; j++) + rows[i][j] = (T)other.rows[i][j]; + } + Matrix(T v0, T v1, T v2, T v3) + { + rows[0][0] = v0; rows[0][1] = v1; + rows[1][0] = v2; rows[1][1] = v3; + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5) + { + if (COLS == 3) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; + rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; + } + else + { + rows[0][0] = v0; rows[0][1] = v1; + rows[1][0] = v2; rows[1][1] = v3; + rows[2][0] = v4; rows[2][1] = v5; + } + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7) + { + if (COLS == 4) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; + rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; + } + else + { + rows[0][0] = v0; rows[0][1] = v1; + rows[1][0] = v2; rows[1][1] = v3; + rows[2][0] = v4; rows[2][1] = v5; + rows[3][0] = v6; rows[3][1] = v7; + } + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; + rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; + rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8; + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11) + { + if (COLS == 4) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; + rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; + rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11; + } + else + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; + rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; + rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8; + rows[3][0] = v9; rows[3][1] = v10; rows[3][2] = v11; + } + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; + rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; + rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11; + rows[3][0] = v12; rows[3][1] = v13; rows[3][2] = v14; rows[3][3] = v15; + } +}; + +#define SLANG_MATRIX_BINARY_OP(T, op) \ + template \ + Matrix operator op(const Matrix& thisVal, const Matrix& other) \ + { \ + Matrix result;\ + for (int i = 0; i < R; i++) \ + for (int j = 0; j < C; j++) \ + result.rows[i][j] = thisVal.rows[i][j] op other.rows[i][j]; \ + return result;\ + } + +#define SLANG_MATRIX_UNARY_OP(T, op) \ + template \ + Matrix operator op(const Matrix& thisVal) \ + { \ + Matrix result;\ + for (int i = 0; i < R; i++) \ + for (int j = 0; j < C; j++) \ + result[i].rows[i][j] = op thisVal.rows[i][j]; \ + return result;\ + } +#define SLANG_INT_MATRIX_OPS(T) \ + SLANG_MATRIX_BINARY_OP(T, +)\ + SLANG_MATRIX_BINARY_OP(T, -)\ + SLANG_MATRIX_BINARY_OP(T, *)\ + SLANG_MATRIX_BINARY_OP(T, / )\ + SLANG_MATRIX_BINARY_OP(T, &)\ + SLANG_MATRIX_BINARY_OP(T, |)\ + SLANG_MATRIX_BINARY_OP(T, &&)\ + SLANG_MATRIX_BINARY_OP(T, ||)\ + SLANG_MATRIX_BINARY_OP(T, ^)\ + SLANG_MATRIX_BINARY_OP(T, %)\ + SLANG_MATRIX_UNARY_OP(T, !)\ + SLANG_MATRIX_UNARY_OP(T, ~) +#define SLANG_FLOAT_MATRIX_OPS(T) \ + SLANG_MATRIX_BINARY_OP(T, +)\ + SLANG_MATRIX_BINARY_OP(T, -)\ + SLANG_MATRIX_BINARY_OP(T, *)\ + SLANG_MATRIX_BINARY_OP(T, /)\ + SLANG_MATRIX_UNARY_OP(T, -) +SLANG_INT_MATRIX_OPS(int) +SLANG_INT_MATRIX_OPS(int8_t) +SLANG_INT_MATRIX_OPS(int16_t) +SLANG_INT_MATRIX_OPS(int64_t) +SLANG_INT_MATRIX_OPS(uint) +SLANG_INT_MATRIX_OPS(uint8_t) +SLANG_INT_MATRIX_OPS(uint16_t) +SLANG_INT_MATRIX_OPS(uint64_t) + +SLANG_FLOAT_MATRIX_OPS(float) +SLANG_FLOAT_MATRIX_OPS(double) + +#define SLANG_MATRIX_INT_NEG_OP(T) \ + template\ + SLANG_FORCE_INLINE Matrix operator-(Matrix thisVal) \ + { \ + Matrix result;\ + for (int i = 0; i < R; i++) \ + for (int j = 0; j < C; j++) \ + result.rows[i][j] = 0 - thisVal.rows[i][j]; \ + return result;\ + } + SLANG_MATRIX_INT_NEG_OP(int) + SLANG_MATRIX_INT_NEG_OP(int8_t) + SLANG_MATRIX_INT_NEG_OP(int16_t) + SLANG_MATRIX_INT_NEG_OP(int64_t) + SLANG_MATRIX_INT_NEG_OP(uint) + SLANG_MATRIX_INT_NEG_OP(uint8_t) + SLANG_MATRIX_INT_NEG_OP(uint16_t) + SLANG_MATRIX_INT_NEG_OP(uint64_t) + +#define SLANG_FLOAT_MATRIX_MOD(T)\ + template \ + SLANG_FORCE_INLINE Matrix operator%(Matrix left, Matrix right) \ + {\ + Matrix result;\ + for (int i = 0; i < R; i++) \ + for (int j = 0; j < C; j++) \ + result.rows[i][j] = _slang_fmod(left.rows[i][j], right.rows[i][j]); \ + return result;\ + } + + SLANG_FLOAT_MATRIX_MOD(float) + SLANG_FLOAT_MATRIX_MOD(double) +#undef SLANG_FLOAT_MATRIX_MOD +#undef SLANG_MATRIX_BINARY_OP +#undef SLANG_MATRIX_UNARY_OP +#undef SLANG_INT_MATRIX_OPS +#undef SLANG_FLOAT_MATRIX_OPS +#undef SLANG_MATRIX_INT_NEG_OP +#undef SLANG_FLOAT_MATRIX_MOD + +template +TResult slang_bit_cast(TInput val) +{ + return *(TResult*)(&val); +} + +#endif + + diff --git a/prelude/slang-cpp-types.h b/prelude/slang-cpp-types.h index 28fe3dd8d..ac66ad9f3 100644 --- a/prelude/slang-cpp-types.h +++ b/prelude/slang-cpp-types.h @@ -1,244 +1,12 @@ #ifndef SLANG_PRELUDE_CPP_TYPES_H #define SLANG_PRELUDE_CPP_TYPES_H -#ifndef SLANG_PRELUDE_ASSERT -# ifdef SLANG_PRELUDE_ENABLE_ASSERT -# define SLANG_PRELUDE_ASSERT(VALUE) assert(VALUE) -# else -# define SLANG_PRELUDE_ASSERT(VALUE) -# endif -#endif - -// Since we are using unsigned arithmatic care is need in this comparison. -// It is *assumed* that sizeInBytes >= elemSize. Which means (sizeInBytes >= elemSize) >= 0 -// Which means only a single test is needed - -// Asserts for bounds checking. -// It is assumed index/count are unsigned types. -#define SLANG_BOUND_ASSERT(index, count) SLANG_PRELUDE_ASSERT(index < count); -#define SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_PRELUDE_ASSERT(index <= (sizeInBytes - elemSize) && (index & 3) == 0); - -// Macros to zero index if an access is out of range -#define SLANG_BOUND_ZERO_INDEX(index, count) index = (index < count) ? index : 0; -#define SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) index = (index <= (sizeInBytes - elemSize)) ? index : 0; - -// The 'FIX' macro define how the index is fixed. The default is to do nothing. If SLANG_ENABLE_BOUND_ZERO_INDEX -// the fix macro will zero the index, if out of range -#ifdef SLANG_ENABLE_BOUND_ZERO_INDEX -# define SLANG_BOUND_FIX(index, count) SLANG_BOUND_ZERO_INDEX(index, count) -# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) -# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) SLANG_BOUND_ZERO_INDEX(index, count) -#else -# define SLANG_BOUND_FIX(index, count) -# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) -# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) -#endif - -#ifndef SLANG_BOUND_CHECK -# define SLANG_BOUND_CHECK(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX(index, count) -#endif - -#ifndef SLANG_BOUND_CHECK_BYTE_ADDRESS -# define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) -#endif - -#ifndef SLANG_BOUND_CHECK_FIXED_ARRAY -# define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX_FIXED_ARRAY(index, count) -#endif - #ifdef SLANG_PRELUDE_NAMESPACE namespace SLANG_PRELUDE_NAMESPACE { #endif -struct TypeInfo -{ - size_t typeSize; -}; - -template -struct FixedArray -{ - const T& operator[](size_t index) const { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; } - T& operator[](size_t index) { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; } - - T m_data[SIZE]; -}; - -// An array that has no specified size, becomes a 'Array'. This stores the size so it can potentially -// do bounds checking. -template -struct Array -{ - const T& operator[](size_t index) const { SLANG_BOUND_CHECK(index, count); return data[index]; } - T& operator[](size_t index) { SLANG_BOUND_CHECK(index, count); return data[index]; } - - T* data; - size_t count; -}; - -/* Constant buffers become a pointer to the contained type, so ConstantBuffer becomes T* in C++ code. -*/ - -template -struct Vector; - -template -struct Vector -{ - T x; - const T& operator[](size_t /*index*/) const { return x; } - T& operator[](size_t /*index*/) { return x; } - operator T() const { return x; } - Vector() = default; - Vector(T scalar) - { - x = scalar; - } - template - Vector(Vector other) - { - x = (T)other.x; - } - template - Vector(Vector other) - { - int minSize = 1; - if (otherSize < minSize) minSize = otherSize; - for (int i = 0; i < minSize; i++) - (*this)[i] = (T)other[i]; - } -}; - -template -struct Vector -{ - T x, y; - const T& operator[](size_t index) const { return index == 0 ? x : y; } - T& operator[](size_t index) { return index == 0 ? x : y; } - Vector() = default; - Vector(T scalar) - { - x = y = scalar; - } - Vector(T _x, T _y) - { - x = _x; - y = _y; - } - template - Vector(Vector other) - { - x = (T)other.x; - y = (T)other.y; - } - template - Vector(Vector other) - { - int minSize = 2; - if (otherSize < minSize) minSize = otherSize; - for (int i = 0; i < minSize; i++) - (*this)[i] = (T)other[i]; - } -}; - -template -struct Vector -{ - T x, y, z; - const T& operator[](size_t index) const { return *((T*)(this) + index); } - T& operator[](size_t index) { return *((T*)(this) + index); } - - Vector() = default; - Vector(T scalar) - { - x = y = z = scalar; - } - Vector(T _x, T _y, T _z) - { - x = _x; - y = _y; - z = _z; - } - template - Vector(Vector other) - { - x = (T)other.x; - y = (T)other.y; - z = (T)other.z; - } - template - Vector(Vector other) - { - int minSize = 3; - if (otherSize < minSize) minSize = otherSize; - for (int i = 0; i < minSize; i++) - (*this)[i] = (T)other[i]; - } -}; -template -struct Vector -{ - T x, y, z, w; - - const T& operator[](size_t index) const { return *((T*)(this) + index); } - T& operator[](size_t index) { return *((T*)(this) + index); } - Vector() = default; - Vector(T scalar) - { - x = y = z = w = scalar; - } - Vector(T _x, T _y, T _z, T _w) - { - x = _x; - y = _y; - z = _z; - w = _w; - } - template - Vector(Vector other) - { - int minSize = 4; - if (otherSize < minSize) minSize = otherSize; - for (int i = 0; i < minSize; i++) - (*this)[i] = (T)other[i]; - } - -}; - -template -SLANG_FORCE_INLINE T _slang_vector_get_element(Vector x, int index) -{ - return x[index]; -} - -template -SLANG_FORCE_INLINE const T* _slang_vector_get_element_ptr(const Vector* x, int index) -{ - return &((*const_cast*>(x))[index]); -} - -template -SLANG_FORCE_INLINE T* _slang_vector_get_element_ptr(Vector* x, int index) -{ - return &((*x)[index]); -} - -template -SLANG_FORCE_INLINE Vector _slang_vector_reshape(const Vector other) -{ - Vector result; - for (int i = 0; i < n; i++) - { - OtherT otherElement = T(0); - if (i < m) - otherElement = _slang_vector_get_element(other, i); - *_slang_vector_get_element_ptr(&result, i) = (T)otherElement; - } - return result; -} - -typedef uint32_t uint; +#include "slang-cpp-types-core.h" typedef Vector float2; typedef Vector float3; @@ -252,320 +20,6 @@ typedef Vector uint2; typedef Vector uint3; typedef Vector uint4; -#define SLANG_VECTOR_BINARY_OP(T, op) \ - template \ - SLANG_FORCE_INLINE Vector operator op(const Vector& thisVal, const Vector& other) \ - { \ - Vector result;\ - for (int i = 0; i < n; i++) \ - result[i] = thisVal[i] op other[i]; \ - return result;\ - } -#define SLANG_VECTOR_BINARY_COMPARE_OP(T, op) \ - template \ - SLANG_FORCE_INLINE Vector operator op(const Vector& thisVal, const Vector& other) \ - { \ - Vector result;\ - for (int i = 0; i < n; i++) \ - result[i] = thisVal[i] op other[i]; \ - return result;\ - } - -#define SLANG_VECTOR_UNARY_OP(T, op) \ - template \ - SLANG_FORCE_INLINE Vector operator op(const Vector& thisVal) \ - { \ - Vector result;\ - for (int i = 0; i < n; i++) \ - result[i] = op thisVal[i]; \ - return result;\ - } -#define SLANG_INT_VECTOR_OPS(T) \ - SLANG_VECTOR_BINARY_OP(T, +)\ - SLANG_VECTOR_BINARY_OP(T, -)\ - SLANG_VECTOR_BINARY_OP(T, *)\ - SLANG_VECTOR_BINARY_OP(T, / )\ - SLANG_VECTOR_BINARY_OP(T, &)\ - SLANG_VECTOR_BINARY_OP(T, |)\ - SLANG_VECTOR_BINARY_OP(T, &&)\ - SLANG_VECTOR_BINARY_OP(T, ||)\ - SLANG_VECTOR_BINARY_OP(T, ^)\ - SLANG_VECTOR_BINARY_OP(T, %)\ - SLANG_VECTOR_BINARY_OP(T, >>)\ - SLANG_VECTOR_BINARY_OP(T, <<)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)\ - SLANG_VECTOR_UNARY_OP(T, !)\ - SLANG_VECTOR_UNARY_OP(T, ~) -#define SLANG_FLOAT_VECTOR_OPS(T) \ - SLANG_VECTOR_BINARY_OP(T, +)\ - SLANG_VECTOR_BINARY_OP(T, -)\ - SLANG_VECTOR_BINARY_OP(T, *)\ - SLANG_VECTOR_BINARY_OP(T, /)\ - SLANG_VECTOR_UNARY_OP(T, -)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, !=) - -SLANG_INT_VECTOR_OPS(bool) -SLANG_INT_VECTOR_OPS(int) -SLANG_INT_VECTOR_OPS(int8_t) -SLANG_INT_VECTOR_OPS(int16_t) -SLANG_INT_VECTOR_OPS(int64_t) -SLANG_INT_VECTOR_OPS(uint) -SLANG_INT_VECTOR_OPS(uint8_t) -SLANG_INT_VECTOR_OPS(uint16_t) -SLANG_INT_VECTOR_OPS(uint64_t) - -SLANG_FLOAT_VECTOR_OPS(float) -SLANG_FLOAT_VECTOR_OPS(double) - -#define SLANG_VECTOR_INT_NEG_OP(T) \ - template\ - Vector operator-(const Vector& thisVal) \ - { \ - Vector result;\ - for (int i = 0; i < N; i++) \ - result[i] = 0 - thisVal[i]; \ - return result;\ - } -SLANG_VECTOR_INT_NEG_OP(int) -SLANG_VECTOR_INT_NEG_OP(int8_t) -SLANG_VECTOR_INT_NEG_OP(int16_t) -SLANG_VECTOR_INT_NEG_OP(int64_t) -SLANG_VECTOR_INT_NEG_OP(uint) -SLANG_VECTOR_INT_NEG_OP(uint8_t) -SLANG_VECTOR_INT_NEG_OP(uint16_t) -SLANG_VECTOR_INT_NEG_OP(uint64_t) - -#define SLANG_FLOAT_VECTOR_MOD(T)\ - template \ - Vector operator%(const Vector& left, const Vector& right) \ - {\ - Vector result;\ - for (int i = 0; i < N; i++) \ - result[i] = _slang_fmod(left[i], right[i]); \ - return result;\ - } - -SLANG_FLOAT_VECTOR_MOD(float) -SLANG_FLOAT_VECTOR_MOD(double) -#undef SLANG_FLOAT_VECTOR_MOD -#undef SLANG_VECTOR_BINARY_OP -#undef SLANG_VECTOR_UNARY_OP -#undef SLANG_INT_VECTOR_OPS -#undef SLANG_FLOAT_VECTOR_OPS -#undef SLANG_VECTOR_INT_NEG_OP -#undef SLANG_FLOAT_VECTOR_MOD - -template -struct Matrix -{ - Vector rows[ROWS]; - Vector& operator[](size_t index) { return rows[index]; } - Matrix() = default; - Matrix(T scalar) - { - for (int i = 0; i < ROWS; i++) - rows[i] = Vector(scalar); - } - Matrix(const Vector& row0) - { - rows[0] = row0; - } - Matrix(const Vector& row0, const Vector& row1) - { - rows[0] = row0; - rows[1] = row1; - } - Matrix(const Vector& row0, const Vector& row1, const Vector& row2) - { - rows[0] = row0; - rows[1] = row1; - rows[2] = row2; - } - Matrix(const Vector& row0, const Vector& row1, const Vector& row2, const Vector& row3) - { - rows[0] = row0; - rows[1] = row1; - rows[2] = row2; - rows[3] = row3; - } - template - Matrix(const Matrix& other) - { - int minRow = ROWS; - int minCol = COLS; - if (minRow > otherRow) minRow = otherRow; - if (minCol > otherCol) minCol = otherCol; - for (int i = 0; i < minRow; i++) - for (int j = 0; j < minCol; j++) - rows[i][j] = (T)other.rows[i][j]; - } - Matrix(T v0, T v1, T v2, T v3) - { - rows[0][0] = v0; rows[0][1] = v1; - rows[1][0] = v2; rows[1][1] = v3; - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5) - { - if (COLS == 3) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; - rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; - } - else - { - rows[0][0] = v0; rows[0][1] = v1; - rows[1][0] = v2; rows[1][1] = v3; - rows[2][0] = v4; rows[2][1] = v5; - } - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7) - { - if (COLS == 4) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; - rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; - } - else - { - rows[0][0] = v0; rows[0][1] = v1; - rows[1][0] = v2; rows[1][1] = v3; - rows[2][0] = v4; rows[2][1] = v5; - rows[3][0] = v6; rows[3][1] = v7; - } - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; - rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; - rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8; - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11) - { - if (COLS == 4) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; - rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; - rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11; - } - else - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; - rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; - rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8; - rows[3][0] = v9; rows[3][1] = v10; rows[3][2] = v11; - } - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; - rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; - rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11; - rows[3][0] = v12; rows[3][1] = v13; rows[3][2] = v14; rows[3][3] = v15; - } -}; - -#define SLANG_MATRIX_BINARY_OP(T, op) \ - template \ - Matrix operator op(const Matrix& thisVal, const Matrix& other) \ - { \ - Matrix result;\ - for (int i = 0; i < R; i++) \ - for (int j = 0; j < C; j++) \ - result.rows[i][j] = thisVal.rows[i][j] op other.rows[i][j]; \ - return result;\ - } - -#define SLANG_MATRIX_UNARY_OP(T, op) \ - template \ - Matrix operator op(const Matrix& thisVal) \ - { \ - Matrix result;\ - for (int i = 0; i < R; i++) \ - for (int j = 0; j < C; j++) \ - result[i].rows[i][j] = op thisVal.rows[i][j]; \ - return result;\ - } -#define SLANG_INT_MATRIX_OPS(T) \ - SLANG_MATRIX_BINARY_OP(T, +)\ - SLANG_MATRIX_BINARY_OP(T, -)\ - SLANG_MATRIX_BINARY_OP(T, *)\ - SLANG_MATRIX_BINARY_OP(T, / )\ - SLANG_MATRIX_BINARY_OP(T, &)\ - SLANG_MATRIX_BINARY_OP(T, |)\ - SLANG_MATRIX_BINARY_OP(T, &&)\ - SLANG_MATRIX_BINARY_OP(T, ||)\ - SLANG_MATRIX_BINARY_OP(T, ^)\ - SLANG_MATRIX_BINARY_OP(T, %)\ - SLANG_MATRIX_UNARY_OP(T, !)\ - SLANG_MATRIX_UNARY_OP(T, ~) -#define SLANG_FLOAT_MATRIX_OPS(T) \ - SLANG_MATRIX_BINARY_OP(T, +)\ - SLANG_MATRIX_BINARY_OP(T, -)\ - SLANG_MATRIX_BINARY_OP(T, *)\ - SLANG_MATRIX_BINARY_OP(T, /)\ - SLANG_MATRIX_UNARY_OP(T, -) -SLANG_INT_MATRIX_OPS(int) -SLANG_INT_MATRIX_OPS(int8_t) -SLANG_INT_MATRIX_OPS(int16_t) -SLANG_INT_MATRIX_OPS(int64_t) -SLANG_INT_MATRIX_OPS(uint) -SLANG_INT_MATRIX_OPS(uint8_t) -SLANG_INT_MATRIX_OPS(uint16_t) -SLANG_INT_MATRIX_OPS(uint64_t) - -SLANG_FLOAT_MATRIX_OPS(float) -SLANG_FLOAT_MATRIX_OPS(double) - -#define SLANG_MATRIX_INT_NEG_OP(T) \ - template\ - SLANG_FORCE_INLINE Matrix operator-(Matrix thisVal) \ - { \ - Matrix result;\ - for (int i = 0; i < R; i++) \ - for (int j = 0; j < C; j++) \ - result.rows[i][j] = 0 - thisVal.rows[i][j]; \ - return result;\ - } - SLANG_MATRIX_INT_NEG_OP(int) - SLANG_MATRIX_INT_NEG_OP(int8_t) - SLANG_MATRIX_INT_NEG_OP(int16_t) - SLANG_MATRIX_INT_NEG_OP(int64_t) - SLANG_MATRIX_INT_NEG_OP(uint) - SLANG_MATRIX_INT_NEG_OP(uint8_t) - SLANG_MATRIX_INT_NEG_OP(uint16_t) - SLANG_MATRIX_INT_NEG_OP(uint64_t) - -#define SLANG_FLOAT_MATRIX_MOD(T)\ - template \ - SLANG_FORCE_INLINE Matrix operator%(Matrix left, Matrix right) \ - {\ - Matrix result;\ - for (int i = 0; i < R; i++) \ - for (int j = 0; j < C; j++) \ - result.rows[i][j] = _slang_fmod(left.rows[i][j], right.rows[i][j]); \ - return result;\ - } - - SLANG_FLOAT_MATRIX_MOD(float) - SLANG_FLOAT_MATRIX_MOD(double) -#undef SLANG_FLOAT_MATRIX_MOD -#undef SLANG_MATRIX_BINARY_OP -#undef SLANG_MATRIX_UNARY_OP -#undef SLANG_INT_MATRIX_OPS -#undef SLANG_FLOAT_MATRIX_OPS -#undef SLANG_MATRIX_INT_NEG_OP -#undef SLANG_FLOAT_MATRIX_MOD - // We can just map `NonUniformResourceIndex` type directly to the index type on CPU, as CPU does not require // any special handling around such accesses. typedef size_t NonUniformResourceIndex; @@ -1484,12 +938,6 @@ struct ComputeVaryingInput typedef void(*ComputeThreadFunc)(ComputeThreadVaryingInput* varyingInput, void* uniformEntryPointParams, void* uniformState); typedef void(*ComputeFunc)(ComputeVaryingInput* varyingInput, void* uniformEntryPointParams, void* uniformState); -template -TResult slang_bit_cast(TInput val) -{ - return *(TResult*)(&val); -} - #ifdef SLANG_PRELUDE_NAMESPACE } #endif diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 7a4c5a918..9a55aed57 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -6,6 +6,8 @@ #define SLANG_CUDA_RTC 0 #endif +#include + // Define SLANG_CUDA_ENABLE_HALF to use the cuda_fp16 include to add half support. // For this to work NVRTC needs to have the path to the CUDA SDK. // @@ -2080,4 +2082,73 @@ __forceinline__ __device__ void *traceOptiXRay( r0, r1 ); } + #endif + + +// TensorView +struct TensorView +{ + uint8_t* data; + uint32_t* strides; + uint32_t* sizes; + uint32_t dimensionCount; + + template + __device__ T* data_ptr() + { + return reinterpret_cast(data); + } + + template + __device__ T load(uint32_t x) + { + return *reinterpret_cast(data + strides[0] * x); + } + template + __device__ T load(uint32_t x, uint32_t y) + { + return *reinterpret_cast(data + strides[0] * x + strides[1] * y); + } + template + __device__ T load(uint32_t x, uint32_t y, uint32_t z) + { + return *reinterpret_cast(data + strides[0] * x + strides[1] * y + strides[2] * z); + } + template + __device__ T load(uint32_t x, uint32_t y, uint32_t z, uint32_t w) + { + return *reinterpret_cast(data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w); + } + template + __device__ T load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4) + { + return *reinterpret_cast(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4); + } + template + __device__ void store(uint32_t x, T val) + { + *reinterpret_cast(data + strides[0] * x) = val; + } + template + __device__ void store(uint32_t x, uint32_t y, T val) + { + *reinterpret_cast(data + strides[0] * x + strides[1] * y) = val; + } + template + __device__ void store(uint32_t x, uint32_t y, uint32_t z, T val) + { + *reinterpret_cast(data + strides[0] * x + strides[1] * y + strides[2] * z) = val; + } + template + __device__ void store(uint32_t x, uint32_t y, uint32_t z, uint32_t w, T val) + { + *reinterpret_cast( + data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w) = val; + } + template + __device__ void store(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4, T val) + { + *reinterpret_cast(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4) = val; + } +}; diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h new file mode 100644 index 000000000..f2accc149 --- /dev/null +++ b/prelude/slang-torch-prelude.h @@ -0,0 +1,126 @@ +// Prelude for PyTorch cpp binding. + +#include +#include +#include +#include + +#ifndef SLANG_NO_THROW +# define SLANG_NO_THROW +#endif + +#ifndef SLANG_STDCALL +# define SLANG_STDCALL +#endif +#ifndef SLANG_MCALL +# define SLANG_MCALL SLANG_STDCALL +#endif +#ifndef SLANG_FORCE_INLINE +# define SLANG_FORCE_INLINE inline +#endif + +#ifdef SLANG_LLVM +#include "slang-llvm.h" +#else // SLANG_LLVM +# if SLANG_GCC_FAMILY && __GNUC__ < 6 +# include +# define SLANG_PRELUDE_STD std:: +# else +# include +# define SLANG_PRELUDE_STD +# endif + +# include +# include +# include +# include +#endif // SLANG_LLVM + +#include "slang-cpp-types-core.h" +#include "slang-cpp-scalar-intrinsics.h" + +struct TensorView +{ + uint8_t* data; + uint32_t* strides; + uint32_t* sizes; + uint32_t dimensionCount; +}; + +struct CudaTaskMemoryAllocator +{ + std::vector allocations; + + uint32_t* allocUIntArray(uint32_t size) + { + void* ptr = nullptr; + cudaMallocManaged(&ptr, size * sizeof(uint32_t)); + AT_CUDA_CHECK(cudaGetLastError()); + return (uint32_t*)ptr; + } + + ~CudaTaskMemoryAllocator() + { + for (auto ptr : allocations) + cudaFree(ptr); + } +}; + +TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val) +{ + val = val.to(torch::kCUDA); + TensorView res = {}; + res.dimensionCount = val.dim(); + res.strides = allocator->allocUIntArray(val.dim()); + res.sizes = allocator->allocUIntArray(val.dim()); + res.data = nullptr; + size_t elementSize = 4; + switch (val.scalar_type()) + { + case torch::kInt8: + case torch::kUInt8: + elementSize = 1; + res.data = (uint8_t*)val.data_ptr(); + break; + case torch::kBFloat16: + elementSize = 2; + res.data = (uint8_t*)val.data_ptr(); + break; + case torch::kInt16: + elementSize = 2; + res.data = (uint8_t*)val.data_ptr(); + break; + case torch::kFloat32: + elementSize = 4; + res.data = (uint8_t*)val.data_ptr(); + break; + case torch::kInt32: + elementSize = 4; + res.data = (uint8_t*)val.data_ptr(); + break; + case torch::kFloat64: + elementSize = 8; + res.data = (uint8_t*)val.data_ptr(); + break; + case torch::kInt64: + elementSize = 8; + res.data = (uint8_t*)val.data_ptr(); + break; + } + for (int i = 0; i < val.dim(); ++i) + { + res.strides[i] = val.stride(i) * elementSize; + res.sizes[i] = val.size(i); + } + return res; +} + +size_t slangGetCudaKernelSharedMemSize(const void* func) +{ + cudaFuncAttributes attr = {}; + cudaFuncGetAttributes(&attr, func); + AT_CUDA_CHECK(cudaGetLastError()); + return attr.sharedSizeBytes; +} + +#define SLANG_PRELUDE_EXPORT diff --git a/premake5.lua b/premake5.lua index add7544f2..093b4659d 100644 --- a/premake5.lua +++ b/premake5.lua @@ -1350,7 +1350,8 @@ if enableEmbedStdLib then "prelude/slang-cuda-prelude.h.cpp", "prelude/slang-hlsl-prelude.h.cpp", "prelude/slang-cpp-prelude.h.cpp", - "prelude/slang-cpp-host-prelude.h.cpp" + "prelude/slang-cpp-host-prelude.h.cpp", + "prelude/slang-torch-prelude.h.cpp" } end @@ -1460,7 +1461,8 @@ standardProject("slang", "source/slang") "prelude/slang-cuda-prelude.h.cpp", "prelude/slang-hlsl-prelude.h.cpp", "prelude/slang-cpp-prelude.h.cpp", - "prelude/slang-cpp-host-prelude.h.cpp" + "prelude/slang-cpp-host-prelude.h.cpp", + "prelude/slang-torch-prelude.h.cpp" } -- Similarly for any generated lookup tables @@ -1553,7 +1555,8 @@ if enableProfile then "prelude/slang-cuda-prelude.h.cpp", "prelude/slang-hlsl-prelude.h.cpp", "prelude/slang-cpp-prelude.h.cpp", - "prelude/slang-cpp-host-prelude.h.cpp" + "prelude/slang-cpp-host-prelude.h.cpp", + "prelude/slang-torch-prelude.h.cpp" } -- Add the slang source diff --git a/slang.h b/slang.h index 1a72c4348..ea165fbf7 100644 --- a/slang.h +++ b/slang.h @@ -566,6 +566,7 @@ extern "C" SLANG_DXIL_ASM, SLANG_C_SOURCE, ///< The C language SLANG_CPP_SOURCE, ///< C++ code for shader kernels. + SLANG_CPP_PYTORCH_BINDING, ///< C++ PyTorch binding code. SLANG_HOST_EXECUTABLE, ///< Standalone binary executable (for hosting CPU/OS) SLANG_SHADER_SHARED_LIBRARY, ///< A shared library/Dll for shader kernels (for hosting CPU/OS) SLANG_SHADER_HOST_CALLABLE, ///< A CPU target that makes the compiled shader code available to be run immediately diff --git a/source/compiler-core/slang-artifact-desc-util.cpp b/source/compiler-core/slang-artifact-desc-util.cpp index ca7dcb70f..3be0448c4 100644 --- a/source/compiler-core/slang-artifact-desc-util.cpp +++ b/source/compiler-core/slang-artifact-desc-util.cpp @@ -273,6 +273,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL case SLANG_C_SOURCE: return Desc::make(Kind::Source, Payload::C, Style::Kernel, 0); case SLANG_CPP_SOURCE: return Desc::make(Kind::Source, Payload::Cpp, Style::Kernel, 0); case SLANG_HOST_CPP_SOURCE: return Desc::make(Kind::Source, Payload::Cpp, Style::Host, 0); + case SLANG_CPP_PYTORCH_BINDING: return Desc::make(Kind::Source, Payload::Cpp, Style::Host, 0); case SLANG_HOST_EXECUTABLE: return Desc::make(Kind::Executable, Payload::HostCPU, Style::Host, 0); case SLANG_SHADER_SHARED_LIBRARY: return Desc::make(Kind::SharedLibrary, Payload::HostCPU, Style::Kernel, 0); case SLANG_SHADER_HOST_CALLABLE: return Desc::make(Kind::HostCallable, Payload::HostCPU, Style::Kernel, 0); diff --git a/source/compiler-core/slang-artifact.h b/source/compiler-core/slang-artifact.h index cc4d3e9fd..65d4d1bf9 100644 --- a/source/compiler-core/slang-artifact.h +++ b/source/compiler-core/slang-artifact.h @@ -208,7 +208,6 @@ enum class ArtifactStyle : uint8_t Kernel, ///< Compiled as `GPU kernel` style. Host, ///< Compiled in `host` style - Obfuscated, ///< Holds something specific to obfuscation, such as an obfuscated source map CountOf, diff --git a/source/core/slang-type-convert-util.cpp b/source/core/slang-type-convert-util.cpp index 6e6598357..c763b2835 100644 --- a/source/core/slang-type-convert-util.cpp +++ b/source/core/slang-type-convert-util.cpp @@ -17,6 +17,7 @@ namespace Slang case SLANG_HLSL: return SLANG_SOURCE_LANGUAGE_HLSL; case SLANG_C_SOURCE: return SLANG_SOURCE_LANGUAGE_C; case SLANG_CPP_SOURCE: return SLANG_SOURCE_LANGUAGE_CPP; + case SLANG_CPP_PYTORCH_BINDING:return SLANG_SOURCE_LANGUAGE_CPP; case SLANG_HOST_CPP_SOURCE: return SLANG_SOURCE_LANGUAGE_CPP; case SLANG_CUDA_SOURCE: return SLANG_SOURCE_LANGUAGE_CUDA; default: break; diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp index d37051e47..9d2f93ba3 100644 --- a/source/core/slang-type-text-util.cpp +++ b/source/core/slang-type-text-util.cpp @@ -78,6 +78,7 @@ static const CompileTargetInfo s_compileTargetInfos[] = { SLANG_SPIRV_ASM, "spv-asm", "spirv-asm,spirv-assembly" }, { SLANG_C_SOURCE, "c", "c" }, { SLANG_CPP_SOURCE, "cpp,c++,cxx", "cpp,c++,cxx" }, + { SLANG_CPP_PYTORCH_BINDING, "cpp,c++,cxx", "torch,torch-binding,torch-cpp,torch-cpp-binding" }, { SLANG_HOST_CPP_SOURCE, "cpp,c++,cxx", "host-cpp,host-c++,host-cxx"}, { SLANG_HOST_EXECUTABLE,"exe", "exe,executable" }, { SLANG_SHADER_SHARED_LIBRARY, "dll,so", "sharedlib,sharedlibrary,dll" }, diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 82a60a612..c45ad5bd6 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -3080,6 +3080,9 @@ attribute_syntax [DllImport(modulePath: String)] : DllImportAttribute; __attributeTarget(FuncDecl) attribute_syntax [DllExport] : DllExportAttribute; +__attributeTarget(FuncDecl) +attribute_syntax [TorchEntryPoint] : TorchEntryPointAttribute; + __attributeTarget(FuncDecl) attribute_syntax [CudaDeviceExport] : CudaDeviceExportAttribute; diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index bbe94dbc2..d5b70bbb3 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -9,7 +9,6 @@ attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute; - __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; @@ -26,6 +25,89 @@ attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [NoDiffThis] : NoDiffThisAttribute; +__generic +__magic_type(TensorViewType) +__intrinsic_type($(kIROp_TensorViewType)) +struct TensorView +{ + __target_intrinsic(cuda, "$0.data_ptr<$G0>()") + Ptr data_ptr(); + + __implicit_conversion($(kConversionCost_ImplicitDereference)) + __intrinsic_op($(kIROp_TorchTensorGetView)) + __init(TorchTensor t); + + __target_intrinsic(cuda, "$0.load<$G0>($1)") + T load(uint x); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2)") + T load(uint x, uint y); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3)") + T load(uint x, uint y, uint z); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4)") + T load(uint x, uint y, uint z, uint w); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4, $5)") + T load(uint i0, uint i1, uint i2, uint i3, uint i4); + + __target_intrinsic(cuda, "$0.store<$G0>($1, $2)") + void store(uint x, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3)") + void store(uint x, uint y, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4)") + void store(uint x, uint y, uint z, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5)") + void store(uint x, uint y, uint z, uint w, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5, $6)") + void store(uint i0, uint i1, uint i2, uint i3, uint i4, T val); + + __target_intrinsic(cuda, "$0.dimensionCount") + uint dims(); + + __target_intrinsic(cuda, "$0.sizes[$1]") + uint size(uint i); + + __target_intrinsic(cuda, "$0.strides[$1]") + uint stride(uint i); +} + +__generic +__intrinsic_type($(kIROp_TorchTensorType)) +struct TorchTensor +{ + __intrinsic_op($(kIROp_TorchTensorGetView)) + TensorView getView(); + + __target_intrinsic(cuda, "$0.dims()") + __target_intrinsic(cpp, "$0.dims()") + uint dims(); + + __target_intrinsic(cuda, "$0.size($1)") + __target_intrinsic(cpp, "$0.size($1)") + uint size(uint i); + + __target_intrinsic(cuda, "$0.stride($1)") + __target_intrinsic(cpp, "$0.stride($1)") + uint stride(uint i); + + __target_intrinsic(cuda, "$0.data_ptr<$G0>()") + __target_intrinsic(cpp, "$0.data_ptr<$G0>()") + Ptr data_ptr(); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor alloc(uint x); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor alloc(uint x, uint y); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor alloc(uint x, uint y, uint z); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor alloc(uint x, uint y, uint z, uint w); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor alloc(uint i0, uint i1, uint i2, uint i3, uint i4); +} + __generic __intrinsic_op($(kIROp_MakeDifferentialPairUserCode)) DifferentialPair diffPair(T primal, T.Differential diff); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index d417e3b7c..4c4aaa4f0 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6834,3 +6834,13 @@ void debugBreak(); __specialized_for_target(glsl) [[vk::spirv_instruction(1, "NonSemantic.DebugBreak")]] void debugBreak(); + + +__target_intrinsic(cuda, "(threadIdx)") +uint3 cudaThreadIdx(); + +__target_intrinsic(cuda, "(blockIdx)") +uint3 cudaBlockIdx(); + +__target_intrinsic(cuda, "(blockDim)") +uint3 cudaBlockDim(); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 0d2e27e5f..00a6570ef 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1055,6 +1055,11 @@ class DllExportAttribute : public Attribute SLANG_AST_CLASS(DllExportAttribute) }; +class TorchEntryPointAttribute : public Attribute +{ + SLANG_AST_CLASS(TorchEntryPointAttribute) +}; + class CudaDeviceExportAttribute : public Attribute { SLANG_AST_CLASS(CudaDeviceExportAttribute) diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index fdbd56377..1fed2d52a 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -275,6 +275,13 @@ BasicExpressionType* BasicExpressionType::_getScalarTypeOverride() return this; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TensorViewType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +Type* TensorViewType::getElementType() +{ + return as(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); +} + + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! VectorExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void VectorExpressionType::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 47608405a..cb3fde9f9 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -524,6 +524,15 @@ private: MatrixExpressionType(Type*, IntVal*, IntVal*) {} }; +class TensorViewType : public BuiltinType +{ + SLANG_AST_CLASS(TensorViewType) + + Type* getElementType(); +private: + TensorViewType(Type*) {} +}; + // Base class for built in string types class StringTypeBase : public BuiltinType { diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 8cbe12ef0..ff38bbbde 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -470,6 +470,7 @@ namespace Slang case CodeGenTarget::CUDASource: case CodeGenTarget::CPPSource: case CodeGenTarget::HostCPPSource: + case CodeGenTarget::PyTorchCppBinding: case CodeGenTarget::CSource: { return PassThroughMode::None; @@ -1570,6 +1571,7 @@ namespace Slang case CodeGenTarget::CUDASource: case CodeGenTarget::CPPSource: case CodeGenTarget::HostCPPSource: + case CodeGenTarget::PyTorchCppBinding: case CodeGenTarget::CSource: { RefPtr extensionTracker = _newExtensionTracker(target); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index b287b21eb..b49ce4fc3 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -75,6 +75,7 @@ namespace Slang DXILAssembly = SLANG_DXIL_ASM, CSource = SLANG_C_SOURCE, CPPSource = SLANG_CPP_SOURCE, + PyTorchCppBinding = SLANG_CPP_PYTORCH_BINDING, HostCPPSource = SLANG_HOST_CPP_SOURCE, HostExecutable = SLANG_HOST_EXECUTABLE, ShaderSharedLibrary = SLANG_SHADER_SHARED_LIBRARY, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index c3e0adbca..39ceb6678 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -664,6 +664,8 @@ DIAGNOSTIC(54002, Error, meshOutputMustBeArray, "HLSL style mesh shader outputs DIAGNOSTIC(54003, Error, meshOutputArrayMustHaveSize, "HLSL style mesh shader output arrays must have a length specified") DIAGNOSTIC(54004, Warning, unnecessaryHLSLMeshOutputModifier, "Unnecessary HLSL style mesh shader output modifier") +DIAGNOSTIC(55101, Error, invalidTorchKernelReturnType, "'$0' is not a valid return type for a pytorch kernel function.") +DIAGNOSTIC(55102, Error, invalidTorchKernelParamType, "'$0' is not a valid parameter type for a pytorch kernel function.") // // 8xxxx - Issues specific to a particular library/technology/platform/etc. diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 356f1c7ce..ebc312560 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -129,6 +129,7 @@ Index LocationTracker::getValue(Kind kind, IRInst* inst, IRDecoration* decoratio } case CodeGenTarget::CPPSource: case CodeGenTarget::HostCPPSource: + case CodeGenTarget::PyTorchCppBinding: { return SourceLanguage::CPP; } diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 2a2ae06c6..346926712 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -304,6 +304,18 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S out << ">"; return SLANG_OK; } + case kIROp_TargetTupleType: + { + out << "std::tuple<"; + for (UInt i = 0; i < type->getOperandCount(); i++) + { + if (i > 0) out << ", "; + auto elementType = (IRType*)type->getOperand(i); + SLANG_RETURN_ON_FAIL(calcTypeName(elementType, target, out)); + } + out << ">"; + return SLANG_OK; + } default: { if (isNominalOp(type->getOp())) @@ -1187,6 +1199,19 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut return true; } + case kIROp_MakeTargetTuple: + { + m_writer->emit("std::make_tuple("); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (i > 0) + m_writer->emit(", "); + auto arg = inst->getOperand(i); + emitOperand(arg, getInfo(EmitOp::General)); + } + m_writer->emit(")"); + return true; + } case kIROp_CastFloatToInt: case kIROp_CastIntToFloat: case kIROp_FloatCast: diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index 846b3b1f2..d2fa892ba 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -146,6 +146,11 @@ SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, out << prefix << vecCount; return SLANG_OK; } + case kIROp_TensorViewType: + { + out << "TensorView"; + return SLANG_OK; + } default: { if (isNominalOp(type->getOp())) diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp new file mode 100644 index 000000000..ef67c520a --- /dev/null +++ b/source/slang/slang-emit-torch.cpp @@ -0,0 +1,181 @@ +// slang-emit-torch.cpp +#include "slang-emit-torch.h" + +#include "../core/slang-writer.h" + +#include "slang-emit-source-writer.h" +#include "slang-mangled-lexer.h" + +#include + +namespace Slang +{ +bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) +{ + switch (inst->getOp()) + { + default: + { + return Super::tryEmitInstExprImpl(inst, inOuterPrec); + } + case kIROp_MakeTensorView: + { + m_writer->emit("make_tensor_view("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + case kIROp_CudaKernelLaunch: + { + m_writer->emit("cudaLaunchKernel("); + // func + m_writer->emit("(const void*)("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + + // gridDim + m_writer->emit("slang_bit_cast("); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("), "); + + // blockDim + m_writer->emit("slang_bit_cast("); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit("), "); + + // args + emitOperand(inst->getOperand(3), getInfo(EmitOp::General)); + m_writer->emit(", "); + + // shared mem + m_writer->emit("slangGetCudaKernelSharedMemSize((const void*)("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")), "); + + // stream + m_writer->emit("((cudaStream_t)"); + emitOperand(inst->getOperand(4), getInfo(EmitOp::General)); + m_writer->emit("))"); + return true; + } + case kIROp_TorchGetCudaStream: + { + m_writer->emit("at::cuda::getCurrentCUDAStream()"); + return true; + } + case kIROp_AllocateTorchTensor: + { + /* + Emit something like: + ``` + torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... }, + torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); + ``` + */ + m_writer->emit("torch::empty({ "); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (i > 0) + m_writer->emit(", "); + auto arg = inst->getOperand(i); + emitOperand(arg, getInfo(EmitOp::General)); + } + m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::"); + switch (inst->getDataType()->getOperand(0)->getOp()) + { + case kIROp_FloatType: + m_writer->emit("kFloat32"); + break; + case kIROp_HalfType: + m_writer->emit("kFloat16"); + break; + case kIROp_DoubleType: + m_writer->emit("kFloat64"); + break; + case kIROp_UInt8Type: + m_writer->emit("kUInt8"); + break; + case kIROp_UInt16Type: + m_writer->emit("kUInt16"); + break; + case kIROp_UIntType: + m_writer->emit("kUInt32"); + break; + case kIROp_UInt64Type: + m_writer->emit("kUInt64"); + break; + case kIROp_Int8Type: + m_writer->emit("kInt8"); + break; + case kIROp_Int16Type: + m_writer->emit("kInt16"); + break; + case kIROp_IntType: + m_writer->emit("kInt32"); + break; + case kIROp_Int64Type: + m_writer->emit("kInt64"); + break; + default: + SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); + break; + } + m_writer->emit("))"); + return true; + } + } +} + +SlangResult TorchCppSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) +{ + switch (type->getOp()) + { + default: + return Super::calcTypeName(type, target, out); + case kIROp_TensorViewType: + { + out << "TensorView"; + return SLANG_OK; + } + case kIROp_TorchTensorType: + { + out << "torch::Tensor"; + return SLANG_OK; + } + case kIROp_TorchKernelMemoryAllocatorType: + { + out << "CudaTaskMemoryAllocator"; + return SLANG_OK; + } + } +} + +void TorchCppSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink) +{ + Super::emitModuleImpl(module, sink); + + // Emit PyBind declarations. + m_writer->emit("PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n"); + m_writer->indent(); + for (auto inst : module->getGlobalInsts()) + { + auto func = as(inst); + if (!func) continue; + auto decor = func->findDecoration(); + if (!decor) continue; + m_writer->emit("m.def("); + emitStringLiteral(decor->getFunctionName()); + m_writer->emit(", &"); + m_writer->emit(decor->getFunctionName()); + m_writer->emit(", "); + emitStringLiteral(decor->getFunctionName()); + m_writer->emit(");\n"); + } + m_writer->dedent(); + m_writer->emit("}\n"); + +} + +} // namespace Slang diff --git a/source/slang/slang-emit-torch.h b/source/slang/slang-emit-torch.h new file mode 100644 index 000000000..84ce42331 --- /dev/null +++ b/source/slang/slang-emit-torch.h @@ -0,0 +1,28 @@ +// slang-emit-torch.h +#ifndef SLANG_EMIT_TORCH_H +#define SLANG_EMIT_TORCH_H + +#include "slang-emit-cpp.h" + +namespace Slang +{ + +class TorchCppSourceEmitter : public CPPSourceEmitter +{ +public: + typedef CPPSourceEmitter Super; + + TorchCppSourceEmitter(const Desc& desc) : + Super(desc) + { + } + +protected: + // CPPSourceEmitter overrides + virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) override; + virtual SlangResult calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) override; + virtual void emitModuleImpl(IRModule* module, DiagnosticSink* sink) override; +}; + +} +#endif diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 1b4eed8fd..fe72efcc7 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -56,6 +56,7 @@ #include "slang-ir-glsl-liveness.h" #include "slang-ir-string-hash.h" #include "slang-ir-simplify-for-emit.h" +#include "slang-ir-pytorch-cpp-binding.h" #include "slang-legalize-types.h" #include "slang-lower-to-ir.h" #include "slang-mangle.h" @@ -74,6 +75,7 @@ #include "slang-emit-hlsl.h" #include "slang-emit-cpp.h" #include "slang-emit-cuda.h" +#include "slang-emit-torch.h" #include "../compiler-core/slang-artifact-desc-util.h" #include "../compiler-core/slang-artifact-util.h" @@ -83,6 +85,7 @@ #include Slang::String get_slang_cpp_host_prelude(); +Slang::String get_slang_torch_prelude(); namespace Slang { @@ -402,6 +405,18 @@ Result linkAndOptimizeIR( finalizeSpecialization(irModule); + switch (target) + { + case CodeGenTarget::PyTorchCppBinding: + generatePyTorchCppBinding(irModule, sink); + break; + case CodeGenTarget::CUDASource: + removeTorchKernels(irModule); + break; + default: + break; + } + // If we have a target that is GPU like we use the string hashing mechanism // but for that to work we need to inline such that calls (or returns) of strings // boil down into getStringHash(stringLiteral) @@ -969,31 +984,39 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr& outAr LinkedIR linkedIR; RefPtr sourceEmitter; - SourceLanguage sourceLanguage = CLikeSourceEmitter::getSourceLanguage(target); - switch (sourceLanguage) + + switch (target) { - case SourceLanguage::CPP: - { - sourceEmitter = new CPPSourceEmitter(desc); - break; - } - case SourceLanguage::GLSL: - { - sourceEmitter = new GLSLSourceEmitter(desc); - break; - } - case SourceLanguage::HLSL: - { - sourceEmitter = new HLSLSourceEmitter(desc); - break; - } - case SourceLanguage::CUDA: + default: + switch (sourceLanguage) { - sourceEmitter = new CUDASourceEmitter(desc); - break; + case SourceLanguage::CPP: + { + sourceEmitter = new CPPSourceEmitter(desc); + break; + } + case SourceLanguage::GLSL: + { + sourceEmitter = new GLSLSourceEmitter(desc); + break; + } + case SourceLanguage::HLSL: + { + sourceEmitter = new HLSLSourceEmitter(desc); + break; + } + case SourceLanguage::CUDA: + { + sourceEmitter = new CUDASourceEmitter(desc); + break; + } + default: break; } - default: break; + break; + case CodeGenTarget::PyTorchCppBinding: + sourceEmitter = new TorchCppSourceEmitter(desc); + break; } if (!sourceEmitter) @@ -1072,16 +1095,23 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr& outAr // Emit any front matter sourceEmitter->emitFrontMatter(targetRequest); - // If heterogeneous we output the prelude before everything else - if (isHeterogeneousTarget(target)) - { - sourceWriter.emit(get_slang_cpp_host_prelude()); - } - else + switch (target) { - // Get the prelude - String prelude = session->getPreludeForLanguage(sourceLanguage); - sourceWriter.emit(prelude); + case CodeGenTarget::PyTorchCppBinding: + sourceWriter.emit(get_slang_torch_prelude()); + break; + default: + if (isHeterogeneousTarget(target)) + { + sourceWriter.emit(get_slang_cpp_host_prelude()); + } + else + { + // Get the prelude + String prelude = session->getPreludeForLanguage(sourceLanguage); + sourceWriter.emit(prelude); + } + break; } // Emit anything that goes before the contents of the code generated for the module diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 04e08293f..68f1a28e6 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -67,6 +67,11 @@ INST(Nop, nop, 0, 0) INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, HOISTABLE) + INST(TensorViewType, TensorView, 1, HOISTABLE) + INST(TorchTensorType, TorchTensor, 0, HOISTABLE) + INST(TorchKernelMemoryAllocatorType, TorchMemAllocatorType, 0, HOISTABLE) + INST(ArrayListType, ArrayListVector, 1, HOISTABLE) + /* BindExistentialsTypeBase */ // A `BindExistentials` represents @@ -220,6 +225,7 @@ INST(ThisType, this_type, 0, HOISTABLE) INST(RTTIType, rtti_type, 0, HOISTABLE) INST(RTTIHandleType, rtti_handle_type, 0, HOISTABLE) INST(TupleType, tuple_type, 0, HOISTABLE) +INST(TargetTupleType, TargetTuple, 0, HOISTABLE) // A type that identifies it's contained type as being emittable as `spirv_literal. INST(SPIRVLiteralType, spirvLiteralType, 1, HOISTABLE) @@ -308,6 +314,7 @@ INST(MakeArray, makeArray, 0, 0) INST(MakeArrayFromElement, makeArrayFromElement, 1, 0) INST(MakeStruct, makeStruct, 0, 0) INST(MakeTuple, makeTuple, 0, 0) +INST(MakeTargetTuple, makeTuple, 0, 0) INST(GetTupleElement, getTupleElement, 2, 0) INST(MakeResultValue, makeResultValue, 1, 0) INST(MakeResultError, makeResultError, 1, 0) @@ -509,24 +516,24 @@ INST(SwizzledStore, swizzledStore, 2, 0) /* IRConditionalbranch */ // conditionalBranch - INST(conditionalBranch, conditionalBranch, 3, 0) +INST(conditionalBranch, conditionalBranch, 3, 0) - // ifElse - INST(ifElse, ifElse, 4, 0) - INST_RANGE(ConditionalBranch, conditionalBranch, ifElse) +// ifElse +INST(ifElse, ifElse, 4, 0) +INST_RANGE(ConditionalBranch, conditionalBranch, ifElse) - INST(Throw, throw, 1, 0) - // tryCall ... - INST(TryCall, tryCall, 3, 0) - // switch ... - INST(Switch, switch, 3, 0) +INST(Throw, throw, 1, 0) +// tryCall ... +INST(TryCall, tryCall, 3, 0) +// switch ... +INST(Switch, switch, 3, 0) - INST(discard, discard, 0, 0) +INST(discard, discard, 0, 0) - /* IRUnreachable */ - INST(MissingReturn, missingReturn, 0, 0) - INST(Unreachable, unreachable, 0, 0) - INST_RANGE(Unreachable, MissingReturn, Unreachable) +/* IRUnreachable */ +INST(MissingReturn, missingReturn, 0, 0) +INST(Unreachable, unreachable, 0, 0) +INST_RANGE(Unreachable, MissingReturn, Unreachable) INST_RANGE(TerminatorInst, Return, Unreachable) @@ -575,10 +582,10 @@ INST(GetStringHash, getStringHash, 1, 0) INST(WaveGetActiveMask, waveGetActiveMask, 0, 0) - /// trueMask = waveMaskBallot(mask, condition) +/// trueMask = waveMaskBallot(mask, condition) INST(WaveMaskBallot, waveMaskBallot, 2, 0) - /// matchMask = waveMaskBallot(mask, value) +/// matchMask = waveMaskBallot(mask, value) INST(WaveMaskMatch, waveMaskMatch, 2, 0) // Texture sampling operation of the form `t.Sample(s,u)` @@ -604,6 +611,12 @@ INST(GetOptiXHitAttribute, getOptiXHitAttribute, 2, 0) // using a pointer. INST(GetOptiXSbtDataPtr, getOptiXSbtDataPointer, 0, 0) +INST(MakeArrayList, makeArrayList, 0, 0) +INST(MakeTensorView, makeTensorView, 0, 0) +INST(AllocateTorchTensor, allocTorchTensor , 0, 0) +INST(TorchGetCudaStream, TorchGetCudaStream, 0, 0) +INST(TorchTensorGetView, TorchTensorGetView, 0, 0) + /* Decoration */ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) @@ -669,7 +682,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(CudaKernelDecoration, CudaKernel, 0, 0) INST(CudaHostDecoration, CudaHost, 0, 0) - + INST(TorchEntryPointDecoration, TorchEntryPoint, 0, 0) + /// Used to mark parameters that are moved from entry point parameters to global params as coming from the entry point. INST(EntryPointParamDecoration, entryPointParam, 0, 0) @@ -908,6 +922,7 @@ INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0) INST(PrimalSubstitute, PrimalSubstitute, 1, 0) INST(DispatchKernel, DispatchKernel, 3, 0) +INST(CudaKernelLaunch, CudaKernelLaunch, 6, 0) // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9bb66823b..4cdf6c749 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -518,6 +518,18 @@ struct IRDllExportDecoration : IRDecoration UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); } }; +struct IRTorchEntryPointDecoration : IRDecoration +{ + enum + { + kOp = kIROp_TorchEntryPointDecoration + }; + IR_LEAF_ISA(TorchEntryPointDecoration) + + IRStringLit* getFunctionNameOperand() { return cast(getOperand(0)); } + UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); } +}; + struct IRFormatDecoration : IRDecoration { enum { kOp = kIROp_FormatDecoration }; @@ -936,6 +948,15 @@ struct IRDispatchKernel : IRInst IR_LEAF_ISA(DispatchKernel) }; +struct IRTorchTensorGetView : IRInst +{ + enum + { + kOp = kIROp_TorchTensorGetView + }; + IR_LEAF_ISA(TorchTensorGetView) +}; + // Dictionary item mapping a type with a corresponding // IDifferentiable witness table // @@ -2720,6 +2741,8 @@ public: IRAnyValueType* getAnyValueType(IRInst* size); IRDynamicType* getDynamicType(); + IRTargetTupleType* getTargetTupleType(UInt count, IRType* const* types); + IRTupleType* getTupleType(UInt count, IRType* const* types); IRTupleType* getTupleType(List const& types) { @@ -2775,6 +2798,10 @@ public: IRInst* rowCount, IRInst* columnCount); + IRArrayListType* getArrayListType(IRType* elementType); + IRTensorViewType* getTensorViewType(IRType* elementType); + IRTorchTensorType* getTorchTensorType(); + IRDifferentialPairType* getDifferentialPairType( IRType* valueType, IRInst* witnessTable); @@ -2896,7 +2923,10 @@ public: IRInst* emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn); IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn); IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn); + IRInst* emitDispatchKernelInst(IRType* type, IRInst* baseFn, IRInst* threadGroupSize, IRInst* dispatchSize, Int argCount, IRInst* const* inArgs); + IRInst* emitCudaKernelLaunch(IRInst* baseFn, IRInst* gridDim, IRInst* blockDim, IRInst* argsArray, IRInst* cudaStream); + IRInst* emitGetTorchCudaStream(); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential); @@ -2999,6 +3029,8 @@ public: // Creates an RTTI object. Result is of `IRRTTIType`. IRInst* emitMakeRTTIObject(IRInst* typeInst); + IRInst* emitMakeTargetTuple(IRType* type, UInt count, IRInst* const* args); + IRInst* emitMakeTuple(IRType* type, UInt count, IRInst* const* args); IRInst* emitMakeTuple(UInt count, IRInst* const* args); @@ -3067,6 +3099,11 @@ public: UInt argCount, IRInst* const* args); + IRInst* emitMakeArrayList( + IRType* type, + UInt argCount, + IRInst* const* args); + IRInst* emitMakeArrayFromElement( IRType* type, IRInst* element); @@ -3083,6 +3120,8 @@ public: return emitMakeStruct(type, args.getCount(), args.getBuffer()); } + IRInst* emitMakeTensorView(IRType* type, IRInst* allocator, IRInst* val); + IRInst* emitMakeExistential( IRType* type, IRInst* value, @@ -3785,6 +3824,11 @@ public: addDecoration(value, kIROp_DllExportDecoration, getStringValue(functionName)); } + void addTorchEntryPointDecoration(IRInst* value, UnownedStringSlice const& functionName) + { + addDecoration(value, kIROp_TorchEntryPointDecoration, getStringValue(functionName)); + } + void addCudaDeviceExportDecoration(IRInst* value, UnownedStringSlice const& functionName) { addDecoration(value, kIROp_CudaDeviceExportDecoration, getStringValue(functionName)); diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp new file mode 100644 index 000000000..e33adec1d --- /dev/null +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -0,0 +1,248 @@ +#include "slang-ir-pytorch-cpp-binding.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-diagnostics.h" + +namespace Slang +{ +static bool getHostReturnTypeImpl(List& elementTypes, IRBuilder& builder, IRType* type) +{ + bool isValid = true; + if (as(type)) + return true; + if (as(type)) + elementTypes.add(type); + else if (as(type)) + elementTypes.add(type); + else if (auto vectorType = as(type)) + { + auto count = as(vectorType->getElementCount()); + if (!count) + { + return false; + } + for (IRIntegerValue i = 0; i < count->getValue(); i++) + { + elementTypes.addRange(vectorType->getElementType()); + } + } + else if (auto arrayType = as(type)) + { + auto arraySize = as(arrayType->getElementCount()); + if (!arraySize) + { + return false; + } + List subElementTypes; + isValid &= getHostReturnTypeImpl(subElementTypes, builder, arrayType->getElementType()); + for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) + { + elementTypes.addRange(subElementTypes); + } + } + else if (auto structType = as(type)) + { + for (auto field : structType->getFields()) + { + isValid &= getHostReturnTypeImpl(elementTypes, builder, field->getFieldType()); + } + } + else + { + return false; + } + return isValid; +} + +static IRType* getHostReturnType(IRBuilder& builder, IRType* type) +{ + List types; + bool isValid = getHostReturnTypeImpl(types, builder, type); + if (isValid) + return builder.getTargetTupleType((UInt)types.getCount(), types.getBuffer()); + return nullptr; +} + +static void flattenToTupleImpl(List& result, IRBuilder& builder, IRInst* val) +{ + auto type = val->getDataType(); + if (as(type)) + return; + if (as(type)) + result.add(val); + else if (as(type)) + result.add(val); + else if (auto vectorType = as(type)) + { + auto count = as(vectorType->getElementCount()); + if (!count) + { + return; + } + for (IRIntegerValue i = 0; i < count->getValue(); i++) + { + result.add(builder.emitElementExtract(vectorType->getElementType(), builder.getIntValue(builder.getIntType(), i))); + } + } + else if (auto arrayType = as(type)) + { + auto arraySize = as(arrayType->getElementCount()); + if (!arraySize) + { + return; + } + for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) + { + auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); + flattenToTupleImpl(result, builder, elementVal); + } + } + else if (auto structType = as(type)) + { + for (auto field : structType->getFields()) + { + auto elementVal = builder.emitFieldExtract(field->getFieldType(), val, field->getKey()); + flattenToTupleImpl(result, builder, elementVal); + } + } +} + +static IRInst* flattenToHostReturnTuple(IRBuilder& builder, IRType* type, IRInst* val) +{ + List vals; + flattenToTupleImpl(vals, builder, val); + return builder.emitMakeTargetTuple(type, (UInt)vals.getCount(), vals.getBuffer()); +} + +static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) +{ + IRBuilder builder(func); + + builder.setInsertBefore(func); + auto hostReturnType = getHostReturnType(builder, func->getResultType()); + if (!hostReturnType) + { + sink->diagnose(func->sourceLoc, Diagnostics::invalidTorchKernelReturnType, func->getResultType()); + return; + } + List hostParamTypes; + auto funcType = as(func->getDataType()); + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + hostParamTypes.add(funcType->getParamType(i)); + } + auto bindingFuncType = builder.getFuncType(hostParamTypes, hostReturnType); + func->setFullType(bindingFuncType); + + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto allocator = builder.emitVar(builder.getType(kIROp_TorchKernelMemoryAllocatorType)); + + List instsToRemove; + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (auto kernelDispatch = as(inst)) + { + builder.setInsertBefore(kernelDispatch); + List kernelArgs; + auto kernelArgCount = kernelDispatch->getArgCount(); + auto argArrayType = builder.getArrayType(builder.getPtrType(builder.getVoidType()), + builder.getIntValue(builder.getIntType(), kernelArgCount)); + auto argArrayVar = builder.emitVar(argArrayType); + for (UInt i = 0; i < kernelArgCount; i++) + { + auto arg = kernelDispatch->getArg(i); + auto argVar = builder.emitVar(arg->getFullType()); + builder.emitStore(argVar, arg); + auto addr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), i)); + builder.emitStore(addr, argVar); + } + auto argArrayPtr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), 0)); + builder.emitCudaKernelLaunch( + kernelDispatch->getBaseFn(), + kernelDispatch->getDispatchSize(), + kernelDispatch->getThreadGroupSize(), + argArrayPtr, + builder.emitGetTorchCudaStream()); + instsToRemove.add(inst); + } + else if (auto getView = as(inst)) + { + builder.setInsertBefore(getView); + auto makeView = builder.emitMakeTensorView(getView->getFullType(), allocator, inst->getOperand(0)); + getView->replaceUsesWith(makeView); + instsToRemove.add(getView); + } + else if (auto ret = as(inst)) + { + builder.setInsertBefore(ret); + auto retVal = flattenToHostReturnTuple(builder, hostReturnType, ret->getVal()); + ret->setOperand(0, retVal); + } + } + } + + for (auto inst : instsToRemove) + inst->removeAndDeallocate(); +} + +void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) +{ + List workList; + List cudaKernels; + for (auto globalInst : module->getGlobalInsts()) + { + auto func = as(globalInst); + if (!func) + continue; + if (func->findDecoration()) + { + workList.add(func); + } + else if (func->findDecoration()) + { + cudaKernels.add(func); + } + else + { + // Remove all other export decorations if this is not a cuda host func. + if (auto decor = func->findDecoration()) + decor->removeAndDeallocate(); + if (auto decor = func->findDecoration()) + decor->removeAndDeallocate(); + if (auto decor = func->findDecoration()) + decor->removeAndDeallocate(); + if (auto decor = func->findDecoration()) + decor->removeAndDeallocate(); + } + } + + for (auto func : workList) + generateCppBindingForFunc(func, sink); + + for (auto func : cudaKernels) + { + for (auto block = func->getFirstBlock(); block;) + { + auto nextBlock = block->getNextBlock(); + block->removeAndDeallocate(); + block = nextBlock; + } + } +} + +// Remove all [TorchEntryPoint] functions when emitting CUDA source. +void removeTorchKernels(IRModule* module) +{ + for (auto globalInst : module->getGlobalInsts()) + { + if (!as(globalInst)) + continue; + if (globalInst->findDecoration()) + globalInst->removeAndDeallocate(); + } + +} + +} diff --git a/source/slang/slang-ir-pytorch-cpp-binding.h b/source/slang/slang-ir-pytorch-cpp-binding.h new file mode 100644 index 000000000..c35b6a8eb --- /dev/null +++ b/source/slang/slang-ir-pytorch-cpp-binding.h @@ -0,0 +1,12 @@ +#pragma once + +namespace Slang +{ +struct IRModule; +class DiagnosticSink; + +void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink); +void removeTorchKernels(IRModule* module); + +} + diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 69870c128..6ce54a948 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2594,6 +2594,11 @@ namespace Slang IRDynamicType* IRBuilder::getDynamicType() { return (IRDynamicType*)getType(kIROp_DynamicType); } + IRTargetTupleType* IRBuilder::getTargetTupleType(UInt count, IRType* const* types) + { + return (IRTargetTupleType*)getType(kIROp_TargetTupleType, count, (IRInst* const*)types); + } + IRAssociatedType* IRBuilder::getAssociatedType(ArrayView constraintTypes) { return (IRAssociatedType*)getType(kIROp_AssociatedType, @@ -2788,6 +2793,27 @@ namespace Slang operands); } + IRArrayListType* IRBuilder::getArrayListType(IRType* elementType) + { + return (IRArrayListType*)getType( + kIROp_ArrayListType, + 1, + (IRInst**)&elementType); + } + + IRTensorViewType* IRBuilder::getTensorViewType(IRType* elementType) + { + return (IRTensorViewType*)getType( + kIROp_TensorViewType, + 1, + (IRInst**)&elementType); + } + + IRTorchTensorType* IRBuilder::getTorchTensorType() + { + return (IRTorchTensorType*)getType(kIROp_TorchTensorType, 0, nullptr); + } + IRDifferentialPairType* IRBuilder::getDifferentialPairType( IRType* valueType, IRInst* witnessTable) @@ -3173,6 +3199,21 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitCudaKernelLaunch(IRInst* baseFn, IRInst* gridDim, IRInst* blockDim, IRInst* argsArray, IRInst* cudaStream) + { + IRInst* args[5] = {baseFn, gridDim, blockDim, argsArray, cudaStream}; + return emitIntrinsicInst( + getVoidType(), + kIROp_CudaKernelLaunch, + 5, + args); + } + + IRInst* IRBuilder::emitGetTorchCudaStream() + { + return emitIntrinsicInst(getPtrType(getVoidType()), kIROp_TorchGetCudaStream, 0, nullptr); + } + IRInst* IRBuilder::emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn) { auto inst = createInst( @@ -3659,6 +3700,11 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeTuple, count, args); } + IRInst* IRBuilder::emitMakeTargetTuple(IRType* type, UInt count, IRInst* const* args) + { + return emitIntrinsicInst(type, kIROp_MakeTargetTuple, count, args); + } + IRInst* IRBuilder::emitMakeTuple(UInt count, IRInst* const* args) { List types; @@ -3851,6 +3897,11 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeArray, argCount, args); } + IRInst* IRBuilder::emitMakeArrayList(IRType* type, UInt argCount, IRInst* const* args) + { + return emitIntrinsicInst(type, kIROp_MakeArrayList, argCount, args); + } + IRInst* IRBuilder::emitMakeArrayFromElement( IRType* type, IRInst* element) @@ -3866,6 +3917,12 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeStruct, argCount, args); } + IRInst* IRBuilder::emitMakeTensorView(IRType* type, IRInst* allocator, IRInst* val) + { + IRInst* args[2] = { allocator, val }; + return emitIntrinsicInst(type, kIROp_MakeTensorView, 2, args); + } + IRInst* IRBuilder::emitMakeExistential( IRType* type, IRInst* value, diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 025812f83..d74a679d3 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1490,6 +1490,25 @@ struct IRMatrixType : IRType IR_LEAF_ISA(MatrixType) }; +struct IRArrayListType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + + IR_LEAF_ISA(ArrayListType) +}; + +struct IRTensorViewType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + + IR_LEAF_ISA(TensorViewType) +}; + +struct IRTorchTensorType : IRType +{ + IR_LEAF_ISA(TorchTensorType) +}; + struct IRSPIRVLiteralType : IRType { IR_LEAF_ISA(SPIRVLiteralType) @@ -1699,6 +1718,12 @@ struct IRTupleType : IRType IR_LEAF_ISA(TupleType) }; +/// Represents a tuple in target language. TargetTupleType will not be lowered to structs. +struct IRTargetTupleType : IRType +{ + IR_LEAF_ISA(TargetTupleType) +}; + /// Represents an `Result`, used by functions that throws error codes. struct IRResultType : IRType { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 7144b3450..9d424d1e8 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1150,51 +1150,67 @@ static void addLinkageDecoration( { builder->addExportDecoration(inst, mangledName); } - if (decl->findModifier()) + for (auto modifier : decl->modifiers) { - builder->addPublicDecoration(inst); - builder->addKeepAliveDecoration(inst); - } - if (decl->findModifier()) - { - builder->addHLSLExportDecoration(inst); - builder->addKeepAliveDecoration(inst); - } - if (decl->findModifier()) - { - builder->addExternCppDecoration(inst, mangledName); + if (as(modifier)) + { + builder->addPublicDecoration(inst); + builder->addKeepAliveDecoration(inst); + } + else if (as(modifier)) + { + builder->addHLSLExportDecoration(inst); + builder->addKeepAliveDecoration(inst); + } + else if (as(modifier)) + { + builder->addExternCppDecoration(inst, mangledName); + } + else if (auto dllImportModifier = as(modifier)) + { + auto libraryName = dllImportModifier->modulePath; + auto functionName = dllImportModifier->functionName.getLength() + ? dllImportModifier->functionName.getUnownedSlice() + : decl->getName()->text.getUnownedSlice(); + builder->addDllImportDecoration(inst, libraryName.getUnownedSlice(), functionName); + } + else if (as(modifier)) + { + builder->addDllExportDecoration(inst, decl->getName()->text.getUnownedSlice()); + builder->addPublicDecoration(inst); + } + else if (as(modifier)) + { + builder->addCudaDeviceExportDecoration(inst, decl->getName()->text.getUnownedSlice()); + builder->addPublicDecoration(inst); + builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); + } + else if (as(modifier)) + { + builder->addCudaHostDecoration(inst); + builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); + } + else if (as(modifier)) + { + builder->addCudaKernelDecoration(inst); + builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); + builder->addPublicDecoration(inst); + builder->addKeepAliveDecoration(inst); + } + else if (as(modifier)) + { + builder->addTorchEntryPointDecoration(inst, decl->getName()->text.getUnownedSlice()); + builder->addCudaHostDecoration(inst); + builder->addPublicDecoration(inst); + builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); + } } if (as(decl->parentDecl) && - decl->parentDecl->hasModifier()) + decl->parentDecl->hasModifier() && + !inst->findDecoration()) { builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); } - if (auto dllImportModifier = decl->findModifier()) - { - auto libraryName = dllImportModifier->modulePath; - auto functionName = dllImportModifier->functionName.getLength() - ? dllImportModifier->functionName.getUnownedSlice() - : decl->getName()->text.getUnownedSlice(); - builder->addDllImportDecoration(inst, libraryName.getUnownedSlice(), functionName); - } - if (decl->findModifier()) - { - builder->addDllExportDecoration(inst, decl->getName()->text.getUnownedSlice()); - builder->addPublicDecoration(inst); - } - if (decl->findModifier()) - { - builder->addCudaDeviceExportDecoration(inst, decl->getName()->text.getUnownedSlice()); - builder->addPublicDecoration(inst); - } - if (decl->findModifier()) - { - builder->addCudaHostDecoration(inst); - } - if (decl->findModifier()) - { - builder->addCudaKernelDecoration(inst); - } } static void addLinkageDecoration( diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index 714e2c99d..d30c02484 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -2193,6 +2193,7 @@ struct OptionsParser if (rawOutputs.getCount() == 0 && rawTargets.getCount() == 1 && (rawTargets[0].format == CodeGenTarget::HostCPPSource || + rawTargets[0].format == CodeGenTarget::PyTorchCppBinding || rawTargets[0].format == CodeGenTarget::CUDASource || ArtifactDescUtil::makeDescForCompileTarget(asExternal(rawTargets[0].format)).kind == ArtifactKind::HostCallable)) { @@ -2258,7 +2259,7 @@ struct OptionsParser case CodeGenTarget::ShaderHostCallable: case CodeGenTarget::HostExecutable: case CodeGenTarget::ShaderSharedLibrary: - + case CodeGenTarget::PyTorchCppBinding: case CodeGenTarget::DXIL: rawOutput.isWholeProgram = true; diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index cdeb0b259..45f4be477 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2166,9 +2166,9 @@ namespace Slang dispatchExpr->baseFunction = parser->ParseArgExpr(); parser->ReadToken(TokenType::Comma); - dispatchExpr->threadGroupSize = parser->ParseArgExpr(); - parser->ReadToken(TokenType::Comma); dispatchExpr->dispatchSize = parser->ParseArgExpr(); + parser->ReadToken(TokenType::Comma); + dispatchExpr->threadGroupSize = parser->ParseArgExpr(); parser->ReadToken(TokenType::RParent); return dispatchExpr; diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 27aba435f..1c2726551 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -530,6 +530,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt matType->declRef = declRef; return matType; } + else if (magicMod->magicName == "TensorViewType") + { + SLANG_ASSERT(subst && subst->getArgs().getCount() == 1); + auto vecType = astBuilder->getOrCreate(ExtractGenericArgType(subst->getArgs()[0])); + vecType->declRef = declRef; + return vecType; + } else if (magicMod->magicName == "Texture") { SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1); diff --git a/source/slangc/main.cpp b/source/slangc/main.cpp index 2870a5a3c..2fe9d19a1 100644 --- a/source/slangc/main.cpp +++ b/source/slangc/main.cpp @@ -79,15 +79,15 @@ SLANG_TEST_TOOL_API SlangResult innerMain(StdWriters* stdWriters, slang::IGlobal if (TestToolUtil::hasDeferredStdLib(Index(argc - 1), argv + 1)) { SLANG_RETURN_ON_FAIL(slang_createGlobalSessionWithoutStdLib(SLANG_API_VERSION, session.writeRef())); - TestToolUtil::setSessionDefaultPreludeFromExePath(argv[0], session); } else if (!session) { // Just create the global session in the regular way if there isn't one set SLANG_RETURN_ON_FAIL(slang_createGlobalSession(SLANG_API_VERSION, session.writeRef())); - TestToolUtil::setSessionDefaultPreludeFromExePath(argv[0], session); } + TestToolUtil::setSessionDefaultPreludeFromExePath(argv[0], session); + SlangCompileRequest* compileRequest = spCreateCompileRequest(session); compileRequest->addSearchPath(Path::getParentDirectory(Path::getExecutablePath()).getBuffer()); SlangResult res = _compile(compileRequest, argc, argv); diff --git a/tests/autodiff/cuda-kernel-export.slang b/tests/autodiff/cuda-kernel-export.slang index 54442498b..2700fb054 100644 --- a/tests/autodiff/cuda-kernel-export.slang +++ b/tests/autodiff/cuda-kernel-export.slang @@ -3,39 +3,19 @@ // Verify that we can output a cuda device function with [CudaDeviceExport]. // Disabled until we have FileCheck. -struct MixedType : IDifferentiable -{ - no_diff float noDiffField; - float field; -} - -[BackwardDifferentiable] -float f1(MixedType m) -{ - return 2.0 * m.field; -} - -[BackwardDifferentiable] -float f(MixedType m) -{ - MixedType m1 = { m.noDiffField, m.field }; - return f1(m1); -} - -[CudaDeviceExport] -void diffF(inout DifferentialPair m, float dout) -{ - __bwd_diff(f)(m, dout); -} [CudaKernel] -void myKernel(float* inValues, float* outValues) +void myKernel(TensorView inValues, TensorView outValues) { - outValues[0] = sin(inValues[0]); + if (cudaThreadIdx().x > 0) + return; + outValues.store(cudaThreadIdx().x, sin(inValues.load(cudaThreadIdx().x))); } -[CudaHost] -public __extern_cpp void runCompute(float *inValues, float *outValues, uint3 dispathcSize) +[TorchEntryPoint] +public __extern_cpp TorchTensor runCompute(TorchTensor inValues) { - __dispatch_kernel(myKernel, uint3(128, 1, 1), dispathcSize)(inValues, outValues); + var outValues = TorchTensor.alloc(1); + __dispatch_kernel(myKernel, uint3(1, 1, 1), uint3(32, 1, 1))(inValues, outValues); + return outValues; } \ No newline at end of file diff --git a/tools/gfx/slang.slang b/tools/gfx/slang.slang index 4250cb62e..1fd06560f 100644 --- a/tools/gfx/slang.slang +++ b/tools/gfx/slang.slang @@ -53,6 +53,7 @@ enum SlangCompileTarget SLANG_DXIL_ASM, SLANG_C_SOURCE, ///< The C language SLANG_CPP_SOURCE, ///< C++ code for shader kernels. + SLANG_CPP_PYTORCH_BINDING, SLANG_HOST_EXECUTABLE, ///< Standalone binary executable (for hosting CPU/OS) SLANG_SHADER_SHARED_LIBRARY, ///< A shared library/Dll for shader kernels (for hosting CPU/OS) SLANG_SHADER_HOST_CALLABLE, ///< A CPU target that makes the compiled shader code available to be run immediately @@ -448,4 +449,4 @@ struct SpecializationArg TypeReflection *type; } -} \ No newline at end of file +} diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp index d2023a0f0..71f5c04bd 100644 --- a/tools/slang-test/slang-test-main.cpp +++ b/tools/slang-test/slang-test-main.cpp @@ -765,6 +765,7 @@ static PassThroughFlags _getPassThroughFlagsForTarget(SlangCompileTarget target) case SLANG_GLSL: case SLANG_C_SOURCE: case SLANG_CPP_SOURCE: + case SLANG_CPP_PYTORCH_BINDING: case SLANG_HOST_CPP_SOURCE: case SLANG_CUDA_SOURCE: { -- cgit v1.2.3