| field | value | date |
|---|---|---|
| author | Yong He <yonghe@outlook.com> | 2023-02-16 13:55:32 -0800 |
| committer | GitHub <noreply@github.com> | 2023-02-16 13:55:32 -0800 |
| commit | 4c4826d47eeef4675daae4ae53ff76f4d5ebd84a | |
| tree | ed4af0ded878e4f06e9641ce61d26ffd7c89ccbc | |
| parent | eda88e513e8b1e2abc05e9dc8555f237d96472df | |
Overhaul global inst deduplication and cpp/cuda backend. (#2654)
* Overhaul global inst deduplication and cpp/cuda backend.
* Update IR documentation.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
64 files changed, 2717 insertions, 4024 deletions
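The IR documentation touched by this commit warns that walking a child list with a saved `next` pointer is unsafe once an edit can hoist instructions into an outer scope, and recommends iterating over a work list instead. The following is a minimal, self-contained C++ sketch of that hazard. It does not use the Slang API: `ToyInst`, `appendChild`, `hoistOutOfParent`, `walkRawList`, `walkWorkList`, and `runToyWalk` are hypothetical names for a toy intrusive list, and the hoist is triggered by hand when `%3` is visited rather than by `replaceUsesWith`.

```cpp
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for an IR inst with an intrusive child list (hypothetical, not Slang's IRInst).
struct ToyInst
{
    std::string name;
    ToyInst* parent = nullptr;
    ToyInst* next = nullptr;
    ToyInst* firstChild = nullptr;
    explicit ToyInst(std::string n) : name(std::move(n)) {}
};

// Append `child` to the end of `parent`'s child list.
void appendChild(ToyInst* parent, ToyInst* child)
{
    child->parent = parent;
    child->next = nullptr;
    ToyInst** link = &parent->firstChild;
    while (*link) link = &(*link)->next;
    *link = child;
}

// Move `inst` out of its parent and relink it just before that parent,
// mimicking how a hoisted inst ends up in the enclosing scope.
void hoistOutOfParent(ToyInst* inst)
{
    ToyInst* oldParent = inst->parent;
    ToyInst* grandParent = oldParent->parent;
    ToyInst** link = &oldParent->firstChild;
    while (*link != inst) link = &(*link)->next;
    *link = inst->next; // unlink from the old child list
    link = &grandParent->firstChild;
    while (*link != oldParent) link = &(*link)->next;
    inst->next = oldParent; // the hoisted inst is now followed by the old parent
    inst->parent = grandParent;
    *link = inst;
}

// Stand-in for an edit with cascading effects: visiting "%3" hoists `toHoist`.
void processInst(ToyInst* inst, ToyInst* toHoist, std::vector<std::string>& visited)
{
    visited.push_back(inst->name);
    if (inst->name == "%3")
        hoistOutOfParent(toHoist);
}

// Unsafe: the saved `nextInst` may have moved to another list by the next iteration.
std::vector<std::string> walkRawList(ToyInst* func, ToyInst* toHoist)
{
    std::vector<std::string> visited;
    ToyInst* nextInst = nullptr;
    for (ToyInst* inst = func->firstChild; inst; inst = nextInst)
    {
        nextInst = inst->next;
        processInst(inst, toHoist, visited);
    }
    return visited;
}

// Safer: snapshot the children into a work list before making any edits.
std::vector<std::string> walkWorkList(ToyInst* func, ToyInst* toHoist)
{
    std::vector<ToyInst*> workList;
    for (ToyInst* inst = func->firstChild; inst; inst = inst->next)
        workList.push_back(inst);
    std::vector<std::string> visited;
    for (ToyInst* inst : workList)
        processInst(inst, toHoist, visited);
    return visited;
}

// Build module { %2 = func { %x, %3, %4, %5 } } and walk the children of %2.
std::vector<std::string> runToyWalk(bool useWorkList)
{
    ToyInst mod("module"), func("%2"), x("%x"), p3("%3"), a4("%4"), v5("%5");
    appendChild(&mod, &func);
    appendChild(&func, &x);
    appendChild(&func, &p3);
    appendChild(&func, &a4);
    appendChild(&func, &v5);
    return useWorkList ? walkWorkList(&func, &a4) : walkRawList(&func, &a4);
}
```

In this toy setup, `runToyWalk(false)` visits `%x`, `%3`, `%4`, and then `%2` (the function itself), never reaching `%5`, while `runToyWalk(true)` visits all four original children. The snapshot costs one extra pass over the list, which is the trade-off the work-list approach accepts for stability under edits.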
diff --git a/docs/design/ir-deduplication.md b/docs/design/ir-deduplication.md deleted file mode 100644 index 8d6ec3df0..000000000 --- a/docs/design/ir-deduplication.md +++ /dev/null @@ -1,80 +0,0 @@ -# IR Global Value Deduplication - -Types, constants and certain operations on constants are considered "global value" in the Slang IR. Some other insts like `Specialize()` and `Ptr(x)` are considered as "hoistable" insts, in that they will be defined at the outer most scope where their operands are available. For example, `Ptr(int)` will always be defined at global scope(as direct children of `IRModuleInst`) because its only operand, `int`, is defined at global scope. However if we have `Ptr(T)` where `T` is a generic parameter, then this `Ptr(T)` inst will be always be defined in the block of the generic. Global and hoistable values are always deduplicated and we can always assume two hoistable values with different pointer addresses are distinct values. - -The `IRBuilder` class is responsible for ensuring the uniqueness of global/hoistable values. If you call any `IRBuilder` methods that creates a new hoistable instruction, e.g. `IRBuilder::createIntrinsicInst`, `IRBuilder::emitXXX` or `IRBuilder::getType`, `IRBuilder` will check if an equivalent value already exists, and if so it returns the existing inst instead of creating a new one. - -The trickier part here is to always maintain the uniqueness when we modify the IR. When we update the operand of an inst from a non-hoistable-value to a hoistable-value, we may need to hoist `inst` itself as a result. For example, considered the following code: -``` -%1 = IntType -%p = Ptr(%1) -%2 = func { - %x = ...; - %3 = Ptr(%x); - %4 = ArrayType(%3); - %5 = Var (type: %4); - ... 
-} -``` - -Now consider the scenario where we need to replace the operand in `Ptr(x)` to `int` (where `x` is some non-constant value), we will get a `Ptr(int)` which is now a global value and should be deduplicated: -``` -%1 = IntType -%p = Ptr(%1) -%2 = func { - %x = ...; - //%3 now becomes %p. - %4 = ArrayType(%p); - %5 = Var (type: %4); - ... -} -``` -Note this code is now breaking the invariant that hoistable insts are always defined at the top-most scope, because `%4` becomes is no longer dependent on any local insts in the function, and should be hoisted to the global scope after replacing `%3` with `%p`. This means that we need to continue to perform hoisting of `%4`, to result this final code: -``` -%1 = IntType -%p = Ptr(%1) -%4 = ArrayType(%p); // hoisted to global scope -%2 = func { - %x = ...; - %5 = Var (type: %4); - ... -} -``` - -As illustrated above, because we need to maintain the invariants of global/hoistable values, replacing an operand of an inst can have wide-spread effect on the IR. - -To help ensure these invariants, we introduce the `IRBuilder.replaceOperand(inst, operandIndex, newOperand)` method to perform all the cascading modifications after replacing an operand. However the `IRInst.setOperand(idx, newOperand)` will not perform the cascading modifications, and using `setOperand` to modify the operand of a hoistable inst will trigger a runtime assertion error. - -Similarly, `inst->replaceUsesWith` will also perform any cascading modifications to ensure the uniqueness of hoistable values. Because of this, we need to be particularly careful when using a loop to iterate the IR linked list or def-use linked list and call `replaceUsesWith` or `replaceOperand` inside the loop. - -Consider the following code: - -``` -IRInst* nextInst = nullptr; -for (auto inst = func->getFirstChild(); inst; inst = nextInst) -{ - nextInst = inst->getNextInst(); // save a copy of nestInst - // ... 
- inst->replaceUsesWith(someNewInst); // Warning: this may be unsafe, because nextInst could been moved to parent->parent! -} -``` - -Now imagine this code is running on the `func` defined above, imagine we are now at `inst == %3` and we want to replace `inst` with `Ptr(int)`. Before calling `replaceUsesWith`, we have stored `inst->nextInst` to `nextInst`, so `nextInst` is now `%4`(the array type). Now after we call `replaceUsesWith`, `%4` is hoisted to global scope, so in the next iteration, we will start to process `%4` and follow its `next` pointer to `%2` and we will be processing `func` instead of continue walking the child list! - -Because of this, we should never be calling `replaceOperand` or `replaceUsesWith` when we are walking the IR linked list. If we want to do so, we must create a temporary workList and add all the insts to the work list before we make any modifications. The same can be said to the def-use linked list. There is `traverseUses` and `traverseUsers` utility functions defined in `slang-ir.h` to help with walking the def-use list safely. - -Another detail to keep in mind is that any local references to an inst may become out-of-date after a call to `replaceOperand` or `replaceUsesWith`. Consider the following code: -``` -IRBuilder builder; -auto x = builder.emitXXX(); // x is some non-hoistable value. -auto ptr = builder.getPtrType(x); // create ptr(x). -x->replaceUsesWith(intType); // this renders `ptr` obsolete!! -auto var = builder.emitVar(ptr); // use the obsolete inst to create another inst. -``` -In this example, calling `replaceUsesWith` will cause `ptr` to represent `Ptr(int)`, which may already exist in the global scope. After this call, all uses of `ptr` should be replaced with the global `Ptr(int)` inst instead. `IRBuilder` has provided the mechanism to track all the insts that are removed due to deduplication, and map those removed but not yet deleted inst to the existing inst. 
When using `ptr` to create a new inst, `IRBuilder` will first check if `ptr` should map to some existing hoistable inst in the global deduplication map and replace it if possible. This means that after the call to `builder.emitVar`, `var->type` is not equal to to `ptr`. - -## Best Practices - -In summary, the best practices when modifying the IR is: -- Never call `replaceUsesWith` or `replaceOperand` when walking raw linked lists in the IR. Always create a work list and iterate on the work list instead. -- Never assume any local references to an `inst` is up-to-date after a call to `replaceUsesWith` or `replaceOperand`. It is OK to continue using them as operands/types to create a new inst, but do not assume the created inst will reference the same inst passed in as argument. diff --git a/docs/design/ir.md b/docs/design/ir.md index 086d58b91..97c13c6ee 100644 --- a/docs/design/ir.md +++ b/docs/design/ir.md @@ -191,5 +191,85 @@ The current approach we use requires the structuring information to be maintaine In the future, it would be good to investigate adapting the "Relooper" algorithm used in Emscripten so that we can reover valid structured control flow from an arbitrary CFG; for now we put off that work. If we had a more powerful restructuring algorithm at hand, we could start to support things like multi-level `break`, and also ensure that `continue` clauses don't lead to code duplication any more. +## IR Global and Hoistable Value Deduplication + +Types, constants and certain operations on constants are considered "global value" in the Slang IR. Some other insts like `Specialize()` and `Ptr(x)` are considered as "hoistable" insts, in that they will be defined at the outer most scope where their operands are available. For example, `Ptr(int)` will always be defined at global scope(as direct children of `IRModuleInst`) because its only operand, `int`, is defined at global scope. 
However if we have `Ptr(T)` where `T` is a generic parameter, then this `Ptr(T)` inst will be always be defined in the block of the generic. Global and hoistable values are always deduplicated and we can always assume two hoistable values with different pointer addresses are distinct values. + +The `IRBuilder` class is responsible for ensuring the uniqueness of global/hoistable values. If you call any `IRBuilder` methods that creates a new hoistable instruction, e.g. `IRBuilder::createIntrinsicInst`, `IRBuilder::emitXXX` or `IRBuilder::getType`, `IRBuilder` will check if an equivalent value already exists, and if so it returns the existing inst instead of creating a new one. + +The trickier part here is to always maintain the uniqueness when we modify the IR. When we update the operand of an inst from a non-hoistable-value to a hoistable-value, we may need to hoist `inst` itself as a result. For example, considered the following code: +``` +%1 = IntType +%p = Ptr(%1) +%2 = func { + %x = ...; + %3 = Ptr(%x); + %4 = ArrayType(%3); + %5 = Var (type: %4); + ... +} +``` + +Now consider the scenario where we need to replace the operand in `Ptr(x)` to `int` (where `x` is some non-constant value), we will get a `Ptr(int)` which is now a global value and should be deduplicated: +``` +%1 = IntType +%p = Ptr(%1) +%2 = func { + %x = ...; + //%3 now becomes %p. + %4 = ArrayType(%p); + %5 = Var (type: %4); + ... +} +``` +Note this code is now breaking the invariant that hoistable insts are always defined at the top-most scope, because `%4` becomes is no longer dependent on any local insts in the function, and should be hoisted to the global scope after replacing `%3` with `%p`. This means that we need to continue to perform hoisting of `%4`, to result this final code: +``` +%1 = IntType +%p = Ptr(%1) +%4 = ArrayType(%p); // hoisted to global scope +%2 = func { + %x = ...; + %5 = Var (type: %4); + ... 
+} +``` + +As illustrated above, because we need to maintain the invariants of global/hoistable values, replacing an operand of an inst can have wide-spread effect on the IR. + +To help ensure these invariants, we introduce the `IRBuilder.replaceOperand(inst, operandIndex, newOperand)` method to perform all the cascading modifications after replacing an operand. However the `IRInst.setOperand(idx, newOperand)` will not perform the cascading modifications, and using `setOperand` to modify the operand of a hoistable inst will trigger a runtime assertion error. + +Similarly, `inst->replaceUsesWith` will also perform any cascading modifications to ensure the uniqueness of hoistable values. Because of this, we need to be particularly careful when using a loop to iterate the IR linked list or def-use linked list and call `replaceUsesWith` or `replaceOperand` inside the loop. + +Consider the following code: + +``` +IRInst* nextInst = nullptr; +for (auto inst = func->getFirstChild(); inst; inst = nextInst) +{ + nextInst = inst->getNextInst(); // save a copy of nestInst + // ... + inst->replaceUsesWith(someNewInst); // Warning: this may be unsafe, because nextInst could been moved to parent->parent! +} +``` + +Now imagine this code is running on the `func` defined above, imagine we are now at `inst == %3` and we want to replace `inst` with `Ptr(int)`. Before calling `replaceUsesWith`, we have stored `inst->nextInst` to `nextInst`, so `nextInst` is now `%4`(the array type). Now after we call `replaceUsesWith`, `%4` is hoisted to global scope, so in the next iteration, we will start to process `%4` and follow its `next` pointer to `%2` and we will be processing `func` instead of continue walking the child list! + +Because of this, we should never be calling `replaceOperand` or `replaceUsesWith` when we are walking the IR linked list. If we want to do so, we must create a temporary workList and add all the insts to the work list before we make any modifications. 
The `IRInst::getModifiableChildren` utility function will return a temporary work list for safe iteration on the children. The same can be said to the def-use linked list. There is `traverseUses` and `traverseUsers` utility functions defined in `slang-ir.h` to help with walking the def-use list safely. + +Another detail to keep in mind is that any local references to an inst may become out-of-date after a call to `replaceOperand` or `replaceUsesWith`. Consider the following code: +``` +IRBuilder builder; +auto x = builder.emitXXX(); // x is some non-hoistable value. +auto ptr = builder.getPtrType(x); // create ptr(x). +x->replaceUsesWith(intType); // this renders `ptr` obsolete!! +auto var = builder.emitVar(ptr); // use the obsolete inst to create another inst. +``` +In this example, calling `replaceUsesWith` will cause `ptr` to represent `Ptr(int)`, which may already exist in the global scope. After this call, all uses of `ptr` should be replaced with the global `Ptr(int)` inst instead. `IRBuilder` has provided the mechanism to track all the insts that are removed due to deduplication, and map those removed but not yet deleted inst to the existing inst. When using `ptr` to create a new inst, `IRBuilder` will first check if `ptr` should map to some existing hoistable inst in the global deduplication map and replace it if possible. This means that after the call to `builder.emitVar`, `var->type` is not equal to to `ptr`. + +### Best Practices + +In summary, the best practices when modifying the IR is: +- Never call `replaceUsesWith` or `replaceOperand` when walking raw linked lists in the IR. Always create a work list and iterate on the work list instead. Use `IRInst::getModifiableChildren` and `traverseUses` when you need to modify the IR while iterating. +- Never assume any local references to an `inst` is up-to-date after a call to `replaceUsesWith` or `replaceOperand`. 
It is OK to continue using them as operands/types to create a new inst, but do not assume the created inst will reference the same inst passed in as argument. diff --git a/prelude/slang-cpp-prelude.h b/prelude/slang-cpp-prelude.h index 84a61f929..d15abdb88 100644 --- a/prelude/slang-cpp-prelude.h +++ b/prelude/slang-cpp-prelude.h @@ -296,8 +296,8 @@ struct ISlangUnknown // Includes -#include "slang-cpp-types.h" #include "slang-cpp-scalar-intrinsics.h" +#include "slang-cpp-types.h" // TODO(JS): Hack! Output C++ code from slang can copy uninitialized variables. #if defined(_MSC_VER) diff --git a/prelude/slang-cpp-scalar-intrinsics.h b/prelude/slang-cpp-scalar-intrinsics.h index 66035260d..2b9e7f777 100644 --- a/prelude/slang-cpp-scalar-intrinsics.h +++ b/prelude/slang-cpp-scalar-intrinsics.h @@ -490,6 +490,17 @@ void InterlockedAdd(uint32_t* dest, uint32_t value, uint32_t* oldValue) #endif // SLANG_LLVM + +// ----------------------- fmod -------------------------- +SLANG_FORCE_INLINE float _slang_fmod(float x, float y) +{ + return F32_fmod(x, y); +} +SLANG_FORCE_INLINE double _slang_fmod(double x, double y) +{ + return F64_fmod(x, y); +} + #ifdef SLANG_PRELUDE_NAMESPACE } #endif diff --git a/prelude/slang-cpp-types.h b/prelude/slang-cpp-types.h index c15c5ec40..28fe3dd8d 100644 --- a/prelude/slang-cpp-types.h +++ b/prelude/slang-cpp-types.h @@ -86,26 +86,159 @@ template <typename T> struct Vector<T, 1> { T x; + const T& operator[](size_t /*index*/) const { return x; } + T& operator[](size_t /*index*/) { return x; } + operator T() const { return x; } + Vector() = default; + Vector(T scalar) + { + x = scalar; + } + template <typename U> + Vector(Vector<U, 1> other) + { + x = (T)other.x; + } + template <typename U, int otherSize> + Vector(Vector<U, otherSize> other) + { + int minSize = 1; + if (otherSize < minSize) minSize = otherSize; + for (int i = 0; i < minSize; i++) + (*this)[i] = (T)other[i]; + } }; template <typename T> struct Vector<T, 2> { T x, y; + const T& 
operator[](size_t index) const { return index == 0 ? x : y; } + T& operator[](size_t index) { return index == 0 ? x : y; } + Vector() = default; + Vector(T scalar) + { + x = y = scalar; + } + Vector(T _x, T _y) + { + x = _x; + y = _y; + } + template <typename U> + Vector(Vector<U, 2> other) + { + x = (T)other.x; + y = (T)other.y; + } + template <typename U, int otherSize> + Vector(Vector<U, otherSize> other) + { + int minSize = 2; + if (otherSize < minSize) minSize = otherSize; + for (int i = 0; i < minSize; i++) + (*this)[i] = (T)other[i]; + } }; template <typename T> struct Vector<T, 3> { T x, y, z; + const T& operator[](size_t index) const { return *((T*)(this) + index); } + T& operator[](size_t index) { return *((T*)(this) + index); } + + Vector() = default; + Vector(T scalar) + { + x = y = z = scalar; + } + Vector(T _x, T _y, T _z) + { + x = _x; + y = _y; + z = _z; + } + template <typename U> + Vector(Vector<U, 3> other) + { + x = (T)other.x; + y = (T)other.y; + z = (T)other.z; + } + template <typename U, int otherSize> + Vector(Vector<U, otherSize> other) + { + int minSize = 3; + if (otherSize < minSize) minSize = otherSize; + for (int i = 0; i < minSize; i++) + (*this)[i] = (T)other[i]; + } }; template <typename T> struct Vector<T, 4> { T x, y, z, w; + + const T& operator[](size_t index) const { return *((T*)(this) + index); } + T& operator[](size_t index) { return *((T*)(this) + index); } + Vector() = default; + Vector(T scalar) + { + x = y = z = w = scalar; + } + Vector(T _x, T _y, T _z, T _w) + { + x = _x; + y = _y; + z = _z; + w = _w; + } + template <typename U, int otherSize> + Vector(Vector<U, otherSize> other) + { + int minSize = 4; + if (otherSize < minSize) minSize = otherSize; + for (int i = 0; i < minSize; i++) + (*this)[i] = (T)other[i]; + } + }; +template<typename T, int N> +SLANG_FORCE_INLINE T _slang_vector_get_element(Vector<T, N> x, int index) +{ + return x[index]; +} + +template<typename T, int N> +SLANG_FORCE_INLINE const T* 
_slang_vector_get_element_ptr(const Vector<T, N>* x, int index) +{ + return &((*const_cast<Vector<T,N>*>(x))[index]); +} + +template<typename T, int N> +SLANG_FORCE_INLINE T* _slang_vector_get_element_ptr(Vector<T, N>* x, int index) +{ + return &((*x)[index]); +} + +template<typename T, int n, typename OtherT, int m> +SLANG_FORCE_INLINE Vector<T, n> _slang_vector_reshape(const Vector<OtherT, m> other) +{ + Vector<T, n> result; + for (int i = 0; i < n; i++) + { + OtherT otherElement = T(0); + if (i < m) + otherElement = _slang_vector_get_element(other, i); + *_slang_vector_get_element_ptr(&result, i) = (T)otherElement; + } + return result; +} + +typedef uint32_t uint; typedef Vector<float, 2> float2; typedef Vector<float, 3> float3; @@ -119,12 +252,320 @@ typedef Vector<uint32_t, 2> uint2; typedef Vector<uint32_t, 3> uint3; typedef Vector<uint32_t, 4> uint4; +#define SLANG_VECTOR_BINARY_OP(T, op) \ + template<int n> \ + SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \ + { \ + Vector<T, n> result;\ + for (int i = 0; i < n; i++) \ + result[i] = thisVal[i] op other[i]; \ + return result;\ + } +#define SLANG_VECTOR_BINARY_COMPARE_OP(T, op) \ + template<int n> \ + SLANG_FORCE_INLINE Vector<bool, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \ + { \ + Vector<bool, n> result;\ + for (int i = 0; i < n; i++) \ + result[i] = thisVal[i] op other[i]; \ + return result;\ + } + +#define SLANG_VECTOR_UNARY_OP(T, op) \ + template<int n> \ + SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal) \ + { \ + Vector<T, n> result;\ + for (int i = 0; i < n; i++) \ + result[i] = op thisVal[i]; \ + return result;\ + } +#define SLANG_INT_VECTOR_OPS(T) \ + SLANG_VECTOR_BINARY_OP(T, +)\ + SLANG_VECTOR_BINARY_OP(T, -)\ + SLANG_VECTOR_BINARY_OP(T, *)\ + SLANG_VECTOR_BINARY_OP(T, / )\ + SLANG_VECTOR_BINARY_OP(T, &)\ + SLANG_VECTOR_BINARY_OP(T, |)\ + SLANG_VECTOR_BINARY_OP(T, &&)\ + 
SLANG_VECTOR_BINARY_OP(T, ||)\ + SLANG_VECTOR_BINARY_OP(T, ^)\ + SLANG_VECTOR_BINARY_OP(T, %)\ + SLANG_VECTOR_BINARY_OP(T, >>)\ + SLANG_VECTOR_BINARY_OP(T, <<)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)\ + SLANG_VECTOR_UNARY_OP(T, !)\ + SLANG_VECTOR_UNARY_OP(T, ~) +#define SLANG_FLOAT_VECTOR_OPS(T) \ + SLANG_VECTOR_BINARY_OP(T, +)\ + SLANG_VECTOR_BINARY_OP(T, -)\ + SLANG_VECTOR_BINARY_OP(T, *)\ + SLANG_VECTOR_BINARY_OP(T, /)\ + SLANG_VECTOR_UNARY_OP(T, -)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, !=) + +SLANG_INT_VECTOR_OPS(bool) +SLANG_INT_VECTOR_OPS(int) +SLANG_INT_VECTOR_OPS(int8_t) +SLANG_INT_VECTOR_OPS(int16_t) +SLANG_INT_VECTOR_OPS(int64_t) +SLANG_INT_VECTOR_OPS(uint) +SLANG_INT_VECTOR_OPS(uint8_t) +SLANG_INT_VECTOR_OPS(uint16_t) +SLANG_INT_VECTOR_OPS(uint64_t) + +SLANG_FLOAT_VECTOR_OPS(float) +SLANG_FLOAT_VECTOR_OPS(double) + +#define SLANG_VECTOR_INT_NEG_OP(T) \ + template<int N>\ + Vector<T, N> operator-(const Vector<T, N>& thisVal) \ + { \ + Vector<T, N> result;\ + for (int i = 0; i < N; i++) \ + result[i] = 0 - thisVal[i]; \ + return result;\ + } +SLANG_VECTOR_INT_NEG_OP(int) +SLANG_VECTOR_INT_NEG_OP(int8_t) +SLANG_VECTOR_INT_NEG_OP(int16_t) +SLANG_VECTOR_INT_NEG_OP(int64_t) +SLANG_VECTOR_INT_NEG_OP(uint) +SLANG_VECTOR_INT_NEG_OP(uint8_t) +SLANG_VECTOR_INT_NEG_OP(uint16_t) +SLANG_VECTOR_INT_NEG_OP(uint64_t) + +#define SLANG_FLOAT_VECTOR_MOD(T)\ + template<int N> \ + Vector<T, N> operator%(const Vector<T, N>& left, const Vector<T, N>& right) \ + {\ + Vector<T, N> result;\ + for (int i = 0; i < N; i++) \ + result[i] = _slang_fmod(left[i], right[i]); \ + return 
result;\ + } + +SLANG_FLOAT_VECTOR_MOD(float) +SLANG_FLOAT_VECTOR_MOD(double) +#undef SLANG_FLOAT_VECTOR_MOD +#undef SLANG_VECTOR_BINARY_OP +#undef SLANG_VECTOR_UNARY_OP +#undef SLANG_INT_VECTOR_OPS +#undef SLANG_FLOAT_VECTOR_OPS +#undef SLANG_VECTOR_INT_NEG_OP +#undef SLANG_FLOAT_VECTOR_MOD + template <typename T, int ROWS, int COLS> struct Matrix { Vector<T, COLS> rows[ROWS]; + Vector<T, COLS>& operator[](size_t index) { return rows[index]; } + Matrix() = default; + Matrix(T scalar) + { + for (int i = 0; i < ROWS; i++) + rows[i] = Vector<T, COLS>(scalar); + } + Matrix(const Vector<T, COLS>& row0) + { + rows[0] = row0; + } + Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1) + { + rows[0] = row0; + rows[1] = row1; + } + Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2) + { + rows[0] = row0; + rows[1] = row1; + rows[2] = row2; + } + Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2, const Vector<T, COLS>& row3) + { + rows[0] = row0; + rows[1] = row1; + rows[2] = row2; + rows[3] = row3; + } + template<typename U, int otherRow, int otherCol> + Matrix(const Matrix<U, otherRow, otherCol>& other) + { + int minRow = ROWS; + int minCol = COLS; + if (minRow > otherRow) minRow = otherRow; + if (minCol > otherCol) minCol = otherCol; + for (int i = 0; i < minRow; i++) + for (int j = 0; j < minCol; j++) + rows[i][j] = (T)other.rows[i][j]; + } + Matrix(T v0, T v1, T v2, T v3) + { + rows[0][0] = v0; rows[0][1] = v1; + rows[1][0] = v2; rows[1][1] = v3; + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5) + { + if (COLS == 3) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; + rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; + } + else + { + rows[0][0] = v0; rows[0][1] = v1; + rows[1][0] = v2; rows[1][1] = v3; + rows[2][0] = v4; rows[2][1] = v5; + } + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7) + { + if (COLS == 4) + { + rows[0][0] = v0; rows[0][1] = v1; 
rows[0][2] = v2; rows[0][3] = v3; + rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; + } + else + { + rows[0][0] = v0; rows[0][1] = v1; + rows[1][0] = v2; rows[1][1] = v3; + rows[2][0] = v4; rows[2][1] = v5; + rows[3][0] = v6; rows[3][1] = v7; + } + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; + rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; + rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8; + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11) + { + if (COLS == 4) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; + rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; + rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11; + } + else + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; + rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; + rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8; + rows[3][0] = v9; rows[3][1] = v10; rows[3][2] = v11; + } + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; + rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; + rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11; + rows[3][0] = v12; rows[3][1] = v13; rows[3][2] = v14; rows[3][3] = v15; + } }; +#define SLANG_MATRIX_BINARY_OP(T, op) \ + template<int R, int C> \ + Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal, const Matrix<T, R, C>& other) \ + { \ + Matrix<T, R, C> result;\ + for (int i = 0; i < R; i++) \ + for (int j = 0; j < C; j++) \ + result.rows[i][j] = thisVal.rows[i][j] op other.rows[i][j]; \ + return result;\ + } + +#define SLANG_MATRIX_UNARY_OP(T, op) \ + template<int R, int C> \ + Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal) \ + { \ + Matrix<T, R, C> result;\ + for (int i = 0; i < 
R; i++) \ + for (int j = 0; j < C; j++) \ + result[i].rows[i][j] = op thisVal.rows[i][j]; \ + return result;\ + } +#define SLANG_INT_MATRIX_OPS(T) \ + SLANG_MATRIX_BINARY_OP(T, +)\ + SLANG_MATRIX_BINARY_OP(T, -)\ + SLANG_MATRIX_BINARY_OP(T, *)\ + SLANG_MATRIX_BINARY_OP(T, / )\ + SLANG_MATRIX_BINARY_OP(T, &)\ + SLANG_MATRIX_BINARY_OP(T, |)\ + SLANG_MATRIX_BINARY_OP(T, &&)\ + SLANG_MATRIX_BINARY_OP(T, ||)\ + SLANG_MATRIX_BINARY_OP(T, ^)\ + SLANG_MATRIX_BINARY_OP(T, %)\ + SLANG_MATRIX_UNARY_OP(T, !)\ + SLANG_MATRIX_UNARY_OP(T, ~) +#define SLANG_FLOAT_MATRIX_OPS(T) \ + SLANG_MATRIX_BINARY_OP(T, +)\ + SLANG_MATRIX_BINARY_OP(T, -)\ + SLANG_MATRIX_BINARY_OP(T, *)\ + SLANG_MATRIX_BINARY_OP(T, /)\ + SLANG_MATRIX_UNARY_OP(T, -) +SLANG_INT_MATRIX_OPS(int) +SLANG_INT_MATRIX_OPS(int8_t) +SLANG_INT_MATRIX_OPS(int16_t) +SLANG_INT_MATRIX_OPS(int64_t) +SLANG_INT_MATRIX_OPS(uint) +SLANG_INT_MATRIX_OPS(uint8_t) +SLANG_INT_MATRIX_OPS(uint16_t) +SLANG_INT_MATRIX_OPS(uint64_t) + +SLANG_FLOAT_MATRIX_OPS(float) +SLANG_FLOAT_MATRIX_OPS(double) + +#define SLANG_MATRIX_INT_NEG_OP(T) \ + template<int R, int C>\ + SLANG_FORCE_INLINE Matrix<T, R, C> operator-(Matrix<T, R, C> thisVal) \ + { \ + Matrix<T, R, C> result;\ + for (int i = 0; i < R; i++) \ + for (int j = 0; j < C; j++) \ + result.rows[i][j] = 0 - thisVal.rows[i][j]; \ + return result;\ + } + SLANG_MATRIX_INT_NEG_OP(int) + SLANG_MATRIX_INT_NEG_OP(int8_t) + SLANG_MATRIX_INT_NEG_OP(int16_t) + SLANG_MATRIX_INT_NEG_OP(int64_t) + SLANG_MATRIX_INT_NEG_OP(uint) + SLANG_MATRIX_INT_NEG_OP(uint8_t) + SLANG_MATRIX_INT_NEG_OP(uint16_t) + SLANG_MATRIX_INT_NEG_OP(uint64_t) + +#define SLANG_FLOAT_MATRIX_MOD(T)\ + template<int R, int C> \ + SLANG_FORCE_INLINE Matrix<T, R, C> operator%(Matrix<T, R, C> left, Matrix<T, R, C> right) \ + {\ + Matrix<T, R, C> result;\ + for (int i = 0; i < R; i++) \ + for (int j = 0; j < C; j++) \ + result.rows[i][j] = _slang_fmod(left.rows[i][j], right.rows[i][j]); \ + return result;\ + } + + SLANG_FLOAT_MATRIX_MOD(float) 
+ SLANG_FLOAT_MATRIX_MOD(double) +#undef SLANG_FLOAT_MATRIX_MOD +#undef SLANG_MATRIX_BINARY_OP +#undef SLANG_MATRIX_UNARY_OP +#undef SLANG_INT_MATRIX_OPS +#undef SLANG_FLOAT_MATRIX_OPS +#undef SLANG_MATRIX_INT_NEG_OP +#undef SLANG_FLOAT_MATRIX_MOD + // We can just map `NonUniformResourceIndex` type directly to the index type on CPU, as CPU does not require // any special handling around such accesses. typedef size_t NonUniformResourceIndex; diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 448b69c63..cb1bb188b 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -149,12 +149,11 @@ typedef size_t NonUniformResourceIndex; template <typename T, int ROWS, int COLS> struct Matrix; -typedef bool bool1; +typedef int1 bool1; typedef int2 bool2; typedef int3 bool3; typedef int4 bool4; - typedef signed char int8_t; typedef short int16_t; typedef int int32_t; @@ -186,163 +185,522 @@ union Union64 double d; }; -// -// Half support -// +SLANG_FORCE_INLINE SLANG_CUDA_CALL float _slang_fmod(float x, float y) +{ + return ::fmodf(x, y); +} +SLANG_FORCE_INLINE SLANG_CUDA_CALL double _slang_fmod(double x, double y) +{ + return ::fmod(x, y); +} #if SLANG_CUDA_ENABLE_HALF // Add the other vector half types -struct __half3 { __half2 xy; __half z; }; -struct __half4 { __half2 xy; __half2 zw; }; - -// *** convert *** - -// half -> other - -// float -SLANG_FORCE_INLINE SLANG_CUDA_CALL float2 convert_float2(const __half2& v) { return __half22float2(v); } -SLANG_FORCE_INLINE SLANG_CUDA_CALL float3 convert_float3(const __half3& v) { const float2 xy = __half22float2(v.xy); return float3{xy.x, xy.y, __half2float(v.z)}; } -SLANG_FORCE_INLINE SLANG_CUDA_CALL float4 convert_float4(const __half4& v) { const float2 xy = __half22float2(v.xy); const float2 zw = __half22float2(v.zw); return float4{xy.x, xy.y, zw.x, zw.y}; } - -// double -SLANG_FORCE_INLINE SLANG_CUDA_CALL double2 convert_double2(const __half2& v) { const float2 xy = 
__half22float2(v); return double2{ xy.x, xy.y }; } -SLANG_FORCE_INLINE SLANG_CUDA_CALL double3 convert_double3(const __half3& v) { const float2 xy = __half22float2(v.xy); return double3{ xy.x, xy.y, __half2float(v.z)}; } -SLANG_FORCE_INLINE SLANG_CUDA_CALL double4 convert_double4(const __half4& v) { const float2 xy = __half22float2(v.xy); const float2 zw = __half22float2(v.zw); return double4{xy.x, xy.y, zw.x, zw.y}; } - -// int -SLANG_FORCE_INLINE SLANG_CUDA_CALL int2 convert_int2(const __half2& v) { return int2 { __half2int_rz(v.x), __half2int_rz(v.y) }; } -SLANG_FORCE_INLINE SLANG_CUDA_CALL int3 convert_int3(const __half3& v) { return int3 { __half2int_rz(v.xy.x), __half2int_rz(v.xy.y), __half2int_rz(v.z) }; } -SLANG_FORCE_INLINE SLANG_CUDA_CALL int4 convert_int4(const __half4& v) { return int4 { __half2int_rz(v.xy.x), __half2int_rz(v.xy.y), __half2int_rz(v.zw.x), __half2int_rz(v.zw.y)}; } - -// uint -SLANG_FORCE_INLINE SLANG_CUDA_CALL uint2 convert_uint2(const __half2& v) { return uint2 { __half2uint_rz(v.x), __half2uint_rz(v.y) }; } -SLANG_FORCE_INLINE SLANG_CUDA_CALL uint3 convert_uint3(const __half3& v) { return uint3 { __half2uint_rz(v.xy.x), __half2uint_rz(v.xy.y), __half2uint_rz(v.z) }; } -SLANG_FORCE_INLINE SLANG_CUDA_CALL uint4 convert_uint4(const __half4& v) { return uint4 { __half2uint_rz(v.xy.x), __half2uint_rz(v.xy.y), __half2uint_rz(v.zw.x), __half2uint_rz(v.zw.y)}; } +struct __half1 { __half x; }; +struct __align__(4) __half3 { __half x, y, z; }; +struct __align__(4) __half4 { __half x, y, z, w; }; +#endif -// other -> half +#define SLANG_VECTOR_GET_ELEMENT(T) \ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_vector_get_element(T##1 x, int index) { return ((T*)(&x))[index]; }\ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_vector_get_element(T##2 x, int index) { return ((T*)(&x))[index]; }\ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_vector_get_element(T##3 x, int index) { return ((T*)(&x))[index]; }\ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T 
_slang_vector_get_element(T##4 x, int index) { return ((T*)(&x))[index]; } +SLANG_VECTOR_GET_ELEMENT(int) +SLANG_VECTOR_GET_ELEMENT(uint) +SLANG_VECTOR_GET_ELEMENT(short) +SLANG_VECTOR_GET_ELEMENT(ushort) +SLANG_VECTOR_GET_ELEMENT(char) +SLANG_VECTOR_GET_ELEMENT(uchar) +SLANG_VECTOR_GET_ELEMENT(longlong) +SLANG_VECTOR_GET_ELEMENT(ulonglong) +SLANG_VECTOR_GET_ELEMENT(float) +SLANG_VECTOR_GET_ELEMENT(double) + +#define SLANG_VECTOR_GET_ELEMENT_PTR(T) \ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(T##1* x, int index) { return ((T*)(x)) + index; }\ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(T##2* x, int index) { return ((T*)(x)) + index; }\ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(T##3* x, int index) { return ((T*)(x)) + index; }\ + SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(T##4* x, int index) { return ((T*)(x)) + index; } +SLANG_VECTOR_GET_ELEMENT_PTR(int) +SLANG_VECTOR_GET_ELEMENT_PTR(uint) +SLANG_VECTOR_GET_ELEMENT_PTR(short) +SLANG_VECTOR_GET_ELEMENT_PTR(ushort) +SLANG_VECTOR_GET_ELEMENT_PTR(char) +SLANG_VECTOR_GET_ELEMENT_PTR(uchar) +SLANG_VECTOR_GET_ELEMENT_PTR(longlong) +SLANG_VECTOR_GET_ELEMENT_PTR(ulonglong) +SLANG_VECTOR_GET_ELEMENT_PTR(float) +SLANG_VECTOR_GET_ELEMENT_PTR(double) -SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 convert___half2(const float2& v) { return __float22half2_rn(v); } -SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 convert___half3(const float3& v) { return __half3{ __float22half2_rn(float2{v.x, v.y}), __float2half_rn(v.z) }; } -SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 convert___half4(const float4& v) { return __half4{ __float22half2_rn(float2{v.x, v.y}), __float22half2_rn(float2{v.z, v.w}) }; } +#if SLANG_CUDA_ENABLE_HALF +SLANG_VECTOR_GET_ELEMENT(__half) +SLANG_VECTOR_GET_ELEMENT_PTR(__half) +#endif -SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 convert___half2(const int2& v) { return __half2{ __int2half_rz(v.x), __int2half_rz(v.y) }; } 
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 convert___half3(const int3& v) { return __half3{ __half2{__int2half_rz(v.x), __int2half_rz(v.y)}, __int2half_rz(v.z) }; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 convert___half4(const int4& v) { return __half4{ __half2{__int2half_rz(v.x), __int2half_rz(v.y)}, __half2{__int2half_rz(v.z), __int2half_rz(v.w)} }; }
+#define SLANG_CUDA_VECTOR_BINARY_OP(T, n, op) \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##n operator op(T##n thisVal, T##n other) \
+ { \
+ T##n result;\
+ for (int i = 0; i < n; i++) \
+ *_slang_vector_get_element_ptr(&result, i) = _slang_vector_get_element(thisVal,i) op _slang_vector_get_element(other,i); \
+ return result;\
+ }
+#define SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, op) \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL bool##n operator op(T##n thisVal, T##n other) \
+ { \
+ bool##n result;\
+ for (int i = 0; i < n; i++) \
+ *_slang_vector_get_element_ptr(&result, i) = (int)(_slang_vector_get_element(thisVal,i) op _slang_vector_get_element(other,i)); \
+ return result;\
+ }
+#define SLANG_CUDA_VECTOR_UNARY_OP(T, n, op) \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##n operator op(T##n thisVal) \
+ { \
+ T##n result;\
+ for (int i = 0; i < n; i++) \
+ *_slang_vector_get_element_ptr(&result, i) = op _slang_vector_get_element(thisVal,i); \
+ return result;\
+ }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 convert___half2(const uint2& v) { return __half2{ __uint2half_rz(v.x), __uint2half_rz(v.y) }; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 convert___half3(const uint3& v) { return __half3{ __half2{__uint2half_rz(v.x), __uint2half_rz(v.y)}, __uint2half_rz(v.z) }; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 convert___half4(const uint4& v) { return __half4{ __half2{__uint2half_rz(v.x), __uint2half_rz(v.y)}, __half2{__uint2half_rz(v.z), __uint2half_rz(v.w)} }; }
+#define SLANG_CUDA_VECTOR_INT_OP(T, n) \
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, +)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, -)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n,
*)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, /)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, %)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, ^)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, &)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, |)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, &&)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, ||)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, >>)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, <<)\
+ SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, >)\
+ SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, <)\
+ SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, >=)\
+ SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, <=)\
+ SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, ==)\
+ SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, !=)\
+ SLANG_CUDA_VECTOR_UNARY_OP(T, n, !)\
+ SLANG_CUDA_VECTOR_UNARY_OP(T, n, -)\
+ SLANG_CUDA_VECTOR_UNARY_OP(T, n, ~)
+
+#define SLANG_CUDA_VECTOR_INT_OPS(T) \
+ SLANG_CUDA_VECTOR_INT_OP(T, 2) \
+ SLANG_CUDA_VECTOR_INT_OP(T, 3) \
+ SLANG_CUDA_VECTOR_INT_OP(T, 4)
+
+SLANG_CUDA_VECTOR_INT_OPS(int)
+SLANG_CUDA_VECTOR_INT_OPS(uint)
+SLANG_CUDA_VECTOR_INT_OPS(ushort)
+SLANG_CUDA_VECTOR_INT_OPS(short)
+SLANG_CUDA_VECTOR_INT_OPS(char)
+SLANG_CUDA_VECTOR_INT_OPS(uchar)
+SLANG_CUDA_VECTOR_INT_OPS(longlong)
+SLANG_CUDA_VECTOR_INT_OPS(ulonglong)
+
+#define SLANG_CUDA_VECTOR_FLOAT_OP(T, n) \
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, +)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, -)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, *)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, /)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, &&)\
+ SLANG_CUDA_VECTOR_BINARY_OP(T, n, ||)\
+ SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, >)\
+ SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, <)\
+ SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, >=)\
+ SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, <=)\
+ SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, ==)\
+ SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, !=)\
+ SLANG_CUDA_VECTOR_UNARY_OP(T, n, -)
+#define SLANG_CUDA_VECTOR_FLOAT_OPS(T) \
+ SLANG_CUDA_VECTOR_FLOAT_OP(T, 2) \
+ SLANG_CUDA_VECTOR_FLOAT_OP(T, 3) \
+ SLANG_CUDA_VECTOR_FLOAT_OP(T, 4)
+
+SLANG_CUDA_VECTOR_FLOAT_OPS(float)
+SLANG_CUDA_VECTOR_FLOAT_OPS(double)
+#if SLANG_CUDA_ENABLE_HALF
+SLANG_CUDA_VECTOR_FLOAT_OPS(__half)
+#endif
+#define SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL(T, n)\
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##n operator%(const T##n& left, const T##n& right) \
+ {\
+ T##n result;\
+ for (int i = 0; i < n; i++) \
+ *_slang_vector_get_element_ptr(&result, i) = _slang_fmod(_slang_vector_get_element(left,i), _slang_vector_get_element(right,i)); \
+ return result;\
+ }
+#define SLANG_CUDA_FLOAT_VECTOR_MOD(T) \
+ SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL(T, 2)\
+ SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL(T, 3)\
+ SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL(T, 4)
+
+SLANG_CUDA_FLOAT_VECTOR_MOD(float)
+SLANG_CUDA_FLOAT_VECTOR_MOD(double)
+
+#define SLANG_MAKE_VECTOR(T) \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x, T y) { return T##2{x, y}; }\
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x, T y, T z) { return T##3{ x, y, z }; }\
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x, T y, T z, T w) { return T##4{ x, y, z, w }; }
+SLANG_MAKE_VECTOR(int)
+SLANG_MAKE_VECTOR(uint)
+SLANG_MAKE_VECTOR(short)
+SLANG_MAKE_VECTOR(ushort)
+SLANG_MAKE_VECTOR(char)
+SLANG_MAKE_VECTOR(uchar)
+SLANG_MAKE_VECTOR(float)
+SLANG_MAKE_VECTOR(double)
+SLANG_MAKE_VECTOR(longlong)
+SLANG_MAKE_VECTOR(ulonglong)
+#if SLANG_CUDA_ENABLE_HALF
+SLANG_MAKE_VECTOR(__half)
+#endif
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 convert___half2(const double2& v) { return __float22half2_rn(float2{v.x, v.y}); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 convert___half3(const double3& v) { return __half3{ __float22half2_rn(float2{v.x, v.y}), __float2half_rn(v.z) }; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 convert___half4(const double4& v) { return __half4{ __float22half2_rn(float2{v.x, v.y}), __float22half2_rn(float2{v.z, v.w}) }; }
+#define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##1 make_##T##1(T x) { return T##1{x}; }\
+ SLANG_FORCE_INLINE
SLANG_CUDA_CALL T##2 make_##T##2(T x) { return make_##T##2(x, x); }\
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x) { return make_##T##3(x, x, x); }\
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x) { return make_##T##4(x, x, x, x); }
+SLANG_MAKE_VECTOR_FROM_SCALAR(int)
+SLANG_MAKE_VECTOR_FROM_SCALAR(uint)
+SLANG_MAKE_VECTOR_FROM_SCALAR(short)
+SLANG_MAKE_VECTOR_FROM_SCALAR(ushort)
+SLANG_MAKE_VECTOR_FROM_SCALAR(char)
+SLANG_MAKE_VECTOR_FROM_SCALAR(uchar)
+SLANG_MAKE_VECTOR_FROM_SCALAR(longlong)
+SLANG_MAKE_VECTOR_FROM_SCALAR(ulonglong)
+SLANG_MAKE_VECTOR_FROM_SCALAR(float)
+SLANG_MAKE_VECTOR_FROM_SCALAR(double)
+#if SLANG_CUDA_ENABLE_HALF
+SLANG_MAKE_VECTOR_FROM_SCALAR(__half)
+#endif
-// *** make ***
+template<typename T, int n>
+struct GetVectorTypeImpl {};
-// Mechanism to make half vectors
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 make___half2(__half x, __half y) { return __halves2half2(x, y); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 make___half3(__half x, __half y, __half z) { return __half3{ __halves2half2(x, y), z }; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 make___half4(__half x, __half y, __half z, __half w) { return __half4{ __halves2half2(x, y), __halves2half2(z, w)}; }
+#define GET_VECTOR_TYPE_IMPL(T, n)\
+struct GetVectorTypeImpl<T,n>\
+{\
+ typedef T##n type;\
+ static SLANG_FORCE_INLINE SLANG_CUDA_CALL T##n fromScalar(T v) { return make_##T##n(v); } \
+};
+#define GET_VECTOR_TYPE_IMPL_N(T)\
+ GET_VECTOR_TYPE_IMPL(T, 1)\
+ GET_VECTOR_TYPE_IMPL(T, 2)\
+ GET_VECTOR_TYPE_IMPL(T, 3)\
+ GET_VECTOR_TYPE_IMPL(T, 4)
+
+GET_VECTOR_TYPE_IMPL_N(int)
+GET_VECTOR_TYPE_IMPL_N(uint)
+GET_VECTOR_TYPE_IMPL_N(short)
+GET_VECTOR_TYPE_IMPL_N(ushort)
+GET_VECTOR_TYPE_IMPL_N(char)
+GET_VECTOR_TYPE_IMPL_N(uchar)
+GET_VECTOR_TYPE_IMPL_N(longlong)
+GET_VECTOR_TYPE_IMPL_N(ulonglong)
+GET_VECTOR_TYPE_IMPL_N(float)
+GET_VECTOR_TYPE_IMPL_N(double)
+#if SLANG_CUDA_ENABLE_HALF
+GET_VECTOR_TYPE_IMPL_N(__half)
+#endif
+template<typename T, int n>
+using Vector = typename GetVectorTypeImpl<T, n>::type;
-// *** constructFromScalar ***
+template<typename T, int n, typename OtherT, int m>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Vector<T, n> _slang_vector_reshape(const Vector<OtherT, m> other)
+{
+ Vector<T, n> result;
+ for (int i = 0; i < n; i++)
+ {
+ OtherT otherElement = T(0);
+ if (i < m)
+ otherElement = _slang_vector_get_element(other, i);
+ *_slang_vector_get_element_ptr(&result, i) = (T)otherElement;
+ }
+ return result;
+}
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 constructFromScalar___half2(half x) { return __half2half2(x); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 constructFromScalar___half3(half x) { return __half3{__half2half2(x), x}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 constructFromScalar___half4(half x) { const __half2 v = __half2half2(x); return __half4{v, v}; }
+template <typename T, int ROWS, int COLS>
+struct Matrix
+{
+ Vector<T, COLS> rows[ROWS];
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL Vector<T, COLS>& operator[](size_t index) { return rows[index]; }
+};
-// *** half2 ***
-// half2 maths ops
+template<typename T, int ROWS, int COLS>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(T scalar)
+{
+ Matrix<T, ROWS, COLS> result;
+ for (int i = 0; i < ROWS; i++)
+ result.rows[i] = GetVectorTypeImpl<T, COLS>::fromScalar(scalar);
+ return result;
-// NOTE!
That by default these are in cuda_fp16.hpp, but we disable them, because we need to define the comparison operators
-// as we need versions that will return vector<bool>
+}
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2& lh, const __half2& rh) { return __hadd2(lh, rh); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2& lh, const __half2& rh) { return __hsub2(lh, rh); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator*(const __half2& lh, const __half2& rh) { return __hmul2(lh, rh); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator/(const __half2& lh, const __half2& rh) { return __h2div(lh, rh); }
+template<typename T, int ROWS, int COLS>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(const Vector<T, COLS>& row0)
+{
+ Matrix<T, ROWS, COLS> result;
+ result.rows[0] = row0;
+ return result;
+}
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator+=(__half2& lh, const __half2& rh) { lh = __hadd2(lh, rh); return lh; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator-=(__half2& lh, const __half2& rh) { lh = __hsub2(lh, rh); return lh; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator*=(__half2& lh, const __half2& rh) { lh = __hmul2(lh, rh); return lh; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator/=(__half2& lh, const __half2& rh) { lh = __h2div(lh, rh); return lh; }
+template<typename T, int ROWS, int COLS>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1)
+{
+ Matrix<T, ROWS, COLS> result;
+ result.rows[0] = row0;
+ result.rows[1] = row1;
+ return result;
+}
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 &operator++(__half2 &h) { __half2_raw one; one.x = 0x3C00; one.y = 0x3C00; h = __hadd2(h, one); return h; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 &operator--(__half2 &h) { __half2_raw one; one.x = 0x3C00; one.y = 0x3C00; h = __hsub2(h, one); return h; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2
operator++(__half2 &h, int) { __half2 ret = h; __half2_raw one; one.x = 0x3C00; one.y = 0x3C00; h = __hadd2(h, one); return ret; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator--(__half2 &h, int) { __half2 ret = h; __half2_raw one; one.x = 0x3C00; one.y = 0x3C00; h = __hsub2(h, one); return ret; }
+template<typename T, int ROWS, int COLS>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2)
+{
+ Matrix<T, ROWS, COLS> result;
+ result.rows[0] = row0;
+ result.rows[1] = row1;
+ result.rows[2] = row2;
+ return result;
+}
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2 &h) { return h; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2 &h) { return __hneg2(h); }
+template<typename T, int ROWS, int COLS>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2, const Vector<T, COLS>& row3)
+{
+ Matrix<T, ROWS, COLS> result;
+ result.rows[0] = row0;
+ result.rows[1] = row1;
+ result.rows[2] = row2;
+ result.rows[3] = row3;
+ return result;
+}
-// vec op scalar
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2& lh, __half rh) { return __hadd2(lh, __half2half2(rh)); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2& lh, __half rh) { return __hsub2(lh, __half2half2(rh)); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator*(const __half2& lh, __half rh) { return __hmul2(lh, __half2half2(rh)); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator/(const __half2& lh, __half rh) { return __h2div(lh, __half2half2(rh)); }
+template<typename T, int ROWS, int COLS, typename U, int otherRow, int otherCol>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(const Matrix<U, otherRow, otherCol>& other)
+{
+ Matrix<T, ROWS, COLS> result;
+ int minRow = ROWS;
+ int minCol = COLS;
+ if
(minRow > otherRow) minRow = otherRow;
+ if (minCol > otherCol) minCol = otherCol;
+ for (int i = 0; i < minRow; i++)
+ for (int j = 0; j < minCol; j++)
+ *_slang_vector_get_element_ptr(result.rows + i, j) = (T)_slang_vector_get_element(other.rows[i], j);
+ return result;
+}
-// scalar op vec
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(__half lh, const __half2& rh) { return __hadd2(__half2half2(lh), rh); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(__half lh, const __half2& rh) { return __hsub2(__half2half2(lh), rh); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator*(__half lh, const __half2& rh) { return __hmul2(__half2half2(lh), rh); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator/(__half lh, const __half2& rh) { return __h2div(__half2half2(lh), rh); }
+template<typename T, int ROWS, int COLS>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(T v0, T v1, T v2, T v3)
+{
+ Matrix<T, ROWS, COLS> rs;
+ rs.rows[0].x = v0; rs.rows[0].y = v1;
+ rs.rows[1].x = v2; rs.rows[1].y = v3;
+ return rs;
+}
-// *** half3 ***
+template<typename T, int ROWS, int COLS>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(T v0, T v1, T v2, T v3, T v4, T v5)
+{
+ Matrix<T, ROWS, COLS> rs;
+ if (COLS == 3)
+ {
+ rs.rows[0].x = v0; rs.rows[0].y = v1; rs.rows[0].z = v2;
+ rs.rows[1].x = v3; rs.rows[1].y = v4; rs.rows[1].z = v5;
+ }
+ else
+ {
+ rs.rows[0].x = v0; rs.rows[0].y = v1;
+ rs.rows[1].x = v2; rs.rows[1].y = v3;
+ rs.rows[2].x = v4; rs.rows[2].y = v5;
+ }
+ return rs;
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator+(const __half3& lh, const __half3& rh) { return __half3{__hadd2(lh.xy, rh.xy), __hadd(lh.z, rh.z)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator-(const __half3& lh, const __half3& rh) { return __half3{__hsub2(lh.xy, rh.xy), __hsub(lh.z, rh.z)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator*(const __half3& lh, const __half3& rh) { return __half3{__hmul2(lh.xy, rh.xy),
__hmul(lh.z, rh.z)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator/(const __half3& lh, const __half3& rh) { return __half3{__h2div(lh.xy, rh.xy), __hdiv(lh.z, rh.z)}; }
+}
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator-(const __half3& h) { return __half3{__hneg2(h.xy), __hneg(h.z)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator+(const __half3& h) { return h; }
+template<typename T, int ROWS, int COLS>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7)
+{
+ Matrix<T, ROWS, COLS> rs;
+ if (COLS == 4)
+ {
+ rs.rows[0].x = v0; rs.rows[0].y = v1; rs.rows[0].z = v2; rs.rows[0].w = v3;
+ rs.rows[1].x = v4; rs.rows[1].y = v5; rs.rows[1].z = v6; rs.rows[1].w = v7;
+ }
+ else
+ {
+ rs.rows[0].x = v0; rs.rows[0].y = v1;
+ rs.rows[1].x = v2; rs.rows[1].y = v3;
+ rs.rows[2].x = v4; rs.rows[2].y = v5;
+ rs.rows[3].x = v6; rs.rows[3].y = v7;
+ }
+ return rs;
+}
-// vec op scalar
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator+(const __half3& lh, __half rh) { return __half3{__hadd2(lh.xy, __half2half2(rh)), __hadd(lh.z, rh)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator-(const __half3& lh, __half rh) { return __half3{__hsub2(lh.xy, __half2half2(rh)), __hsub(lh.z, rh)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator*(const __half3& lh, __half rh) { return __half3{__hmul2(lh.xy, __half2half2(rh)), __hmul(lh.z, rh)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator/(const __half3& lh, __half rh) { return __half3{__h2div(lh.xy, __half2half2(rh)), __hdiv(lh.z, rh)}; }
+template<typename T, int ROWS, int COLS>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8)
+{
+ Matrix<T, ROWS, COLS> rs;
+ rs.rows[0].x = v0; rs.rows[0].y = v1; rs.rows[0].z = v2;
+ rs.rows[1].x = v3; rs.rows[1].y = v4; rs.rows[1].z = v5;
+ rs.rows[2].x = v6; rs.rows[2].y = v7; rs.rows[2].z = v8;
+ return rs;
+}
-// scalar op vec
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator+(__half lh, const __half3& rh) { return __half3{__hadd2(__half2half2(lh), rh.xy), __hadd(lh, rh.z)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator-(__half lh, const __half3& rh) { return __half3{__hsub2(__half2half2(lh), rh.xy), __hsub(lh, rh.z)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator*(__half lh, const __half3& rh) { return __half3{__hmul2(__half2half2(lh), rh.xy), __hmul(lh, rh.z)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator/(__half lh, const __half3& rh) { return __half3{__h2div(__half2half2(lh), rh.xy), __hdiv(lh, rh.z)}; }
+template<typename T, int ROWS, int COLS>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11)
+{
+ Matrix<T, ROWS, COLS> rs;
+ if (COLS == 4)
+ {
+ rs.rows[0].x = v0; rs.rows[0].y = v1; rs.rows[0].z = v2; rs.rows[0].w = v3;
+ rs.rows[1].x = v4; rs.rows[1].y = v5; rs.rows[1].z = v6; rs.rows[1].w = v7;
+ rs.rows[2].x = v8; rs.rows[2].y = v9; rs.rows[2].z = v10; rs.rows[2].w = v11;
+ }
+ else
+ {
+ rs.rows[0].x = v0; rs.rows[0].y = v1; rs.rows[0].z = v2;
+ rs.rows[1].x = v3; rs.rows[1].y = v4; rs.rows[1].z = v5;
+ rs.rows[2].x = v6; rs.rows[2].y = v7; rs.rows[2].z = v8;
+ rs.rows[3].x = v9; rs.rows[3].y = v10; rs.rows[3].z = v11;
+ }
+ return rs;
+}
-// *** half4 ***
+template<typename T, int ROWS, int COLS>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15)
+{
+ Matrix<T, ROWS, COLS> rs;
+ rs.rows[0].x = v0; rs.rows[0].y = v1; rs.rows[0].z = v2; rs.rows[0].w = v3;
+ rs.rows[1].x = v4; rs.rows[1].y = v5; rs.rows[1].z = v6; rs.rows[1].w = v7;
+ rs.rows[2].x = v8; rs.rows[2].y = v9; rs.rows[2].z = v10; rs.rows[2].w = v11;
+ rs.rows[3].x = v12; rs.rows[3].y = v13; rs.rows[3].z = v14; rs.rows[3].w = v15;
+ return rs;
+}
-SLANG_FORCE_INLINE
SLANG_CUDA_CALL __half4 operator+(const __half4& lh, const __half4& rh) { return __half4{__hadd2(lh.xy, rh.xy), __hadd2(lh.zw, rh.zw)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator-(const __half4& lh, const __half4& rh) { return __half4{__hsub2(lh.xy, rh.xy), __hsub2(lh.zw, rh.zw)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator*(const __half4& lh, const __half4& rh) { return __half4{__hmul2(lh.xy, rh.xy), __hmul2(lh.zw, rh.zw)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator/(const __half4& lh, const __half4& rh) { return __half4{__h2div(lh.xy, rh.xy), __h2div(lh.zw, rh.zw)}; }
+#define SLANG_MATRIX_BINARY_OP(T, op) \
+ template<int R, int C> \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal, const Matrix<T, R, C>& other) \
+ { \
+ Matrix<T, R, C> result;\
+ for (int i = 0; i < R; i++) \
+ for (int j = 0; j < C; j++) \
+ *_slang_vector_get_element_ptr(result.rows+i,j) = _slang_vector_get_element(thisVal.rows[i], j) op _slang_vector_get_element(other.rows[i], j); \
+ return result;\
+ }
-// vec op scalar
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator+(const __half4& lh, __half rh) { const __half2 rhv = __half2half2(rh); return __half4{__hadd2(lh.xy, rhv), __hadd2(lh.zw, rhv)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator-(const __half4& lh, __half rh) { const __half2 rhv = __half2half2(rh); return __half4{__hsub2(lh.xy, rhv), __hsub2(lh.zw, rhv)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator*(const __half4& lh, __half rh) { const __half2 rhv = __half2half2(rh); return __half4{__hmul2(lh.xy, rhv), __hmul2(lh.zw, rhv)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator/(const __half4& lh, __half rh) { const __half2 rhv = __half2half2(rh); return __half4{__h2div(lh.xy, rhv), __h2div(lh.zw, rhv)}; }
+#define SLANG_MATRIX_UNARY_OP(T, op) \
+ template<int R, int C> \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal) \
+ {
\
+ Matrix<T, R, C> result;\
+ for (int i = 0; i < R; i++) \
+ for (int j = 0; j < C; j++) \
+ *_slang_vector_get_element_ptr(result.rows+i,j) = op _slang_vector_get_element(thisVal.rows[i], j); \
+ return result;\
+ }
+#define SLANG_INT_MATRIX_OPS(T) \
+ SLANG_MATRIX_BINARY_OP(T, +)\
+ SLANG_MATRIX_BINARY_OP(T, -)\
+ SLANG_MATRIX_BINARY_OP(T, *)\
+ SLANG_MATRIX_BINARY_OP(T, / )\
+ SLANG_MATRIX_BINARY_OP(T, &)\
+ SLANG_MATRIX_BINARY_OP(T, |)\
+ SLANG_MATRIX_BINARY_OP(T, &&)\
+ SLANG_MATRIX_BINARY_OP(T, ||)\
+ SLANG_MATRIX_BINARY_OP(T, ^)\
+ SLANG_MATRIX_BINARY_OP(T, %)\
+ SLANG_MATRIX_UNARY_OP(T, !)\
+ SLANG_MATRIX_UNARY_OP(T, ~)
+#define SLANG_FLOAT_MATRIX_OPS(T) \
+ SLANG_MATRIX_BINARY_OP(T, +)\
+ SLANG_MATRIX_BINARY_OP(T, -)\
+ SLANG_MATRIX_BINARY_OP(T, *)\
+ SLANG_MATRIX_BINARY_OP(T, /)\
+ SLANG_MATRIX_UNARY_OP(T, -)
+SLANG_INT_MATRIX_OPS(int)
+SLANG_INT_MATRIX_OPS(uint)
+SLANG_INT_MATRIX_OPS(short)
+SLANG_INT_MATRIX_OPS(ushort)
+SLANG_INT_MATRIX_OPS(char)
+SLANG_INT_MATRIX_OPS(uchar)
+SLANG_INT_MATRIX_OPS(longlong)
+SLANG_INT_MATRIX_OPS(ulonglong)
+SLANG_FLOAT_MATRIX_OPS(float)
+SLANG_FLOAT_MATRIX_OPS(double)
+#if SLANG_CUDA_ENABLE_HALF
+SLANG_FLOAT_MATRIX_OPS(__half)
+#endif
+#define SLANG_MATRIX_INT_NEG_OP(T) \
+ template<int R, int C>\
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, R, C> operator-(Matrix<T, R, C> thisVal) \
+ { \
+ Matrix<T, R, C> result;\
+ for (int i = 0; i < R; i++) \
+ for (int j = 0; j < C; j++) \
+ *_slang_vector_get_element_ptr(result.rows+i,j) = 0 - _slang_vector_get_element(thisVal.rows[i], j); \
+ return result;\
+ }
+ SLANG_MATRIX_INT_NEG_OP(int)
+ SLANG_MATRIX_INT_NEG_OP(uint)
+ SLANG_MATRIX_INT_NEG_OP(short)
+ SLANG_MATRIX_INT_NEG_OP(ushort)
+ SLANG_MATRIX_INT_NEG_OP(char)
+ SLANG_MATRIX_INT_NEG_OP(uchar)
+ SLANG_MATRIX_INT_NEG_OP(longlong)
+ SLANG_MATRIX_INT_NEG_OP(ulonglong)
+
+#define SLANG_FLOAT_MATRIX_MOD(T)\
+ template<int R, int C> \
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, R, C> operator%(Matrix<T, R, C> left,
Matrix<T, R, C> right) \
+ {\
+ Matrix<T, R, C> result;\
+ for (int i = 0; i < R; i++) \
+ for (int j = 0; j < C; j++) \
+ *_slang_vector_get_element_ptr(result.rows+i,j) = _slang_fmod(_slang_vector_get_element(left.rows[i], j), _slang_vector_get_element(right.rows[i], j)); \
+ return result;\
+ }
-// scalar op vec
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator+(__half lh, const __half4& rh) { const __half2 lhv = __half2half2(lh); return __half4{__hadd2(lhv, rh.xy), __hadd2(lhv, rh.zw)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator-(__half lh, const __half4& rh) { const __half2 lhv = __half2half2(lh); return __half4{__hsub2(lhv, rh.xy), __hsub2(lhv, rh.zw)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator*(__half lh, const __half4& rh) { const __half2 lhv = __half2half2(lh); return __half4{__hmul2(lhv, rh.xy), __hmul2(lhv, rh.zw)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator/(__half lh, const __half4& rh) { const __half2 lhv = __half2half2(lh); return __half4{__h2div(lhv, rh.xy), __h2div(lhv, rh.zw)}; }
+ SLANG_FLOAT_MATRIX_MOD(float)
+ SLANG_FLOAT_MATRIX_MOD(double)
+#if SLANG_CUDA_ENABLE_HALF
+ template<int R, int C>
+ SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<__half, R, C> operator%(Matrix<__half, R, C> left, Matrix<__half, R, C> right)
+ {
+ Matrix<__half, R, C> result;
+ for (int i = 0; i < R; i++)
+ for (int j = 0; j < C; j++)
+ * _slang_vector_get_element_ptr(result.rows + i, j) = __float2half(_slang_fmod(__half2float(_slang_vector_get_element(left.rows[i], j)), __half2float(_slang_vector_get_element(right.rows[i], j))));
+ return result;
+ }
+#endif
+#undef SLANG_FLOAT_MATRIX_MOD
+#undef SLANG_MATRIX_BINARY_OP
+#undef SLANG_MATRIX_UNARY_OP
+#undef SLANG_INT_MATRIX_OPS
+#undef SLANG_FLOAT_MATRIX_OPS
+#undef SLANG_MATRIX_INT_NEG_OP
+#undef SLANG_FLOAT_MATRIX_MOD
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator-(const __half4& h) { return __half4{__hneg2(h.xy), __hneg2(h.zw)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4
operator+(const __half4& h) { return h; }
+//
+// Half support
+//
+#if SLANG_CUDA_ENABLE_HALF
 // Convenience functions ushort -> half
 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __ushort_as_half(const ushort2& i) { return __halves2half2(__ushort_as_half(i.x), __ushort_as_half(i.y)); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 __ushort_as_half(const ushort3& i) { return __half3{__halves2half2(__ushort_as_half(i.x), __ushort_as_half(i.y)), __ushort_as_half(i.z)}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 __ushort_as_half(const ushort4& i) { return __half4{ __halves2half2(__ushort_as_half(i.x), __ushort_as_half(i.y)), __halves2half2(__ushort_as_half(i.z), __ushort_as_half(i.w)) }; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 __ushort_as_half(const ushort3& i) { return __half3{__ushort_as_half(i.x), __ushort_as_half(i.y), __ushort_as_half(i.z)}; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 __ushort_as_half(const ushort4& i) { return __half4{ __ushort_as_half(i.x), __ushort_as_half(i.y), __ushort_as_half(i.z), __ushort_as_half(i.w) }; }
 // Convenience functions half -> ushort
 SLANG_FORCE_INLINE SLANG_CUDA_CALL ushort2 __half_as_ushort(const __half2& i) { return make_ushort2(__half_as_ushort(i.x), __half_as_ushort(i.y)); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL ushort3 __half_as_ushort(const __half3& i) { return make_ushort3(__half_as_ushort(i.xy.x), __half_as_ushort(i.xy.y), __half_as_ushort(i.z)); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL ushort4 __half_as_ushort(const __half4& i) { return make_ushort4(__half_as_ushort(i.xy.x), __half_as_ushort(i.xy.y), __half_as_ushort(i.zw.x), __half_as_ushort(i.zw.y)); }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL ushort3 __half_as_ushort(const __half3& i) { return make_ushort3(__half_as_ushort(i.x), __half_as_ushort(i.y), __half_as_ushort(i.z)); }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL ushort4 __half_as_ushort(const __half4& i) { return make_ushort4(__half_as_ushort(i.x), __half_as_ushort(i.y), __half_as_ushort(i.z), __half_as_ushort(i.w));
}
 // This is a little bit of a hack. Fortunately CUDA has the definitions of the templated types in
 // include/surface_indirect_functions.h
@@ -438,7 +796,7 @@ template <> \
 SLANG_FORCE_INLINE SLANG_CUDA_CALL float4 FUNC_NAME##_convert<float4>(cudaSurfaceObject_t surfObj, SLANG_DROP_PARENS TYPE_ARGS, cudaSurfaceBoundaryMode boundaryMode) \
 { \
 const __half4 v = __ushort_as_half(FUNC_NAME<ushort4>(surfObj, SLANG_DROP_PARENS ARGS, boundaryMode)); \
- return float4{v.xy.x, v.xy.y, v.zw.x, v.zw.y}; \
+ return float4{v.x, v.y, v.z, v.w}; \
 }
 SLANG_SURFACE_READ_HALF_CONVERT(surf1Dread, (int x), (x))
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 2a8344e3a..6357d58bd 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -786,6 +786,8 @@ __generic<T = float, let R : int = 4, let C : int = 4>
 __magic_type(Matrix)
 struct matrix
 {
+ __intrinsic_op($(kIROp_MakeMatrixFromScalar))
+ __init(T val);
 }
 ${{{{
@@ -1093,9 +1095,6 @@ extension matrix<T, R, C> : IDifferentiable
 {
 typedef matrix<T, R, C> Differential;
- __intrinsic_op($(kIROp_MakeMatrixFromScalar))
- __init(T val);
-
 [__unsafeForceInlineEarly]
 static Differential dzero()
 {
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 464811a96..1d2b327d2 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -872,36 +872,31 @@ matrix<T, N, M> acos(matrix<T, N, M> x)
 // Test if all components are non-zero (HLSL SM 1.0)
 __generic<T : __BuiltinType>
+__target_intrinsic(cpp, "bool($0)")
+__target_intrinsic(cuda, "bool($0)")
 __target_intrinsic(glsl, "bool($0)")
 bool all(T x);
 __generic<T : __BuiltinType, let N : int>
 __target_intrinsic(hlsl)
 __target_intrinsic(glsl, "all(bvec$N0($0))")
-bool all(vector<T,N> x);
-// TODO: implementation of `all()` in the stdlib is
-// blocked on fixing implementation of `bool` vector
-// `getAt` on the CUDA codegen path.
-/*
+bool all(vector<T,N> x)
 {
 bool result = true;
 for(int i = 0; i < N; ++i)
 result = result && all(x[i]);
 return result;
 }
-*/
 __generic<T : __BuiltinType, let N : int, let M : int>
 __target_intrinsic(hlsl)
-bool all(matrix<T,N,M> x);
-/*
+bool all(matrix<T,N,M> x)
 {
 bool result = true;
 for(int i = 0; i < N; ++i)
 result = result && all(x[i]);
 return result;
 }
-*/
 // Barrier for writes to all memory spaces (HLSL SM 5.0)
 __target_intrinsic(glsl, "memoryBarrier(), groupMemoryBarrier(), memoryBarrierImage(), memoryBarrierBuffer()")
@@ -916,42 +911,39 @@ void AllMemoryBarrierWithGroupSync();
 // Test if any components is non-zero (HLSL SM 1.0)
 __generic<T : __BuiltinType>
+__target_intrinsic(cpp, "bool($0)")
+__target_intrinsic(cuda, "bool($0)")
 __target_intrinsic(glsl, "bool($0)")
 bool any(T x);
 __generic<T : __BuiltinType, let N : int>
 __target_intrinsic(hlsl)
 __target_intrinsic(glsl, "any(bvec$N0($0))")
-bool any(vector<T, N> x);
-// TODO: implementation of `any()` in the stdlib is
-// blocked on fixing implementation of `bool` vector
-// `getAt` on the CUDA codegen path.
-/*
+bool any(vector<T, N> x)
 {
 bool result = false;
 for(int i = 0; i < N; ++i)
 result = result || any(x[i]);
 return result;
 }
-*/
 __generic<T : __BuiltinType, let N : int, let M : int>
 __target_intrinsic(hlsl)
-bool any(matrix<T, N, M> x);
-/*
+bool any(matrix<T, N, M> x)
 {
 bool result = false;
 for(int i = 0; i < N; ++i)
 result = result || any(x[i]);
 return result;
 }
-*/
 // Reinterpret bits as a double (HLSL SM 5.0)
 __target_intrinsic(hlsl)
 __target_intrinsic(glsl, "packDouble2x32(uvec2($0, $1))")
+__target_intrinsic(cpp, "$P_asdouble($0, $1)")
+__target_intrinsic(cuda, "$P_asdouble($0, $1)")
 __target_intrinsic(spirv_direct, "%v = OpCompositeConstruct _type(uint2) resultId _0 _1; OpExtInst resultType resultId glsl450 59 %v")
 __glsl_extension(GL_ARB_gpu_shader5)
 double asdouble(uint lowbits, uint highbits);
@@ -960,11 +952,15 @@ double asdouble(uint lowbits, uint highbits);
 __target_intrinsic(hlsl)
 __target_intrinsic(glsl, "intBitsToFloat")
+__target_intrinsic(cpp, "$P_asfloat($0)")
+__target_intrinsic(cuda, "$P_asfloat($0)")
 __target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0")
 float asfloat(int x);
 __target_intrinsic(hlsl)
 __target_intrinsic(glsl, "uintBitsToFloat")
+__target_intrinsic(cpp, "$P_asfloat($0)")
+__target_intrinsic(cuda, "$P_asfloat($0)")
 __target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0")
 float asfloat(uint x);
@@ -1044,11 +1040,15 @@ matrix<T, N, M> asin(matrix<T, N, M> x)
 __target_intrinsic(hlsl)
 __target_intrinsic(glsl, "floatBitsToInt")
+__target_intrinsic(cpp, "$P_asint($0)")
+__target_intrinsic(cuda, "$P_asint($0)")
 __target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0")
 int asint(float x);
 __target_intrinsic(hlsl)
 __target_intrinsic(glsl, "int($0)")
+__target_intrinsic(cpp, "$P_asint($0)")
+__target_intrinsic(cuda, "$P_asint($0)")
 __target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0")
 int asint(uint x);
@@ -1104,6 +1104,8 @@ matrix<int,N,M> asint(matrix<int,N,M> x)
 __target_intrinsic(hlsl)
 __target_intrinsic(glsl, "{ uvec2 v = unpackDouble2x32($0); $1 = v.x; $2 = v.y; }")
 __glsl_extension(GL_ARB_gpu_shader5)
+__target_intrinsic(cpp, "$P_asuint($0, $1, $2)")
+__target_intrinsic(cuda, "$P_asuint($0, $1, $2)")
 void asuint(double value, out uint lowbits, out uint highbits);
 // Reinterpret bits as a uint (HLSL SM 4.0)
@@ -1111,11 +1113,15 @@ void asuint(double value, out uint lowbits, out uint highbits);
 __target_intrinsic(hlsl)
 __target_intrinsic(glsl, "floatBitsToUint")
 __target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0")
+__target_intrinsic(cpp, "$P_asuint($0)")
+__target_intrinsic(cuda, "$P_asuint($0)")
 uint asuint(float x);
 __target_intrinsic(hlsl)
 __target_intrinsic(glsl, "uint($0)")
 __target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0")
+__target_intrinsic(cpp, "$P_asuint($0)")
+__target_intrinsic(cuda, "$P_asuint($0)")
 uint asuint(int x);
 __generic<let N : int>
@@ -1812,7 +1818,7 @@ __target_intrinsic(glsl, "unpackHalf2x16($0).x")
 __glsl_version(420)
 __target_intrinsic(hlsl)
 __cuda_sm_version(6.0)
-__target_intrinsic(cuda, "__half2float(__short_as_half($0))")
+__target_intrinsic(cuda, "__half2float(__ushort_as_half($0))")
 float f16tof32(uint value);
 __generic<let N : int>
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 87b620ed2..ba6b26ec6 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -66,111 +66,6 @@ namespace Slang {
 static const char s_xyzwNames[] = "xyzw";
-static UnownedStringSlice _getTypePrefix(IROp op)
-{
-    switch (op)
-    {
-        case kIROp_BoolType: return UnownedStringSlice::fromLiteral("Bool");
-        case kIROp_IntType: return UnownedStringSlice::fromLiteral("I32");
-        case kIROp_UIntType: return UnownedStringSlice::fromLiteral("U32");
-        case kIROp_FloatType: return UnownedStringSlice::fromLiteral("F32");
-        case kIROp_Int64Type: return UnownedStringSlice::fromLiteral("I64");
-        case kIROp_UInt64Type: return
UnownedStringSlice::fromLiteral("U64"); - case kIROp_DoubleType: return UnownedStringSlice::fromLiteral("F64"); - default: return UnownedStringSlice::fromLiteral("?"); - } -} - - -static IROp _getCType(IROp op) -{ - switch (op) - { - case kIROp_VoidType: - case kIROp_BoolType: - { - return op; - } - case kIROp_Int8Type: - case kIROp_Int16Type: - case kIROp_IntType: - case kIROp_UInt8Type: - case kIROp_UInt16Type: - case kIROp_UIntType: - { - // Promote all these to Int - return kIROp_IntType; - } - case kIROp_IntPtrType: - case kIROp_UIntPtrType: - { - return kIROp_IntPtrType; - } - case kIROp_Int64Type: - case kIROp_UInt64Type: - { - // Promote all these to Int64, we can just vary the call to make these work - return kIROp_Int64Type; - } - case kIROp_DoubleType: - { - return kIROp_DoubleType; - } - case kIROp_HalfType: - case kIROp_FloatType: - { - // Promote both to float - return kIROp_FloatType; - } - default: - { - SLANG_ASSERT(!"Unhandled type"); - return kIROp_undefined; - } - } -} - -static UnownedStringSlice _getCTypeVecPostFix(IROp op) -{ - switch (op) - { - case kIROp_BoolType: return UnownedStringSlice::fromLiteral("B"); - case kIROp_IntType: return UnownedStringSlice::fromLiteral("I"); - case kIROp_UIntType: return UnownedStringSlice::fromLiteral("U"); - case kIROp_FloatType: return UnownedStringSlice::fromLiteral("F"); - case kIROp_Int64Type: return UnownedStringSlice::fromLiteral("I64"); - case kIROp_DoubleType: return UnownedStringSlice::fromLiteral("F64"); - case kIROp_IntPtrType: return UnownedStringSlice::fromLiteral(""); - case kIROp_UIntPtrType: return UnownedStringSlice::fromLiteral(""); - default: return UnownedStringSlice::fromLiteral("?"); - } -} - -static bool _isCppTarget(CodeGenTarget target) -{ - switch (target) - { - case CodeGenTarget::CPPSource: - case CodeGenTarget::HostCPPSource: - return true; - default: - return false; - } -} - -static bool _isCppOrCudaTarget(CodeGenTarget target) -{ - switch (target) - { - case 
CodeGenTarget::CPPSource: - case CodeGenTarget::HostCPPSource: - case CodeGenTarget::CUDASource: - return true; - default: - return false; - } -} - /* !!!!!!!!!!!!!!!!!!!!!!!! CPPEmitHandler !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ /* static */ UnownedStringSlice CPPSourceEmitter::getBuiltinTypeName(IROp op) @@ -204,118 +99,8 @@ static bool _isCppOrCudaTarget(CodeGenTarget target) } } -void CPPSourceEmitter::emitTypeDefinition(IRType* inType) +UnownedStringSlice CPPSourceEmitter::_getTypeName(IRType* type) { - if (_isCppTarget(m_target)) - { - // All types are templates in C++ - return; - } - - IRType* type = m_typeSet.getType(inType); - if (!m_typeSet.isOwned(type)) - { - // If defined in a different module, we assume they are emitted already. (Assumed to - // be a nominal type) - return; - } - - SourceWriter* writer = getSourceWriter(); - - switch (type->getOp()) - { - case kIROp_VectorType: - { - auto vecType = static_cast<IRVectorType*>(type); - - const UnownedStringSlice* elemNames = getVectorElementNames(vecType); - - int count = int(getIntVal(vecType->getElementCount())); - - SLANG_ASSERT(count > 0 && count < 4); - - UnownedStringSlice typeName = _getTypeName(type); - UnownedStringSlice elemName = _getTypeName(vecType->getElementType()); - - writer->emit("struct "); - writer->emit(typeName); - writer->emit("\n{\n"); - writer->indent(); - - writer->emit(elemName); - writer->emit(" "); - for (int i = 0; i < count; ++i) - { - if (i > 0) - { - writer->emit(", "); - } - writer->emit(elemNames[i]); - } - writer->emit(";\n"); - - writer->dedent(); - writer->emit("};\n\n"); - break; - } - case kIROp_MatrixType: - { - auto matType = static_cast<IRMatrixType*>(type); - - const auto rowCount = int(getIntVal(matType->getRowCount())); - const auto colCount = int(getIntVal(matType->getColumnCount())); - - IRType* vecType = m_typeSet.addVectorType(matType->getElementType(), colCount); - - UnownedStringSlice typeName = _getTypeName(type); - UnownedStringSlice rowTypeName = 
_getTypeName(vecType); - - writer->emit("template<>\n"); - writer->emit("struct "); - writer->emit(typeName); - writer->emit("\n{\n"); - writer->indent(); - - writer->emit(rowTypeName); - writer->emit(" rows["); - writer->emit(rowCount); - writer->emit("];\n"); - - writer->dedent(); - writer->emit("};\n\n"); - break; - } - case kIROp_PtrType: - case kIROp_RefType: - { - // We don't need to output a definition for these types - break; - } - case kIROp_ArrayType: - case kIROp_UnsizedArrayType: - case kIROp_HLSLRWStructuredBufferType: - { - // We don't need to output a definition for these with C++ templates - // For C we may need to (or do casting at point of usage) - break; - } - default: - { - if (IRBasicType::isaImpl(type->getOp())) - { - // Don't emit anything for built in types - return; - } - SLANG_ASSERT(!"Unhandled type"); - break; - } - } -} - -UnownedStringSlice CPPSourceEmitter::_getTypeName(IRType* inType) -{ - IRType* type = m_typeSet.getType(inType); - StringSlicePool::Handle handle = StringSlicePool::kNullHandle; if (m_typeNameMap.TryGetValue(type, handle)) { @@ -424,22 +209,7 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S auto vecCount = int(getIntVal(vecType->getElementCount())); auto elemType = vecType->getElementType(); - if (_isCppOrCudaTarget(target)) - { - out << "Vector<" << _getTypeName(elemType) << ", " << vecCount << ">"; - } - else - { - out << "Vec"; - UnownedStringSlice postFix = _getCTypeVecPostFix(elemType->getOp()); - - out << postFix; - if (postFix.getLength() > 1) - { - out << "_"; - } - out << vecCount; - } + out << "Vector<" << _getTypeName(elemType) << ", " << vecCount << ">"; return SLANG_OK; } case kIROp_MatrixType: @@ -450,22 +220,8 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S const auto rowCount = int(getIntVal(matType->getRowCount())); const auto colCount = int(getIntVal(matType->getColumnCount())); - if (_isCppOrCudaTarget(target)) - { - out << 
"Matrix<" << _getTypeName(elementType) << ", " << rowCount << ", " << colCount << ">"; - } - else - { - out << "Mat"; - const UnownedStringSlice postFix = _getCTypeVecPostFix(_getCType(elementType->getOp())); - out << postFix; - if (postFix.getLength() > 1) - { - out << "_"; - } - out << rowCount; - out << colCount; - } + out << "Matrix<" << _getTypeName(elementType) << ", " << rowCount << ", " << colCount << ">"; + return SLANG_OK; } case kIROp_WitnessTableType: @@ -625,17 +381,6 @@ void CPPSourceEmitter::useType(IRType* type) _getTypeName(type); } -static IRBasicType* _getElementType(IRType* type) -{ - switch (type->getOp()) - { - case kIROp_VectorType: type = static_cast<IRVectorType*>(type)->getElementType(); break; - case kIROp_MatrixType: type = static_cast<IRMatrixType*>(type)->getElementType(); break; - default: break; - } - return dynamicCast<IRBasicType>(type); -} - /* static */CPPSourceEmitter::TypeDimension CPPSourceEmitter::_getTypeDimension(IRType* type, bool vecSwap) { switch (type->getOp()) @@ -735,943 +480,11 @@ void CPPSourceEmitter::_emitAccess(const UnownedStringSlice& name, const TypeDim } } -static bool _isOperator(const UnownedStringSlice& funcName) -{ - if (funcName.getLength() > 0) - { - const char c = funcName[0]; - return !((c >= 'a' && c <='z') || (c >= 'A' && c <= 'Z') || c == '_'); - } - return false; -} - -void CPPSourceEmitter::_emitAryDefinition(const HLSLIntrinsic* specOp) -{ - auto info = HLSLIntrinsic::getInfo(specOp->op); - auto funcName = info.funcName; - SLANG_ASSERT(funcName.getLength() > 0); - - const bool isOperator = _isOperator(funcName); - - SourceWriter* writer = getSourceWriter(); - - IRFuncType* funcType = specOp->signatureType; - const int numParams = int(funcType->getParamCount()); - SLANG_ASSERT(numParams <= 3); - - bool areAllScalar = true; - TypeDimension paramDims[3]; - for (int i = 0; i < numParams; ++i) - { - paramDims[i]= _getTypeDimension(funcType->getParamType(i), false); - areAllScalar = areAllScalar && 
paramDims[i].isScalar(); - } - - // If all are scalar, then we don't need to emit a definition - if (areAllScalar) - { - return; - } - - IRType* retType = specOp->returnType; - - UnownedStringSlice scalarFuncName(funcName); - if (isOperator) - { - StringBuilder builder; - builder << "operator"; - builder << funcName; - _emitSignature(builder.getUnownedSlice(), specOp); - } - else - { - scalarFuncName = _getScalarFuncName(specOp->op, _getElementType(funcType->getParamType(0))); - _emitSignature(funcName, specOp); - } - - writer->emit("\n{\n"); - writer->indent(); - - const bool hasReturnType = retType->getOp() != kIROp_VoidType; - - TypeDimension calcDim; - if (hasReturnType) - { - emitType(retType); - writer->emit(" r;\n"); - - calcDim = _getTypeDimension(retType, false); - } - else - { - calcDim = _getTypeDimension(funcType->getParamType(0), false); - } - - for (int i = 0; i < calcDim.rowCount; ++i) - { - for (int j = 0; j < calcDim.colCount; ++j) - { - if (hasReturnType) - { - _emitAccess(UnownedStringSlice::fromLiteral("r"), calcDim, i, j, writer); - writer->emit(" = "); - } - - if (isOperator) - { - switch (numParams) - { - case 1: - { - writer->emit(funcName); - _emitAccess(UnownedStringSlice::fromLiteral("a"), paramDims[0], i, j, writer); - break; - } - case 2: - { - _emitAccess(UnownedStringSlice::fromLiteral("a"), paramDims[0], i, j, writer); - writer->emit(" "); - writer->emit(funcName); - writer->emit(" "); - _emitAccess(UnownedStringSlice::fromLiteral("b"), paramDims[1], i, j, writer); - break; - } - default: SLANG_ASSERT(!"Unhandled"); - } - } - else - { - writer->emit(scalarFuncName); - writer->emit("("); - for (int k = 0; k < numParams; k++) - { - if (k > 0) - { - writer->emit(", "); - } - char c = char('a' + k); - _emitAccess(UnownedStringSlice(&c, 1), paramDims[k], i, j, writer); - } - writer->emit(")"); - } - writer->emit(";\n"); - } - } - - if (hasReturnType) - { - writer->emit("return r;\n"); - } - - writer->dedent(); - writer->emit("}\n\n"); -} 
- -void CPPSourceEmitter::_emitAnyAllDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp) -{ - IRFuncType* funcType = specOp->signatureType; - SLANG_ASSERT(funcType->getParamCount() == 1); - IRType* paramType0 = funcType->getParamType(0); - - SourceWriter* writer = getSourceWriter(); - - IRType* elementType = _getElementType(paramType0); - SLANG_ASSERT(elementType); - IRType* retType = specOp->returnType; - auto retTypeName = _getTypeName(retType); - - IROp style = getTypeStyle(elementType->getOp()); - - const TypeDimension dim = _getTypeDimension(paramType0, false); - - _emitSignature(funcName, specOp); - writer->emit("\n{\n"); - writer->indent(); - - writer->emit("return "); - - for (int i = 0; i < dim.rowCount; ++i) - { - for (int j = 0; j < dim.colCount; ++j) - { - if (i > 0 || j > 0) - { - if (specOp->op == HLSLIntrinsic::Op::All) - { - writer->emit(" && "); - } - else - { - writer->emit(" || "); - } - } - - switch (style) - { - case kIROp_BoolType: - { - _emitAccess(UnownedStringSlice::fromLiteral("a"), dim, i, j, writer); - break; - } - case kIROp_IntType: - { - writer->emit("("); - _emitAccess(UnownedStringSlice::fromLiteral("a"), dim, i, j, writer); - writer->emit(" != 0)"); - break; - } - case kIROp_FloatType: - { - writer->emit("("); - _emitAccess(UnownedStringSlice::fromLiteral("a"), dim, i, j, writer); - writer->emit(" != 0.0)"); - break; - } - } - } - } - - writer->emit(";\n"); - - writer->dedent(); - writer->emit("}\n\n"); -} - -void CPPSourceEmitter::_emitSignature(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp) -{ - IRFuncType* funcType = specOp->signatureType; - const int paramsCount = int(funcType->getParamCount()); - IRType* retType = specOp->returnType; - - emitFunctionPreambleImpl(nullptr); - - SourceWriter* writer = getSourceWriter(); - - emitType(retType); - writer->emit(" "); - writer->emit(funcName); - writer->emit("("); - - for (int i = 0; i < paramsCount; ++i) - { - if (i > 0) - { - writer->emit(", 
"); - } - - // We can't pass as const& for vector, scalar, array types, as they are pass by value - // For types passed by reference, we should do something different - IRType* paramType = funcType->getParamType(i); -#if 0 - writer->emit("const "); -#endif - emitType(paramType); -#if 0 - if (dynamicCast<IRBasicType>(paramType)) - { - writer->emit(" "); - } - else - { - writer->emit("& "); - } -#else - - writer->emit(" "); -#endif - - writer->emitChar(char('a' + i)); - } - writer->emit(")"); -} - -UnownedStringSlice CPPSourceEmitter::_getAndEmitSpecializedOperationDefinition(HLSLIntrinsic::Op op, IRType*const* argTypes, Int argCount, IRType* retType) -{ - HLSLIntrinsic intrinsic; - m_intrinsicSet.calcIntrinsic(op, retType, argTypes, argCount, intrinsic); - auto specOp = m_intrinsicSet.add(intrinsic); - _maybeEmitSpecializedOperationDefinition(specOp); - return _getFuncName(specOp); -} - -void CPPSourceEmitter::_emitGetAtDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp) -{ - SourceWriter* writer = getSourceWriter(); - - IRFuncType* funcType = specOp->signatureType; - SLANG_ASSERT(funcType->getParamCount() == 2); - - IRType* srcType = funcType->getParamType(0); - - for (Index i = 0; i < 3; ++i) - { - UnownedStringSlice typePrefix = (i == 0) ? 
UnownedStringSlice::fromLiteral("const ") : UnownedStringSlice(); - bool lValue = (i != 2); - - emitFunctionPreambleImpl(nullptr); - - writer->emit(typePrefix); - emitType(specOp->returnType); - if (lValue) - m_writer->emit("*"); - writer->emit(" "); - writer->emit(funcName); - writer->emit("("); - - writer->emit(typePrefix); - emitType(funcType->getParamType(0)); - if (lValue) - writer->emit("*"); - writer->emit(" a, "); - emitType(funcType->getParamType(1)); - writer->emit(" b)\n{\n"); - - writer->indent(); - - if (auto vectorType = as<IRVectorType>(srcType)) - { - int vecSize = int(getIntVal(vectorType->getElementCount())); - - writer->emit("SLANG_PRELUDE_ASSERT(b >= 0 && b < "); - writer->emit(vecSize); - writer->emit(");\n"); - - writer->emit("return (("); - emitType(specOp->returnType); - writer->emit("*)"); - - if (lValue) - writer->emit("a) + b;\n"); - else - writer->emit("&a)[b];\n"); - } - else if (auto matrixType = as<IRMatrixType>(srcType)) - { - //int colCount = int(getIntVal(matrixType->getColumnCount())); - int rowCount = int(getIntVal(matrixType->getRowCount())); - - writer->emit("SLANG_PRELUDE_ASSERT(b >= 0 && b < "); - writer->emit(rowCount); - writer->emit(");\n"); - - if (lValue) - writer->emit("return &(a->rows[b]);\n"); - else - writer->emit("return a.rows[b];\n"); - } - - writer->dedent(); - writer->emit("}\n\n"); - } -} - -void CPPSourceEmitter::_emitConstructConvertDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp) -{ - SourceWriter* writer = getSourceWriter(); - IRFuncType* funcType = specOp->signatureType; - - SLANG_ASSERT(funcType->getParamCount() == 2); - - IRType* srcType = funcType->getParamType(1); - IRType* retType = specOp->returnType; - - emitFunctionPreambleImpl(nullptr); - - emitType(retType); - writer->emit(" "); - writer->emit(funcName); - writer->emit("("); - emitType(srcType); - writer->emitChar(' '); - writer->emitChar(char('a' + 0)); - writer->emit(")"); - - writer->emit("\n{\n"); - 
writer->indent(); - - writer->emit("return "); - emitType(retType); - writer->emit("{ "); - - - IRType* dstElemType = _getElementType(retType); - //IRType* srcElemType = _getElementType(srcType); - - TypeDimension dim = _getTypeDimension(retType, false); - - UnownedStringSlice rowTypeName; - if (dim.rowCount > 1) - { - IRType* rowType = m_typeSet.addVectorType(dstElemType, int(dim.colCount)); - rowTypeName = _getTypeName(rowType); - } - - for (int i = 0; i < dim.rowCount; ++i) - { - if (dim.rowCount > 1) - { - if (i > 0) - { - writer->emit(", \n"); - } - - if (m_target == CodeGenTarget::CUDASource) - { - m_writer->emit("make_"); - writer->emit(rowTypeName); - m_writer->emit("("); - } - else - { - writer->emit(rowTypeName); - writer->emit("{ "); - } - } - - for (int j = 0; j < dim.colCount; ++j) - { - if (j > 0) - { - writer->emit(", "); - } - - emitType(dstElemType); - writer->emit("("); - _emitAccess(UnownedStringSlice::fromLiteral("a"), dim, i, j, writer); - writer->emit(")"); - } - if (dim.rowCount > 1) - { - if (m_target == CodeGenTarget::CUDASource) - { - writer->emit(")"); - } - else - { - writer->emit("}"); - } - } - } - - writer->emit("};\n"); - - writer->dedent(); - writer->emit("}\n\n"); -} - -void CPPSourceEmitter::_emitInitDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp) -{ - SourceWriter* writer = getSourceWriter(); - IRFuncType* funcType = specOp->signatureType; - - emitFunctionPreambleImpl(nullptr); - - IRType* retType = specOp->returnType; - - _emitSignature(funcName, specOp); - writer->emit("\n{\n"); - writer->indent(); - - // Use C++ construction - writer->emit("return "); - emitType(retType); - writer->emit("{ "); - - const Index paramCount = Index(funcType->getParamCount()); - bool handled = false; - - if (IRVectorType* vecType = as<IRVectorType>(retType)) - { - Index elementCount = Index(getIntVal(vecType->getElementCount())); - - Index paramIndex = 0; - Index paramSubIndex = 0; - - for (Index i = 0; i < elementCount; 
++i) - { - if (i > 0) - { - writer->emit(", "); - } - - if (paramIndex >= paramCount) - { - writer->emit("0"); - } - else - { - IRType* paramType = funcType->getParamType(paramIndex); - - if (IRVectorType* paramVecType = as<IRVectorType>(paramType)) - { - Index paramElementCount = Index(getIntVal(paramVecType->getElementCount())); - - const UnownedStringSlice* elemNames = getVectorElementNames(paramVecType); - - writer->emitChar('a' + char(paramIndex)); - writer->emit("."); - writer->emit(elemNames[paramSubIndex]); - - paramSubIndex++; - - if (paramSubIndex >= paramElementCount) - { - paramIndex++; - paramSubIndex = 0; - } - } - else - { - writer->emitChar('a' + char(paramIndex)); - paramIndex++; - } - } - } - handled = true; - } - else if (IRMatrixType* matType = as<IRMatrixType>(retType)) - { - if (paramCount != 1) - goto fallback; - - auto paramMat = as<IRMatrixType>(funcType->getParamType(0)); - if (!paramMat) - goto fallback; - - // We are constructing a matrix from a differently sized matrix. - - Index rows = Index(getIntVal(matType->getRowCount())); - Index cols = Index(getIntVal(matType->getColumnCount())); - Index paramRows = Index(getIntVal(paramMat->getRowCount())); - Index paramCols = Index(getIntVal(paramMat->getColumnCount())); - char elementNames[] = { 'x', 'y', 'z', 'w' }; - - for (Index r = 0; r < rows; r++) - { - for (Index c = 0; c < cols; c++) - { - if (r != 0 || c != 0) - writer->emit(", "); - - if (r < paramRows && c < paramCols && c < 4) - { - writer->emitRawText("a.rows["); - writer->emit(r); - writer->emitRawText("]."); - writer->emitChar(elementNames[c]); - } - else - { - writer->emit("0"); - } - } - } - handled = true; - } -fallback: - if (!handled) - { - // Fallback default: just use all params to construct. 
- for (Index i = 0; i < paramCount; ++i) - { - if (i > 0) - { - writer->emit(", "); - } - writer->emitChar('a' + char(i)); - } - } - - writer->emit("};\n"); - - writer->dedent(); - writer->emit("}\n\n"); -} - - -void CPPSourceEmitter::_emitConstructFromScalarDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp) -{ - SourceWriter* writer = getSourceWriter(); - IRFuncType* funcType = specOp->signatureType; - - SLANG_ASSERT(funcType->getParamCount() == 2); - - IRType* srcType = funcType->getParamType(1); - IRType* retType = specOp->returnType; - - emitFunctionPreambleImpl(nullptr); - - emitType(retType); - writer->emit(" "); - writer->emit(funcName); - writer->emit("("); - emitType(srcType); - writer->emitChar(' '); - writer->emitChar(char('a' + 0)); - writer->emit(")"); - - writer->emit("\n{\n"); - writer->indent(); - - writer->emit("return "); - emitType(retType); - writer->emit("{ "); - - const TypeDimension dim = _getTypeDimension(retType, false); - - for (int i = 0; i < dim.rowCount; ++i) - { - if (dim.rowCount > 1) - { - if (i > 0) - { - writer->emit(", \n"); - } - writer->emit("{ "); - } - for (int j = 0; j < dim.colCount; ++j) - { - if (j > 0) - { - writer->emit(", "); - } - writer->emit("a"); - } - if (dim.rowCount > 1) - { - writer->emit("}"); - } - } - - writer->emit("};\n"); - - writer->dedent(); - writer->emit("}\n\n"); -} - -void CPPSourceEmitter::_maybeEmitSpecializedOperationDefinition(const HLSLIntrinsic* specOp) -{ - // Check if it's been emitted already, if not add it. 
- if (!m_intrinsicEmitted.Add(specOp)) - { - return; - } - emitSpecializedOperationDefinition(specOp); -} - -void CPPSourceEmitter::emitSpecializedOperationDefinition(const HLSLIntrinsic* specOp) -{ - typedef HLSLIntrinsic::Op Op; - - switch (specOp->op) - { - case Op::Init: - { - return _emitInitDefinition(_getFuncName(specOp), specOp); - } - case Op::Any: - case Op::All: - { - return _emitAnyAllDefinition(_getFuncName(specOp), specOp); - } - case Op::ConstructConvert: - { - return _emitConstructConvertDefinition(_getFuncName(specOp), specOp); - } - case Op::ConstructFromScalar: - { - return _emitConstructFromScalarDefinition(_getFuncName(specOp), specOp); - } - case Op::GetAt: - { - return _emitGetAtDefinition(_getFuncName(specOp), specOp); - } - case Op::Swizzle: - { - // Don't have to output anything for swizzle for now - return; - } - default: - { - const auto& info = HLSLIntrinsic::getInfo(specOp->op); - const int paramCount = (info.numOperands < 0) ? int(specOp->signatureType->getParamCount()) : info.numOperands; - - if (paramCount >= 1 && paramCount <= 3) - { - return _emitAryDefinition(specOp); - } - break; - } - } - - SLANG_ASSERT(!"Unhandled"); -} - -void CPPSourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec) -{ - typedef HLSLIntrinsic::Op Op; - - SLANG_UNUSED(inOuterPrec); - SourceWriter* writer = getSourceWriter(); - - switch (specOp->op) - { - case Op::Init: - { - IRType* retType = specOp->returnType; - if (IRBasicType::isaImpl(retType->getOp())) - { - SLANG_ASSERT(numOperands == 1); - - writer->emit(_getTypeName(retType)); - writer->emitChar('('); - - emitOperand(operands[0].get(), getInfo(EmitOp::General)); - - writer->emitChar(')'); - return; - } - break; - } - case Op::Swizzle: - { - // Currently only works for C++ (we use {} constuction) - which means we don't need to generate a function. 
- // For C we need to generate suitable construction function - auto swizzleInst = static_cast<IRSwizzle*>(inst); - const Index elementCount = Index(swizzleInst->getElementCount()); - - IRType* srcType = swizzleInst->getBase()->getDataType(); - IRVectorType* srcVecType = as<IRVectorType>(srcType); - - const UnownedStringSlice* elemNames = getVectorElementNames(srcVecType); - - // TODO(JS): Not 100% sure this is correct on the parens handling front - IRType* retType = specOp->returnType; - emitType(retType); - writer->emit("{"); - - for (Index i = 0; i < elementCount; ++i) - { - if (i > 0) - { - writer->emit(", "); - } - - auto outerPrec = getInfo(EmitOp::General); - - auto prec = getInfo(EmitOp::Postfix); - emitOperand(swizzleInst->getBase(), leftSide(outerPrec, prec)); - - writer->emit("."); - - IRInst* irElementIndex = swizzleInst->getElementIndex(i); - SLANG_RELEASE_ASSERT(irElementIndex->getOp() == kIROp_IntLit); - IRConstant* irConst = (IRConstant*)irElementIndex; - UInt elementIndex = (UInt)irConst->value.intVal; - SLANG_RELEASE_ASSERT(elementIndex < 4); - - writer->emit(elemNames[elementIndex]); - } - - writer->emit("}"); - return; - } - default: break; - } - - { - const auto& info = HLSLIntrinsic::getInfo(specOp->op); - // Make sure that the return type is available - const bool isOperator = _isOperator(info.funcName); - const UnownedStringSlice funcName = _getFuncName(specOp); - - switch (specOp->op) - { - case Op::ConstructFromScalar: - { - // We need to special case, because this may have come from a swizzle from a built in - // type, in that case the only parameter we want is the first one - numOperands = 1; - break; - } - - default: break; - } - - // add that we want a function - SLANG_ASSERT(info.numOperands < 0 || numOperands == info.numOperands); - - useType(specOp->returnType); - - if (isOperator) - { - // Just do the default output - defaultEmitInstExpr(inst, inOuterPrec); - } - else - { - writer->emit(funcName); - writer->emitChar('('); - - for 
(int i = 0; i < numOperands; ++i) - { - if (i > 0) - { - writer->emit(", "); - } - emitOperand(operands[i].get(), getInfo(EmitOp::General)); - } - - writer->emitChar(')'); - } - } -} - -HLSLIntrinsic* CPPSourceEmitter::_addIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* argTypes, Index argTypeCount) -{ - HLSLIntrinsic intrinsic; - m_intrinsicSet.calcIntrinsic(op, returnType, argTypes, argTypeCount, intrinsic); - HLSLIntrinsic* addedIntrinsic = m_intrinsicSet.add(intrinsic); - _getFuncName(addedIntrinsic); - return addedIntrinsic; -} - -SlangResult CPPSourceEmitter::calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder) -{ - outBuilder << _getTypePrefix(type->getOp()) << "_" << HLSLIntrinsic::getInfo(op).funcName; - return SLANG_OK; -} - -UnownedStringSlice CPPSourceEmitter::_getScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type) -{ - /* TODO(JS): This is kind of fast and loose. That we don't know all the parameters that are taken or - what the return type is, so we can't add to the HLSLIntrinsic map - we just generate the scalar - function name and use it (whilst also adding to the slice pool, so that we can return an - unowned slice). */ - - StringBuilder builder; - if (SLANG_FAILED(calcScalarFuncName(op, type, builder))) - { - SLANG_ASSERT(!"Unable to create scalar function name"); - return UnownedStringSlice(); - } - - // Add to the pool. - auto handle = m_slicePool.add(builder); - return m_slicePool.getSlice(handle); -} - -UnownedStringSlice CPPSourceEmitter::_getFuncName(const HLSLIntrinsic* specOp) -{ - StringSlicePool::Handle handle = StringSlicePool::kNullHandle; - if (m_intrinsicNameMap.TryGetValue(specOp, handle)) - { - return m_slicePool.getSlice(handle); - } - - StringBuilder builder; - if (SLANG_FAILED(calcFuncName(specOp, builder))) - { - SLANG_ASSERT(!"Unable to create function name"); - // Return an empty slice, as an error... 
- return UnownedStringSlice(); - } - - handle = m_slicePool.add(builder); - m_intrinsicNameMap.Add(specOp, handle); - - SLANG_ASSERT(handle != StringSlicePool::kNullHandle); - return m_slicePool.getSlice(handle); -} - -SlangResult CPPSourceEmitter::calcFuncName(const HLSLIntrinsic* specOp, StringBuilder& outBuilder) -{ - typedef HLSLIntrinsic::Op Op; - - if (specOp->isScalar()) - { - IRType* paramType = specOp->signatureType->getParamType(0); - IRBasicType* basicType = as<IRBasicType>(paramType); - if (basicType) - { - return calcScalarFuncName(specOp->op, basicType, outBuilder); - } - else - { - outBuilder << getName(paramType) << HLSLIntrinsic::getInfo(specOp->op).name; - return SLANG_OK; - } - } - else - { - switch (specOp->op) - { - case Op::ConstructConvert: - { - // Work out the function name - IRFuncType* signatureType = specOp->signatureType; - SLANG_ASSERT(signatureType->getParamCount() == 2); - - IRType* dstType = signatureType->getParamType(0); - //IRType* srcType = signatureType->getParamType(1); - - outBuilder << "convert_"; - // I need a function that is called that will construct this - SLANG_RETURN_ON_FAIL(calcTypeName(dstType, CodeGenTarget::CSource, outBuilder)); - return SLANG_OK; - } - case Op::ConstructFromScalar: - { - // Work out the function name - IRFuncType* signatureType = specOp->signatureType; - SLANG_ASSERT(signatureType->getParamCount() == 2); - - IRType* dstType = signatureType->getParamType(0); - - outBuilder << "constructFromScalar_"; - // I need a function that is called that will construct this - SLANG_RETURN_ON_FAIL(calcTypeName(dstType, CodeGenTarget::CSource, outBuilder)); - return SLANG_OK; - } - case Op::GetAt: - { - outBuilder << "getAt"; - return SLANG_OK; - } - case Op::Init: - { - outBuilder << "make_"; - SLANG_RETURN_ON_FAIL(calcTypeName(specOp->returnType, CodeGenTarget::CSource, outBuilder)); - return SLANG_OK; - } - default: break; - } - - const auto& info = HLSLIntrinsic::getInfo(specOp->op); - if 
(info.funcName.getLength()) - { - if (!_isOperator(info.funcName)) - { - // If there is a standard default name, just use that - outBuilder << info.funcName; - return SLANG_OK; - } - } - - // Just use the name of the Op. This is probably wrong, but gives a pretty good idea of what the desired (presumably missing) op is. - outBuilder << info.name; - return SLANG_OK; - } -} - /* !!!!!!!!!!!!!!!!!!!!!! CPPSourceEmitter !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ CPPSourceEmitter::CPPSourceEmitter(const Desc& desc): Super(desc), - m_slicePool(StringSlicePool::Style::Default), - m_typeSet(desc.codeGenContext->getSession()), - m_opLookup(new HLSLIntrinsicOpLookup), - m_intrinsicSet(&m_typeSet, m_opLookup) + m_slicePool(StringSlicePool::Style::Default) { m_semanticUsedFlags = 0; //m_semanticUsedFlags = SemanticUsedFlag::GroupID | SemanticUsedFlag::GroupThreadID | SemanticUsedFlag::DispatchThreadID; @@ -2145,12 +958,16 @@ void CPPSourceEmitter::emitSimpleFuncParamImpl(IRParam* param) void CPPSourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) { - emitSimpleType(m_typeSet.addVectorType(elementType, int(elementCount))); + m_writer->emit("Vector<"); + m_writer->emit(_getTypeName(elementType)); + m_writer->emit(", "); + m_writer->emit(elementCount); + m_writer->emit(">"); } void CPPSourceEmitter::emitSimpleTypeImpl(IRType* inType) { - UnownedStringSlice slice = _getTypeName(m_typeSet.getType(inType)); + UnownedStringSlice slice = _getTypeName(inType); m_writer->emit(slice); } @@ -2225,8 +1042,6 @@ void CPPSourceEmitter::emitIntrinsicCallExprImpl( IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) { - typedef HLSLIntrinsic::Op Op; - // TODO: Much of this logic duplicates code that is already // in `CLikeSourceEmitter::emitIntrinsicCallExpr`. 
The only // real difference is that when things bottom out on an ordinary @@ -2248,36 +1063,6 @@ void CPPSourceEmitter::emitIntrinsicCallExprImpl( if (name == ".operator[]") { SLANG_ASSERT(argCount == 2 || argCount == 3); - - // If the first item is either a matrix or a vector, we use 'getAt' logic - IRType* targetType = args[0].get()->getDataType(); - if (targetType->getOp() == kIROp_VectorType || targetType->getOp() == kIROp_MatrixType) - { - // Work out the intrinsic used - HLSLIntrinsic intrinsic; - m_intrinsicSet.calcIntrinsic(HLSLIntrinsic::Op::GetAt, inst->getDataType(), args, 2, intrinsic); - HLSLIntrinsic* specOp = m_intrinsicSet.add(intrinsic); - - if (argCount == 2) - { - // Load - emitCall(specOp, inst, args, 2, inOuterPrec); - } - else - { - // Store - auto prec = getInfo(EmitOp::Postfix); - needClose = maybeEmitParens(outerPrec, prec); - - emitCall(specOp, inst, inst->getOperands(), 2, inOuterPrec); - - m_writer->emit(" = "); - emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); - - maybeCloseParens(needClose); - } - } - else { // The user is invoking a built-in subscript operator @@ -2318,21 +1103,6 @@ void CPPSourceEmitter::emitIntrinsicCallExprImpl( return; } - { - Op op = m_opLookup->getOpByName(name); - if (op != Op::Invalid) - { - - // Work out the intrinsic used - HLSLIntrinsic intrinsic; - m_intrinsicSet.calcIntrinsic(op, inst->getDataType(), args, argCount, intrinsic); - HLSLIntrinsic* specOp = m_intrinsicSet.add(intrinsic); - - emitCall(specOp, inst, args, int(argCount), inOuterPrec); - return; - } - } - // Use default impl (which will do intrinsic special macro expansion as necessary) return Super::emitIntrinsicCallExprImpl(inst, targetIntrinsic, inOuterPrec); } @@ -2372,32 +1142,147 @@ const UnownedStringSlice* CPPSourceEmitter::getVectorElementNames(IRVectorType* return getVectorElementNames(basicType->getBaseType(), elemCount); } -bool CPPSourceEmitter::_tryEmitInstExprAsIntrinsic(IRInst* inst, const EmitOpInfo& inOuterPrec) 
+bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { - HLSLIntrinsic* specOp = m_intrinsicSet.add(inst); - if (specOp) + switch (inst->getOp()) { - if (inst->getOp() == kIROp_Call) + default: { - IRCall* call = static_cast<IRCall*>(inst); - emitCall(specOp, inst, call->getArgs(), int(call->getArgCount()), inOuterPrec); + return false; } - else + case kIROp_MakeVector: { - emitCall(specOp, inst, inst->getOperands(), int(inst->getOperandCount()), inOuterPrec); + IRType* retType = inst->getFullType(); + emitType(retType); + m_writer->emit("("); + bool isFirst = true; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto arg = inst->getOperand(i); + if (auto vectorType = as<IRVectorType>(arg->getDataType())) + { + for (int j = 0; j < cast<IRIntLit>(vectorType->getElementCount())->getValue(); j++) + { + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); + auto outerPrec = getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(arg, leftSide(outerPrec, prec)); + m_writer->emit("."); + m_writer->emitChar(s_xyzwNames[j]); + } + } + else + { + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); + emitOperand(arg, getInfo(EmitOp::General)); + } + } + m_writer->emit(")"); + + return true; } - return true; - } - return false; -} + case kIROp_CastFloatToInt: + case kIROp_CastIntToFloat: + case kIROp_FloatCast: + case kIROp_IntCast: + { + if (auto vectorType = as<IRVectorType>(inst->getDataType())) + { + emitType(vectorType); + m_writer->emit("{"); + for (Index i = 0; i < cast<IRIntLit>(vectorType->getElementCount())->getValue(); i++) + { + if (i > 0) + m_writer->emit(", "); + m_writer->emit("("); + emitType(vectorType->getElementType()); + m_writer->emit(")_slang_vector_get_element("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + m_writer->emit(i); + m_writer->emit(")"); + } + m_writer->emit("}"); + return true; + } + return false; + } + case 
kIROp_VectorReshape: + { + if (auto vectorType = as<IRVectorType>(inst->getDataType())) + { + m_writer->emit("_slang_vector_reshape<"); + emitType(vectorType->getElementType()); + m_writer->emit(", "); + emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General)); + m_writer->emit(">("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + return false; + } + case kIROp_GetElement: + { + auto getElementInst = static_cast<IRGetElement*>(inst); -bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) -{ - switch (inst->getOp()) - { - default: + IRInst* baseInst = getElementInst->getBase(); + IRType* baseType = baseInst->getDataType(); + if (as<IRVectorType>(baseType)) + { + m_writer->emit("_slang_vector_get_element("); + emitOperand(baseInst, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + else if (as<IRMatrixType>(baseType)) + { + auto outerPrec = getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(baseInst, leftSide(outerPrec, prec)); + m_writer->emit(".rows["); + emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); + m_writer->emit("]"); + return true; + } + return false; + } + case kIROp_GetElementPtr: { - return _tryEmitInstExprAsIntrinsic(inst, inOuterPrec); + auto getElementInst = static_cast<IRGetElement*>(inst); + + IRInst* baseInst = getElementInst->getBase(); + IRType* baseType = as<IRPtrTypeBase>(baseInst->getDataType())->getValueType(); + if (as<IRVectorType>(baseType)) + { + m_writer->emit("_slang_vector_get_element_ptr("); + emitOperand(baseInst, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + else if (as<IRMatrixType>(baseType)) + { + m_writer->emit("&("); + auto outerPrec = 
getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(baseInst, leftSide(outerPrec, prec)); + m_writer->emit("->rows["); + emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); + m_writer->emit("]"); + m_writer->emit(")"); + return true; + } + return false; } case kIROp_swizzle: { @@ -2430,8 +1315,79 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut return true; } } - // try doing automatically - return _tryEmitInstExprAsIntrinsic(inst, inOuterPrec); + + { + // Currently only works for C++ (we use {} constuction) - which means we don't need to generate a function. + // For C we need to generate suitable construction function + + const Index elementCount = Index(swizzleInst->getElementCount()); + + IRType* srcType = swizzleInst->getBase()->getDataType(); + IRVectorType* srcVecType = as<IRVectorType>(srcType); + + const UnownedStringSlice* elemNames = nullptr; + if (srcVecType) + elemNames = getVectorElementNames(srcVecType); + + IRType* retType = swizzleInst->getFullType(); + emitType(retType); + m_writer->emit("{"); + + for (Index i = 0; i < elementCount; ++i) + { + if (i > 0) + { + m_writer->emit(", "); + } + + auto outerPrec = getInfo(EmitOp::General); + + auto prec = getInfo(EmitOp::Postfix); + emitOperand(swizzleInst->getBase(), leftSide(outerPrec, prec)); + + if (srcVecType) + { + m_writer->emit("."); + + IRInst* irElementIndex = swizzleInst->getElementIndex(i); + SLANG_RELEASE_ASSERT(irElementIndex->getOp() == kIROp_IntLit); + IRConstant* irConst = (IRConstant*)irElementIndex; + UInt elementIndex = (UInt)irConst->value.intVal; + SLANG_RELEASE_ASSERT(elementIndex < 4); + + m_writer->emit(elemNames[elementIndex]); + } + } + + m_writer->emit("}"); + } + return true; + } + case kIROp_FRem: + { + if (auto basicType = as<IRBasicType>(inst->getDataType())) + { + switch (basicType->getOp()) + { + case kIROp_HalfType: + m_writer->emit("F16_fmod("); + break; + case kIROp_FloatType: + 
m_writer->emit("F32_fmod("); + break; + case kIROp_DoubleType: + m_writer->emit("F64_fmod("); + break; + default: + return false; + } + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + return false; } case kIROp_Call: { @@ -2441,7 +1397,7 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut handleRequiredCapabilities(funcValue); // try doing automatically - return _tryEmitInstExprAsIntrinsic(inst, inOuterPrec); + return false; } case kIROp_LookupWitness: { @@ -2562,29 +1518,6 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut } } -// We want order of built in types (typically output nothing), vector, matrix, other types -// Types that aren't output have negative indices -static Index _calcTypeOrder(IRType* a) -{ - switch (a->getOp()) - { - case kIROp_FuncType: - { - return -2; - } - case kIROp_VectorType: return 1; - case kIROp_MatrixType: return 2; - default: - { - if (as<IRBasicType>(a)) - { - return -1; - } - return 3; - } - } -} - void CPPSourceEmitter::emitPreModuleImpl() { if (m_target == CodeGenTarget::CPPSource) @@ -2604,45 +1537,6 @@ void CPPSourceEmitter::emitPreModuleImpl() m_writer->emit("using namespace SLANG_PRELUDE_NAMESPACE;\n"); m_writer->emit("#endif\n\n"); } - - // Emit generated functions and types - - if (m_target == CodeGenTarget::CSource) - { - // For C output we need to emit type definitions. 
- List<IRType*> types; - m_typeSet.getTypes(types); - - // Remove ones we don't need to emit - for (Index i = 0; i < types.getCount(); ++i) - { - if (_calcTypeOrder(types[i]) < 0) - { - types.fastRemoveAt(i); - --i; - } - } - - // Sort them so that vectors come before matrices and everything else after that - types.sort([&](IRType* a, IRType* b) { return _calcTypeOrder(a) < _calcTypeOrder(b); }); - - // Emit the type definitions - for (auto type : types) - { - emitTypeDefinition(type); - } - } - - { - List<const HLSLIntrinsic*> intrinsics; - m_intrinsicSet.getIntrinsics(intrinsics); - - // Emit all the intrinsics that were used - for (auto intrinsic : intrinsics) - { - _maybeEmitSpecializedOperationDefinition(intrinsic); - } - } } @@ -2980,11 +1874,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); - // Setup all built in types used in the module - m_typeSet.addAllBuiltinTypes(module); - // If any matrix types are used, then we need appropriate vector types too. 
- m_typeSet.addVectorForMatrixTypes(); - List<EmitAction> actions; computeEmitActions(module, actions); diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h index c5b9f3d9c..ec70b02b8 100644 --- a/source/slang/slang-emit-cpp.h +++ b/source/slang/slang-emit-cpp.h @@ -39,9 +39,6 @@ public: }; virtual void useType(IRType* type); - virtual void emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec); - virtual void emitTypeDefinition(IRType* type); - virtual void emitSpecializedOperationDefinition(const HLSLIntrinsic* specOp); static UnownedStringSlice getBuiltinTypeName(IROp op); @@ -78,43 +75,21 @@ protected: virtual void emitVarDecorationsImpl(IRInst* var) SLANG_OVERRIDE; virtual void emitGlobalInstImpl(IRInst* inst) SLANG_OVERRIDE; - virtual const UnownedStringSlice* getVectorElementNames(BaseType elemType, Index elemCount); + const UnownedStringSlice* getVectorElementNames(BaseType elemType, Index elemCount); // Replaceable for classes derived from CPPSourceEmitter virtual SlangResult calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out); - virtual SlangResult calcFuncName(const HLSLIntrinsic* specOp, StringBuilder& out); - virtual SlangResult calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder); const UnownedStringSlice* getVectorElementNames(IRVectorType* vectorType); - void _maybeEmitSpecializedOperationDefinition(const HLSLIntrinsic* specOp); - void _emitForwardDeclarations(const List<EmitAction>& actions); - void _emitAryDefinition(const HLSLIntrinsic* specOp); - - // Really we don't want any of these defined like they are here, they should be defined in slang stdlib - void _emitAnyAllDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp); - void _emitConstructConvertDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp); - void _emitConstructFromScalarDefinition(const 
UnownedStringSlice& funcName, const HLSLIntrinsic* specOp); - void _emitGetAtDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp); - void _emitInitDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp); - - void _emitSignature(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp); - void _emitInOutParamType(IRType* type, String const& name, IRType* valueType); - - UnownedStringSlice _getAndEmitSpecializedOperationDefinition(HLSLIntrinsic::Op op, IRType*const* argTypes, Int argCount, IRType* retType); - static TypeDimension _getTypeDimension(IRType* type, bool vecSwap); void _emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer); - UnownedStringSlice _getScalarFuncName(HLSLIntrinsic::Op operation, IRBasicType* scalarType); - - UnownedStringSlice _getFuncName(const HLSLIntrinsic* specOp); - UnownedStringSlice _getTypeName(IRType* type); SlangResult _calcCPPTextureTypeName(IRTextureTypeBase* texType, StringBuilder& outName); @@ -126,8 +101,6 @@ protected: void _emitInitAxisValues(const Int sizeAlongAxis[kThreadGroupAxisCount], const UnownedStringSlice& mulName, const UnownedStringSlice& addName); - bool _tryEmitInstExprAsIntrinsic(IRInst* inst, const EmitOpInfo& inOuterPrec); - // Emit the actual definition (including intializer list) // of all the witness table objects in `pendingWitnessTableDefinitions`. 
void _emitWitnessTableDefinitions(); @@ -136,18 +109,9 @@ protected: void _getExportStyle(IRInst* inst, bool& outIsExport, bool& outIsExternC); void _maybeEmitExportLike(IRInst* inst); - HLSLIntrinsic* _addIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* argTypes, Index argTypeCount); - static bool _isVariable(IROp op); Dictionary<IRType*, StringSlicePool::Handle> m_typeNameMap; - Dictionary<const HLSLIntrinsic*, StringSlicePool::Handle> m_intrinsicNameMap; - - IRTypeSet m_typeSet; - RefPtr<HLSLIntrinsicOpLookup> m_opLookup; - HLSLIntrinsicSet m_intrinsicSet; - - HashSet<const HLSLIntrinsic*> m_intrinsicEmitted; HashSet<IRInterfaceType*> m_interfaceTypesEmitted; diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index 284652682..a151ab0e2 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -123,131 +123,6 @@ SlangResult CUDASourceEmitter::_calcCUDATextureTypeName(IRTextureTypeBase* texTy return SLANG_FAIL; } -SlangResult CUDASourceEmitter::calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder) -{ - typedef HLSLIntrinsic::Op Op; - - UnownedStringSlice funcName; - - switch (op) - { - case Op::FRem: - { - if (type->getOp() == kIROp_FloatType || type->getOp() == kIROp_DoubleType) - { - funcName = HLSLIntrinsic::getInfo(op).funcName; - } - break; - } - default: break; - } - - if (funcName.getLength()) - { - outBuilder << funcName; - if (type->getOp() == kIROp_FloatType) - { - outBuilder << "f"; - } - return SLANG_OK; - } - - // Defer to the supers impl - return Super::calcScalarFuncName(op, type, outBuilder); -} - -void CUDASourceEmitter::emitSpecializedOperationDefinition(const HLSLIntrinsic* specOp) -{ - typedef HLSLIntrinsic::Op Op; - - if (auto vecType = as <IRVectorType>(specOp->returnType)) - { - // Converting to or from half vector types is implemented prelude as convert___half functions - // Get the from type -> if it's half we ignore - - if 
(specOp->op == Op::ConstructConvert) - { - auto signatureType = specOp->signatureType; - - // Need to have impl of convert_float, double, int, uint, in prelude - - const auto paramCount = signatureType->getParamCount(); - SLANG_UNUSED(paramCount); - - // We have 2 'params' and param 1 is the source type - SLANG_ASSERT(paramCount == 2); - IRType* paramType = signatureType->getParamType(1); - - auto vecParamType = as<IRVectorType>(paramType); - - if (auto baseType = as<IRBasicType>(vecParamType->getElementType())) - { - if (baseType->getBaseType() == BaseType::Half) - { - return; - } - } - } - - if (auto baseType = as<IRBasicType>(vecType->getElementType())) - { - if (baseType->getBaseType() == BaseType::Half) - { - switch (specOp->op) - { - case Op::Init: - - case Op::Add: - case Op::Mul: - case Op::Div: - case Op::Sub: - - case Op::Neg: - - case Op::ConstructFromScalar: - case Op::ConstructConvert: - - case Op::Leq: - case Op::Less: - case Op::Greater: - case Op::Geq: - case Op::Neq: - case Op::Eql: - { - return; - } - } - } - } - } - - switch (specOp->op) - { - case Op::Init: - { - // Special case handling - auto returnType = specOp->returnType; - - if (auto vecType = as <IRVectorType>(returnType)) - { - if (auto baseType = as<IRBasicType>(vecType->getElementType())) - { - if (baseType->getBaseType() == BaseType::Half) - { - // Defined already in cuda-prelude.h - return; - } - } - } - - break; - } - default: break; - } - - Super::emitSpecializedOperationDefinition(specOp); -} - SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) { SLANG_UNUSED(target); @@ -322,25 +197,6 @@ SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, return Super::calcTypeName(type, target, out); } -const UnownedStringSlice* CUDASourceEmitter::getVectorElementNames(BaseType baseType, Index elemCount) -{ - static const UnownedStringSlice normal[] = { UnownedStringSlice::fromLiteral("x"), 
UnownedStringSlice::fromLiteral("y"), UnownedStringSlice::fromLiteral("z"), UnownedStringSlice::fromLiteral("w") }; - static const UnownedStringSlice half3[] = { UnownedStringSlice::fromLiteral("xy.x"), UnownedStringSlice::fromLiteral("xy.y"), UnownedStringSlice::fromLiteral("z") }; - static const UnownedStringSlice half4[] = { UnownedStringSlice::fromLiteral("xy.x"), UnownedStringSlice::fromLiteral("xy.y"), UnownedStringSlice::fromLiteral("zw.x"), UnownedStringSlice::fromLiteral("zw.y")}; - - if (baseType == BaseType::Half) - { - switch (elemCount) - { - default: break; - case 3: return half3; - case 4: return half4; - } - } - - return normal; -} - void CUDASourceEmitter::emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling) { Super::emitLayoutSemanticsImpl(inst, uniformSemanticSpelling); @@ -436,49 +292,6 @@ void CUDASourceEmitter::emitGlobalRTTISymbolPrefix() m_writer->emit("__constant__ "); } -void CUDASourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec) -{ - switch (specOp->op) - { - case HLSLIntrinsic::Op::Init: - { - // For CUDA vector types we construct with make_ - - auto writer = m_writer; - - IRType* retType = specOp->returnType; - - if (IRVectorType* vecType = as<IRVectorType>(retType)) - { - if (numOperands == getIntVal(vecType->getElementCount())) - { - // Get the type name - writer->emit("make_"); - emitType(retType); - writer->emitChar('('); - - for (int i = 0; i < numOperands; ++i) - { - if (i > 0) - { - writer->emit(", "); - } - emitOperand(operands[i].get(), getInfo(EmitOp::General)); - } - - writer->emitChar(')'); - return; - } - } - // Just use the default - break; - } - default: break; - } - - return Super::emitCall(specOp, inst, operands, numOperands, inOuterPrec); -} - void CUDASourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) { if (decl->getMode() == kIRLoopControl_Unroll) @@ -487,59 +300,25 @@ void 
CUDASourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* d } } -static bool _areEquivalent(IRType* a, IRType* b) -{ - if (a == b) - { - return true; - } - if (a->getOp() != b->getOp()) - { - return false; - } - - switch (a->getOp()) - { - case kIROp_VectorType: - { - IRVectorType* vecA = static_cast<IRVectorType*>(a); - IRVectorType* vecB = static_cast<IRVectorType*>(b); - return getIntVal(vecA->getElementCount()) == getIntVal(vecB->getElementCount()) && - _areEquivalent(vecA->getElementType(), vecB->getElementType()); - } - case kIROp_MatrixType: - { - IRMatrixType* matA = static_cast<IRMatrixType*>(a); - IRMatrixType* matB = static_cast<IRMatrixType*>(b); - return getIntVal(matA->getColumnCount()) == getIntVal(matB->getColumnCount()) && - getIntVal(matA->getRowCount()) == getIntVal(matB->getRowCount()) && - _areEquivalent(matA->getElementType(), matB->getElementType()); - } - default: - { - return as<IRBasicType>(a) != nullptr; - } - } -} - void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value) { // When constructing a matrix or vector from a single value this is handled by the default path switch (value->getOp()) { - case kIROp_MakeMatrix: case kIROp_MakeVector: + case kIROp_MakeMatrix: { IRType* type = value->getDataType(); // If the types are the same, we can can just break down and use - if (_areEquivalent(dstType, type)) + if (dstType == type) { if (auto vecType = as<IRVectorType>(type)) { if (UInt(getIntVal(vecType->getElementCount())) == value->getOperandCount()) { + emitType(type); _emitInitializerList(vecType->getElementType(), value->getOperands(), value->getOperandCount()); return; } @@ -551,20 +330,25 @@ void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value // TODO(JS): If num cols = 1, then it *doesn't* actually return a vector. // That could be argued is an error because we want swizzling or [] to work. 
- IRType* rowType = m_typeSet.addVectorType(matType->getElementType(), int(colCount)); - IRVectorType* rowVectorType = as<IRVectorType>(rowType); + IRBuilder builder(matType->getModule()); + builder.setInsertBefore(matType); const Index operandCount = Index(value->getOperandCount()); // Can init, with vectors. // For now special case if the rowVectorType is not actually a vector (when elementSize == 1) - if (operandCount == rowCount || rowVectorType == nullptr) + if (operandCount == rowCount) { - // We have to output vectors - - // Emit the braces for the Matrix struct, contains an row array. + // Emit the braces for the Matrix struct, and then each row vector in its own line. + emitType(matType); m_writer->emit("{\n"); m_writer->indent(); - _emitInitializerList(rowType, value->getOperands(), rowCount); + for (Index i = 0; i < rowCount; ++i) + { + if (i != 0) m_writer->emit(",\n"); + emitType(matType->getElementType()); + m_writer->emit(colCount); + _emitInitializerList(matType->getElementType(), value->getOperand(i)->getOperands(), colCount); + } m_writer->dedent(); m_writer->emit("\n}"); return; @@ -575,21 +359,18 @@ void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value IRType* elementType = matType->getElementType(); IRUse* operands = value->getOperands(); - // Emit the braces for the Matrix struct, and the array of rows - m_writer->emit("{\n"); - m_writer->indent(); + // Emit the braces for the Matrix struct, and the elements of each row in its own line. 
+ emitType(matType); m_writer->emit("{\n"); m_writer->indent(); for (Index i = 0; i < rowCount; ++i) { - if (i != 0) m_writer->emit(", "); - _emitInitializerList(elementType, operands, colCount); + if (i != 0) m_writer->emit(",\n"); + _emitInitializerListContent(elementType, operands, colCount); operands += colCount; } m_writer->dedent(); m_writer->emit("\n}"); - m_writer->dedent(); - m_writer->emit("\n}"); return; } } @@ -603,116 +384,157 @@ void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value emitOperand(value, getInfo(EmitOp::General)); } -void CUDASourceEmitter::_emitInitializerList(IRType* elementType, IRUse* operands, Index operandCount) +void CUDASourceEmitter::_emitInitializerListContent(IRType* elementType, IRUse* operands, Index operandCount) { - m_writer->emit("{\n"); - m_writer->indent(); - for (Index i = 0; i < operandCount; ++i) { if (i != 0) m_writer->emit(", "); _emitInitializerListValue(elementType, operands[i].get()); } - - m_writer->dedent(); - m_writer->emit("\n}"); } -void CUDASourceEmitter::_emitGetHalfVectorElement(IRInst* base, Index index, Index vecSize, const EmitOpInfo& inOuterPrec) -{ - SLANG_ASSERT(index < vecSize); - - EmitOpInfo outerPrec = inOuterPrec; - - auto prec = getInfo(EmitOp::Postfix); - const bool needClose = maybeEmitParens(outerPrec, prec); - emitOperand(base, leftSide(outerPrec, prec)); +void CUDASourceEmitter::_emitInitializerList(IRType* elementType, IRUse* operands, Index operandCount) +{ + m_writer->emit("{\n"); + m_writer->indent(); - m_writer->emit("."); + _emitInitializerListContent(elementType, operands, operandCount); - switch (vecSize) - { - default: - { - char const* kComponents[] = { "x", "y", "z", "w" }; - m_writer->emit(kComponents[index]); - break; - } - case 3: - { - char const* kComponents[] = { "xy.x", "xy.y", "z"}; - m_writer->emit(kComponents[index]); - break; - } - case 4: - { - char const* kComponents[] = { "xy.x", "xy.y", "zw.x", "zw.y" }; - 
m_writer->emit(kComponents[index]); - break; - } - } + m_writer->dedent(); + m_writer->emit("\n}"); +} - maybeCloseParens(needClose); +void CUDASourceEmitter::emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) +{ + if (targetIntrinsic->getDefinition().startsWith("__half")) + m_extensionTracker->requireBaseType(BaseType::Half); + Super::emitIntrinsicCallExprImpl(inst, targetIntrinsic, inOuterPrec); } bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { switch(inst->getOp()) { - case kIROp_swizzle: + case kIROp_MakeVector: + case kIROp_MakeVectorFromScalar: { - // We need to special case for half types. - auto swizzleInst = static_cast<IRSwizzle*>(inst); - - IRInst* baseInst = swizzleInst->getBase(); - IRType* baseType = baseInst->getDataType(); - - // If we are swizzling from a built in type, - if (as<IRBasicType>(baseType)) + m_writer->emit("make_"); + emitType(inst->getDataType()); + m_writer->emit("("); + bool isFirst = true; + char xyzwNames[] = "xyzw"; + for (UInt i = 0; i < inst->getOperandCount(); i++) { - // Just use the default behavior + auto arg = inst->getOperand(i); + if (auto vectorType = as<IRVectorType>(arg->getDataType())) + { + for (int j = 0; j < cast<IRIntLit>(vectorType->getElementCount())->getValue(); j++) + { + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); + auto outerPrec = getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(arg, leftSide(outerPrec, prec)); + m_writer->emit("."); + m_writer->emitChar(xyzwNames[j]); + } + } + else + { + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); + emitOperand(arg, getInfo(EmitOp::General)); + } } - else if (auto vecType = as<IRVectorType>(baseType)) + m_writer->emit(")"); + return true; + } + case kIROp_FloatCast: + case kIROp_CastIntToFloat: + case kIROp_IntCast: + case kIROp_CastFloatToInt: + { + if (auto dstVectorType = 
as<IRVectorType>(inst->getDataType())) { - if (auto basicType = as<IRBasicType>(vecType->getElementType())) + m_writer->emit("make_"); + emitType(inst->getDataType()); + m_writer->emit("("); + bool isFirst = true; + char xyzwNames[] = "xyzw"; + for (UInt i = 0; i < inst->getOperandCount(); i++) { - if (basicType->getBaseType() == BaseType::Half) + auto arg = inst->getOperand(i); + if (auto vectorType = as<IRVectorType>(arg->getDataType())) { - const Index vecElementCount = Index(getIntVal(vecType->getElementCount())); - - const Index elementCount = Index(swizzleInst->getElementCount()); - if (elementCount == 1) - { - const Index index = Index(getIntVal(swizzleInst->getElementIndex(0))); - _emitGetHalfVectorElement(baseInst, index, vecElementCount, inOuterPrec); - } - else + for (int j = 0; j < cast<IRIntLit>(vectorType->getElementCount())->getValue(); j++) { - auto outerPrec = getInfo(EmitOp::General); - - m_writer->emit("make___half"); - m_writer->emitInt64(elementCount); + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); m_writer->emit("("); - - for (Index i = 0; i < elementCount; ++i) - { - if (i) - { - m_writer->emit(", "); - } - - const Index index = Index(getIntVal(swizzleInst->getElementIndex(i))); - _emitGetHalfVectorElement(baseInst, index, vecElementCount, outerPrec); - } - + emitType(dstVectorType->getElementType()); m_writer->emit(")"); + auto outerPrec = getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(arg, leftSide(outerPrec, prec)); + m_writer->emit("."); + m_writer->emitChar(xyzwNames[j]); } - return true; + } + else + { + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); + m_writer->emit("("); + emitType(dstVectorType->getElementType()); + m_writer->emit(")"); + emitOperand(arg, getInfo(EmitOp::General)); } } + m_writer->emit(")"); + return true; } - break; + else if (auto matrixType = as<IRMatrixType>(inst->getDataType())) + { + m_writer->emit("make"); + emitType(inst->getDataType()); + 
m_writer->emit("("); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto arg = inst->getOperand(i); + if (i > 0) + m_writer->emit(", "); + emitOperand(arg, getInfo(EmitOp::General)); + } + m_writer->emit(")"); + return true; + } + return false; + } + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MatrixReshape: + { + m_writer->emit("make"); + emitType(inst->getDataType()); + m_writer->emit("("); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto arg = inst->getOperand(i); + if (i > 0) + m_writer->emit(", "); + emitOperand(arg, getInfo(EmitOp::General)); + } + m_writer->emit(")"); + return true; } case kIROp_MakeArray: { @@ -722,13 +544,9 @@ bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu IRType* elementType = arrayType->getElementType(); // Emit braces for the FixedArray struct. - m_writer->emit("{\n"); - m_writer->indent(); _emitInitializerList(elementType, inst->getOperands(), Index(inst->getOperandCount())); - m_writer->dedent(); - m_writer->emit("\n}"); return true; } case kIROp_WaveMaskBallot: @@ -820,7 +638,19 @@ void CUDASourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerVal void CUDASourceEmitter::emitSimpleTypeImpl(IRType* type) { - m_writer->emit(_getTypeName(type)); + switch (type->getOp()) + { + case kIROp_VectorType: + { + auto vectorType = as<IRVectorType>(type); + m_writer->emit(getVectorPrefix(vectorType->getElementType()->getOp())); + m_writer->emit(as<IRIntLit>(vectorType->getElementCount())->getValue()); + break; + } + default: + m_writer->emit(_getTypeName(type)); + break; + } } void CUDASourceEmitter::emitRateQualifiersImpl(IRRate* rate) @@ -907,27 +737,6 @@ void CUDASourceEmitter::emitPreModuleImpl() // Emit generated types/functions writer->emit("\n"); - - { - List<IRType*> types; - m_typeSet.getTypes(IRTypeSet::Kind::Matrix, types); - - // Emit the type definitions - for (auto type : types) - { - emitTypeDefinition(type); - } - } - - { - 
List<const HLSLIntrinsic*> intrinsics; - m_intrinsicSet.getIntrinsics(intrinsics); - // Emit all the intrinsics that were used - for (auto intrinsic : intrinsics) - { - _maybeEmitSpecializedOperationDefinition(intrinsic); - } - } } @@ -951,22 +760,6 @@ bool CUDASourceEmitter::tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* v void CUDASourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink) { - // Setup all built in types used in the module - m_typeSet.addAllBuiltinTypes(module); - // If any matrix types are used, then we need appropriate vector types too. - m_typeSet.addVectorForMatrixTypes(); - - // We need to add some vector intrinsics - used for calculating thread ids - { - IRType* type = m_typeSet.addVectorType(m_typeSet.getBuilder().getBasicType(BaseType::UInt), 3); - IRType* args[] = { type, type }; - - _addIntrinsic(HLSLIntrinsic::Op::Add, type, args, SLANG_COUNT_OF(args)); - _addIntrinsic(HLSLIntrinsic::Op::Mul, type, args, SLANG_COUNT_OF(args)); - } - - // TODO(JS): We may need to generate types (for example for matrices) - CLikeSourceEmitter::emitModuleImpl(module, sink); // Emit all witness table definitions. 
diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h
index ff947fe58..8a907dc7c 100644
--- a/source/slang/slang-emit-cuda.h
+++ b/source/slang/slang-emit-cuda.h
@@ -78,12 +78,9 @@ protected:
     virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE;
     virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE;
     virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE;
-    virtual void emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE;

     virtual void emitFunctionPreambleImpl(IRInst* inst) SLANG_OVERRIDE;
     virtual String generateEntryPointNameImpl(IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE;

-    virtual const UnownedStringSlice* getVectorElementNames(BaseType baseType, Index elemCount) SLANG_OVERRIDE;
-
     virtual void emitGlobalRTTISymbolPrefix() SLANG_OVERRIDE;
     virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE;

@@ -92,23 +89,19 @@ protected:

     virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) SLANG_OVERRIDE;
     virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE;
-
+    virtual void emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE;
     virtual void emitModuleImpl(IRModule* module, DiagnosticSink* sink) SLANG_OVERRIDE;

     // CPPSourceEmitter overrides
     virtual SlangResult calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) SLANG_OVERRIDE;
-    virtual SlangResult calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder) SLANG_OVERRIDE;
-
-    virtual void emitSpecializedOperationDefinition(const HLSLIntrinsic* specOp) SLANG_OVERRIDE;

     SlangResult _calcCUDATextureTypeName(IRTextureTypeBase* texType, StringBuilder& outName);

     void _emitInitializerList(IRType* elementType, IRUse* operands, Index operandCount);
+    void _emitInitializerListContent(IRType* elementType, IRUse* operands, Index operandCount);
     void _emitInitializerListValue(IRType* elementType, IRInst* value);

-    void _emitGetHalfVectorElement(IRInst* baseInst, Index index, Index vecSize, const EmitOpInfo& inOuterPrec);
-
     RefPtr<CUDAExtensionTracker> m_extensionTracker;
 };

diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index c49265fe7..ef0d062bb 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -1022,7 +1022,7 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
     auto irModule = linkedIR.module;

     // Perform final simplifications to help emit logic to generate more compact code.
-    simplifyForEmit(irModule);
+    simplifyForEmit(irModule, targetRequest);

     metadata = linkedIR.metadata;

diff --git a/source/slang/slang-hlsl-intrinsic-set.cpp b/source/slang/slang-hlsl-intrinsic-set.cpp
index ea3476473..e69de29bb 100644
--- a/source/slang/slang-hlsl-intrinsic-set.cpp
+++ b/source/slang/slang-hlsl-intrinsic-set.cpp
@@ -1,590 +0,0 @@
-// slang-hlsl-intrinsic-set.cpp
-#include "slang-hlsl-intrinsic-set.h"
-
-#include "slang-ir.h"
-#include "slang-ir-insts.h"
-
-namespace Slang
-{
-
-/* static */const HLSLIntrinsic::Info HLSLIntrinsic::s_operationInfos[] =
-{
-#define SLANG_HLSL_INTRINSIC_OP_INFO(x, funcName, numOperands) { UnownedStringSlice::fromLiteral(#x), UnownedStringSlice::fromLiteral(funcName), int8_t(numOperands) },
-    SLANG_HLSL_INTRINSIC_OP(SLANG_HLSL_INTRINSIC_OP_INFO)
-};
-
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!! HLSLIntrinsicSet !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
- -HLSLIntrinsicSet::HLSLIntrinsicSet(IRTypeSet* typeSet, HLSLIntrinsicOpLookup* lookup): - m_intrinsicFreeList(sizeof(HLSLIntrinsic), SLANG_ALIGN_OF(HLSLIntrinsic), 1024), - m_typeSet(typeSet), - m_opLookup(lookup) -{ -} - -static IRBasicType* _getElementType(IRType* type) -{ - switch (type->getOp()) - { - case kIROp_VectorType: type = static_cast<IRVectorType*>(type)->getElementType(); break; - case kIROp_MatrixType: type = static_cast<IRMatrixType*>(type)->getElementType(); break; - default: break; - } - return dynamicCast<IRBasicType>(type); -} - -void HLSLIntrinsicSet::_calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* inArgs, Index argsCount, HLSLIntrinsic& out) -{ - IRBuilder& builder = m_typeSet->getBuilder(); - - // Check all types belong to the module - - IRModule* module = builder.getModule(); - - SLANG_UNUSED(module); - SLANG_ASSERT(returnType->getModule() == module); - - for (Index i = 0; i < argsCount; ++i) - { - SLANG_ASSERT(inArgs[i]->getModule() == module); - } - - // Set up the out - out.op = op; - out.returnType = returnType; - - switch (op) - { - case Op::GetAt: - { - IRType* argTypes[3]; - - SLANG_ASSERT(argsCount == 2 || argsCount == 3); - // TODO(JS): - // HACK! GetAt can be from getElementPtr or from getElement. Get element ptr means the return type will be - // a pointer. 
We don't want to deal with that, so strip it - if (returnType->getOp() == kIROp_PtrType) - { - returnType = as<IRType>(returnType->getOperand(0)); - } - - // TODO(JS): Similarly for the input parameters - for (Index i = 0; i < argsCount; ++i) - { - IRType* argType = inArgs[i]; - - if (argType->getOp() == kIROp_PtrType) - { - argType = as<IRType>(argType->getOperand(0)); - } - argTypes[i] = argType; - } - - out.returnType = returnType; - out.signatureType = builder.getFuncType(argsCount, argTypes, builder.getVoidType()); - break; - } - case Op::ConstructFromScalar: - { - //SLANG_ASSERT(argsCount == 1); - SLANG_ASSERT(argsCount == 1); - IRType* srcType = _getElementType(returnType); - IRType* argTypes[2] = { returnType, srcType }; - - out.signatureType = builder.getFuncType(2, argTypes, builder.getVoidType()); - break; - } - case Op::ConstructConvert: - { - // Make the return type a parameter, to make the signature take into account - SLANG_ASSERT(argsCount == 1); - IRType* argTypes[2] = { returnType, inArgs[0] }; - - out.signatureType = builder.getFuncType(2, argTypes, builder.getVoidType()); - break; - } - default: - { - out.signatureType = builder.getFuncType(argsCount, inArgs, builder.getVoidType()); - break; - } - } -} - -void HLSLIntrinsicSet::calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* inArgTypes, Index argCount, HLSLIntrinsic& out) -{ - returnType = m_typeSet->getType(returnType); - - if (argCount <= 8) - { - IRType* args[8]; - for (Index i = 0; i < argCount; ++i) - { - args[i] = m_typeSet->getType(inArgTypes[i]); - } - _calcIntrinsic(op, returnType, args, argCount, out); - } - else - { - List<IRType*> args; - args.setCount(argCount); - - for (Index i = 0; i < argCount; ++i) - { - args[i] = m_typeSet->getType(inArgTypes[i]); - } - _calcIntrinsic(op, returnType, args.getBuffer(), argCount, out); - } -} - -void HLSLIntrinsicSet::calcIntrinsic(HLSLIntrinsic::Op op, IRInst* inst, Index operandCount, HLSLIntrinsic& out) -{ - IRType* 
returnType = m_typeSet->getType(inst->getDataType()); - if (operandCount <= 8) - { - IRType* argTypes[8]; - for (Index i = 0; i < operandCount; ++i) - { - auto operand = inst->getOperand(i); - argTypes[i] = m_typeSet->getType(operand->getDataType()); - } - _calcIntrinsic(op, returnType, argTypes, operandCount, out); - } - else - { - List<IRType*> argTypes; - argTypes.setCount(operandCount); - - for (Index i = 0; i < operandCount; ++i) - { - auto operand = inst->getOperand(i); - argTypes[i] = m_typeSet->getType(operand->getDataType()); - } - _calcIntrinsic(op, returnType, argTypes.getBuffer(), operandCount, out); - } -} - -void HLSLIntrinsicSet::calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRUse* inArgs, Index argCount, HLSLIntrinsic& out) -{ - returnType = m_typeSet->getType(returnType); - - if (argCount <= 8) - { - IRType* argTypes[8]; - - for (Index i = 0; i < argCount; ++i) - { - auto operand = inArgs[i].get(); - argTypes[i] = m_typeSet->getType(operand->getDataType()); - } - _calcIntrinsic(op, returnType, argTypes, argCount, out); - } - else - { - List<IRType*> argTypes; - argTypes.setCount(argCount); - - for (Index i = 0; i < argCount; ++i) - { - auto operand = inArgs[i].get(); - argTypes[i] = m_typeSet->getType(operand->getDataType()); - } - _calcIntrinsic(op, returnType, argTypes.getBuffer(), argCount, out); - } -} - -HLSLIntrinsic* HLSLIntrinsicSet::add(IRInst* inst) -{ - HLSLIntrinsic intrinsic; - if (SLANG_SUCCEEDED(makeIntrinsic(inst, intrinsic))) - { - return add(intrinsic); - } - return nullptr; -} - -SlangResult HLSLIntrinsicSet::makeIntrinsic(IRInst* inst, HLSLIntrinsic& out) -{ - // Mark as invalid... - out.op = Op::Invalid; - - { - // See if we can just directly convert - Op op = HLSLIntrinsicOpLookup::getOpForIROp(inst->getOp()); - - - // HACK: some cases we want to stop handling via the synthesis - // path, but only for vector and matrix types (not scalars). 
- // - switch( op ) - { - default: break; - - case Op::AsFloat: - case Op::AsInt: - case Op::AsUInt: - // Note: the `any()`/`all()` case can't be handled via a stdlib definition - // right now because `bool` vectors map to `int` vectors on the CUDA - // path, so that the generated `geAt` operation is incorrect. - // -// case Op::Any: -// case Op::All: - { - IRType* srcType = inst->getOperand(0)->getDataType(); - switch( srcType->getOp() ) - { - default: - break; - - case kIROp_VectorType: - case kIROp_MatrixType: - return SLANG_FAIL; - } - } - break; - } - - - if (op != Op::Invalid) - { - calcIntrinsic(op, inst, inst->getOperandCount(), out); - return SLANG_OK; - } - } - - // All the special cases - switch (inst->getOp()) - { - case kIROp_MakeVectorFromScalar: - case kIROp_MakeMatrixFromScalar: - { - SLANG_ASSERT(inst->getOperandCount() == 1); - calcIntrinsic(Op::ConstructFromScalar, inst, 1, out); - return SLANG_OK; - } - case kIROp_IntCast: - case kIROp_FloatCast: - case kIROp_CastIntToFloat: - case kIROp_CastFloatToInt: - { - IRType* dstType = inst->getDataType(); - IRType* srcType = inst->getOperand(0)->getDataType(); - - if ((dstType->getOp() == kIROp_VectorType || dstType->getOp() == kIROp_MatrixType) && - inst->getOperandCount() == 1) - { - if (as<IRBasicType>(srcType)) - { - calcIntrinsic(Op::ConstructFromScalar, inst, out); - } - else - { - SLANG_ASSERT(m_typeSet->getType(dstType) != m_typeSet->getType(srcType)); - // If it's constructed from a type conversion - calcIntrinsic(Op::ConstructConvert, inst, out); - } - return SLANG_OK; - } - else - { - // If we are constructing a basic type, we don't need an Op::Init - if (!IRBasicType::isaImpl(dstType->getOp())) - { - // Emit the 'init' intrinsic - calcIntrinsic(Op::Init, inst, inst->getOperandCount(), out); - return SLANG_OK; - } - } - return SLANG_FAIL; - } - case kIROp_MakeVector: - case kIROp_VectorReshape: - { - if (inst->getOperandCount() == 1 && as<IRBasicType>(inst->getOperand(0)->getDataType())) - { 
- // This is make from scalar - calcIntrinsic(Op::ConstructFromScalar, inst, out); - } - else - { - calcIntrinsic(Op::Init, inst, inst->getOperandCount(), out); - } - return SLANG_OK; - } - case kIROp_MakeMatrix: - case kIROp_MatrixReshape: - { - // We only emit as if it has one operand, but we can tell how many it actually has from the return type - calcIntrinsic(Op::Init, inst, inst->getOperandCount(), out); - return SLANG_OK; - } - case kIROp_swizzle: - { - // We don't need to add swizzle function, but we do output the need for some other functions - - // For C++ we don't need to emit a swizzle function - // For C we need a construction function - auto swizzleInst = static_cast<IRSwizzle*>(inst); - - IRInst* baseInst = swizzleInst->getBase(); - IRType* baseType = baseInst->getDataType(); - - // If we are swizzling from a built in type, - if (as<IRBasicType>(baseType)) - { - // We can swizzle a scalar type to be a vector, or just a scalar - IRType* dstType = swizzleInst->getDataType(); - if (!as<IRBasicType>(dstType)) - { - // If it's a scalar make sure we have construct from scalar, because we will want to use that - SLANG_ASSERT(dstType->getOp() == kIROp_VectorType); - IRType* argTypes[] = { baseType }; - calcIntrinsic(Op::ConstructFromScalar, inst->getDataType(), argTypes, 1, out); - return SLANG_OK; - } - } - else - { - const Index elementCount = Index(swizzleInst->getElementCount()); - if (elementCount >= 1) - { - // Will need to generate a swizzle method - calcIntrinsic(Op::Swizzle, inst, out); - return SLANG_OK; - } - } - break; - } - case kIROp_GetElement: - { - IRInst* target = inst->getOperand(0); - IRType* targetType = target->getDataType(); - if (targetType->getOp() == kIROp_VectorType || targetType->getOp() == kIROp_MatrixType) - { - // Specially handle this - calcIntrinsic(Op::GetAt, inst, out); - return SLANG_OK; - } - break; - } - case kIROp_GetElementPtr: - { - IRInst* target = inst->getOperand(0); - IRType* targetType = target->getDataType(); - 
- if (auto ptrType = as<IRPtrType>(targetType)) - { - targetType = as<IRType>(ptrType->getOperand(0)); - if (targetType->getOp() == kIROp_VectorType || targetType->getOp() == kIROp_MatrixType) - { - // Specially handle this - calcIntrinsic(Op::GetAt, inst, out); - return SLANG_OK; - } - } - break; - } - case kIROp_Call: - { - IRCall* callInst = (IRCall*)inst; - auto funcValue = callInst->getCallee(); - - const Op op = m_opLookup->getOpFromTargetDecoration(funcValue); - if (op != Op::Invalid) - { - calcIntrinsic(op, inst->getDataType(), callInst->getArgs(), callInst->getArgCount(), out); - return SLANG_OK; - } - break; - } - - default: break; - } - - return SLANG_FAIL; -} - -void HLSLIntrinsicSet::getIntrinsics(List<const HLSLIntrinsic*>& out) const -{ - for (auto& intrinsic : m_intrinsicsList) - { - out.add(intrinsic); - } -} - -HLSLIntrinsic* HLSLIntrinsicSet::add(const HLSLIntrinsic& intrinsic) -{ - // Make sure it's valid(!) - SLANG_ASSERT(intrinsic.op != Op::Invalid); - - HLSLIntrinsic* copy = (HLSLIntrinsic*)m_intrinsicFreeList.allocate(); - *copy = intrinsic; - HLSLIntrinsicRef ref(copy); - HLSLIntrinsic** found = m_intrinsicsDict.TryGetValueOrAdd(ref, copy); - if (found) - { - // If we have found an intrinsic, we can free the copy - m_intrinsicFreeList.deallocate(copy); - return *found; - } - - // If we are adding an intrinsic for the first time, - // it should be added to the deduplicated list - m_intrinsicsList.add(copy); - - return copy; -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!! HLSLIntrinsicOpLookup !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
- -HLSLIntrinsicOpLookup::HLSLIntrinsicOpLookup(): - m_slicePool(StringSlicePool::Style::Default) -{ - // Add all the operations with names (not ops like -, / etc) to the lookup map - for (int i = 0; i < SLANG_COUNT_OF(HLSLIntrinsic::s_operationInfos); ++i) - { - const auto& info = HLSLIntrinsic::getInfo(Op(i)); - UnownedStringSlice slice = info.funcName; - - if (slice.getLength() > 0 && slice[0] >= 'a' && slice[0] <= 'z') - { - auto handle = m_slicePool.add(slice); - Index index = Index(handle); - // Make sure there is space - if (index >= m_sliceToOpMap.getCount()) - { - Index oldSize = m_sliceToOpMap.getCount(); - m_sliceToOpMap.setCount(index + 1); - for (Index j = oldSize; j < index; j++) - { - m_sliceToOpMap[j] = Op::Invalid; - } - } - m_sliceToOpMap[index] = Op(i); - } - } -} - -HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpByName(const UnownedStringSlice& slice) -{ - const Index index = m_slicePool.findIndex(slice); - return (index >= 0 && index < m_sliceToOpMap.getCount()) ? m_sliceToOpMap[index] : Op::Invalid; -} - -static IRInst* _getSpecializedValue(IRSpecialize* specInst) -{ - auto base = specInst->getBase(); - auto baseGeneric = as<IRGeneric>(base); - if (!baseGeneric) - return base; - - auto lastBlock = baseGeneric->getLastBlock(); - if (!lastBlock) - return base; - - auto returnInst = as<IRReturn>(lastBlock->getTerminator()); - if (!returnInst) - return base; - - return returnInst->getVal(); -} - -HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpFromTargetDecoration(IRInst* inInst) -{ - // An intrinsic generic function will be invoked through a `specialize` instruction, - // so the callee won't directly be the thing that is decorated. We will look up - // through specializations until we can see the actual thing being called. 
- // - IRInst* inst = inInst; - while (auto specInst = as<IRSpecialize>(inst)) - { - inst = _getSpecializedValue(specInst); - - // If `getSpecializedValue` can't find the result value - // of the generic being specialized, then it returns - // the original instruction. This would be a disaster - // for use because this loop would go on forever. - // - // This case should never happen if the stdlib is well-formed - // and the compiler is doing its job right. - // - SLANG_ASSERT(inst != specInst); - } - - // We are just looking for the original name so we can match against it - for (auto dd : inst->getDecorations()) - { - if (auto decor = as<IRTargetIntrinsicDecoration>(dd)) - { - // TODO(JS): Should confirm that we'll always have this entry - which we need for lookups to work (we need the name - // not a targets transformation) - // - // It turns out that addCatchAllIntrinsicDecorationIfNeeded will add a target intrinsic with the - // original HLSL name, which has an empty `CapabilitySet`. 
- // - // It's not 100% clear this covers all the cases, but for now lets go with that - if (decor->getTargetCaps().isEmpty()) - { - Op op = getOpByName(decor->getDefinition()); - if (op != Op::Invalid) - { - return op; - } - } - } - } - - return Op::Invalid; -} - -HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpForIROp(IRInst* inst) -{ - switch (inst->getOp()) - { - case kIROp_Call: - { - return getOpFromTargetDecoration(inst); - } - default: break; - } - return getOpForIROp(inst->getOp()); -} - -/* static */HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpForIROp(IROp op) -{ - switch (op) - { - case kIROp_Add: return Op::Add; - case kIROp_Mul: return Op::Mul; - case kIROp_Sub: return Op::Sub; - case kIROp_Div: return Op::Div; - case kIROp_Lsh: return Op::Lsh; - case kIROp_Rsh: return Op::Rsh; - case kIROp_IRem: return Op::IRem; - case kIROp_FRem: return Op::FRem; - - case kIROp_Eql: return Op::Eql; - case kIROp_Neq: return Op::Neq; - case kIROp_Greater: return Op::Greater; - case kIROp_Less: return Op::Less; - case kIROp_Geq: return Op::Geq; - case kIROp_Leq: return Op::Leq; - - case kIROp_BitAnd: return Op::BitAnd; - case kIROp_BitXor: return Op::BitXor; - case kIROp_BitOr: return Op::BitOr; - - case kIROp_And: return Op::And; - case kIROp_Or: return Op::Or; - - case kIROp_Neg: return Op::Neg; - case kIROp_Not: return Op::Not; - case kIROp_BitNot: return Op::BitNot; - - case kIROp_MakeVectorFromScalar: return Op::ConstructFromScalar; - - default: return Op::Invalid; - } -} - -} diff --git a/source/slang/slang-hlsl-intrinsic-set.h b/source/slang/slang-hlsl-intrinsic-set.h index 3dc2996c1..8368491db 100644 --- a/source/slang/slang-hlsl-intrinsic-set.h +++ b/source/slang/slang-hlsl-intrinsic-set.h @@ -11,217 +11,5 @@ namespace Slang { -/* TODO(JS): Note that there are multiple methods to handle 'construction' operations. That is because 'construct' is used as a kind of -generic 'construction' for built in types including vectors and matrices. 
- -For the moment the cpp emit code, determines what kind of construct is needed, and has special handling for ConstructConvert and -ConstructFromScalar. - -That currently we do not see MakeVectorFromScalar - for example when we do... - -int2 fromScalar = 1; - -This appears as a construction from an int. - -That the better thing to do would be that there were IR instructions for the specific types of construction. I suppose there is a question -about whether there should be separate instructions for vector/matrix, or emit code should just use the destination type. In practice I think -it's fine that there isn't an instruction separating vector/matrix. That being the case I guess we arguably don't need MakeVectorFromScalar, -just constructXXXFromScalar. Would be good if there was a suitable name to encompass vector/matrix. -*/ -#define SLANG_HLSL_INTRINSIC_OP(x) \ - x(Invalid, "", -1) \ - x(Init, "", -1) \ - \ - x(Mul, "*", 2) \ - x(Div, "/", 2) \ - x(Add, "+", 2) \ - x(Sub, "-", 2) \ - x(Lsh, "<<", 2) \ - x(Rsh, ">>", 2) \ - x(IRem, "%", 2) \ - x(FRem, "fmod", 2) \ - \ - x(Eql, "==", 2) \ - x(Neq, "!=", 2) \ - x(Greater, ">", 2) \ - x(Less, "<", 2) \ - x(Geq, ">=", 2) \ - x(Leq, "<=", 2) \ - \ - x(BitAnd, "&", 2) \ - x(BitXor, "^", 2) \ - x(BitOr, "|" , 2) \ - \ - x(And, "&&", 2) \ - x(Or, "||", 2) \ - \ - x(Neg, "-", 1) \ - x(Not, "!", 1) \ - x(BitNot, "~", 1) \ - \ - x(Any, "any", 1) \ - x(All, "all", 1) \ - \ - x(Swizzle, "", -1) \ - \ - x(AsFloat, "asfloat", 1) \ - x(AsInt, "asint", -1) \ - x(AsUInt, "asuint", -1) \ - x(AsDouble, "asdouble", 2) \ - \ - x(ConstructConvert, "", 1) \ - x(ConstructFromScalar, "", 1) \ - \ - x(GetAt, "", 2) \ - /* end */ - -struct HLSLIntrinsic -{ - typedef HLSLIntrinsic ThisType; - - enum class Op : uint8_t - { -#define SLANG_HLSL_INTRINSIC_OP_ENUM(name, hlslName, numOperands) name, - SLANG_HLSL_INTRINSIC_OP(SLANG_HLSL_INTRINSIC_OP_ENUM) - }; - - struct Info - { - UnownedStringSlice name; ///< The enum name - UnownedStringSlice 
funcName; ///< The HLSL function name (if there is one) - int8_t numOperands; ///< -1 if can't be handled automatically via amount of params - }; - - bool operator==(const ThisType& rhs) const { return op == rhs.op && returnType == rhs.returnType && signatureType == rhs.signatureType; } - bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - - static bool isTypeScalar(IRType* type) - { - // Strip off ptr if it's an operand type - if (type->getOp() == kIROp_PtrType) - { - type = as<IRType>(type->getOperand(0)); - } - // If any are vec or matrix, then we - return !(type->getOp() == kIROp_MatrixType || type->getOp() == kIROp_VectorType); - } - - bool isScalar() const - { - Index paramCount = Index(signatureType->getParamCount()); - for (Index i = 0; i < paramCount; ++i) - { - if (!isTypeScalar(signatureType->getParamType(i))) - { - return false; - } - } - return isTypeScalar(returnType); - } - - HashCode getHashCode() const { return combineHash(int(op), combineHash(Slang::getHashCode(returnType), Slang::getHashCode(signatureType))); } - - static const Info& getInfo(Op op) { return s_operationInfos[Index(op)]; } - static const Info s_operationInfos[]; - - Op op; - IRType* returnType; - IRFuncType* signatureType; // Same as funcType, but has return type of void -}; - -/* A helper type that allows comparing pointers to HLSLIntrinsic types as if they are the values */ -struct HLSLIntrinsicRef -{ - typedef HLSLIntrinsicRef ThisType; - - HashCode getHashCode() const { return m_intrinsic->getHashCode(); } - bool operator==(const ThisType& rhs) const { return m_intrinsic == rhs.m_intrinsic || (*m_intrinsic == *rhs.m_intrinsic); } - bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - - HLSLIntrinsicRef():m_intrinsic(nullptr) {} - HLSLIntrinsicRef(const ThisType& rhs):m_intrinsic(rhs.m_intrinsic) {} - HLSLIntrinsicRef(const HLSLIntrinsic* intrinsic): m_intrinsic(intrinsic) {} - void operator=(const ThisType& rhs) { m_intrinsic = 
rhs.m_intrinsic; } - - const HLSLIntrinsic* m_intrinsic; -}; - -class HLSLIntrinsicOpLookup : public RefObject -{ -public: - typedef HLSLIntrinsic::Op Op; - - Op getOpFromTargetDecoration(IRInst* inInst); - Op getOpByName(const UnownedStringSlice& slice); - - Op getOpForIROp(IRInst* inst); - - HLSLIntrinsicOpLookup(); - - /// Given an IROp returns the Op equivalent or Op::Invalid if not found - static Op getOpForIROp(IROp op); - -protected: - - StringSlicePool m_slicePool; - List<Op> m_sliceToOpMap; -}; - - -/* This is used so as to try and use slangs type system to uniquely identify types and specializations on intrinsic. -That we want to have a pointer to a type be unique, and slang supports this through the m_sharedIRBuilder. BUT for this to -work all work on the module must use the same sharedIRBuilder, and that appears to not be the case in terms -of other passes. -Even if it was the case when we may want to add types as part of emitting, we can't use the previously used -shared builder, so again we end up with pointers to the same things not being the same thing. - -To work around this we clone types we want to use as keys into the 'unique module'. -This is not necessary for all types though - as we assume nominal types *must* have unique pointers (that is the -definition of nominal). - -This could be handled in other ways (for example not testing equality on pointer equality). Anyway for now this -works, but probably needs to be handled in a better way. The better way may involve having guarantees about equality -enabled in other code generation and making de-duping possible in emit code. - -Note that one pro for this approach is that it does not alter the source module. That as it stands it's not necessary -for the source module to be immutable, because it is created for emitting and then discarded. - */ -class HLSLIntrinsicSet -{ -public: - typedef HLSLIntrinsic::Op Op; - - /* Note that calculating an intrinsic, the types will be added to the type set. 
That might mean subsequent code will - emit those types being required, which may not be the case */ - - void calcIntrinsic(Op op, IRType* returnType, IRType*const* args, Index argsCount, HLSLIntrinsic& out); - void calcIntrinsic(Op op, IRInst* inst, Index argsCount, HLSLIntrinsic& out); - void calcIntrinsic(Op op, IRType* returnType, IRUse* args, Index argCount, HLSLIntrinsic& out); - void calcIntrinsic(Op op, IRInst* inst, HLSLIntrinsic& out) { calcIntrinsic(op, inst, Index(inst->getOperandCount()), out); } - - SlangResult makeIntrinsic(IRInst* inst, HLSLIntrinsic& out); - - HLSLIntrinsic* add(const HLSLIntrinsic& intrinsic); - - /// Returns the intrinsic constructed if there is one from the inst. If not possible to construct returns nullptr. - HLSLIntrinsic* add(IRInst* inst); - - void getIntrinsics(List<const HLSLIntrinsic*>& out) const; - - HLSLIntrinsicSet(IRTypeSet* typeSet, HLSLIntrinsicOpLookup* lookup); - -protected: - // All calcs must go through this choke point for some special case handling. - // NOTE that this function must only be called with unique types (ie from the m_typeSet) - void _calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* inArgs, Index argsCount, HLSLIntrinsic& out); - - List<HLSLIntrinsic*> m_intrinsicsList; - Dictionary<HLSLIntrinsicRef, HLSLIntrinsic*> m_intrinsicsDict; - - FreeList m_intrinsicFreeList; ///< the storage for the intrinsics when they are in the map - - HLSLIntrinsicOpLookup* m_opLookup; - IRTypeSet* m_typeSet; -}; } // namespace Slang diff --git a/source/slang/slang-ir-address-analysis.cpp b/source/slang/slang-ir-address-analysis.cpp index aba59e1de..1473bc466 100644 --- a/source/slang/slang-ir-address-analysis.cpp +++ b/source/slang/slang-ir-address-analysis.cpp @@ -79,9 +79,8 @@ namespace Slang // Deduplicate and move known address insts. 
         for (auto block : func->getBlocks())
         {
-            for (auto inst = block->getFirstChild(); inst;)
+            for (auto inst : block->getModifiableChildren())
             {
-                auto next = inst->getNextInst();
                 switch (inst->getOp())
                 {
                 case kIROp_Var:
@@ -151,7 +150,6 @@ namespace Slang
                     }
                     break;
                 }
-                inst = next;
             }
         }

diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index b5d3dba10..1f599a344 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -170,40 +170,36 @@ InstPair ForwardDiffTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRIns
 {
     SLANG_ASSERT(origLogic->getOperandCount() == 2);

-    // TODO: Check other boolean cases.
-    if (as<IRBoolType>(origLogic->getDataType()))
-    {
-        // Boolean operations are not differentiable. For the linearization
-        // pass, we do not need to do anything but copy them over to the ne
-        // function.
-        auto primalLogic = maybeCloneForPrimalInst(builder, origLogic);
-        return InstPair(primalLogic, nullptr);
-    }
-
-    SLANG_UNEXPECTED("Logical operation with non-boolean result");
+    // Boolean operations are not differentiable. For the linearization
+    // pass, we do not need to do anything but copy them over to the ne
+    // function.
+    auto primalLogic = maybeCloneForPrimalInst(builder, origLogic);
+    return InstPair(primalLogic, nullptr);
 }

 InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
 {
     auto origPtr = origLoad->getPtr();
     auto primalPtr = lookupPrimalInst(builder, origPtr, nullptr);
-    auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType();
-
-    if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType))
+    auto primalPtrType = as<IRPtrTypeBase>(primalPtr->getFullType());
+    if (primalPtrType)
     {
-        // Special case load from an `out` param, which will not have corresponding `diff` and
-        // `primal` insts yet.
-
-        // TODO: Could we move this load to _after_ DifferentialPairGetPrimal,
-        // and DifferentialPairGetDifferential?
-        //
-        auto load = builder->emitLoad(primalPtr);
-        builder->markInstAsMixedDifferential(load, diffPairType);
+        if (auto diffPairType = as<IRDifferentialPairType>(primalPtrType->getValueType()))
+        {
+            // Special case load from an `out` param, which will not have corresponding `diff` and
+            // `primal` insts yet.

-        auto primalElement = builder->emitDifferentialPairGetPrimal(load);
-        auto diffElement = builder->emitDifferentialPairGetDifferential(
-            (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load);
-        return InstPair(primalElement, diffElement);
+            // TODO: Could we move this load to _after_ DifferentialPairGetPrimal,
+            // and DifferentialPairGetDifferential?
+            //
+            auto load = builder->emitLoad(primalPtr);
+            builder->markInstAsMixedDifferential(load, diffPairType);
+
+            auto primalElement = builder->emitDifferentialPairGetPrimal(load);
+            auto diffElement = builder->emitDifferentialPairGetDifferential(
+                (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load);
+            return InstPair(primalElement, diffElement);
+        }
     }

     auto primalLoad = maybeCloneForPrimalInst(builder, origLoad);
@@ -492,7 +488,6 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig

     if (!diffReturnType)
     {
-        SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType);
         diffReturnType = argBuilder.getVoidType();
     }

@@ -1364,6 +1359,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
     case kIROp_Or:
     case kIROp_Geq:
     case kIROp_Leq:
+    case kIROp_Eql:
+    case kIROp_Neq:
         return transcribeBinaryLogic(builder, origInst);

     case kIROp_CastIntToFloat:
@@ -1452,7 +1449,27 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
     case kIROp_undefined:
         return transcribeUndefined(builder, origInst);

+    case kIROp_Not:
+    case kIROp_BitAnd:
+    case kIROp_BitNot:
+    case kIROp_BitXor:
+    case kIROp_BitCast:
+    case kIROp_Lsh:
+    case kIROp_Rsh:
+    case kIROp_IRem:
+    case kIROp_ByteAddressBufferLoad:
+    case kIROp_ByteAddressBufferStore:
+    case kIROp_StructuredBufferLoad:
+    case kIROp_StructuredBufferStore:
+    case kIROp_Reinterpret:
+    case kIROp_IsType:
+    case kIROp_ImageSubscript:
+    case kIROp_ImageLoad:
+    case kIROp_ImageStore:
     case kIROp_CreateExistentialObject:
+    case kIROp_PackAnyValue:
+    case kIROp_UnpackAnyValue:
+    case kIROp_GetNativePtr:
         // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
         // so we treat this inst as non differentiable.
         // We can extend the frontend and IR with a separate op-code that can provide an explicit diff value.
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index d83ff57e4..d10a9349d 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -1256,10 +1256,8 @@ struct DiffUnzipPass
         diffBuilder.setInsertInto(diffBlock);

         List<IRInst*> splitInsts;
-        for (auto child = block->getFirstChild(); child;)
+        for (auto child : block->getModifiableChildren())
         {
-            IRInst* nextChild = child->getNextInst();
-
             if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(child))
             {
                 // Replace GetDiff(A) with A.d
@@ -1267,7 +1265,6 @@ struct DiffUnzipPass
                 {
                     getDiffInst->replaceUsesWith(lookupDiffInst(getDiffInst->getBase()));
                     getDiffInst->removeAndDeallocate();
-                    child = nextChild;
                     continue;
                 }
             }
@@ -1278,7 +1275,6 @@ struct DiffUnzipPass
                 {
                     getPrimalInst->replaceUsesWith(lookupPrimalInst(getPrimalInst->getBase()));
                     getPrimalInst->removeAndDeallocate();
-                    child = nextChild;
                     continue;
                 }
             }
@@ -1296,8 +1292,6 @@ struct DiffUnzipPass
             {
                 child->insertAtEnd(primalBlock);
             }
-
-            child = nextChild;
         }

         // Remove insts that were split.
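Review note: several hunks in this commit replace the manual "save the next pointer, then mutate" loop with a range-for over `getModifiableChildren()`. The sketch below illustrates why that is safe when the loop body unlinks the current node. It uses hypothetical `Node` and `Block` types as stand-ins for `IRInst` and `IRBlock`, and it assumes the helper works by taking a snapshot of the children before iterating; that is an assumption for illustration, since the helper's implementation is not part of the hunks shown here.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for IRInst: a node in an intrusive doubly linked list.
struct Node
{
    int value = 0;
    Node* prev = nullptr;
    Node* next = nullptr;
};

// Hypothetical stand-in for IRBlock: owns the list head and tail.
struct Block
{
    Node* first = nullptr;
    Node* last = nullptr;

    void append(Node* n)
    {
        n->prev = last;
        n->next = nullptr;
        if (last) last->next = n; else first = n;
        last = n;
    }

    void remove(Node* n)
    {
        if (n->prev) n->prev->next = n->next; else first = n->next;
        if (n->next) n->next->prev = n->prev; else last = n->prev;
        n->prev = nullptr;
        n->next = nullptr;
    }

    // Assumed snapshot semantics: copy the child pointers up front so the
    // caller can unlink nodes while iterating over the returned vector.
    std::vector<Node*> modifiableChildren() const
    {
        std::vector<Node*> out;
        for (Node* n = first; n; n = n->next)
            out.push_back(n);
        return out;
    }
};

// Old style: remember `next` before the body may unlink `n`.
int removeOddManual(Block& b)
{
    int removed = 0;
    for (Node* n = b.first; n;)
    {
        Node* next = n->next;
        if (n->value % 2 != 0) { b.remove(n); ++removed; }
        n = next;
    }
    return removed;
}

// New style: range-for over a snapshot; no bookkeeping on early `continue`.
int removeOddSnapshot(Block& b)
{
    int removed = 0;
    for (Node* n : b.modifiableChildren())
    {
        if (n->value % 2 != 0) { b.remove(n); ++removed; }
    }
    return removed;
}
```

One trade-off of a snapshot, under the assumption above: nodes appended during the loop are not visited, and the loop body must not rely on a snapshotted node still being linked if some earlier iteration removed it.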
diff --git a/source/slang/slang-ir-byte-address-legalize.cpp b/source/slang/slang-ir-byte-address-legalize.cpp index 3a8d1852a..721efadaf 100644 --- a/source/slang/slang-ir-byte-address-legalize.cpp +++ b/source/slang/slang-ir-byte-address-legalize.cpp @@ -66,11 +66,8 @@ struct ByteAddressBufferLegalizationContext break; } - - IRInst* nextChild = nullptr; - for( IRInst* child = inst->getFirstChild(); child; child = nextChild ) + for( IRInst* child : inst->getModifiableChildren()) { - nextChild = child->getNextInst(); processInstRec(child); } } diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp index dbeb1e934..8b8b28f09 100644 --- a/source/slang/slang-ir-clone.cpp +++ b/source/slang/slang-ir-clone.cpp @@ -72,29 +72,29 @@ IRInst* cloneInstAndOperands( auto oldType = oldInst->getFullType(); auto newType = (IRType*) findCloneForOperand(env, oldType); - // Next we will create an empty shell of the instruction, - // with space for the operands, but no actual operand - // values attached. - // - UInt operandCount = oldInst->getOperandCount(); - auto newInst = builder->emitIntrinsicInst( - newType, - oldInst->getOp(), - operandCount, - nullptr); - - // Finally we will iterate over the operands of `oldInst` + // Next we will iterate over the operands of `oldInst` // to find their replacements and install them as // the operands of `newInst`. // - for(UInt ii = 0; ii < operandCount; ++ii) + UInt operandCount = oldInst->getOperandCount(); + + ShortList<IRInst*> newOperands; + newOperands.setCount(operandCount); + for (UInt ii = 0; ii < operandCount; ++ii) { auto oldOperand = oldInst->getOperand(ii); auto newOperand = findCloneForOperand(env, oldOperand); - newInst->getOperands()[ii].init(newInst, newOperand); + newOperands[ii] = newOperand; } + // Finally we create the inst with the updated operands. 
+ auto newInst = builder->emitIntrinsicInst( + newType, + oldInst->getOp(), + operandCount, + newOperands.getArrayView().getBuffer()); + newInst->sourceLoc = oldInst->sourceLoc; return newInst; diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp index ca5e56b53..ad0dfda91 100644 --- a/source/slang/slang-ir-collect-global-uniforms.cpp +++ b/source/slang/slang-ir-collect-global-uniforms.cpp @@ -192,7 +192,8 @@ struct CollectGlobalUniformParametersContext // per-field layout information to reference the key we created // instead of the existing parameter (which we will be removing). // - fieldLayoutAttr->setOperand(0, fieldKey); + fieldLayoutAttr = as<IRStructFieldLayoutAttr>( + builder->replaceOperand(fieldLayoutAttr->getOperands(), fieldKey)); // If the given parameter doesn't contribute to uniform/ordinary usage, then // we can safely leave it at the global scope and potentially avoid a lot @@ -266,7 +267,7 @@ struct CollectGlobalUniformParametersContext // if(auto layoutAttr = as<IRStructFieldLayoutAttr>(user)) { - layoutAttr->setOperand(0, fieldKey); + builder->replaceOperand(layoutAttr->getOperands(), fieldKey); continue; } diff --git a/source/slang/slang-ir-com-interface.cpp b/source/slang/slang-ir-com-interface.cpp index 3e52054cd..0684cc8e6 100644 --- a/source/slang/slang-ir-com-interface.cpp +++ b/source/slang/slang-ir-com-interface.cpp @@ -105,7 +105,7 @@ void lowerComInterfaces(IRModule* module, ArtifactStyle artifactStyle, Diagnosti for (auto use : uses) { // Do the replacement - use->set(result); + builder.replaceOperand(use, result); } } } diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 05c10b317..251b473e0 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -237,14 +237,16 @@ struct DeadCodeEliminationContext // might still be dead. 
// // The biggest wrinkle is that we walk the linked list of - // children/decorations a bit carefully, using a temporary - // to hold the next node, in case we eliminate one of - // the children as we go. + // children/decorations a bit carefully, because eliminating one inst + // may cause the other nodes to be hoisted out of the current scope. + // We need to cache all children in a work list to ensure they are + // properly traversed. // - IRInst* next = nullptr; - for( IRInst* child = inst->getFirstDecorationOrChild(); child; child = next ) + List<IRInst*> children; + for (auto child : inst->getDecorationsAndChildren()) + children.add(child); + for(IRInst* child : children) { - next = child->getNextInst(); changed |= eliminateDeadInstsRec(child); } } diff --git a/source/slang/slang-ir-deduplicate.cpp b/source/slang/slang-ir-deduplicate.cpp index 51a677627..74efc3cb3 100644 --- a/source/slang/slang-ir-deduplicate.cpp +++ b/source/slang/slang-ir-deduplicate.cpp @@ -2,116 +2,84 @@ namespace Slang { - struct DeduplicateContext + void SharedIRBuilder::deduplicateAndRebuildGlobalNumberingMap() { - SharedIRBuilder* builder; - IRInst* addValue(IRInst* value) - { - if (!value) return nullptr; - if (as<IRType>(value)) - return addTypeValue(value); - if (auto constValue = as<IRConstant>(value)) - return addConstantValue(constValue); - return value; - } - IRInst* addConstantValue(IRConstant* value) - { - IRConstantKey key = { value }; - value->setFullType((IRType*)addValue(value->getFullType())); - if (auto newValue = builder->getConstantMap().TryGetValue(key)) - return *newValue; - builder->getConstantMap()[key] = value; - return value; - } - IRInst* addTypeValue(IRInst* value) - { - // Do not deduplicate struct or interface types. 
- switch (value->getOp()) - { - case kIROp_StructType: - case kIROp_InterfaceType: - return value; - default: - break; - } + } - for (UInt i = 0; i < value->getOperandCount(); i++) - { - value->setOperand(i, addValue(value->getOperand(i))); - } - value->setFullType((IRType*)addValue(value->getFullType())); - IRInstKey key = { value }; - if (auto newValue = builder->getGlobalValueNumberingMap().TryGetValue(key)) - return *newValue; - builder->getGlobalValueNumberingMap()[key] = value; - return value; - } - }; - void SharedIRBuilder::deduplicateAndRebuildGlobalNumberingMap() + void SharedIRBuilder::replaceGlobalInst(IRInst* oldInst, IRInst* newInst) + { + oldInst->replaceUsesWith(newInst); + } + + void SharedIRBuilder::removeHoistableInstFromGlobalNumberingMap(IRInst* instToRemove) { - DeduplicateContext context; - context.builder = this; - m_constantMap.Clear(); - m_globalValueNumberingMap.Clear(); - List<IRInst*> instToRemove; - for (auto inst : m_module->getGlobalInsts()) + HashSet<IRInst*> userWorkListSet; + List<IRInst*> userWorkList; + auto addToWorkList = [&](IRInst* i) { - if (auto constVal = as<IRConstant>(inst)) - { - auto newConst = context.addConstantValue(constVal); - if (newConst != constVal) - { - constVal->replaceUsesWith(newConst); - instToRemove.add(constVal); - } - } - } - for (auto inst : m_module->getGlobalInsts()) + if (userWorkListSet.Add(i)) + userWorkList.add(i); + }; + addToWorkList(instToRemove); + for (Index i = 0; i < userWorkList.getCount(); i++) { - if (as<IRType>(inst) || as<IRSpecialize>(inst)) + auto inst = userWorkList[i]; + if (getIROpInfo(inst->getOp()).isHoistable()) { - auto newInst = context.addTypeValue(inst); - if (newInst != inst) + _removeGlobalNumberingEntry(inst); + for (auto use = inst->firstUse; use; use = use->nextUse) { - inst->replaceUsesWith(newInst); - instToRemove.add(inst); + addToWorkList(use->getUser()); } } } - for (auto inst : instToRemove) - inst->removeAndDeallocate(); } - void 
SharedIRBuilder::replaceGlobalInst(IRInst* oldInst, IRInst* newInst) + void addHoistableInst( + IRBuilder* builder, + IRInst* inst); + + void SharedIRBuilder::tryHoistInst(IRInst* inst) { - List<IRUse*> uses; - for (auto use = oldInst->firstUse; use; use = use->nextUse) - { - uses.add(use); - } + List<IRInst*> workList; + HashSet<IRInst*> workListSet; + workList.add(inst); + workListSet.Add(inst); + IRBuilder builder(inst->getModule()); - bool shouldUpdateGlobalNumberedCache = false; - for (auto use : uses) + for (Index i = 0; i < workList.getCount(); i++) { - use->set(newInst); - // depending on the type of the user inst, we may need to rebuild and update the global - // numbering cache. - if (isGloballyNumberedInst(use->getUser())) + auto item = workList[i]; + + // Does inst no longer depend on anything defined locally? + // If so we should hoist it. + bool shouldHoist = false; + for (UInt a = 0; a < item->getOperandCount(); a++) { - shouldUpdateGlobalNumberedCache = true; + auto opParent = item->getOperand(a)->getParent(); + if (opParent != item->getParent()) + { + shouldHoist = true; + break; + } } - } - oldInst->removeAndDeallocate(); - if (shouldUpdateGlobalNumberedCache) - { - deduplicateAndRebuildGlobalNumberingMap(); - } - } - bool SharedIRBuilder::isGloballyNumberedInst(IRInst* inst) - { - if (!inst->getParent() || inst->getParent()->getOp() != kIROp_Module) - return false; - return m_globalValueNumberingMap.ContainsKey(IRInstKey{inst}); + // Hoisting this inst + if (shouldHoist) + { + item->removeFromParent(); + addHoistableInst(&builder, item); + + // Continue to consider all users for hoisting. 
+ for (auto use = item->firstUse; use; use = use->nextUse) + { + if (getIROpInfo(use->getUser()->getOp()).isHoistable()) + { + if (workListSet.Add(use->getUser())) + workList.add(use->getUser()); + } + } + } + } } } diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 0dcd437fe..55d120228 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -1791,7 +1791,7 @@ void legalizeMeshOutputParam( // the writes may only be writing to parts of the output struct, or may not // be writes at all (i.e. being passed as an out paramter). // - traverseUses(g, [&](IRInst* u) + traverseUsers(g, [&](IRInst* u) { auto l = as<IRLoad>(u); SLANG_EXPECT(l, "Mesh Output sentinel parameter wasn't used in a load"); @@ -1811,7 +1811,7 @@ void legalizeMeshOutputParam( return; } // Otherwise, go through the uses one by one and see what we can do - traverseUses(a, [&](IRInst* s) + traverseUsers(a, [&](IRInst* s) { IRBuilderInsertLocScope locScope{builder}; builder->setInsertBefore(s); @@ -2022,7 +2022,7 @@ void legalizeMeshOutputParam( for(auto builtin : builtins) { - traverseUses(builtin.param, [&](IRInst* u) + traverseUsers(builtin.param, [&](IRInst* u) { auto p = as<IRGetElementPtr>(u); SLANG_EXPECT(p, "Mesh Output sentinel parameter wasn't used as an array"); diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index 7fc977170..643acdbb8 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -53,10 +53,8 @@ struct InliningPassBase // so that even if `child` gets removed (because of inlining) // we automatically start at the next instruction after it. 
// - IRInst* next = nullptr; - for( auto child = inst->getFirstChild(); child; child = next ) + for (auto child : inst->getModifiableChildren()) { - next = child->getNextInst(); changed |= considerAllCallSitesRec(child); } return changed; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 788e02c90..35877d680 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -10,6 +10,8 @@ #define PARENT kIROpFlag_Parent #define USE_OTHER kIROpFlag_UseOther +#define HOISTABLE kIROpFlag_Hoistable +#define GLOBAL kIROpFlag_Global INST(Nop, nop, 0, 0) @@ -17,7 +19,7 @@ INST(Nop, nop, 0, 0) /* Basic Types */ - #define DEFINE_BASE_TYPE_INST(NAME) INST(NAME ## Type, NAME, 0, 0) + #define DEFINE_BASE_TYPE_INST(NAME) INST(NAME ## Type, NAME, 0, HOISTABLE) FOREACH_BASE_TYPE(DEFINE_BASE_TYPE_INST) #undef DEFINE_BASE_TYPE_INST INST(AfterBaseType, afterBaseType, 0, 0) @@ -25,42 +27,42 @@ INST(Nop, nop, 0, 0) INST_RANGE(BasicType, VoidType, AfterBaseType) /* StringTypeBase */ - INST(StringType, String, 0, 0) - INST(NativeStringType, NativeString, 0, 0) + INST(StringType, String, 0, HOISTABLE) + INST(NativeStringType, NativeString, 0, HOISTABLE) INST_RANGE(StringTypeBase, StringType, NativeStringType) - INST(CapabilitySetType, CapabilitySet, 0, 0) + INST(CapabilitySetType, CapabilitySet, 0, HOISTABLE) - INST(DynamicType, DynamicType, 0, 0) + INST(DynamicType, DynamicType, 0, HOISTABLE) - INST(AnyValueType, AnyValueType, 1, 0) + INST(AnyValueType, AnyValueType, 1, HOISTABLE) - INST(RawPointerType, RawPointerType, 0, 0) - INST(RTTIPointerType, RTTIPointerType, 1, 0) + INST(RawPointerType, RawPointerType, 0, HOISTABLE) + INST(RTTIPointerType, RTTIPointerType, 1, HOISTABLE) INST(AfterRawPointerTypeBase, AfterRawPointerTypeBase, 0, 0) INST_RANGE(RawPointerTypeBase, RawPointerType, AfterRawPointerTypeBase) /* ArrayTypeBase */ - INST(ArrayType, Array, 2, 0) - INST(UnsizedArrayType, UnsizedArray, 1, 0) + INST(ArrayType, Array, 
2, HOISTABLE) + INST(UnsizedArrayType, UnsizedArray, 1, HOISTABLE) INST_RANGE(ArrayTypeBase, ArrayType, UnsizedArrayType) - INST(FuncType, Func, 0, 0) - INST(BasicBlockType, BasicBlock, 0, 0) + INST(FuncType, Func, 0, HOISTABLE) + INST(BasicBlockType, BasicBlock, 0, HOISTABLE) - INST(VectorType, Vec, 2, 0) - INST(MatrixType, Mat, 3, 0) + INST(VectorType, Vec, 2, HOISTABLE) + INST(MatrixType, Mat, 3, HOISTABLE) - INST(TaggedUnionType, TaggedUnion, 0, 0) + INST(TaggedUnionType, TaggedUnion, 0, HOISTABLE) - INST(ConjunctionType, Conjunction, 0, 0) - INST(AttributedType, Attributed, 0, 0) - INST(ResultType, Result, 2, 0) - INST(OptionalType, Optional, 1, 0) + INST(ConjunctionType, Conjunction, 0, HOISTABLE) + INST(AttributedType, Attributed, 0, HOISTABLE) + INST(ResultType, Result, 2, HOISTABLE) + INST(OptionalType, Optional, 1, HOISTABLE) - INST(DifferentialPairType, DiffPair, 1, 0) - INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, 0) + INST(DifferentialPairType, DiffPair, 1, HOISTABLE) + INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, HOISTABLE) /* BindExistentialsTypeBase */ @@ -70,58 +72,58 @@ INST(Nop, nop, 0, 0) // where each `Ti, wi` pair represents the concrete type // and witness table to plug in for parameter `i`. // - INST(BindExistentialsType, BindExistentials, 1, 0) + INST(BindExistentialsType, BindExistentials, 1, HOISTABLE) // An `BindInterface<B, T0, w0>` represents the special case // of a `BindExistentials` where the type `B` is known to be // an interface type. 
// - INST(BoundInterfaceType, BoundInterface, 3, 0) + INST(BoundInterfaceType, BoundInterface, 3, HOISTABLE) INST_RANGE(BindExistentialsTypeBase, BindExistentialsType, BoundInterfaceType) /* Rate */ - INST(ConstExprRate, ConstExpr, 0, 0) - INST(GroupSharedRate, GroupShared, 0, 0) - INST(ActualGlobalRate, ActualGlobalRate, 0, 0) + INST(ConstExprRate, ConstExpr, 0, HOISTABLE) + INST(GroupSharedRate, GroupShared, 0, HOISTABLE) + INST(ActualGlobalRate, ActualGlobalRate, 0, HOISTABLE) INST_RANGE(Rate, ConstExprRate, GroupSharedRate) - INST(RateQualifiedType, RateQualified, 2, 0) + INST(RateQualifiedType, RateQualified, 2, HOISTABLE) // Kinds represent the "types of types." // They should not really be nested under `IRType` // in the overall hierarchy, but we can fix that later. // /* Kind */ - INST(TypeKind, Type, 0, 0) - INST(RateKind, Rate, 0, 0) - INST(GenericKind, Generic, 0, 0) + INST(TypeKind, Type, 0, HOISTABLE) + INST(RateKind, Rate, 0, HOISTABLE) + INST(GenericKind, Generic, 0, HOISTABLE) INST_RANGE(Kind, TypeKind, GenericKind) /* PtrTypeBase */ - INST(PtrType, Ptr, 1, 0) - INST(RefType, Ref, 1, 0) + INST(PtrType, Ptr, 1, HOISTABLE) + INST(RefType, Ref, 1, HOISTABLE) // A `PsuedoPtr<T>` logically represents a pointer to a value of type // `T` on a platform that cannot support pointers. The expectation // is that the "pointer" will be legalized away by storing a value // of type `T` somewhere out-of-line. - INST(PseudoPtrType, PseudoPtr, 1, 0) + INST(PseudoPtrType, PseudoPtr, 1, HOISTABLE) /* OutTypeBase */ - INST(OutType, Out, 1, 0) - INST(InOutType, InOut, 1, 0) + INST(OutType, Out, 1, HOISTABLE) + INST(InOutType, InOut, 1, HOISTABLE) INST_RANGE(OutTypeBase, OutType, InOutType) INST_RANGE(PtrTypeBase, PtrType, InOutType) // A ComPtr<T> type is treated as a opaque type that represents a reference-counted handle to a COM object. 
- INST(ComPtrType, ComPtr, 1, 0) + INST(ComPtrType, ComPtr, 1, HOISTABLE) // A NativePtr<T> type represents a native pointer to a managed resource. - INST(NativePtrType, NativePtr, 1, 0) + INST(NativePtrType, NativePtr, 1, HOISTABLE) /* SamplerStateTypeBase */ - INST(SamplerStateType, SamplerState, 0, 0) - INST(SamplerComparisonStateType, SamplerComparisonState, 0, 0) + INST(SamplerStateType, SamplerState, 0, HOISTABLE) + INST(SamplerComparisonStateType, SamplerComparisonState, 0, HOISTABLE) INST_RANGE(SamplerStateTypeBase, SamplerStateType, SamplerComparisonStateType) // TODO: Why do we have all this hierarchy here, when everything @@ -131,11 +133,11 @@ INST(Nop, nop, 0, 0) /* TextureTypeBase */ // NOTE! TextureFlavor::Flavor is stored in 'other' bits for these types. /* TextureType */ - INST(TextureType, TextureType, 0, USE_OTHER) + INST(TextureType, TextureType, 0, USE_OTHER | HOISTABLE) /* TextureSamplerType */ - INST(TextureSamplerType, TextureSamplerType, 0, USE_OTHER) + INST(TextureSamplerType, TextureSamplerType, 0, USE_OTHER | HOISTABLE) /* GLSLImageType */ - INST(GLSLImageType, GLSLImageType, 0, USE_OTHER) + INST(GLSLImageType, GLSLImageType, 0, USE_OTHER | HOISTABLE) INST_RANGE(TextureTypeBase, TextureType, GLSLImageType) INST_RANGE(ResourceType, TextureType, GLSLImageType) INST_RANGE(ResourceTypeBase, TextureType, GLSLImageType) @@ -143,53 +145,53 @@ INST(Nop, nop, 0, 0) /* UntypedBufferResourceType */ /* ByteAddressBufferTypeBase */ - INST(HLSLByteAddressBufferType, ByteAddressBuffer, 0, 0) - INST(HLSLRWByteAddressBufferType, RWByteAddressBuffer, 0, 0) - INST(HLSLRasterizerOrderedByteAddressBufferType, RasterizerOrderedByteAddressBuffer, 0, 0) + INST(HLSLByteAddressBufferType, ByteAddressBuffer, 0, HOISTABLE) + INST(HLSLRWByteAddressBufferType, RWByteAddressBuffer, 0, HOISTABLE) + INST(HLSLRasterizerOrderedByteAddressBufferType, RasterizerOrderedByteAddressBuffer, 0, HOISTABLE) INST_RANGE(ByteAddressBufferTypeBase, HLSLByteAddressBufferType, 
HLSLRasterizerOrderedByteAddressBufferType) - INST(RaytracingAccelerationStructureType, RaytracingAccelerationStructure, 0, 0) + INST(RaytracingAccelerationStructureType, RaytracingAccelerationStructure, 0, HOISTABLE) INST_RANGE(UntypedBufferResourceType, HLSLByteAddressBufferType, RaytracingAccelerationStructureType) /* HLSLPatchType */ - INST(HLSLInputPatchType, InputPatch, 2, 0) - INST(HLSLOutputPatchType, OutputPatch, 2, 0) + INST(HLSLInputPatchType, InputPatch, 2, HOISTABLE) + INST(HLSLOutputPatchType, OutputPatch, 2, HOISTABLE) INST_RANGE(HLSLPatchType, HLSLInputPatchType, HLSLOutputPatchType) - INST(GLSLInputAttachmentType, GLSLInputAttachment, 0, 0) + INST(GLSLInputAttachmentType, GLSLInputAttachment, 0, HOISTABLE) /* BuiltinGenericType */ /* HLSLStreamOutputType */ - INST(HLSLPointStreamType, PointStream, 1, 0) - INST(HLSLLineStreamType, LineStream, 1, 0) - INST(HLSLTriangleStreamType, TriangleStream, 1, 0) + INST(HLSLPointStreamType, PointStream, 1, HOISTABLE) + INST(HLSLLineStreamType, LineStream, 1, HOISTABLE) + INST(HLSLTriangleStreamType, TriangleStream, 1, HOISTABLE) INST_RANGE(HLSLStreamOutputType, HLSLPointStreamType, HLSLTriangleStreamType) /* MeshOutputType */ - INST(VerticesType, Vertices, 2, 0) - INST(IndicesType, Indices, 2, 0) - INST(PrimitivesType, Primitives, 2, 0) + INST(VerticesType, Vertices, 2, HOISTABLE) + INST(IndicesType, Indices, 2, HOISTABLE) + INST(PrimitivesType, Primitives, 2, HOISTABLE) INST_RANGE(MeshOutputType, VerticesType, PrimitivesType) /* HLSLStructuredBufferTypeBase */ - INST(HLSLStructuredBufferType, StructuredBuffer, 0, 0) - INST(HLSLRWStructuredBufferType, RWStructuredBuffer, 0, 0) - INST(HLSLRasterizerOrderedStructuredBufferType, RasterizerOrderedStructuredBuffer, 0, 0) - INST(HLSLAppendStructuredBufferType, AppendStructuredBuffer, 0, 0) - INST(HLSLConsumeStructuredBufferType, ConsumeStructuredBuffer, 0, 0) + INST(HLSLStructuredBufferType, StructuredBuffer, 0, HOISTABLE) + INST(HLSLRWStructuredBufferType, 
RWStructuredBuffer, 0, HOISTABLE) + INST(HLSLRasterizerOrderedStructuredBufferType, RasterizerOrderedStructuredBuffer, 0, HOISTABLE) + INST(HLSLAppendStructuredBufferType, AppendStructuredBuffer, 0, HOISTABLE) + INST(HLSLConsumeStructuredBufferType, ConsumeStructuredBuffer, 0, HOISTABLE) INST_RANGE(HLSLStructuredBufferTypeBase, HLSLStructuredBufferType, HLSLConsumeStructuredBufferType) /* PointerLikeType */ /* ParameterGroupType */ /* UniformParameterGroupType */ - INST(ConstantBufferType, ConstantBuffer, 1, 0) - INST(TextureBufferType, TextureBuffer, 1, 0) - INST(ParameterBlockType, ParameterBlock, 1, 0) - INST(GLSLShaderStorageBufferType, GLSLShaderStorageBuffer, 0, 0) + INST(ConstantBufferType, ConstantBuffer, 1, HOISTABLE) + INST(TextureBufferType, TextureBuffer, 1, HOISTABLE) + INST(ParameterBlockType, ParameterBlock, 1, HOISTABLE) + INST(GLSLShaderStorageBufferType, GLSLShaderStorageBuffer, 0, HOISTABLE) INST_RANGE(UniformParameterGroupType, ConstantBufferType, GLSLShaderStorageBufferType) /* VaryingParameterGroupType */ - INST(GLSLInputParameterGroupType, GLSLInputParameterGroup, 0, 0) - INST(GLSLOutputParameterGroupType, GLSLOutputParameterGroup, 0, 0) + INST(GLSLInputParameterGroupType, GLSLInputParameterGroup, 0, HOISTABLE) + INST(GLSLOutputParameterGroupType, GLSLOutputParameterGroup, 0, HOISTABLE) INST_RANGE(VaryingParameterGroupType, GLSLInputParameterGroupType, GLSLOutputParameterGroupType) INST_RANGE(ParameterGroupType, ConstantBufferType, GLSLOutputParameterGroupType) INST_RANGE(PointerLikeType, ConstantBufferType, GLSLOutputParameterGroupType) @@ -209,28 +211,28 @@ INST(Nop, nop, 0, 0) // INST(StructType, struct, 0, PARENT) INST(ClassType, class, 0, PARENT) -INST(InterfaceType, interface, 0, 0) -INST(AssociatedType, associated_type, 0, 0) -INST(ThisType, this_type, 0, 0) -INST(RTTIType, rtti_type, 0, 0) -INST(RTTIHandleType, rtti_handle_type, 0, 0) -INST(TupleType, tuple_type, 0, 0) +INST(InterfaceType, interface, 0, GLOBAL) +INST(AssociatedType, 
associated_type, 0, HOISTABLE) +INST(ThisType, this_type, 0, HOISTABLE) +INST(RTTIType, rtti_type, 0, HOISTABLE) +INST(RTTIHandleType, rtti_handle_type, 0, HOISTABLE) +INST(TupleType, tuple_type, 0, HOISTABLE) // A type that identifies it's contained type as being emittable as `spirv_literal. -INST(SPIRVLiteralType, spirvLiteralType, 1, 0) +INST(SPIRVLiteralType, spirvLiteralType, 1, HOISTABLE) // A TypeType-typed IRValue represents a IRType. // It is used to represent a type parameter/argument in a generics. -INST(TypeType, type_t, 0, 0) +INST(TypeType, type_t, 0, HOISTABLE) /*IRWitnessTableTypeBase*/ // An `IRWitnessTable` has type `WitnessTableType`. - INST(WitnessTableType, witness_table_t, 1, 0) + INST(WitnessTableType, witness_table_t, 1, HOISTABLE) // An integer type representing a witness table for targets where // witness tables are represented as integer IDs. This type is used // during the lower-generics pass while generating dynamic dispatch // code and will eventually lower into an uint type. 
- INST(WitnessTableIDType, witness_table_id_t, 1, 0) + INST(WitnessTableIDType, witness_table_id_t, 1, HOISTABLE) INST_RANGE(WitnessTableTypeBase, WitnessTableType, WitnessTableIDType) INST_RANGE(Type, VoidType, WitnessTableIDType) @@ -240,14 +242,14 @@ INST_RANGE(Type, VoidType, WitnessTableIDType) INST(Generic, generic, 0, PARENT) INST_RANGE(GlobalValueWithParams, Func, Generic) - INST(GlobalVar, global_var, 0, 0) + INST(GlobalVar, global_var, 0, GLOBAL) INST_RANGE(GlobalValueWithCode, Func, GlobalVar) -INST(GlobalParam, global_param, 0, 0) -INST(GlobalConstant, globalConstant, 0, 0) +INST(GlobalParam, global_param, 0, GLOBAL) +INST(GlobalConstant, globalConstant, 0, GLOBAL) -INST(StructKey, key, 0, 0) -INST(GlobalGenericParam, global_generic_param, 0, 0) +INST(StructKey, key, 0, GLOBAL) +INST(GlobalGenericParam, global_generic_param, 0, GLOBAL) INST(WitnessTable, witness_table, 0, 0) INST(GlobalHashedStringLiterals, global_hashed_string_literals, 0, 0) @@ -265,7 +267,7 @@ INST(Block, block, 0, PARENT) INST(VoidLit, void_constant, 0, 0) INST_RANGE(Constant, BoolLit, VoidLit) -INST(CapabilitySet, capabilitySet, 0, 0) +INST(CapabilitySet, capabilitySet, 0, HOISTABLE) INST(undefined, undefined, 0, 0) @@ -279,10 +281,9 @@ INST(MakeDifferentialPair, MakeDiffPair, 2, 0) INST(DifferentialPairGetDifferential, GetDifferential, 1, 0) INST(DifferentialPairGetPrimal, GetPrimal, 1, 0) -INST(Specialize, specialize, 2, 0) -INST(LookupWitness, lookupWitness, 2, 0) +INST(Specialize, specialize, 2, HOISTABLE) +INST(LookupWitness, lookupWitness, 2, HOISTABLE) INST(GetSequentialID, GetSequentialID, 1, 0) -INST(lookup_witness_table, lookup_witness_table, 2, 0) INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0) INST(AllocObj, allocObj, 0, 0) @@ -317,7 +318,7 @@ INST(PackAnyValue, packAnyValue, 1, 0) INST(UnpackAnyValue, unpackAnyValue, 1, 0) INST(WitnessTableEntry, witness_table_entry, 2, 0) -INST(InterfaceRequirementEntry, interface_req_entry, 2, 0) 
+INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL) INST(Param, param, 0, 0) INST(StructField, field, 2, 0) @@ -558,8 +559,6 @@ INST(BitNot, bitnot, 1, 0) INST(Select, select, 3, 0) -INST(Dot, dot, 2, 0) - INST(GetStringHash, getStringHash, 1, 0) INST(WaveGetActiveMask, waveGetActiveMask, 0, 0) @@ -880,40 +879,40 @@ INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0) INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) /* Layout */ - INST(VarLayout, varLayout, 1, 0) + INST(VarLayout, varLayout, 1, HOISTABLE) /* TypeLayout */ - INST(TypeLayoutBase, typeLayout, 0, 0) - INST(ParameterGroupTypeLayout, parameterGroupTypeLayout, 2, 0) - INST(ArrayTypeLayout, arrayTypeLayout, 1, 0) - INST(StreamOutputTypeLayout, streamOutputTypeLayout, 1, 0) - INST(MatrixTypeLayout, matrixTypeLayout, 1, 0) - INST(TaggedUnionTypeLayout, taggedUnionTypeLayout, 0, 0) - INST(ExistentialTypeLayout, existentialTypeLayout, 0, 0) - INST(StructTypeLayout, structTypeLayout, 0, 0) + INST(TypeLayoutBase, typeLayout, 0, HOISTABLE) + INST(ParameterGroupTypeLayout, parameterGroupTypeLayout, 2, HOISTABLE) + INST(ArrayTypeLayout, arrayTypeLayout, 1, HOISTABLE) + INST(StreamOutputTypeLayout, streamOutputTypeLayout, 1, HOISTABLE) + INST(MatrixTypeLayout, matrixTypeLayout, 1, HOISTABLE) + INST(TaggedUnionTypeLayout, taggedUnionTypeLayout, 0, HOISTABLE) + INST(ExistentialTypeLayout, existentialTypeLayout, 0, HOISTABLE) + INST(StructTypeLayout, structTypeLayout, 0, HOISTABLE) INST_RANGE(TypeLayout, TypeLayoutBase, StructTypeLayout) - INST(EntryPointLayout, EntryPointLayout, 1, 0) + INST(EntryPointLayout, EntryPointLayout, 1, HOISTABLE) INST_RANGE(Layout, VarLayout, EntryPointLayout) /* Attr */ - INST(PendingLayoutAttr, pendingLayout, 1, 0) - INST(StageAttr, stage, 1, 0) - INST(StructFieldLayoutAttr, fieldLayout, 2, 0) - INST(CaseTypeLayoutAttr, caseLayout, 1, 0) - INST(UNormAttr, unorm, 0, 0) - INST(SNormAttr, snorm, 0, 0) - INST(NoDiffAttr, no_diff, 0, 0) + 
INST(PendingLayoutAttr, pendingLayout, 1, HOISTABLE) + INST(StageAttr, stage, 1, HOISTABLE) + INST(StructFieldLayoutAttr, fieldLayout, 2, HOISTABLE) + INST(CaseTypeLayoutAttr, caseLayout, 1, HOISTABLE) + INST(UNormAttr, unorm, 0, HOISTABLE) + INST(SNormAttr, snorm, 0, HOISTABLE) + INST(NoDiffAttr, no_diff, 0, HOISTABLE) /* SemanticAttr */ - INST(UserSemanticAttr, userSemantic, 2, 0) - INST(SystemValueSemanticAttr, systemValueSemantic, 2, 0) + INST(UserSemanticAttr, userSemantic, 2, HOISTABLE) + INST(SystemValueSemanticAttr, systemValueSemantic, 2, HOISTABLE) INST_RANGE(SemanticAttr, UserSemanticAttr, SystemValueSemanticAttr) /* LayoutResourceInfoAttr */ - INST(TypeSizeAttr, size, 2, 0) - INST(VarOffsetAttr, offset, 2, 0) + INST(TypeSizeAttr, size, 2, HOISTABLE) + INST(VarOffsetAttr, offset, 2, HOISTABLE) INST_RANGE(LayoutResourceInfoAttr, TypeSizeAttr, VarOffsetAttr) - INST(FuncThrowTypeAttr, FuncThrowType, 1, 0) + INST(FuncThrowTypeAttr, FuncThrowType, 1, HOISTABLE) INST_RANGE(Attr, PendingLayoutAttr, FuncThrowTypeAttr) /* Liveness */ diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 7bc711f97..7a2e1f0e2 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2436,106 +2436,37 @@ struct IRLiveRangeEnd : IRLiveRangeMarker IR_LEAF_ISA(LiveRangeEnd); }; -// Description of an instruction to be used for global value numbering -struct IRInstKey -{ - IRInst* inst; - - HashCode getHashCode(); -}; - -bool operator==(IRInstKey const& left, IRInstKey const& right); - -struct IRConstantKey -{ - IRConstant* inst; - - bool operator==(const IRConstantKey& rhs) const { return inst->equal(rhs.inst); } - HashCode getHashCode() const { return inst->getHashCode(); } -}; - -struct SharedIRBuilder -{ -public: - SharedIRBuilder() - {} - - explicit SharedIRBuilder(IRModule* module) - { - init(module); - } - - void init(IRModule* module) - { - m_module = module; - m_session = module->getSession(); - - 
m_globalValueNumberingMap.Clear(); - m_constantMap.Clear(); - } - - IRModule* getModule() - { - return m_module; - } - - Session* getSession() - { - return m_session; - } - - void insertBlockAlongEdge(IREdge const& edge); - - // Rebuilds `globalValueNumberingMap`. This is necessary if any existing - // keys are modified (thus its hash code is changed). - void deduplicateAndRebuildGlobalNumberingMap(); - - // Replaces all uses of oldInst with newInst, and ensures the global numbering map is valid after the replacement. - void replaceGlobalInst(IRInst* oldInst, IRInst* newInst); - - typedef Dictionary<IRInstKey, IRInst*> GlobalValueNumberingMap; - typedef Dictionary<IRConstantKey, IRConstant*> ConstantMap; - - GlobalValueNumberingMap& getGlobalValueNumberingMap() { return m_globalValueNumberingMap; } - ConstantMap& getConstantMap() { return m_constantMap; } - - bool isGloballyNumberedInst(IRInst* inst); - -private: - // The module that will own all of the IR - IRModule* m_module; - - // The parent compilation session - Session* m_session; - - GlobalValueNumberingMap m_globalValueNumberingMap; - ConstantMap m_constantMap; -}; - struct IRBuilderSourceLocRAII; struct IRBuilder { private: - /// Shared state for all IR builders working on the same module - SharedIRBuilder* m_sharedBuilder = nullptr; + /// Shared state for all IR builders working on the same module + SharedIRBuilder* m_sharedBuilder = nullptr; - /// Default location for inserting new instructions as they are emitted + IRModule* m_module = nullptr; + + /// Default location for inserting new instructions as they are emitted IRInsertLoc m_insertLoc; - /// Information that controls how source locations are associatd with instructions that get emitted + /// Information that controls how source locations are associatd with instructions that get emitted IRBuilderSourceLocRAII* m_sourceLocInfo = nullptr; public: IRBuilder() {} + explicit IRBuilder(IRModule* module) + : m_module(module) + , 
m_sharedBuilder(module->getSharedBuilder()) + {} + explicit IRBuilder(SharedIRBuilder* sharedBuilder) - : m_sharedBuilder(sharedBuilder) + : IRBuilder(sharedBuilder->getModule()) {} explicit IRBuilder(SharedIRBuilder& sharedBuilder) - : m_sharedBuilder(&sharedBuilder) + : IRBuilder(sharedBuilder.getModule()) {} void init(SharedIRBuilder* sharedBuilder) @@ -2550,17 +2481,17 @@ public: SharedIRBuilder* getSharedBuilder() const { - return m_sharedBuilder; + return m_module->getSharedBuilder(); } Session* getSession() const { - return m_sharedBuilder->getSession(); + return m_module->getSession(); } IRModule* getModule() const { - return m_sharedBuilder->getModule(); + return m_module; } IRInsertLoc const& getInsertLoc() const { return m_insertLoc; } @@ -2597,6 +2528,18 @@ public: IRConstant* _findOrEmitConstant( IRConstant& keyInst); + /// Implements a special case of inst creation (intended only for calling from `_createInst`) + /// that returns an matching existing hoistable inst if it exists, otherwise it creates the inst and + /// add it to the global numbering map. + IRInst* _findOrEmitHoistableInst( + IRType* type, + IROp op, + Int fixedArgCount, + IRInst* const* fixedArgs, + Int varArgListCount, + Int const* listArgCounts, + IRInst* const* const* listArgs); + /// Create a new instruction with the given `type` and `op`, with an allocated /// size of at least `minSizeInBytes`, and with its operand list initialized /// from the provided lists of "fixed" and "variable" operands. @@ -2615,7 +2558,8 @@ public: /// size. /// /// Note: This is an extremely low-level operation and clients of an `IRBuilder` - /// should not be using it when other options are available. + /// should not be using it when other options are available. This is also where + /// all insts creation are bottlenecked through. /// IRInst* _createInst( size_t minSizeInBytes, @@ -2654,6 +2598,12 @@ public: void addInst(IRInst* inst); + // Replace the operand of a potentially hoistable inst. 
+ // If the hoistable inst become duplicate of an existing inst, + // all uses of the original user will be replaced with the existing inst. + // The function returns the new user after any potential updates. + IRInst* replaceOperand(IRUse* use, IRInst* newValue); + IRInst* getBoolValue(bool value); IRInst* getIntValue(IRType* type, IRIntegerValue value); IRInst* getFloatValue(IRType* type, IRFloatingPointValue value); @@ -2918,6 +2868,20 @@ public: UInt argCount, IRInst* const* args); + IRInst* createIntrinsicInst( + IRType* type, + IROp op, + IRInst* operand, + UInt operandCount, + IRInst* const* operands); + + IRInst* createIntrinsicInst( + IRType* type, + IROp op, + UInt operandListCount, + UInt const* listOperandCounts, + IRInst* const* const* listOperands); + IRInst* emitIntrinsicInst( IRType* type, IROp op, @@ -3001,6 +2965,10 @@ public: UInt argCount, IRInst* const* args); + IRInst* emitMakeMatrixFromScalar( + IRType* type, + IRInst* scalarValue); + IRInst* emitMakeArray( IRType* type, UInt argCount, @@ -3066,31 +3034,6 @@ public: IRInst* emitReinterpret(IRInst* type, IRInst* value); - IRInst* findOrAddInst( - IRType* type, - IROp op, - UInt operandListCount, - UInt const* listOperandCounts, - IRInst* const* const* listOperands); - - IRInst* findOrEmitHoistableInst( - IRType* type, - IROp op, - UInt operandListCount, - UInt const* listOperandCounts, - IRInst* const* const* listOperands); - IRInst* findOrEmitHoistableInst( - IRType* type, - IROp op, - UInt operandCount, - IRInst* const* operands); - IRInst* findOrEmitHoistableInst( - IRType* type, - IROp op, - IRInst* operand, - UInt operandCount, - IRInst* const* operands); - IRFunc* createFunc(); IRGlobalVar* createGlobalVar( IRType* valueType); @@ -3841,10 +3784,6 @@ public: } }; -void addHoistableInst( - IRBuilder* builder, - IRInst* inst); - // Helper to establish the source location that will be used // by an IRBuilder. 
struct IRBuilderSourceLocRAII diff --git a/source/slang/slang-ir-legalize-mesh-outputs.cpp b/source/slang/slang-ir-legalize-mesh-outputs.cpp index 7c6d256ab..db4d74ddb 100644 --- a/source/slang/slang-ir-legalize-mesh-outputs.cpp +++ b/source/slang/slang-ir-legalize-mesh-outputs.cpp @@ -25,7 +25,7 @@ void legalizeMeshOutputTypes(IRModule* module) : as<IRPrimitivesType>(meshOutput) ? kIROp_PrimitivesDecoration : (SLANG_UNREACHABLE("Missing case for IRMeshOutputType"), IROp(0)); // Ensure that all params are marked up as vertices/indices/primitives - traverseUses<IRParam>(meshOutput, [&](IRParam* i) + traverseUsers<IRParam>(meshOutput, [&](IRParam* i) { builder.addMeshOutputDecoration(decorationOp, i, maxCount); }); diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 38503155d..d916fa691 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -1861,14 +1861,27 @@ static LegalVal legalizeInst( // While the operands are all "simple," they might not necessarily // be equal to the operands we started with. 
// + ShortList<IRInst*> newArgs; + newArgs.setCount(argCount); + bool recreate = false; for (UInt aa = 0; aa < argCount; ++aa) { auto legalArg = legalArgs[aa]; - inst->setOperand(aa, legalArg.getSimple()); + newArgs[aa] = legalArg.getSimple(); + if (newArgs[aa] != inst->getOperand(aa)) + recreate = true; + } + if (recreate) + { + IRBuilder builder(inst->getModule()); + builder.setInsertBefore(inst); + auto newInst = builder.emitIntrinsicInst(legalType.getSimple(), inst->getOp(), argCount, newArgs.getArrayView().getBuffer()); + inst->replaceUsesWith(newInst); + inst->removeFromParent(); + context->replacedInstructions.add(inst); + return LegalVal::simple(newInst); } - inst->setFullType(legalType.getSimple()); - return LegalVal::simple(inst); } @@ -1888,6 +1901,10 @@ static LegalVal legalizeInst( legalType, legalArgs.getBuffer()); + if (legalVal.flavor == LegalVal::Flavor::simple) + { + inst->replaceUsesWith(legalVal.getSimple()); + } // After we are done, we will eliminate the // original instruction by removing it from // the IR. diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 80f974536..55048484f 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -229,11 +229,14 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) switch (originalValue->getOp()) { case kIROp_StructType: + case kIROp_ClassType: case kIROp_Func: case kIROp_Generic: case kIROp_GlobalVar: case kIROp_GlobalParam: + case kIROp_GlobalConstant: case kIROp_StructKey: + case kIROp_InterfaceRequirementEntry: case kIROp_GlobalGenericParam: case kIROp_WitnessTable: case kIROp_InterfaceType: @@ -277,26 +280,34 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) } break; + case kIROp_VoidLit: + { + return builder->getVoidValue(); + } + break; + default: { // In the default case, assume that we have some sort of "hoistable" // instruction that requires us to create a clone of it. 
UInt argCount = originalValue->getOperandCount(); - IRInst* clonedValue = builder->createIntrinsicInst( - cloneType(this, originalValue->getFullType()), - originalValue->getOp(), - argCount, nullptr); - registerClonedValue(this, clonedValue, originalValue); + ShortList<IRInst*> newArgs; + newArgs.setCount(argCount); for (UInt aa = 0; aa < argCount; ++aa) { IRInst* originalArg = originalValue->getOperand(aa); IRInst* clonedArg = cloneValue(this, originalArg); - clonedValue->getOperands()[aa].init(clonedValue, clonedArg); + newArgs[aa] = clonedArg; } + IRInst* clonedValue = builder->createIntrinsicInst( + cloneType(this, originalValue->getFullType()), + originalValue->getOp(), + argCount, newArgs.getArrayView().getBuffer()); + registerClonedValue(this, clonedValue, originalValue); + cloneDecorationsAndChildren(this, clonedValue, originalValue); - - addHoistableInst(builder, clonedValue); + builder->addInst(clonedValue); return clonedValue; } @@ -524,6 +535,8 @@ IRGlobalConstant* cloneGlobalConstantImpl( IRGlobalConstant* originalVal, IROriginalValuesForClone const& originalValues) { + auto oldBuilder = context->builder; + context->builder = builder; auto clonedType = cloneType(context, originalVal->getFullType()); IRGlobalConstant* clonedVal = nullptr; if(auto originalInitVal = originalVal->getValue()) @@ -537,7 +550,7 @@ IRGlobalConstant* cloneGlobalConstantImpl( } cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); - + context->builder = oldBuilder; return clonedVal; } @@ -1174,21 +1187,24 @@ IRInst* cloneInst( // instruction with the right number of operands, intialize // it, and then add it to the sequence. 
UInt argCount = originalInst->getOperandCount(); - IRInst* clonedInst = builder->createIntrinsicInst( - cloneType(context, originalInst->getFullType()), - originalInst->getOp(), - argCount, nullptr); - registerClonedValue(context, clonedInst, originalValues); + ShortList<IRInst*> newArgs; + newArgs.setCount(argCount); auto oldBuilder = context->builder; context->builder = builder; for (UInt aa = 0; aa < argCount; ++aa) { IRInst* originalArg = originalInst->getOperand(aa); IRInst* clonedArg = cloneValue(context, originalArg); - clonedInst->getOperands()[aa].init(clonedInst, clonedArg); + newArgs[aa] = clonedArg; } - builder->addInst(clonedInst); context->builder = oldBuilder; + + IRInst* clonedInst = builder->createIntrinsicInst( + cloneType(context, originalInst->getFullType()), + originalInst->getOp(), + argCount, newArgs.getArrayView().getBuffer()); + builder->addInst(clonedInst); + registerClonedValue(context, clonedInst, originalValues); cloneDecorationsAndChildren(context, clonedInst, originalInst); cloneExtraDecorations(context, clonedInst, originalValues); return clonedInst; diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index 6f412d579..f2d7159d4 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -56,25 +56,51 @@ namespace Slang lowerGenericFuncType(&builder, genericParent, cast<IRFuncType>(func->getFullType())); SLANG_ASSERT(loweredGenericType); loweredFunc->setFullType(loweredGenericType); - List<IRInst*> clonedParams; + List<IRInst*> childrenToDemote; + List<IRInst*> clonedParams; for (auto genericChild : genericParent->getFirstBlock()->getChildren()) { - if (genericChild == func) + switch (genericChild->getOp()) + { + case kIROp_Func: continue; - if (genericChild->getOp() == kIROp_Return) + case kIROp_Return: continue; + } // Process all generic parameters and local type definitions. 
auto clonedChild = cloneInst(&cloneEnv, &builder, genericChild); - if (clonedChild->getOp() == kIROp_Param) + switch (clonedChild->getOp()) { - auto paramType = clonedChild->getFullType(); - auto loweredParamType = sharedContext->lowerType(&builder, paramType); - if (loweredParamType != paramType) + case kIROp_Param: { - clonedChild->setFullType((IRType*)loweredParamType); + auto paramType = clonedChild->getFullType(); + auto loweredParamType = sharedContext->lowerType(&builder, paramType); + if (loweredParamType != paramType) + { + clonedChild->setFullType((IRType*)loweredParamType); + } + clonedParams.add(clonedChild); + } + break; + + case kIROp_LookupWitness: + case kIROp_Specialize: + { + childrenToDemote.add(clonedChild); + // Make sure all uses are from the function body. + for (auto use = genericChild->firstUse; use; use = use->nextUse) + { + if (use->getUser()->getParent() == genericChild->getParent()) + { + // This specialize/lookup is used as operand to some other + // global inst in the generic. This is not supported now. + SLANG_UNIMPLEMENTED_X( + "Unsupported use of specialize/lookupWitness in generic body."); + } + } + continue; } - clonedParams.add(clonedChild); } } cloneInstDecorationsAndChildren(&cloneEnv, &sharedContext->sharedBuilderStorage, func, loweredFunc); @@ -85,6 +111,15 @@ namespace Slang param->removeFromParent(); block->addParam(as<IRParam>(param)); } + + // Demote specialize and lookupWitness insts and their dependents down to function body. + auto insertPoint = block->getFirstOrdinaryInst(); + for (Index i = childrenToDemote.getCount() - 1; i >= 0; i--) + { + auto child = childrenToDemote[i]; + child->insertBefore(insertPoint); + } + // Lower generic typed parameters into AnyValueType. 
auto firstInst = loweredFunc->getFirstOrdinaryInst(); builder.setInsertBefore(firstInst); @@ -292,7 +327,8 @@ namespace Slang loweredFunc = lowerGenericFunction(funcToSpecialize); if (loweredFunc != funcToSpecialize) { - specializeInst->setOperand(0, loweredFunc); + IRBuilder builder; + builder.replaceOperand(specializeInst->getOperands(), loweredFunc); } } } diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index 176142601..f3996fc01 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -38,8 +38,6 @@ struct RedundancyRemovalContext case kIROp_GetElement: case kIROp_GetElementPtr: case kIROp_UpdateElement: - case kIROp_LookupWitness: - case kIROp_Specialize: case kIROp_OptionalHasValue: case kIROp_GetOptionalValue: case kIROp_MakeOptionalValue: diff --git a/source/slang/slang-ir-simplify-for-emit.cpp b/source/slang/slang-ir-simplify-for-emit.cpp index 5e5f61a4a..67d95c59f 100644 --- a/source/slang/slang-ir-simplify-for-emit.cpp +++ b/source/slang/slang-ir-simplify-for-emit.cpp @@ -5,12 +5,16 @@ namespace Slang { +bool isCPUTarget(TargetRequest* targetReq); +bool isCUDATarget(TargetRequest* targetReq); + struct SimplifyForEmitContext : public InstPassBase { - SimplifyForEmitContext(IRModule* inModule) - : InstPassBase(inModule) + SimplifyForEmitContext(IRModule* inModule, TargetRequest* inTargetReq) + : InstPassBase(inModule), targetReq(inTargetReq) {} + TargetRequest* targetReq; List<IRInst*> followUpWorkList; HashSet<IRInst*> followUpWorkListSet; @@ -134,7 +138,7 @@ struct SimplifyForEmitContext : public InstPassBase IRBuilder builder(sharedBuilderStorage); builder.setInsertBefore(user); auto newLoad = builder.emitLoad(load->getPtr()); - use->set(newLoad); + builder.replaceOperand(use, newLoad); } void processLoad(IRLoad* inst) @@ -330,8 +334,115 @@ struct SimplifyForEmitContext : public InstPassBase processInst(followUpWorkList[i]); } + void 
unifyBinaryExprOperands(IRGlobalValueWithCode* func) + { + IRBuilder builder(func->getModule()); + + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst()) + { + switch (inst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_IRem: + case kIROp_FRem: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Leq: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Greater: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Lsh: + case kIROp_Rsh: + builder.setInsertBefore(inst); + SLANG_ASSERT(inst->getOperandCount() == 2); + if (as<IRVectorType>(inst->getDataType())) + { + for (UInt a = 0; a < 2; a++) + { + if (as<IRBasicType>(inst->getOperand(a)->getDataType())) + { + auto v = builder.emitMakeVectorFromScalar( + inst->getOperand(1 - a)->getDataType(), inst->getOperand(a)); + inst->setOperand(a, v); + } + } + } + else if (as<IRMatrixType>(inst->getDataType())) + { + for (UInt a = 0; a < 2; a++) + { + if (as<IRBasicType>(inst->getOperand(a)->getDataType())) + { + auto v = builder.emitMakeMatrixFromScalar( + inst->getOperand(1 - a)->getDataType(), inst->getOperand(a)); + inst->setOperand(a, v); + } + } + } + + break; + } + } + } + } + + // Turn single element vector values into scalars before using it to call an intrinsic func. + void lowerTrivialVector(IRGlobalValueWithCode* func) + { + IRBuilder builder(func->getModule()); + List<IRInst*> instsToProcess; + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst()) + { + switch (inst->getOp()) + { + case kIROp_Call: + { + // If we are calling an intrinsic with any vector<T,1> argument, replace it with T. 
+ auto callInst = as<IRCall>(inst); + if (getResolvedInstForDecorations(callInst->getCallee())->findDecoration<IRTargetIntrinsicDecoration>()) + { + for (UInt a = 0; a < callInst->getArgCount(); a++) + { + auto arg = callInst->getArg(a); + if (auto argVectorType = as<IRVectorType>(arg->getDataType())) + { + if (cast<IRIntLit>(argVectorType->getElementCount())->getValue() == 1) + { + builder.setInsertBefore(callInst); + UInt idx = 0; + auto newArg = builder.emitSwizzle(argVectorType->getElementType(), arg, 1, &idx); + callInst->setOperand(a + 1, newArg); + } + } + } + } + } + break; + } + } + } + } + + void processFunc(IRGlobalValueWithCode* func) { + if (isCPUTarget(targetReq) || isCUDATarget(targetReq)) + { + unifyBinaryExprOperands(func); + lowerTrivialVector(func); + } eliminateCompositeConstruct(func); deferAndDuplicateElementExtract(func); deferAndDuplicateLoad(func); @@ -345,9 +456,9 @@ struct SimplifyForEmitContext : public InstPassBase } }; -void simplifyForEmit(IRModule* module) +void simplifyForEmit(IRModule* module, TargetRequest* targetRequest) { - SimplifyForEmitContext context(module); + SimplifyForEmitContext context(module, targetRequest); context.processModule(); } diff --git a/source/slang/slang-ir-simplify-for-emit.h b/source/slang/slang-ir-simplify-for-emit.h index a6cf3bad8..e35c74841 100644 --- a/source/slang/slang-ir-simplify-for-emit.h +++ b/source/slang/slang-ir-simplify-for-emit.h @@ -4,6 +4,7 @@ namespace Slang { struct IRModule; + class TargetRequest; - void simplifyForEmit(IRModule* inModule); + void simplifyForEmit(IRModule* inModule, TargetRequest* req); } diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp index 39edaeb16..cfc9d9c76 100644 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp +++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp @@ -200,23 +200,22 @@ struct 
AssociatedTypeLookupSpecializationContext if (!seqId) return; // Insert code to pack sequential ID into an uint2 at all use sites. - IRUse* nextUse = nullptr; - for (auto use = inst->firstUse; use; use = nextUse) + traverseUses(inst, [&](IRUse* use) { - nextUse = use->nextUse; if (as<IRCOMWitnessDecoration>(use->getUser())) - continue; + { + return; + } IRBuilder builder(sharedContext->sharedBuilderStorage); builder.setInsertBefore(use->getUser()); auto uint2Type = builder.getVectorType( builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2)); IRInst* uint2Args[] = { seqId->getSequentialIDOperand(), - builder.getIntValue(builder.getUIntType(), 0)}; + builder.getIntValue(builder.getUIntType(), 0) }; auto uint2seqID = builder.emitMakeVector(uint2Type, 2, uint2Args); - use->set(uint2seqID); - use = nextUse; - } + builder.replaceOperand(use, uint2seqID); + }); } }); @@ -229,14 +228,12 @@ struct AssociatedTypeLookupSpecializationContext builder.setInsertBefore(globalInst); auto witnessTableIDType = builder.getWitnessTableIDType( (IRType*)cast<IRWitnessTableType>(globalInst)->getConformanceType()); - IRUse* nextUse = nullptr; - for (auto use = globalInst->firstUse; use; use = nextUse) + traverseUses(globalInst, [&](IRUse* use) { - nextUse = use->nextUse; if (use->getUser()->getOp() == kIROp_WitnessTable) - continue; - use->set(witnessTableIDType); - } + return; + builder.replaceOperand(use, witnessTableIDType); + }); sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); } } diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp index e4ccf40d5..03eda0d99 100644 --- a/source/slang/slang-ir-specialize-resources.cpp +++ b/source/slang/slang-ir-specialize-resources.cpp @@ -256,16 +256,16 @@ struct ResourceOutputSpecializationPass // the aid of this pass. 
// List<IRCall*> calls; - for( auto use = oldFunc->firstUse; use; use = use->nextUse ) - { - auto user = use->getUser(); - auto call = as<IRCall>(user); - if(!call) - continue; - if(call->getCallee() != oldFunc) - continue; - calls.add(call); - } + traverseUses(oldFunc, [&](IRUse* use) + { + auto user = use->getUser(); + auto call = as<IRCall>(user); + if (!call) + return; + if (call->getCallee() != oldFunc) + return; + calls.add(call); + }); // Once we have identified the calls to `oldFunc`, we will set about replacing // them with calls to `newFunc`. @@ -833,16 +833,16 @@ struct ResourceOutputSpecializationPass // `out`/`inout` parameters that doesn't have as many "gotcha" cases. // List<IRStore*> stores; - for( auto use = param->firstUse; use; use = use->nextUse ) - { - auto user = use->getUser(); - auto store = as<IRStore>(user); - if(!store) - continue; - if(store->ptr.get() != param) - continue; - stores.add(store); - } + traverseUses(param, [&](IRUse* use) + { + auto user = use->getUser(); + auto store = as<IRStore>(user); + if (!store) + return; + if (store->ptr.get() != param) + return; + stores.add(store); + }); // Having identified the places where a value is stored to // the output parameter, we iterate over those values to @@ -1194,16 +1194,16 @@ bool specializeResourceUsage( // Inline unspecializable resource output functions and then continue trying. 
for (auto func : unspecializableFuncs) { - for (auto use = func->firstUse; use; use = use->nextUse) + traverseUses(func, [&](IRUse* use) { auto user = use->getUser(); auto call = as<IRCall>(user); if (!call) - continue; + return; if (call->getCallee() != func) - continue; + return; inlineCall(call); - } + }); } simplifyIR(irModule); } diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index cf7acd46c..0044e5745 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -897,7 +897,8 @@ struct SpecializationContext // specialization opportunities (generic specialization, // existential specialization, simplifications, etc.) // - iterChanged |= maybeSpecializeInst(inst); + if (inst->hasUses() || inst->mightHaveSideEffects()) + iterChanged |= maybeSpecializeInst(inst); // Finally, we need to make our logic recurse through // the whole IR module, so we want to add the children @@ -1041,7 +1042,6 @@ struct SpecializationContext // The old callee should be in the form of `specialize(.operator[], IInterfaceType)`, // we should update it to be `specialize(.operator[], elementType)`, so the return type // of the load call is `elementType`. - auto oldCallee = inst->getCallee(); // A subscript operation on mutable buffers returns a ptr type instead of a value type. // We need to make sure the pointer-ness is preserved correctly. 
@@ -1057,9 +1057,6 @@ struct SpecializationContext inst->replaceUsesWith(newWrapExistential); workList.Remove(inst); inst->removeAndDeallocate(); - SLANG_ASSERT(!oldCallee->hasUses()); - workList.Remove(oldCallee); - oldCallee->removeAndDeallocate(); addUsersToWorkList(newWrapExistential); workList.Remove(wrapExistential); diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index 3f250e31e..b195af2cc 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -923,6 +923,15 @@ IRBlock* IREdge::getSuccessor() const return cast<IRBlock>(getUse()->get()); } +void SharedIRBuilder::init(IRModule* module) +{ + m_module = module; + m_session = module->getSession(); + + m_globalValueNumberingMap.Clear(); + m_constantMap.Clear(); +} + void SharedIRBuilder::insertBlockAlongEdge( IREdge const& edge) { diff --git a/source/slang/slang-ir-type-set.cpp b/source/slang/slang-ir-type-set.cpp index 0cfe69e42..7ac617bda 100644 --- a/source/slang/slang-ir-type-set.cpp +++ b/source/slang/slang-ir-type-set.cpp @@ -7,313 +7,4 @@ namespace Slang { -IRTypeSet::IRTypeSet(Session* session) -{ - m_module = IRModule::create(session); - - m_sharedBuilder.init(m_module); - m_builder.init(m_sharedBuilder); - - m_builder.setInsertInto(m_module->getModuleInst()); -} - -IRTypeSet::~IRTypeSet() -{ - _clearTypes(); -} - -void IRTypeSet::clear() -{ - _clearTypes(); - - m_cloneMap.Clear(); - - m_module = IRModule::create(m_sharedBuilder.getSession()); - - m_sharedBuilder.init(m_module); - m_builder.init(m_sharedBuilder); - - m_builder.setInsertInto(m_module->getModuleInst()); -} - -void IRTypeSet::_clearTypes() -{ - List<IRType*> types; - getTypes(types); - - for (auto type : types) - { - // We need to destroy references to instructions in other modules - if (type->getModule() == m_module) - { - // We want to remove arguments because an argument *could* be an instruction in another module, - // and we don't want to those modules insts to have uses, in this 
module which is being destroyed - type->removeArguments(); - } - } -} - -IRInst* IRTypeSet::cloneInst(IRInst* inst) -{ - if (inst == nullptr) - { - return nullptr; - } - - // See if it's already cloned - if (IRInst*const* newInstPtr = m_cloneMap.TryGetValue(inst)) - { - return *newInstPtr; - } - - IRModule* module = inst->getModule(); - // All inst's must belong to a module - SLANG_ASSERT(module); - - // If it's in this module then we don't need to clone - if (module == m_module) - { - return inst; - } - - if (isNominalOp(inst->getOp())) - { - // We can clone without any definition, and add the linkage - - // TODO(JS) - // This is arguably problematic - I'm adding an instruction from another module to the map, to be it's self. - // I did have code which created a copy of the nominal instruction and name hint, but because nominality means - // 'same address' other code would generate a different name for that instruction (say as compared to being a member in - // the original instruction) - // - // Because I use findOrAddInst which doesn't hoist instructions, the hoisting doesn't rely on parenting, that would - // break. - - // If nominal, we just use the original inst - m_cloneMap.Add(inst, inst); - return inst; - } - - // It would be nice if I could use ir-clone.cpp to do this -> but it doesn't clone - // operands. We wouldn't want to clone decorations, and it can't clone IRConstant(!) 
so - // it's no use - - IRInst* clone = nullptr; - switch (inst->getOp()) - { - case kIROp_IntLit: - { - auto intLit = static_cast<IRConstant*>(inst); - IRType* clonedType = cloneType(intLit->getDataType()); - clone = m_builder.getIntValue(clonedType, intLit->value.intVal); - break; - } - case kIROp_StringLit: - { - auto stringLit = static_cast<IRStringLit*>(inst); - clone = m_builder.getStringValue(stringLit->getStringSlice()); - break; - } - case kIROp_VectorType: - { - auto vecType = static_cast<IRVectorType*>(inst); - const Index elementCount = Index(getIntVal(vecType->getElementCount())); - - if (elementCount <= 1) - { - clone = cloneType(vecType->getElementType()); - } - break; - } - case kIROp_MatrixType: - { - auto matType = static_cast<IRMatrixType*>(inst); - const Index columnCount = Index(getIntVal(matType->getColumnCount())); - const Index rowCount = Index(getIntVal(matType->getRowCount())); - - if (columnCount <= 1 && rowCount <= 1) - { - clone = cloneType(matType->getElementType()); - } - break; - } - default: break; - } - - if (!clone) - { - if (IRBasicType::isaImpl(inst->getOp())) - { - clone = m_builder.getType(inst->getOp()); - } - else - { - IRType* irType = dynamicCast<IRType>(inst); - if (irType) - { - auto clonedType = cloneType(inst->getFullType()); - Index operandCount = Index(inst->getOperandCount()); - - List<IRInst*> cloneOperands; - cloneOperands.setCount(operandCount); - - for (Index i = 0; i < operandCount; ++i) - { - cloneOperands[i] = cloneInst(inst->getOperand(i)); - } - - //clone = m_irBuilder.findOrEmitHoistableInst(cloneType, inst->op, operandCount, cloneOperands.getBuffer()); - - UInt operandCounts[1] = { UInt(operandCount) }; - IRInst*const* listOperands[1] = { cloneOperands.getBuffer() }; - - clone = m_builder.findOrAddInst(clonedType, inst->getOp(), 1, operandCounts, listOperands); - } - else - { - // This cloning style only works on insts that are not unique - auto clonedType = cloneType(inst->getFullType()); - - Index 
operandCount = Index(inst->getOperandCount()); - clone = m_builder.emitIntrinsicInst(clonedType, inst->getOp(), operandCount, nullptr); - for (Index i = 0; i < operandCount; ++i) - { - auto cloneOperand = cloneInst(inst->getOperand(i)); - clone->getOperands()[i].init(clone, cloneOperand); - } - } - } - } - - m_cloneMap.Add(inst, clone); - return clone; -} - -IRType* IRTypeSet::add(IRType* irType) -{ - if (irType->getModule() == m_module) - { - return irType; - } - // We need to clone the type - return cloneType(irType); -} - -void IRTypeSet::getTypes(List<IRType*>& outTypes) const -{ - outTypes.clear(); - for (auto inst : m_module->getModuleInst()->getChildren()) - { - if (IRType* type = as<IRType>(inst)) - { - outTypes.add(type); - } - } -} - -void IRTypeSet::getTypes(Kind kind, List<IRType*>& outTypes) const -{ - outTypes.clear(); - - for (auto inst : m_module->getModuleInst()->getChildren()) - { - IRType* type = nullptr; - - switch (kind) - { - case Kind::Scalar: - { - type = as<IRBasicType>(inst); - break; - } - case Kind::Vector: - { - type = as<IRVectorType>(inst); - break; - } - case Kind::Matrix: - { - type = as<IRMatrixType>(inst); - break; - } - default: break; - } - - if (type) - { - outTypes.add(type); - } - } -} - -IRType* IRTypeSet::addVectorType(IRType* inElementType, int colsCount) -{ - IRType* elementType = cloneType(inElementType); - if (colsCount == 1) - { - return elementType; - } - return m_builder.getVectorType(elementType, m_builder.getIntValue(m_builder.getIntType(), colsCount)); -} - -void IRTypeSet::addVectorForMatrixTypes() -{ - // Make a copy so we can alter m_types dictionary - List<IRType*> types; - getTypes(Kind::Matrix, types); - for (IRType* type : types) - { - SLANG_ASSERT(as<IRMatrixType>(type)); - IRMatrixType* matType = static_cast<IRMatrixType*>(type); - m_builder.getVectorType(matType->getElementType(), matType->getColumnCount()); - } -} - -static bool _hasNominalOperand(IRInst* inst) -{ - const Index operandCount = 
Index(inst->getOperandCount()); - auto operands = inst->getOperands(); - - for (Index i = 0; i < operandCount; ++i) - { - IRInst* operand = operands[i].get(); - if (isNominalOp(operand->getOp())) - { - return true; - } - } - - return false; -} - -void IRTypeSet::_addAllBuiltinTypesRec(IRInst* inst) -{ - for (IRInst* child = inst->getFirstDecorationOrChild(); child; child = child->getNextInst()) - { - IRType* type = nullptr; - - if (auto vectorType = as<IRVectorType>(child)) - { - type = vectorType; - } - else if (auto matrixType = as<IRMatrixType>(child)) - { - type = matrixType; - } - if (type && !_hasNominalOperand(type)) - { - add(type); - } - else - { - _addAllBuiltinTypesRec(child); - } - } -} - -void IRTypeSet::addAllBuiltinTypes(IRModule* module) -{ - _addAllBuiltinTypesRec(module->getModuleInst()); -} - } diff --git a/source/slang/slang-ir-type-set.h b/source/slang/slang-ir-type-set.h index 958d71cf1..f60088fcd 100644 --- a/source/slang/slang-ir-type-set.h +++ b/source/slang/slang-ir-type-set.h @@ -9,85 +9,4 @@ namespace Slang { -/* -NOTE! This type set is only designed to work for emitting code to determine unique types. It is envisaged in the -future that it will not be needed because types will be made unique within a module, and thus the pointer to a type -will uniquely identify the type. - -The other reason this type exists, is to allow an IRModule for emit to be immutable. That is not currently possible -within emit code because it may be necessary in order to emit to be able to create other types that needed (for example -vector types required for a matrix type implementation). - -This is used so as to try and use slangs type system to uniquely identify types and specializations on intrinsic. -That we want to have a pointer to a type be unique, and slang supports this through the m_sharedIRBuilder. BUT for this to -work all work on the module must use the same sharedIRBuilder, and that appears to not be the case in terms -of other passes. 
-Even if it was the case when we may want to add types as part of emitting, we can't use the previously used -shared builder, so again we end up with pointers to the same things not being the same thing. - -To work around this we clone types we want to use as keys into the 'unique module'. -This is not necessary for all types though - as we assume nominal types *must* have unique pointers (that is the -definition of nominal). - -This could be handled in other ways (for example not testing equality on pointer equality). Anyway for now this -works, but probably needs to be handled in a better way. The better way may involve having guarantees about equality -enabled in other code generation and making de-duping possible in emit code. - -Note that one pro for this approach is that it does not alter the source module. That as it stands it's not necessary -for the source module to be immutable, because it is created for emitting and then discarded. - -NOTE! That Vector<X, 1> or Matrix<X, 1, 1> will be turned into the type X. 
- - */ -class IRTypeSet -{ -public: - enum class Kind - { - Scalar, - Vector, - Matrix, - CountOf, - }; - - IRType* add(IRType* type); - IRType* addVectorType(IRType* elementType, int colsCount); - - void addAllBuiltinTypes(IRModule* module); - - void addVectorForMatrixTypes(); - - void getTypes(List<IRType*>& outTypes) const; - void getTypes(Kind kind, List<IRType*>& outTypes) const; - - IRType* getType(IRType* type) { return cloneType(type); } - - IRType* cloneType(IRType* type) { return (IRType*)cloneInst((IRInst*)type); } - IRInst* cloneInst(IRInst* inst); - - /// Returns true if the type belongs and is created on the module owned by the set - bool isOwned(IRType* type) { return type->getModule() == m_module; } - - IRBuilder& getBuilder() { return m_builder; } - IRModule* getModule() const { return m_module; } - - void clear(); - - IRTypeSet(Session* session); - ~IRTypeSet(); - -protected: - void _addAllBuiltinTypesRec(IRInst* inst); - void _clearTypes(); - - // Maps insts from source modules into m_module. - // NOTE! That nominal types are not cloned, as they are identified by pointer. 
They are just - Dictionary<IRInst*, IRInst*> m_cloneMap; - - // Can find all types by traversing the types in the m_module - SharedIRBuilder m_sharedBuilder; - IRBuilder m_builder; - RefPtr<IRModule> m_module; -}; - } // namespace Slang diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 1ea426715..253686aa5 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -549,6 +549,8 @@ struct GenericChildrenMigrationContextImpl } if (as<IRConstant>(inst)) return false; + if (getIROpInfo(inst->getOp()).isHoistable()) + return false; return true; }); } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index efd38f7b7..2f1ac2d1a 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -49,7 +49,7 @@ struct DeduplicateContext return *newValue; for (UInt i = 0; i < value->getOperandCount(); i++) { - value->setOperand(i, deduplicate(value->getOperand(i), shouldDeduplicate)); + value->unsafeSetOperand(i, deduplicate(value->getOperand(i), shouldDeduplicate)); } value->setFullType((IRType*)deduplicate(value->getFullType(), shouldDeduplicate)); if (auto newValue = deduplicateMap.TryGetValue(key)) diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index 03db96ac5..d5c0aa432 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -186,6 +186,28 @@ namespace Slang if (pp == operandParent) return; } + + // We allow out-of-order def-use in global scope. 
+        bool allInGlobalScope = inst->getParent() && inst->getParent()->getOp() == kIROp_Module;
+        if (allInGlobalScope)
+        {
+            for (UInt i = 0; i < inst->getOperandCount(); i++)
+            {
+                auto op = inst->getOperand(i);
+                if (!op)
+                    continue;
+                if (!op->getParent())
+                    continue;
+                if (op->getParent()->getOp() != kIROp_Module)
+                {
+                    allInGlobalScope = false;
+                    break;
+                }
+            }
+        }
+        if (allInGlobalScope)
+            return;
+
         //
         // We failed to find `operandParent` while walking the ancestors of `inst`,
         // so something had gone wrong.
diff --git a/source/slang/slang-ir-wrap-structured-buffers.cpp b/source/slang/slang-ir-wrap-structured-buffers.cpp
index 2ad09aa90..53671fa7f 100644
--- a/source/slang/slang-ir-wrap-structured-buffers.cpp
+++ b/source/slang/slang-ir-wrap-structured-buffers.cpp
@@ -134,7 +134,7 @@ struct WrapStructuredBuffersContext
         // scanning through its IR uses, since values of that
         // type are using it as a (type) operand.
         //
-        for( auto typeUse = newStructuredBufferType->firstUse; typeUse; typeUse = typeUse->nextUse )
+        traverseUses(newStructuredBufferType, [&](IRUse* typeUse)
         {
             // There might be uses of `newStructuredBufferType` where
             // it isn't being used as the type of a value, so we
@@ -142,7 +142,7 @@ struct WrapStructuredBuffersContext
             //
             auto valueOfStructuredBufferType = typeUse->getUser();
             if(valueOfStructuredBufferType->getFullType() != newStructuredBufferType)
-                continue;
+                return;
 
             // Now we have some `valueOfStructuredBufferType`. In our running
             // example, this might be `gBuffer`, which is an `IRGlobalParam`.
@@ -155,7 +155,7 @@ struct WrapStructuredBuffersContext // because these could be calls to intrinsic functions like // `RWStructuredBuffer.Load` // - for( auto valueUse = valueOfStructuredBufferType->firstUse; valueUse; valueUse = valueUse->nextUse ) + traverseUses(valueOfStructuredBufferType, [&](IRUse* valueUse) { // we are only interested in instructions that are calls, // with at least one argument, where the first argument @@ -165,11 +165,11 @@ struct WrapStructuredBuffersContext // auto call = as<IRCall>(valueUse->getUser()); if(!call) - continue; + return; if(call->getArgCount() == 0) - continue; + return; if(call->getArg(0) != valueOfStructuredBufferType) - continue; + return; // At this point we have a candidate `call` instruction, // but we need to determine whether it is a call to @@ -196,7 +196,7 @@ struct WrapStructuredBuffersContext // auto callee = call->getCallee(); if(!as<IRSpecialize>(callee)) - continue; + return; // At this point it seems likely we have one of the calls // we want to rewrite, but there are still intrinsics @@ -285,8 +285,8 @@ struct WrapStructuredBuffersContext newVal->setOperand(0, call); } } - } - } + }); + }); } /// Get the struture field "key" to use for generated wrappers diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 1b16bfe1f..6cf0f09a5 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -148,7 +148,6 @@ namespace Slang void IRUse::init(IRInst* u, IRInst* v) { clear(); - user = u; usedValue = v; if(v) @@ -170,6 +169,9 @@ namespace Slang void IRUse::set(IRInst* uv) { + // Normally we should never be modifying the operand of an hoistable inst. + // They can be modified by `replaceUsesWith`, or to be replaced by a new inst. 
+ SLANG_ASSERT(!getIROpInfo(user->getOp()).isHoistable() || uv == usedValue); init(user, uv); } @@ -1196,11 +1198,57 @@ namespace Slang return as<IRGlobalValueWithCode>(pp); } + void addHoistableInst( + IRBuilder* builder, + IRInst* inst); + // Add an instruction into the current scope void IRBuilder::addInst( IRInst* inst) { - inst->insertAt(m_insertLoc); + if (getIROpInfo(inst->getOp()).isGlobal()) + { + addHoistableInst(this, inst); + return; + } + + if (!inst->parent) + inst->insertAt(m_insertLoc); + } + + IRInst* IRBuilder::replaceOperand(IRUse* use, IRInst* newValue) + { + auto user = use->getUser(); + if (user->getModule()) + { + user->getModule()->getSharedBuilder()->getInstReplacementMap().TryGetValue(newValue, newValue); + } + + if (!getIROpInfo(user->getOp()).isHoistable()) + { + use->set(newValue); + return user; + } + + // If user is hoistable, we need to remove it from the global number map first, + // perform the update, then try to reinsert it back to the global number map. + // If we find an equivalent entry already exists in the global number map, + // we return the existing entry. 
+ auto builder = user->getModule()->getSharedBuilder(); + builder->_removeGlobalNumberingEntry(user); + use->init(user, newValue); + + IRInst* existingVal = nullptr; + if (builder->getGlobalValueNumberingMap().TryGetValue(IRInstKey{ user }, existingVal)) + { + user->replaceUsesWith(existingVal); + return existingVal; + } + else + { + builder->_addGlobalNumberingEntry(user); + return user; + } } // Given two parent instructions, pick the better one to use as as @@ -1645,6 +1693,13 @@ namespace Slang Int const* listArgCounts, IRInst* const* const* listArgs) { + m_sharedBuilder->getInstReplacementMap().TryGetValue((IRInst*)(type), *(IRInst**)&type); + + if (getIROpInfo(op).flags & kIROpFlag_Hoistable) + { + return _findOrEmitHoistableInst(type, op, fixedArgCount, fixedArgs, varArgListCount, listArgCounts, listArgs); + } + Int varArgCount = 0; for (Int ii = 0; ii < varArgListCount; ++ii) { @@ -1671,7 +1726,9 @@ namespace Slang { if (fixedArgs) { - operand->init(inst, fixedArgs[aa]); + auto arg = fixedArgs[aa]; + m_sharedBuilder->getInstReplacementMap().TryGetValue(arg, arg); + operand->init(inst, arg); } else { @@ -1687,7 +1744,9 @@ namespace Slang { if (listArgs[ii]) { - operand->init(inst, listArgs[ii][jj]); + auto arg = listArgs[ii][jj]; + m_sharedBuilder->getInstReplacementMap().TryGetValue(arg, arg); + operand->init(inst, arg); } else { @@ -2309,21 +2368,23 @@ namespace Slang args.add(getIntValue(capabilityAtomType, Int(atom))); } - return findOrEmitHoistableInst( + return createIntrinsicInst( capabilitySetType, kIROp_CapabilitySet, args.getCount(), args.getBuffer()); } - IRInst* IRBuilder::findOrEmitHoistableInst( - IRType* type, - IROp op, - UInt operandListCount, - UInt const* listOperandCounts, - IRInst* const* const* listOperands) - { - UInt operandCount = 0; - for (UInt ii = 0; ii < operandListCount; ++ii) + IRInst* IRBuilder::_findOrEmitHoistableInst( + IRType* type, + IROp op, + Int fixedArgCount, + IRInst* const* fixedArgs, + Int varArgListCount, + Int 
const* listArgCounts, + IRInst* const* const* listArgs) + { + UInt operandCount = fixedArgCount; + for (Int ii = 0; ii < varArgListCount; ++ii) { - operandCount += listOperandCounts[ii]; + operandCount += listArgCounts[ii]; } auto& memoryArena = getModule()->getMemoryArena(); @@ -2350,102 +2411,21 @@ namespace Slang // Don't link up as we may free (if we already have this key) { IRUse* operand = inst->getOperands(); - for (UInt ii = 0; ii < operandListCount; ++ii) + for (Int ii = 0; ii < fixedArgCount; ++ii) { - UInt listOperandCount = listOperandCounts[ii]; - for (UInt jj = 0; jj < listOperandCount; ++jj) - { - operand->usedValue = listOperands[ii][jj]; - operand++; - } - } - } - - // Find or add the key/inst - { - IRInstKey key = { inst }; - - // Ideally we would add if not found, else return if was found instead of testing & then adding. - IRInst** found = getSharedBuilder()->getGlobalValueNumberingMap().TryGetValueOrAdd(key, inst); - SLANG_ASSERT(endCursor == memoryArena.getCursor()); - // If it's found, just return, and throw away the instruction - if (found) - { - memoryArena.rewindToCursor(cursor); - return *found; - } - } - - // Make the lookup 'inst' instruction into 'proper' instruction. 
Equivalent to - // IRInst* inst = createInstImpl<IRInst>(builder, op, type, 0, nullptr, operandListCount, listOperandCounts, listOperands); - { - if (type) - { - inst->typeUse.usedValue = nullptr; - inst->typeUse.init(inst, type); - } - - _maybeSetSourceLoc(inst); - - IRUse*const operands = inst->getOperands(); - for (UInt i = 0; i < operandCount; ++i) - { - IRUse& operand = operands[i]; - auto value = operand.usedValue; - - operand.usedValue = nullptr; - operand.init(inst, value); + auto arg = fixedArgs[ii]; + m_sharedBuilder->getInstReplacementMap().TryGetValue(arg, arg); + operand->usedValue = arg; + operand++; } - } - - addHoistableInst(this, inst); - - return inst; - } - - IRInst* IRBuilder::findOrAddInst( - IRType* type, - IROp op, - UInt operandListCount, - UInt const* listOperandCounts, - IRInst* const* const* listOperands) - { - UInt operandCount = 0; - for (UInt ii = 0; ii < operandListCount; ++ii) - { - operandCount += listOperandCounts[ii]; - } - - auto& memoryArena = getModule()->getMemoryArena(); - void* cursor = memoryArena.getCursor(); - - // We are going to create a 'dummy' instruction on the memoryArena - // which can be used as a key for lookup, so see if we - // already have an equivalent instruction available to use. - size_t keySize = sizeof(IRInst) + operandCount * sizeof(IRUse); - IRInst* inst = (IRInst*)memoryArena.allocateAndZero(keySize); - - void* endCursor = memoryArena.getCursor(); - // Mark as 'unused' cos it is unused on release builds. 
- SLANG_UNUSED(endCursor); - - new(inst) IRInst(); -#if SLANG_ENABLE_IR_BREAK_ALLOC - inst->_debugUID = _debugGetAndIncreaseInstCounter(); -#endif - inst->m_op = op; - inst->typeUse.usedValue = type; - inst->operandCount = (uint32_t)operandCount; - - // Don't link up as we may free (if we already have this key) - { - IRUse* operand = inst->getOperands(); - for (UInt ii = 0; ii < operandListCount; ++ii) + for (Int ii = 0; ii < varArgListCount; ++ii) { - UInt listOperandCount = listOperandCounts[ii]; + UInt listOperandCount = listArgCounts[ii]; for (UInt jj = 0; jj < listOperandCount; ++jj) { - operand->usedValue = listOperands[ii][jj]; + auto arg = listArgs[ii][jj]; + m_sharedBuilder->getInstReplacementMap().TryGetValue(arg, arg); + operand->usedValue = arg; operand++; } } @@ -2488,50 +2468,17 @@ namespace Slang } } - addInst(inst); - return inst; - } - - - IRInst* IRBuilder::findOrEmitHoistableInst( - IRType* type, - IROp op, - UInt operandCount, - IRInst* const* operands) - { - return findOrEmitHoistableInst( - type, - op, - 1, - &operandCount, - &operands); - } - - IRInst* IRBuilder::findOrEmitHoistableInst( - IRType* type, - IROp op, - IRInst* operand, - UInt operandCount, - IRInst* const* operands) - { - UInt counts[] = { 1, operandCount }; - IRInst* const* lists[] = { &operand, operands }; + addHoistableInst(this, inst); - return findOrEmitHoistableInst( - type, - op, - 2, - counts, - lists); + return inst; } - IRType* IRBuilder::getType( IROp op, UInt operandCount, IRInst* const* operands) { - return (IRType*) findOrEmitHoistableInst( + return (IRType*)createIntrinsicInst( nullptr, op, operandCount, @@ -2831,7 +2778,7 @@ namespace Slang IRType* const* paramTypes, IRType* resultType) { - return (IRFuncType*) findOrEmitHoistableInst( + return (IRFuncType*)createIntrinsicInst( nullptr, kIROp_FuncType, resultType, @@ -2844,13 +2791,13 @@ namespace Slang { UInt counts[3] = {1, paramCount, 1}; IRInst** lists[3] = {(IRInst**)&resultType, (IRInst**)paramTypes, 
(IRInst**)&attribute}; - return (IRFuncType*)findOrEmitHoistableInst(nullptr, kIROp_FuncType, 3, counts, lists); + return (IRFuncType*)createIntrinsicInst(nullptr, kIROp_FuncType, 3, counts, lists); } IRWitnessTableType* IRBuilder::getWitnessTableType( IRType* baseType) { - return (IRWitnessTableType*)findOrEmitHoistableInst( + return (IRWitnessTableType*)createIntrinsicInst( nullptr, kIROp_WitnessTableType, 1, @@ -2860,7 +2807,7 @@ namespace Slang IRWitnessTableIDType* IRBuilder::getWitnessTableIDType( IRType* baseType) { - return (IRWitnessTableIDType*)findOrEmitHoistableInst( + return (IRWitnessTableIDType*)createIntrinsicInst( nullptr, kIROp_WitnessTableIDType, 1, @@ -2914,7 +2861,7 @@ namespace Slang UInt caseCount, IRType* const* caseTypes) { - return (IRType*) findOrEmitHoistableInst( + return (IRType*)createIntrinsicInst( getTypeKind(), kIROp_TaggedUnionType, caseCount, @@ -2947,7 +2894,7 @@ namespace Slang } } - return (IRType*) findOrEmitHoistableInst( + return (IRType*)createIntrinsicInst( getTypeKind(), kIROp_BindExistentialsType, baseType, @@ -3197,7 +3144,7 @@ namespace Slang if (as<IRWitnessTable>(innerReturnVal)) { - return findOrEmitHoistableInst( + return createIntrinsicInst( type, kIROp_Specialize, genericVal, @@ -3214,7 +3161,8 @@ namespace Slang argCount, args); - addInst(inst); + if (!inst->parent) + addInst(inst); return inst; } @@ -3233,7 +3181,7 @@ namespace Slang IRInst* args[] = {witnessTableVal, interfaceMethodVal}; - return findOrEmitHoistableInst( + return createIntrinsicInst( type, kIROp_LookupWitness, 2, @@ -3331,6 +3279,17 @@ namespace Slang args); } + IRInst* IRBuilder::createIntrinsicInst( + IRType* type, IROp op, IRInst* operand, UInt operandCount, IRInst* const* operands) + { + return createInstWithTrailingArgs<IRInst>(this, op, type, operand, operandCount, operands); + } + + IRInst* IRBuilder::createIntrinsicInst(IRType* type, IROp op, UInt operandListCount, UInt const* listOperandCounts, IRInst* const* const* listOperands) + { 
+ return createInstImpl<IRInst>(this, op, type, 0, nullptr, (Int)operandListCount, (Int const* )listOperandCounts, listOperands); + } + IRInst* IRBuilder::emitIntrinsicInst( IRType* type, @@ -3343,7 +3302,8 @@ namespace Slang op, argCount, args); - addInst(inst); + if (!inst->parent) + addInst(inst); return inst; } @@ -3772,6 +3732,13 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeMatrix, argCount, args); } + IRInst* IRBuilder::emitMakeMatrixFromScalar( + IRType* type, + IRInst* scalarValue) + { + return emitIntrinsicInst(type, kIROp_MakeMatrixFromScalar, 1, &scalarValue); + } + IRInst* IRBuilder::emitMakeArray( IRType* type, UInt argCount, @@ -3938,7 +3905,7 @@ namespace Slang value->insertAtEnd(parent); } } - + IRInst* IRBuilder::addDifferentiableTypeDictionaryDecoration(IRInst* target) { return addDecoration(target, kIROp_DifferentiableTypeDictionaryDecoration); @@ -5056,7 +5023,7 @@ namespace Slang this, kIROp_GlobalConstant, type); - addInst(inst); + addGlobalValue(this, inst); return inst; } @@ -5069,7 +5036,7 @@ namespace Slang kIROp_GlobalConstant, type, val); - addInst(inst); + addGlobalValue(this, inst); return inst; } @@ -5349,7 +5316,7 @@ namespace Slang IRInst* operands[] = { kindInst, sizeInst }; - return cast<IRTypeSizeAttr>(findOrEmitHoistableInst( + return cast<IRTypeSizeAttr>(createIntrinsicInst( getVoidType(), kIROp_TypeSizeAttr, SLANG_COUNT_OF(operands), @@ -5376,7 +5343,7 @@ namespace Slang operands[operandCount++] = spaceInst; } - return cast<IRVarOffsetAttr>(findOrEmitHoistableInst( + return cast<IRVarOffsetAttr>(createIntrinsicInst( getVoidType(), kIROp_VarOffsetAttr, operandCount, @@ -5388,7 +5355,7 @@ namespace Slang { IRInst* operands[] = { pendingLayout }; - return cast<IRPendingLayoutAttr>(findOrEmitHoistableInst( + return cast<IRPendingLayoutAttr>(createIntrinsicInst( getVoidType(), kIROp_PendingLayoutAttr, SLANG_COUNT_OF(operands), @@ -5401,7 +5368,7 @@ namespace Slang { IRInst* operands[] = { key, layout }; - return 
cast<IRStructFieldLayoutAttr>(findOrEmitHoistableInst( + return cast<IRStructFieldLayoutAttr>(createIntrinsicInst( getVoidType(), kIROp_StructFieldLayoutAttr, SLANG_COUNT_OF(operands), @@ -5413,7 +5380,7 @@ namespace Slang { IRInst* operands[] = { layout }; - return cast<IRCaseTypeLayoutAttr>(findOrEmitHoistableInst( + return cast<IRCaseTypeLayoutAttr>(createIntrinsicInst( getVoidType(), kIROp_CaseTypeLayoutAttr, SLANG_COUNT_OF(operands), @@ -5430,7 +5397,7 @@ namespace Slang IRInst* operands[] = { nameInst, indexInst }; - return cast<IRSemanticAttr>(findOrEmitHoistableInst( + return cast<IRSemanticAttr>(createIntrinsicInst( getVoidType(), op, SLANG_COUNT_OF(operands), @@ -5441,7 +5408,7 @@ namespace Slang { auto stageInst = getIntValue(getIntType(), IRIntegerValue(stage)); IRInst* operands[] = { stageInst }; - return cast<IRStageAttr>(findOrEmitHoistableInst( + return cast<IRStageAttr>(createIntrinsicInst( getVoidType(), kIROp_StageAttr, SLANG_COUNT_OF(operands), @@ -5450,7 +5417,7 @@ namespace Slang IRAttr* IRBuilder::getAttr(IROp op, UInt operandCount, IRInst* const* operands) { - return cast<IRAttr>(findOrEmitHoistableInst( + return cast<IRAttr>(createIntrinsicInst( getVoidType(), op, operandCount, @@ -5461,7 +5428,7 @@ namespace Slang IRTypeLayout* IRBuilder::getTypeLayout(IROp op, List<IRInst*> const& operands) { - return cast<IRTypeLayout>(findOrEmitHoistableInst( + return cast<IRTypeLayout>(createIntrinsicInst( getVoidType(), op, operands.getCount(), @@ -5470,7 +5437,7 @@ namespace Slang IRVarLayout* IRBuilder::getVarLayout(List<IRInst*> const& operands) { - return cast<IRVarLayout>(findOrEmitHoistableInst( + return cast<IRVarLayout>(createIntrinsicInst( getVoidType(), kIROp_VarLayout, operands.getCount(), @@ -5483,7 +5450,7 @@ namespace Slang { IRInst* operands[] = { paramsLayout, resultLayout }; - return cast<IREntryPointLayout>(findOrEmitHoistableInst( + return cast<IREntryPointLayout>(createIntrinsicInst( getVoidType(), kIROp_EntryPointLayout, 
SLANG_COUNT_OF(operands), @@ -6528,70 +6495,146 @@ namespace Slang void validateIRInstOperands(IRInst*); - void IRInst::replaceUsesWith(IRInst* other) + static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other) { - // Safety check: don't try to replace something with itself. - if(other == this) - return; + SharedIRBuilder* sharedBuilder = nullptr; - // We will walk through the list of uses for the current - // instruction, and make them point to the other inst. - IRUse* ff = firstUse; + struct WorkItem + { + IRInst* thisInst; + IRInst* otherInst; + }; - // No uses? Nothing to do. - if(!ff) - return; + // A work list of hoistable users for which we need + // to deduplicate/update their entry in the global numbering map. + List<WorkItem> workList; + HashSet<IRInst*> workListSet; - ff->debugValidate(); + auto addToWorkList = [&](IRInst* src, IRInst* target) + { + if (workListSet.Add(src)) + { + WorkItem item; + item.thisInst = src; + item.otherInst = target; + workList.add(item); + } + }; - IRUse* uu = ff; - for(;;) + addToWorkList(thisInst, other); + + for (Index i = 0; i < workList.getCount(); i++) { - // The uses had better all be uses of this - // instruction, or invariants are broken. - SLANG_ASSERT(uu->get() == this); + auto workItem = workList[i]; + thisInst = workItem.thisInst; + other = workItem.otherInst; - // Swap this use over to use the other value. - uu->usedValue = other; + // Safety check: don't try to replace something with itself. + if (other == thisInst) + continue; - // Try to move to the next use, but bail - // out if we are at the last one. 
- IRUse* nn = uu->nextUse; - if( !nn ) - break; + if (getIROpInfo(thisInst->getOp()).isHoistable()) + { + if (!sharedBuilder) + { + SLANG_ASSERT(thisInst->getModule()); + sharedBuilder = thisInst->getModule()->getSharedBuilder(); + } + sharedBuilder->getInstReplacementMap()[thisInst] = other; + } - uu = nn; - } + // We will walk through the list of uses for the current + // instruction, and make them point to the other inst. + IRUse* ff = thisInst->firstUse; - // We are at the last use (and there must - // be at least one, because we handled - // the case of an empty list earlier). - SLANG_ASSERT(uu); + // No uses? Nothing to do. + if (!ff) + continue; - // Our job at this point is to splice - // our list of uses onto the other - // value's uses. - // - // If the value already had uses, then - // we need to patch our new list onto - // the front. - if( auto nn = other->firstUse ) - { - uu->nextUse = nn; - nn->prevLink = &uu->nextUse; - } + //ff->debugValidate(); + + IRUse* uu = ff; + for (;;) + { + // The uses had better all be uses of this + // instruction, or invariants are broken. + SLANG_ASSERT(uu->get() == thisInst); + + auto user = uu->getUser(); + bool userIsHoistable = getIROpInfo(user->getOp()).isHoistable(); + if (userIsHoistable) + { + if (!sharedBuilder) + { + SLANG_ASSERT(user->getModule()); + sharedBuilder = user->getModule()->getSharedBuilder(); + } + sharedBuilder->_removeGlobalNumberingEntry(user); + } + + // Swap this use over to use the other value. + uu->usedValue = other; + + if (userIsHoistable) + { + // Is the updated inst already exists in the global numbering map? + // If so, we need to continue work on replacing the updated inst with the existing value. 
+ IRInst* existingVal = nullptr; + if (sharedBuilder->getGlobalValueNumberingMap().TryGetValue(IRInstKey{ user }, existingVal)) + { + addToWorkList(user, existingVal); + } + else + { + sharedBuilder->_addGlobalNumberingEntry(user); + } + } + + // Try to move to the next use, but bail + // out if we are at the last one. + IRUse* nn = uu->nextUse; + if (!nn) + break; + + uu = nn; + } - // No matter what, our list of - // uses will become the start - // of the list of uses for - // `other` - other->firstUse = ff; - ff->prevLink = &other->firstUse; + // We are at the last use (and there must + // be at least one, because we handled + // the case of an empty list earlier). + SLANG_ASSERT(uu); - // And `this` will have no uses any more. - this->firstUse = nullptr; + // Our job at this point is to splice + // our list of uses onto the other + // value's uses. + // + // If the value already had uses, then + // we need to patch our new list onto + // the front. + if (auto nn = other->firstUse) + { + uu->nextUse = nn; + nn->prevLink = &uu->nextUse; + } + + // No matter what, our list of + // uses will become the start + // of the list of uses for + // `other` + other->firstUse = ff; + ff->prevLink = &other->firstUse; + + // And `this` will have no uses any more. + thisInst->firstUse = nullptr; + + ff->debugValidate(); + } - ff->debugValidate(); + } + + void IRInst::replaceUsesWith(IRInst* other) + { + _replaceInstUsesWith(this, other); } // Insert this instruction into the same basic block @@ -6750,9 +6793,21 @@ namespace Slang // and then destroy it (it had better have no uses!) 
void IRInst::removeAndDeallocate() { - removeFromParent(); + if (auto module = getModule()) + { + if (getIROpInfo(getOp()).isHoistable()) + { + module->getSharedBuilder()->removeHoistableInstFromGlobalNumberingMap(this); + } + else if (auto constInst = as<IRConstant>(this)) + { + module->getSharedBuilder()->getConstantMap().Remove(IRConstantKey{ constInst }); + } + module->getSharedBuilder()->getInstReplacementMap().Remove(this); + } removeArguments(); removeAndDeallocateAllDecorationsAndChildren(); + removeFromParent(); // Run destructor to be sure... this->~IRInst(); @@ -6919,7 +6974,6 @@ namespace Slang case kIROp_Not: case kIROp_BitNot: case kIROp_Select: - case kIROp_Dot: case kIROp_MakeExistential: case kIROp_ExtractExistentialType: case kIROp_ExtractExistentialValue: diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 41b140972..9b8aa5cb7 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -37,12 +37,14 @@ enum : IROpFlags kIROpFlags_None = 0, kIROpFlag_Parent = 1 << 0, ///< This op is a parent op kIROpFlag_UseOther = 1 << 1, ///< If set this op can use 'other bits' to store information + kIROpFlag_Hoistable = 1 << 2, ///< If set this op is a hoistable inst that needs to be deduplicated. + kIROpFlag_Global = 1 << 3, ///< If set this op should always be hoisted but should never be deduplicated. }; /* Bit usage of IROp is a follows MainOp | Other -Bit range: 0-7 | Remaining bits +Bit range: 0-10 | Remaining bits For doing range checks (for example for doing isa tests), the value is masked by kIROpMeta_OpMask, such that the Other bits don't interfere. The other bits can be used for storage for anything that needs to identify as a different 'op' or 'type'. 
It is currently @@ -92,6 +94,9 @@ struct IROpInfo // Flags to control how we emit additional info IROpFlags flags; + + bool isHoistable() const { return (flags & kIROpFlag_Hoistable) != 0; } + bool isGlobal() const { return (flags & kIROpFlag_Global) != 0; } }; // Look up the info for an op @@ -206,6 +211,43 @@ struct IRInstList : IRInstListBase }; template<typename T> +struct IRModifiableInstList +{ + IRInst* parent; + List<IRInst*> workList; + + IRModifiableInstList() {} + + IRModifiableInstList(T* parent, T* first, T* last); + + T* getFirst() { return workList.getCount() ? (T*)workList.getFirst() : nullptr; } + T* getLast() { return workList.getCount() ? (T*)workList.getLast() : nullptr; } + + struct Iterator + { + IRModifiableInstList<T>* list; + Index position = 0; + + Iterator() {} + Iterator(IRModifiableInstList<T>* inList, Index inPos) : list(inList), position(inPos) {} + + T* operator*() + { + return (T*)(list->workList[position]); + } + void operator++(); + + bool operator!=(Iterator const& i) + { + return i.list != list || i.position != position; + } + }; + + Iterator begin() { return Iterator(this, 0); } + Iterator end() { return Iterator(this, workList.getCount()); } +}; + +template<typename T> struct IRFilteredInstList : IRInstListBase { IRFilteredInstList() {} @@ -591,6 +633,14 @@ struct IRInst getLastChild()); } + IRModifiableInstList<IRInst> getModifiableChildren() + { + return IRModifiableInstList<IRInst>( + this, + getFirstChild(), + getLastChild()); + } + /// A doubly-linked list containing any decorations and then any children of this instruction. 
/// /// We store both the decorations and children of an instruction @@ -607,7 +657,13 @@ struct IRInst IRInst* getFirstDecorationOrChild() { return m_decorationsAndChildren.first; } IRInst* getLastDecorationOrChild() { return m_decorationsAndChildren.last; } IRInstListBase getDecorationsAndChildren() { return m_decorationsAndChildren; } - + IRModifiableInstList<IRInst> getModifiableDecorationsAndChildren() + { + return IRModifiableInstList<IRInst>( + this, + m_decorationsAndChildren.first, + m_decorationsAndChildren.last); + } void removeAndDeallocateAllDecorationsAndChildren(); #ifdef SLANG_ENABLE_IR_BREAK_ALLOC @@ -647,6 +703,12 @@ struct IRInst getOperands()[index].set(value); } + void unsafeSetOperand(UInt index, IRInst* value) + { + SLANG_ASSERT(getOperands()[index].user != nullptr); + getOperands()[index].init(this, value); + } + // @@ -773,6 +835,39 @@ typename IRInstList<T>::Iterator IRInstList<T>::end() } template<typename T> +IRModifiableInstList<T>::IRModifiableInstList(T* inParent, T* first, T* last) +{ + parent = inParent; + for (auto item = first; item; item = item->next) + { + workList.add(item); + if (item == last) + break; + } +} + +template<typename T> +void IRModifiableInstList<T>::Iterator::operator++() +{ + position++; + while (position < list->workList.getCount()) + { + auto inst = list->workList[position]; + if (!as<T>(inst)) + { + // Skip insts that are not of type T. + } + else if (list->parent != inst->parent) + { + // Skip insts that are no longer in its original parent. 
+ } + else + break; + position++; + } +} + +template<typename T> IRFilteredInstList<T>::IRFilteredInstList(IRInst* fst, IRInst* lst) { first = fst; @@ -1796,6 +1891,104 @@ struct IRModuleInst : IRInst IR_LEAF_ISA(Module) }; +struct IRModule; + +// Description of an instruction to be used for global value numbering +struct IRInstKey +{ + IRInst* inst; + + HashCode getHashCode(); +}; + +bool operator==(IRInstKey const& left, IRInstKey const& right); + +struct IRConstantKey +{ + IRConstant* inst; + + bool operator==(const IRConstantKey& rhs) const { return inst->equal(rhs.inst); } + HashCode getHashCode() const { return inst->getHashCode(); } +}; + +struct SharedIRBuilder +{ +public: + SharedIRBuilder() + {} + + explicit SharedIRBuilder(IRModule* module) + { + init(module); + } + + void init(IRModule* module); + + IRModule* getModule() + { + return m_module; + } + + Session* getSession() + { + return m_session; + } + + void insertBlockAlongEdge(IREdge const& edge); + + // Rebuilds `globalValueNumberingMap`. This is necessary if any existing + // keys are modified (thus its hash code is changed). + void deduplicateAndRebuildGlobalNumberingMap(); + + // Replaces all uses of oldInst with newInst, and ensures the global numbering map is valid after the replacement. 
+ void replaceGlobalInst(IRInst* oldInst, IRInst* newInst); + + void removeHoistableInstFromGlobalNumberingMap(IRInst* inst); + + void tryHoistInst(IRInst* inst); + + typedef Dictionary<IRInstKey, IRInst*> GlobalValueNumberingMap; + typedef Dictionary<IRConstantKey, IRConstant*> ConstantMap; + + GlobalValueNumberingMap& getGlobalValueNumberingMap() { return m_globalValueNumberingMap; } + Dictionary<IRInst*, IRInst*>& getInstReplacementMap() { return m_instReplacementMap; } + + void _addGlobalNumberingEntry(IRInst* inst) + { + m_globalValueNumberingMap.Add(IRInstKey{ inst }, inst); + m_instReplacementMap.Remove(inst); + tryHoistInst(inst); + } + void _removeGlobalNumberingEntry(IRInst* inst) + { + IRInst* value = nullptr; + if (m_globalValueNumberingMap.TryGetValue(IRInstKey{ inst }, value)) + { + if (value == inst) + { + m_globalValueNumberingMap.Remove(IRInstKey{ inst }); + } + } + } + + ConstantMap& getConstantMap() { return m_constantMap; } + +private: + // The module that will own all of the IR + IRModule* m_module; + + // The parent compilation session + Session* m_session; + + GlobalValueNumberingMap m_globalValueNumberingMap; + + // Duplicate insts that are still alive and needs to be replaced in m_globalValueNumberMap + // when used as an operand to create another inst. 
+ Dictionary<IRInst*, IRInst*> m_instReplacementMap; + + ConstantMap m_constantMap; +}; + struct IRModule : RefObject { public: @@ -1810,6 +2003,8 @@ public: SLANG_FORCE_INLINE IRModuleInst* getModuleInst() const { return m_moduleInst; } SLANG_FORCE_INLINE MemoryArena& getMemoryArena() { return m_memoryArena; } + SharedIRBuilder* getSharedBuilder() const { return &m_sharedBuilder; } + IRInstListBase getGlobalInsts() const { return getModuleInst()->getChildren(); } /// Create an empty instruction with the `op` opcode and space for @@ -1853,6 +2048,7 @@ private: IRModule(Session* session) : m_session(session) , m_memoryArena(kMemoryArenaBlockSize) + , m_sharedBuilder(this) { } @@ -1870,6 +2066,9 @@ private: /// The memory arena from which all IR instructions (and any associated state) in this module are allocated. MemoryArena m_memoryArena; + + /// Shared contexts for constructing and maintaining the IR. + mutable SharedIRBuilder m_sharedBuilder; }; struct IRSpecializationDictionaryItem : public IRInst @@ -1943,13 +2142,17 @@ uint32_t& _debugGetIRAllocCounter(); // TODO: Ellie, comment and move somewhere more appropriate? 
template<typename I = IRInst, typename F> -static void traverseUses(IRInst* inst, F f) +static void traverseUsers(IRInst* inst, F f) { - auto n = inst->firstUse; - IRUse* u; - while((u = n) != nullptr) + List<IRUse*> uses; + for (auto use = inst->firstUse; use; use = use->nextUse) { - n = u->nextUse; + uses.add(use); + } + for (auto u : uses) + { + if (u->usedValue != inst) + continue; if(auto s = as<I>(u->getUser())) { f(s); @@ -1957,6 +2160,22 @@ static void traverseUses(IRInst* inst, F f) } } +template<typename F> +static void traverseUses(IRInst* inst, F f) +{ + List<IRUse*> uses; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + uses.add(use); + } + for (auto u : uses) + { + if (u->usedValue != inst) + continue; + f(u); + } +} + namespace detail { // A helper to get the singular pointer argument of something callable diff --git a/tests/bugs/vk-structured-buffer-load.hlsl.glsl b/tests/bugs/vk-structured-buffer-load.hlsl.glsl index 05c8de193..7f3ec40a2 100644 --- a/tests/bugs/vk-structured-buffer-load.hlsl.glsl +++ b/tests/bugs/vk-structured-buffer-load.hlsl.glsl @@ -2,17 +2,9 @@ //TEST_IGNORE_FILE: #version 460 +#extension GL_NV_ray_tracing : require layout(row_major) uniform; layout(row_major) buffer; -#extension GL_NV_ray_tracing : require - -#define rcp_tmp _S2 -#define RayData _S3 -#define Attributes _S4 - -#define tmpA _S5 -#define tmpB _S6 -#define tmpC _S7 layout(std430, binding = 1) readonly buffer _S1 { float _data[]; @@ -20,8 +12,8 @@ layout(std430, binding = 1) readonly buffer _S1 { float rcp_0(float x_0) { - float rcp_tmp = float(1.00000000000000000000) / x_0; - return rcp_tmp; + float _S2 = 1.0 / x_0; + return _S2; } struct RayHitInfoPacked_0 @@ -29,49 +21,48 @@ struct RayHitInfoPacked_0 vec4 PackedHitInfoA_0; }; -rayPayloadInNV RayHitInfoPacked_0 RayData; +rayPayloadInNV RayHitInfoPacked_0 _S3; struct BuiltInTriangleIntersectionAttributes_0 { vec2 barycentrics_0; }; -hitAttributeNV BuiltInTriangleIntersectionAttributes_0 Attributes; 
+hitAttributeNV BuiltInTriangleIntersectionAttributes_0 _S4; void main() { - float HitT_0 = (gl_RayTmaxNV); - RayData.PackedHitInfoA_0.x = HitT_0; - + float HitT_0 = ((gl_RayTmaxNV)); + _S3.PackedHitInfoA_0.x = HitT_0; float offsfloat_0 = ((gParamBlock_sbuf_0)._data[(0)]); - uint use_rcp_0 = 0U | uint(HitT_0 > 0.00000000000000000000); + uint use_rcp_0 = 0U | uint(HitT_0 > 0.0); - if(bool(use_rcp_0)) + if(use_rcp_0 != 0U) { - float tmpA = rcp_0(offsfloat_0); + float _S5 = rcp_0(offsfloat_0); - RayData.PackedHitInfoA_0.y = tmpA; + _S3.PackedHitInfoA_0.y = _S5; } else { - if(use_rcp_0 > 0U&&offsfloat_0 == 0.00000000000000000000) + if(use_rcp_0 > 0U&&offsfloat_0 == 0.0) { - float tmpB = (inversesqrt((offsfloat_0 + 1.00000000000000000000))); + float _S6 = (inversesqrt((offsfloat_0 + 1.0))); - RayData.PackedHitInfoA_0.y = tmpB; + _S3.PackedHitInfoA_0.y = _S6; } else { - float tmpC = (inversesqrt((offsfloat_0))); + float _S7 = (inversesqrt((offsfloat_0))); - RayData.PackedHitInfoA_0.y = tmpC; + _S3.PackedHitInfoA_0.y = _S7; } diff --git a/tests/compute/dynamic-dispatch-bindless-texture.slang b/tests/compute/dynamic-dispatch-bindless-texture.slang index 4611fbd48..04c1f1766 100644 --- a/tests/compute/dynamic-dispatch-bindless-texture.slang +++ b/tests/compute/dynamic-dispatch-bindless-texture.slang @@ -10,7 +10,7 @@ interface IInterface //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=gOutputBuffer RWStructuredBuffer<uint> gOutputBuffer; -//TEST_INPUT: set gCb = new StructuredBuffer<IInterface>{new MyImpl{Texture2D(size=8, content = one)}} +//TEST_INPUT: set gCb = new StructuredBuffer<IInterface>{new MyImpl{Texture2D(size=8, content = one), Sampler}} StructuredBuffer<IInterface> gCb; [numthreads(4, 1, 1)] diff --git a/tests/cross-compile/function-static-const.slang.hlsl b/tests/cross-compile/function-static-const.slang.hlsl index a4f1118eb..95d1e3070 100644 --- a/tests/cross-compile/function-static-const.slang.hlsl +++ 
b/tests/cross-compile/function-static-const.slang.hlsl @@ -13,15 +13,15 @@ cbuffer C_0 : register(b0) SLANG_ParameterGroup_C_0 C_0; } -static const int kArray_0[16] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }; +static const int kArray_0[int(16)] = { int(1), int(2), int(3), int(4), int(5), int(6), int(7), int(8), int(9), int(10), int(11), int(12), int(13), int(14), int(15), int(16) }; int test_0(int val_0) { return kArray_0[val_0]; } -vector<float,4> main() : SV_TARGET +float4 main() : SV_TARGET { int _S1 = test_0(C_0.index_0); - return (vector<float,4>) _S1; + return (float4) float(_S1); } diff --git a/tests/hlsl-intrinsic/f16tof32.slang b/tests/hlsl-intrinsic/f16tof32.slang index 78c5fdae6..d45eab00b 100644 --- a/tests/hlsl-intrinsic/f16tof32.slang +++ b/tests/hlsl-intrinsic/f16tof32.slang @@ -2,7 +2,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -render-features half +//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -render-features half //TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer RWStructuredBuffer<float> outputBuffer; diff --git a/tests/hlsl-intrinsic/f32tof16.slang b/tests/hlsl-intrinsic/f32tof16.slang index ebcb6b40a..ad8e8e5df 100644 --- a/tests/hlsl-intrinsic/f32tof16.slang +++ b/tests/hlsl-intrinsic/f32tof16.slang @@ -2,7 +2,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -render-features half +//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -render-features half //TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer RWStructuredBuffer<uint> outputBuffer; diff --git a/tests/hlsl-intrinsic/matrix-double.slang 
b/tests/hlsl-intrinsic/matrix-double.slang index b9e36bef8..08bd78cee 100644 --- a/tests/hlsl-intrinsic/matrix-double.slang +++ b/tests/hlsl-intrinsic/matrix-double.slang @@ -35,8 +35,7 @@ Float calcTotal(FloatMatrix v) FloatMatrix makeFloatMatrix(Float f) { - FloatMatrix m = { { f, f }, { f, f } }; - return m; + return FloatMatrix(f); } IntMatrix makeIntMatrix(int v) @@ -45,68 +44,58 @@ IntMatrix makeIntMatrix(int v) return m; } -[numthreads(4, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void test1(inout FloatMatrix ft, inout FloatMatrix f, int idx) { - int idx = int(dispatchThreadID.x); - - Float scalarF = idx * (1.0f / (4.0f)); - - FloatMatrix ft = {}; - - FloatMatrix f = { { scalarF + 0.01, scalarF + 0.02}, { scalarF + 0.011, scalarF + 0.022}}; - // fmod ft += FloatMatrix(IntMatrix(((f % makeFloatMatrix(0.11f)) * makeFloatMatrix(100)) + makeFloatMatrix(0.5))); - + ft += sin(f); - + // Lets try some matrix/matrix ft = f * ft; - + // Lets try some vector matrix - + { - FloatMatrix r = {mul(f[0], ft), mul(ft, f[1])}; + FloatMatrix r = { mul(f[0], ft), mul(ft, f[1]) }; ft += r; } - + // Back to the transcendentals - + ft += cos(f); ft += tan(f); - + ft += asin(f); ft += acos(f); ft += atan(f); - - ft += atan2(f, makeFloatMatrix(2)); + ft += atan2(f, makeFloatMatrix(2)); { FloatMatrix sf, cf; sincos(f, sf, cf); - + ft += sf; ft += cf; } - + ft += rcp(makeFloatMatrix(1.0) + f); ft += FloatMatrix(sign(f - makeFloatMatrix(0.5))); - + ft += saturate(f * makeFloatMatrix(4) - makeFloatMatrix(2.0)); - + ft += sqrt(f); ft += rsqrt(makeFloatMatrix(1.0f) + f); - + ft += exp2(f); ft += exp(f); - + ft += frac(f * makeFloatMatrix(3)); ft += ceil(f * makeFloatMatrix(5) - makeFloatMatrix(3)); - + ft += floor(f * makeFloatMatrix(10) - makeFloatMatrix(7)); ft += trunc(f * makeFloatMatrix(7)); - + ft += log(f + makeFloatMatrix(10.0)); ft += log2(f * makeFloatMatrix(3) + makeFloatMatrix(2)); @@ -114,12 +103,15 @@ void computeMain(uint3 dispatchThreadID : 
SV_DispatchThreadID) float scalarVs[] = { 1, 10, 100, 1000 }; ft += FloatMatrix(IntMatrix(log10(makeFloatMatrix(scalarVs[idx])) + makeFloatMatrix(0.5f))); } - + ft += abs(f * makeFloatMatrix(4) - makeFloatMatrix(2.0f)); - + ft += min(makeFloatMatrix(0.5), f); ft += max(f, makeFloatMatrix(0.75)); +} +void test2(inout FloatMatrix ft, inout FloatMatrix f) +{ ft += pow(makeFloatMatrix(0.5), f); ft += smoothstep(makeFloatMatrix(0.2), makeFloatMatrix(0.7), f); @@ -135,7 +127,22 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) UIntMatrix vu = asuint(f); ft += asfloat(vu); -#endif - +#endif +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + int idx = int(dispatchThreadID.x); + + Float scalarF = idx * (1.0f / (4.0f)); + + FloatMatrix ft = {}; + + FloatMatrix f = { { scalarF + 0.01, scalarF + 0.02}, { scalarF + 0.011, scalarF + 0.022}}; + + test1(ft, f, idx); + test2(ft, f); + outputBuffer[idx] = calcTotal(ft); }
\ No newline at end of file diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected index 58f06cfec..6e3e5d5d8 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang.1.expected @@ -17,13 +17,17 @@ layout(location = 0) rayPayloadEXT SomeValues_0 p_0; -layout(binding = 0) -uniform accelerationStructureEXT scene_0; - layout(location = 0) hitObjectAttributeNV SomeValues_0 t_0; +layout(location = 1) +rayPayloadEXT +SomeValues_0 p_1; + +layout(binding = 0) +uniform accelerationStructureEXT scene_0; + SomeValues_0 HitObject_GetAttributes_0(hitObjectNV this_0) { hitObjectGetAttributesNV((this_0), ((0))); @@ -50,15 +54,11 @@ uint calcValue_0(hitObjectNV hit_0) return r_0; } -layout(location = 1) -rayPayloadEXT -SomeValues_0 p_1; - void HitObject_Invoke_0(accelerationStructureEXT AccelerationStructure_0, hitObjectNV HitOrMiss_0, inout SomeValues_0 Payload_0) { - p_1 = Payload_0; - hitObjectExecuteShaderNV(HitOrMiss_0, (1)); - Payload_0 = p_1; + p_0 = Payload_0; + hitObjectExecuteShaderNV(HitOrMiss_0, (0)); + Payload_0 = p_0; return; } @@ -87,10 +87,10 @@ void main() ray_0.Direction_0 = vec3(0.0, 1.0, 0.0); ray_0.TMax_0 = 10000.0; RayDesc_0 _S7 = ray_0; - p_0.a_0 = idx_0; - p_0.b_0 = _S6; + p_1.a_0 = idx_0; + p_1.b_0 = _S6; hitObjectNV hitObj_0; - hitObjectTraceRayNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S7.Origin_0, _S7.TMin_0, _S7.Direction_0, _S7.TMax_0, (0)); + hitObjectTraceRayNV(hitObj_0, scene_0, 20U, 255U, 0U, 4U, 0U, _S7.Origin_0, _S7.TMin_0, _S7.Direction_0, _S7.TMax_0, (1)); uint r_1 = calcValue_0(hitObj_0); reorderThreadNV(hitObj_0); float _S8 = _S5 * 4.0; diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected 
b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected index e6d31a3c8..12025efe0 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang.1.expected @@ -15,16 +15,16 @@ struct SomeValues_0 }; layout(location = 0) +hitObjectAttributeNV +SomeValues_0 t_0; + +layout(location = 0) rayPayloadEXT SomeValues_0 p_0; layout(binding = 0) uniform accelerationStructureEXT scene_0; -layout(location = 0) -hitObjectAttributeNV -SomeValues_0 t_0; - SomeValues_0 HitObject_GetAttributes_0(hitObjectNV this_0) { hitObjectGetAttributesNV((this_0), ((0))); diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected index 9b5a4d193..ac4d78234 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang.1.expected @@ -14,16 +14,16 @@ struct SomeValues_0 }; layout(location = 0) +hitObjectAttributeNV +SomeValues_0 t_0; + +layout(location = 0) rayPayloadEXT SomeValues_0 p_0; layout(binding = 0) uniform accelerationStructureEXT scene_0; -layout(location = 0) -hitObjectAttributeNV -SomeValues_0 t_0; - SomeValues_0 HitObject_GetAttributes_0(hitObjectNV this_0) { hitObjectGetAttributesNV((this_0), ((0))); diff --git a/tests/nv-extensions/nv-ray-tracing-motion-blur.slang.glsl b/tests/nv-extensions/nv-ray-tracing-motion-blur.slang.glsl index 79c6232f3..724a0a241 100644 --- a/tests/nv-extensions/nv-ray-tracing-motion-blur.slang.glsl +++ b/tests/nv-extensions/nv-ray-tracing-motion-blur.slang.glsl @@ -1,26 +1,49 @@ -//TEST_IGNORE_FILE: - #version 460 #extension GL_EXT_ray_tracing : require #extension GL_NV_ray_tracing_motion_blur : require layout(row_major) uniform; 
layout(row_major) buffer; +struct ReflectionRay_0 +{ + float color_0; +}; + + +layout(location = 0) +rayPayloadEXT +ReflectionRay_0 p_0; + + +struct ShadowRay_0 +{ + float hitDistance_0; +}; + + +layout(location = 1) +rayPayloadEXT +ShadowRay_0 p_1; + + layout(binding = 0) uniform texture2D samplerPosition_0; + layout(binding = 2) uniform sampler sampler_0; + layout(binding = 1) uniform texture2D samplerNormal_0; struct Light_0 { vec4 position_0; - vec4 color_0; + vec4 color_1; }; + struct Uniforms_0 { Light_0 light_0; @@ -29,21 +52,13 @@ struct Uniforms_0 mat4x4 model_0; }; + layout(binding = 3) layout(std140) uniform _S1 { Uniforms_0 _data; } ubo_0; -struct ShadowRay_0 -{ - float hitDistance_0; -}; - -layout(location = 0) -rayPayloadEXT -ShadowRay_0 p_0; - struct RayDesc_0 { vec3 Origin_0; @@ -52,18 +67,22 @@ struct RayDesc_0 float TMax_0; }; + void TraceMotionRay_0(accelerationStructureEXT AccelerationStructure_0, uint RayFlags_0, uint InstanceInclusionMask_0, uint RayContributionToHitGroupIndex_0, uint MultiplierForGeometryContributionToHitGroupIndex_0, uint MissShaderIndex_0, RayDesc_0 Ray_0, float CurrentTime_0, inout ShadowRay_0 Payload_0) { - p_0 = Payload_0; - traceRayMotionNV(AccelerationStructure_0, RayFlags_0, InstanceInclusionMask_0, RayContributionToHitGroupIndex_0, MultiplierForGeometryContributionToHitGroupIndex_0, MissShaderIndex_0, Ray_0.Origin_0, Ray_0.TMin_0, Ray_0.Direction_0, Ray_0.TMax_0, CurrentTime_0, (0)); - Payload_0 = p_0; + p_1 = Payload_0; + traceRayMotionNV(AccelerationStructure_0, RayFlags_0, InstanceInclusionMask_0, RayContributionToHitGroupIndex_0, MultiplierForGeometryContributionToHitGroupIndex_0, MissShaderIndex_0, Ray_0.Origin_0, Ray_0.TMin_0, Ray_0.Direction_0, Ray_0.TMax_0, CurrentTime_0, (1)); + + Payload_0 = p_1; return; } + layout(binding = 5) uniform accelerationStructureEXT as_0; + float saturate_0(float x_0) { float _S2 = clamp(x_0, 0.0, 1.0); @@ -71,28 +90,23 @@ float saturate_0(float x_0) return _S2; } -struct 
ReflectionRay_0 -{ - float color_1; -}; - -layout(location = 1) -rayPayloadEXT -ReflectionRay_0 p_1; void TraceRay_0(accelerationStructureEXT AccelerationStructure_1, uint RayFlags_1, uint InstanceInclusionMask_1, uint RayContributionToHitGroupIndex_1, uint MultiplierForGeometryContributionToHitGroupIndex_1, uint MissShaderIndex_1, RayDesc_0 Ray_1, inout ReflectionRay_0 Payload_1) { - p_1 = Payload_1; - traceRayEXT(AccelerationStructure_1, RayFlags_1, InstanceInclusionMask_1, RayContributionToHitGroupIndex_1, MultiplierForGeometryContributionToHitGroupIndex_1, MissShaderIndex_1, Ray_1.Origin_0, Ray_1.TMin_0, Ray_1.Direction_0, Ray_1.TMax_0, (1)); - Payload_1 = p_1; + p_0 = Payload_1; + traceRayEXT(AccelerationStructure_1, RayFlags_1, InstanceInclusionMask_1, RayContributionToHitGroupIndex_1, MultiplierForGeometryContributionToHitGroupIndex_1, MissShaderIndex_1, Ray_1.Origin_0, Ray_1.TMin_0, Ray_1.Direction_0, Ray_1.TMax_0, (0)); + + Payload_1 = p_0; return; } + layout(rgba32f) layout(binding = 4) uniform image2D outputImage_0; + void main() { uvec3 _S3 = ((gl_LaunchIDEXT)); @@ -102,6 +116,7 @@ void main() ivec2 launchSize_0 = ivec2(_S4.xy); + float _S5 = (float(launchID_0.x) + 0.5) / float(launchSize_0.x); float _S6 = (float(launchID_0.y) + 0.5) / float(launchSize_0.y); @@ -114,6 +129,7 @@ void main() vec3 N_0 = _S8.xyz * 2.0 - 1.0; + vec3 lightDelta_0 = ubo_0._data.light_0.position_0.xyz - P_0; float lightDist_0 = length(lightDelta_0); vec3 L_0 = normalize(lightDelta_0); @@ -125,22 +141,30 @@ void main() ray_0.Direction_0 = lightDelta_0; ray_0.TMax_0 = lightDist_0; + ShadowRay_0 shadowRay_0; shadowRay_0.hitDistance_0 = 0.0; + + + TraceMotionRay_0(as_0, 1U, 255U, 0U, 0U, 2U, ray_0, 1.0, shadowRay_0); float atten_0; if(shadowRay_0.hitDistance_0 < lightDist_0) { + atten_0 = 0.0; + } else { + atten_0 = _S9; + } - vec3 _S10 = ubo_0._data.light_0.color_0.xyz; + vec3 _S10 = ubo_0._data.light_0.color_1.xyz; float _S11 = dot(N_0, L_0); @@ -148,9 +172,10 @@ void main() vec3 
color_2 = _S10 * _S12 * atten_0; + ReflectionRay_0 reflectionRay_0; TraceRay_0(as_0, 1U, 255U, 0U, 0U, 2U, ray_0, reflectionRay_0); - imageStore((outputImage_0), ivec2((uvec2(launchID_0))), vec4(color_2 + reflectionRay_0.color_1, 1.0)); + imageStore((outputImage_0), ivec2((uvec2(launchID_0))), vec4(color_2 + reflectionRay_0.color_0, 1.0)); return; }
\ No newline at end of file diff --git a/tests/pipeline/rasterization/mesh/component-write.slang.glsl b/tests/pipeline/rasterization/mesh/component-write.slang.glsl index 5234fb62f..143947504 100644 --- a/tests/pipeline/rasterization/mesh/component-write.slang.glsl +++ b/tests/pipeline/rasterization/mesh/component-write.slang.glsl @@ -2,26 +2,30 @@ #extension GL_EXT_mesh_shader : require layout(row_major) uniform; layout(row_major) buffer; -const vec3 colors_0[3] = { vec3(1.00000000000000000000, 1.00000000000000000000, 0.00000000000000000000), vec3(0.00000000000000000000, 1.00000000000000000000, 1.00000000000000000000), vec3(1.00000000000000000000, 0.00000000000000000000, 1.00000000000000000000) }; -const vec2 positions_0[3] = { vec2(0.00000000000000000000, -0.50000000000000000000), vec2(0.50000000000000000000, 0.50000000000000000000), vec2(-0.50000000000000000000, 0.50000000000000000000) }; +const vec2 positions_0[3] = { vec2(0.0, -0.5), vec2(0.5, 0.5), vec2(-0.5, 0.5) }; +const vec3 colors_0[3] = { vec3(1.0, 1.0, 0.0), vec3(0.0, 1.0, 1.0), vec3(1.0, 0.0, 1.0) }; layout(location = 0) out vec3 _S1[3]; + out gl_MeshPerVertexEXT { vec4 gl_Position; } gl_MeshVerticesEXT[3]; +out uvec3 gl_PrimitiveTriangleIndicesEXT[1]; + layout(local_size_x = 3, local_size_y = 1, local_size_z = 1) in; layout(max_vertices = 3) out; layout(max_primitives = 1) out; layout(triangles) out; void main() { + SetMeshOutputsEXT(3U, 1U); if(gl_LocalInvocationIndex < 3U) { - gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = vec4(positions_0[gl_LocalInvocationIndex], 0.00000000000000000000, 1.00000000000000000000); + gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = vec4(positions_0[gl_LocalInvocationIndex], 0.0, 1.0); _S1[gl_LocalInvocationIndex] = colors_0[gl_LocalInvocationIndex]; } else @@ -34,6 +38,6 @@ void main() else { } + return; } - diff --git a/tests/pipeline/rasterization/mesh/hello.slang.glsl b/tests/pipeline/rasterization/mesh/hello.slang.glsl index 
3a0848dcb..0b3d4acb3 100644 --- a/tests/pipeline/rasterization/mesh/hello.slang.glsl +++ b/tests/pipeline/rasterization/mesh/hello.slang.glsl @@ -2,12 +2,12 @@ #extension GL_EXT_mesh_shader : require layout(row_major) uniform; layout(row_major) buffer; -const vec3 colors_0[3] = { vec3(1.0, 1.0, 0.0), vec3(0.0, 1.0, 1.0), vec3(1.0, 0.0, 1.0) }; const vec2 positions_0[3] = { vec2(0.0, -0.5), vec2(0.5, 0.5), vec2(-0.5, 0.5) }; - +const vec3 colors_0[3] = { vec3(1.0, 1.0, 0.0), vec3(0.0, 1.0, 1.0), vec3(1.0, 0.0, 1.0) }; layout(location = 0) out vec3 _S1[3]; + out gl_MeshPerVertexEXT { vec4 gl_Position; @@ -21,6 +21,7 @@ layout(max_primitives = 1) out; layout(triangles) out; void main() { + SetMeshOutputsEXT(3U, 1U); if(gl_LocalInvocationIndex < 3U) { @@ -30,6 +31,7 @@ void main() else { } + if(gl_LocalInvocationIndex < 1U) { gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uvec3(0U, 1U, 2U); @@ -37,6 +39,6 @@ void main() else { } + return; } - diff --git a/tests/pipeline/rasterization/mesh/hello.slang.hlsl b/tests/pipeline/rasterization/mesh/hello.slang.hlsl index a10dc6884..06871fedd 100644 --- a/tests/pipeline/rasterization/mesh/hello.slang.hlsl +++ b/tests/pipeline/rasterization/mesh/hello.slang.hlsl @@ -4,16 +4,15 @@ #endif #pragma warning(disable: 3557) -static const float3 colors_0[int(3)] = { float3(1.0, 1.0, 0.0), float3(0.0, 1.0, 1.0), float3(1.0, 0.0, 1.0) }; static const float2 positions_0[int(3)] = { float2(0.0, -0.5), float2(0.5, 0.5), float2(-0.5, 0.5) }; +static const float3 colors_0[int(3)] = { float3(1.0, 1.0, 0.0), float3(0.0, 1.0, 1.0), float3(1.0, 0.0, 1.0) }; struct Vertex_0 { float4 pos_0 : SV_Position; float3 color_0 : Color; }; -[shader("mesh")] -[numthreads(3, 1, 1)] +[shader("mesh")][numthreads(3, 1, 1)] [outputtopology("triangle")] void main(uint tig_0 : SV_GROUPINDEX, vertices vertices out Vertex_0 verts_0[int(3)], indices indices out uint3 triangles_0[int(1)]) { @@ -26,6 +25,7 @@ void main(uint tig_0 : SV_GROUPINDEX, vertices 
vertices out Vertex_0 verts_0[in else { } + if(tig_0 < 1U) { triangles_0[tig_0] = uint3(0U, 1U, 2U); @@ -33,6 +33,6 @@ void main(uint tig_0 : SV_GROUPINDEX, vertices vertices out Vertex_0 verts_0[in else { } + return; } - diff --git a/tests/pipeline/rasterization/mesh/primitive-output.slang.glsl b/tests/pipeline/rasterization/mesh/primitive-output.slang.glsl index 35efcf4af..cfe266bec 100644 --- a/tests/pipeline/rasterization/mesh/primitive-output.slang.glsl +++ b/tests/pipeline/rasterization/mesh/primitive-output.slang.glsl @@ -2,9 +2,10 @@ #extension GL_EXT_mesh_shader : require layout(row_major) uniform; layout(row_major) buffer; -const vec3 colors_0[3] = { vec3(1.0, 1.0, 0.0), vec3(0.0, 1.0, 1.0), vec3(1.0, 0.0, 1.0) }; const vec2 positions_0[3] = { vec2(0.0, -0.5), vec2(0.5, 0.5), vec2(-0.5, 0.5) }; +const vec3 colors_0[3] = { vec3(1.0, 1.0, 0.0), vec3(0.0, 1.0, 1.0), vec3(1.0, 0.0, 1.0) }; out uvec3 gl_PrimitiveTriangleIndicesEXT[1]; + layout(location = 0) out vec3 _S1[3]; @@ -28,6 +29,7 @@ layout(max_primitives = 1) out; layout(triangles) out; void main() { + SetMeshOutputsEXT(3U, 1U); if(gl_LocalInvocationIndex < 3U) { @@ -37,6 +39,7 @@ void main() else { } + if(gl_LocalInvocationIndex < 1U) { gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uvec3(0U, 1U, 2U); @@ -47,6 +50,6 @@ void main() else { } + return; } - diff --git a/tests/vkray/raygen.slang.glsl b/tests/vkray/raygen.slang.glsl index f8b97973b..e34f1f6e0 100644 --- a/tests/vkray/raygen.slang.glsl +++ b/tests/vkray/raygen.slang.glsl @@ -1,8 +1,25 @@ -//TEST_IGNORE_FILE: #version 460 #extension GL_EXT_ray_tracing : require layout(row_major) uniform; layout(row_major) buffer; +struct ReflectionRay_0 +{ + float color_0; +}; + +layout(location = 0) +rayPayloadEXT +ReflectionRay_0 p_0; + +struct ShadowRay_0 +{ + float hitDistance_0; +}; + +layout(location = 1) +rayPayloadEXT +ShadowRay_0 p_1; + layout(binding = 0) uniform texture2D samplerPosition_0; @@ -11,11 +28,10 @@ uniform sampler sampler_0; 
layout(binding = 1) uniform texture2D samplerNormal_0; - struct Light_0 { vec4 position_0; - vec4 color_0; + vec4 color_1; }; struct Uniforms_0 @@ -31,15 +47,6 @@ layout(std140) uniform _S1 { Uniforms_0 _data; } ubo_0; -struct ShadowRay_0 -{ - float hitDistance_0; -}; - -layout(location = 0) -rayPayloadEXT -ShadowRay_0 p_0; - struct RayDesc_0 { vec3 Origin_0; @@ -50,26 +57,17 @@ struct RayDesc_0 void TraceRay_0(accelerationStructureEXT AccelerationStructure_0, uint RayFlags_0, uint InstanceInclusionMask_0, uint RayContributionToHitGroupIndex_0, uint MultiplierForGeometryContributionToHitGroupIndex_0, uint MissShaderIndex_0, RayDesc_0 Ray_0, inout ShadowRay_0 Payload_0) { - p_0 = Payload_0; - traceRayEXT(AccelerationStructure_0, RayFlags_0, InstanceInclusionMask_0, RayContributionToHitGroupIndex_0, MultiplierForGeometryContributionToHitGroupIndex_0, MissShaderIndex_0, Ray_0.Origin_0, Ray_0.TMin_0, Ray_0.Direction_0, Ray_0.TMax_0, (0)); - Payload_0 = p_0; + p_1 = Payload_0; + traceRayEXT(AccelerationStructure_0, RayFlags_0, InstanceInclusionMask_0, RayContributionToHitGroupIndex_0, MultiplierForGeometryContributionToHitGroupIndex_0, MissShaderIndex_0, Ray_0.Origin_0, Ray_0.TMin_0, Ray_0.Direction_0, Ray_0.TMax_0, (1)); + Payload_0 = p_1; return; } -struct ReflectionRay_0 -{ - float color_1; -}; - -layout(location = 1) -rayPayloadEXT -ReflectionRay_0 p_1; - void TraceRay_1(accelerationStructureEXT AccelerationStructure_1, uint RayFlags_1, uint InstanceInclusionMask_1, uint RayContributionToHitGroupIndex_1, uint MultiplierForGeometryContributionToHitGroupIndex_1, uint MissShaderIndex_1, RayDesc_0 Ray_1, inout ReflectionRay_0 Payload_1) { - p_1 = Payload_1; - traceRayEXT(AccelerationStructure_1, RayFlags_1, InstanceInclusionMask_1, RayContributionToHitGroupIndex_1, MultiplierForGeometryContributionToHitGroupIndex_1, MissShaderIndex_1, Ray_1.Origin_0, Ray_1.TMin_0, Ray_1.Direction_0, Ray_1.TMax_0, (1)); - Payload_1 = p_1; + p_0 = Payload_1; + 
traceRayEXT(AccelerationStructure_1, RayFlags_1, InstanceInclusionMask_1, RayContributionToHitGroupIndex_1, MultiplierForGeometryContributionToHitGroupIndex_1, MissShaderIndex_1, Ray_1.Origin_0, Ray_1.TMin_0, Ray_1.Direction_0, Ray_1.TMax_0, (0)); + Payload_1 = p_0; return; } @@ -78,7 +76,7 @@ uniform accelerationStructureEXT as_0; float saturate_0(float x_0) { - float _S2 = clamp(x_0, float(0), float(1)); + float _S2 = clamp(x_0, 0.0, 1.0); return _S2; } @@ -88,48 +86,51 @@ uniform image2D outputImage_0; void main() { - float atten_0; uvec3 _S3 = ((gl_LaunchIDEXT)); - float _S4 = float(_S3.x) + 0.50000000000000000000; + float _S4 = float(_S3.x) + 0.5; uvec3 _S5 = ((gl_LaunchSizeEXT)); float _S6 = _S4 / float(_S5.x); uvec3 _S7 = ((gl_LaunchIDEXT)); - float _S8 = float(_S7.y) + 0.50000000000000000000; + float _S8 = float(_S7.y) + 0.5; uvec3 _S9 = ((gl_LaunchSizeEXT)); float _S10 = _S8 / float(_S9.y); vec2 inUV_0 = vec2(_S6, _S10); vec4 _S11 = (texture(sampler2D(samplerPosition_0,sampler_0), (inUV_0))); vec3 P_0 = _S11.xyz; vec4 _S12 = (texture(sampler2D(samplerNormal_0,sampler_0), (inUV_0))); - vec3 N_0 = _S12.xyz * 2.00000000000000000000 - 1.00000000000000000000; + vec3 N_0 = _S12.xyz * 2.0 - 1.0; + vec3 lightDelta_0 = ubo_0._data.light_0.position_0.xyz - P_0; float lightDist_0 = length(lightDelta_0); vec3 L_0 = normalize(lightDelta_0); - float _S13 = 1.00000000000000000000 / (lightDist_0 * lightDist_0); + float _S13 = 1.0 / (lightDist_0 * lightDist_0); RayDesc_0 ray_0; ray_0.Origin_0 = P_0; - ray_0.TMin_0 = 0.00000100000000000000; + ray_0.TMin_0 = 0.00000099999999747524; ray_0.Direction_0 = lightDelta_0; ray_0.TMax_0 = lightDist_0; + ShadowRay_0 shadowRay_0; - shadowRay_0.hitDistance_0 = float(0); - TraceRay_0(as_0, uint(1), uint(255), uint(0), uint(0), uint(2), ray_0, shadowRay_0); + shadowRay_0.hitDistance_0 = 0.0; + TraceRay_0(as_0, 1U, 255U, 0U, 0U, 2U, ray_0, shadowRay_0); + float atten_0; if(shadowRay_0.hitDistance_0 < lightDist_0) { - atten_0 = 
0.00000000000000000000; + atten_0 = 0.0; } else { atten_0 = _S13; } - vec3 _S14 = ubo_0._data.light_0.color_0.xyz; + vec3 _S14 = ubo_0._data.light_0.color_1.xyz; float _S15 = dot(N_0, L_0); float _S16 = saturate_0(_S15); vec3 color_2 = _S14 * _S16 * atten_0; + ReflectionRay_0 reflectionRay_0; - TraceRay_1(as_0, uint(1), uint(255), uint(0), uint(0), uint(2), ray_0, reflectionRay_0); - vec3 color_3 = color_2 + reflectionRay_0.color_1; + TraceRay_1(as_0, 1U, 255U, 0U, 0U, 2U, ray_0, reflectionRay_0); + vec3 color_3 = color_2 + reflectionRay_0.color_0; uvec3 _S17 = ((gl_LaunchIDEXT)); - imageStore((outputImage_0), ivec2((uvec2(ivec2(_S17.xy)))), vec4(color_3, 1.00000000000000000000)); + imageStore((outputImage_0), ivec2((uvec2(ivec2(_S17.xy)))), vec4(color_3, 1.0)); return; } diff --git a/tools/gfx/d3d11/d3d11-device.cpp b/tools/gfx/d3d11/d3d11-device.cpp index f15c94da6..e32bdf7ed 100644 --- a/tools/gfx/d3d11/d3d11-device.cpp +++ b/tools/gfx/d3d11/d3d11-device.cpp @@ -1322,8 +1322,11 @@ Result DeviceImpl::createProgram( if (diagnostics) { + DebugMessageType msgType = DebugMessageType::Warning; + if (compileResult != SLANG_OK) + msgType = DebugMessageType::Error; getDebugCallback()->handleMessage( - compileResult == SLANG_OK ? DebugMessageType::Warning : DebugMessageType::Error, + msgType, DebugMessageSource::Slang, (char*)diagnostics->getBufferPointer()); if (outDiagnosticBlob) diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index 4a8fd04b6..445f22e5a 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -1008,8 +1008,11 @@ Result ShaderProgramBase::compileShaders(RendererBase* device) entryPointIndex, 0, kernelCode.writeRef(), diagnostics.writeRef()); if (diagnostics) { + DebugMessageType msgType = DebugMessageType::Warning; + if (compileResult != SLANG_OK) + msgType = DebugMessageType::Error; getDebugCallback()->handleMessage( - compileResult == SLANG_OK ? 
DebugMessageType::Warning : DebugMessageType::Error, + msgType, DebugMessageSource::Slang, (char*)diagnostics->getBufferPointer()); } |
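Note on the `traverseUses` / `traverseUsers` change in this patch: the new versions first copy the use list into a `List<IRUse*>`, then visit each entry and skip any whose `usedValue` no longer equals `inst`. This presumably keeps the traversal safe when the callback rewrites uses while it runs. Below is a minimal, self-contained sketch of that pattern. `Node`, `Use`, `addUse`, `retarget` and `traverseUsesSnapshot` are hypothetical stand-ins written for illustration; they are not Slang's actual `IRInst` / `IRUse` API.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for IRInst / IRUse (illustration only).
struct Node;
struct Use
{
    Node* usedValue = nullptr; // the value this use refers to
    Use*  nextUse   = nullptr; // next use of the same value
};
struct Node
{
    Use* firstUse = nullptr;   // head of the intrusive use list
};

// Prepend `use` to `value`'s use list.
void addUse(Node* value, Use* use)
{
    use->usedValue = value;
    use->nextUse = value->firstUse;
    value->firstUse = use;
}

// Unlink `use` from its current value and attach it to `newValue`.
void retarget(Use* use, Node* newValue)
{
    Use** link = &use->usedValue->firstUse;
    while (*link && *link != use)
        link = &(*link)->nextUse;
    if (*link)
        *link = use->nextUse;
    addUse(newValue, use);
}

// Snapshot the uses first, then visit only those that still point at `value`.
template<typename F>
void traverseUsesSnapshot(Node* value, F f)
{
    std::vector<Use*> uses;
    for (Use* u = value->firstUse; u; u = u->nextUse)
        uses.push_back(u);
    for (Use* u : uses)
    {
        if (u->usedValue != value)
            continue; // rewritten by an earlier callback; skip it
        f(u);
    }
}

// The callback moves the visited use, and also one other use, to node `b`.
int countVisitedWhileRetargeting()
{
    Node a, b;
    Use u0, u1, u2;
    addUse(&a, &u0);
    addUse(&a, &u1);
    addUse(&a, &u2);

    int visited = 0;
    traverseUsesSnapshot(&a, [&](Use* u)
    {
        ++visited;
        retarget(u, &b);
        retarget(&u0, &b); // mutate the list beyond the current element
    });
    return visited;
}
```

In this sketch `countVisitedWhileRetargeting()` returns 2: `u0` has already been moved to `b` by the time the loop reaches it, so the `usedValue` check skips it. Walking the live `nextUse` chain directly would instead follow a use into `b`'s list after the first `retarget`.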
