diff options
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index af06e6bac..a60a77cc3 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -270,6 +270,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma #define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \ vector<TYPE, COUNT> result; \ vector<TYPE, COUNT>.Differential d_result; \ + [ForceUnroll]\ for (int i = 0; i < N; ++i) \ { \ DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(VALUE.p[i], __slang_noop_cast<TYPE.Differential>(VALUE.d[i]))); \ @@ -281,6 +282,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma #define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \ vector<TYPE, COUNT> result; \ vector<TYPE, COUNT>.Differential d_result; \ + [ForceUnroll] \ for (int i = 0; i < N; ++i) \ { \ DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(LEFT.p[i], __slang_noop_cast<TYPE.Differential>(LEFT.d[i])), \ @@ -292,6 +294,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma #define VECTOR_MAP_BWD_D_UNARY(TYPE, COUNT, D_FUNC, VALUE, D_OUT) \ vector<TYPE, COUNT>.Differential d_result; \ + [ForceUnroll] \ for (int i = 0; i < N; ++i) \ { \ DifferentialPair<TYPE> dp_elem = diffPair(VALUE.p[i], TYPE.dzero()); \ @@ -302,6 +305,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma #define VECTOR_MAP_BWD_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT, D_OUT) \ vector<TYPE, COUNT>.Differential left_d_result, right_d_result; \ + [ForceUnroll] \ for (int i = 0; i < N; ++i) \ { \ DifferentialPair<TYPE> left_dp = diffPair(LEFT.p[i], TYPE.dzero()); \ @@ -705,6 +709,7 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair { T result = T(0); T.Differential d_result = T.dzero(); + [ForceUnroll] for (int i = 0; i < N; ++i) { result = result + dpx.p[i] * dpy.p[i]; @@ -719,6 +724,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int> void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut) { vector<T, N>.Differential x_d_result, y_d_result; + [ForceUnroll] for (int i = 0; i < N; ++i) { x_d_result[i] = dpy.p[i] * __slang_noop_cast<T>(dOut); |
