diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/core.meta.slang | 6 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 77 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 4 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 10 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 167 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-ir-collect-global-uniforms.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-entry-point-uniforms.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-init-local-var.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 29 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-types.cpp | 33 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-legalize-types.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-legalize-types.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 8 |
20 files changed, 354 insertions, 85 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 5265d6cb6..c272da75d 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -3141,3 +3141,9 @@ attribute_syntax [payload] : PayloadAttribute; __attributeTarget(DeclBase) attribute_syntax [deprecated(message: String)] : DeprecatedAttribute; + +__attributeTarget(FunctionDeclBase) +attribute_syntax [PreferRecompute] : PreferRecomputeAttribute; + +__attributeTarget(FunctionDeclBase) +attribute_syntax [PreferCheckpoint] : PreferCheckpointAttribute; diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index cb87156f5..f8b36a3ac 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -362,6 +362,7 @@ extension Array<T, N> : IDifferentiable __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] [ForwardDerivativeOf(transpose)] +[PreferRecompute] DifferentialPair<matrix<T, M, N>> __d_transpose(DifferentialPair<matrix<T, N, M>> m) { return DifferentialPair<matrix<T, M, N>>(transpose(m.p), transpose(m.d)); @@ -370,6 +371,7 @@ DifferentialPair<matrix<T, M, N>> __d_transpose(DifferentialPair<matrix<T, N, M> __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] [BackwardDerivativeOf(transpose)] +[PreferRecompute] void __d_transpose(inout DifferentialPair<matrix<T, N, M>> m, matrix<T, M, N>.Differential dOut) { m = diffPair(m.p, transpose(dOut)); @@ -379,6 +381,7 @@ void __d_transpose(inout DifferentialPair<matrix<T, N, M>> m, matrix<T, M, N>.Di __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] [ForwardDerivativeOf(mul)] +[PreferRecompute] DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, DifferentialPair<matrix<T, N, M>> right) { let primal = mul(left.p, right.p); @@ -388,6 +391,7 @@ DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, Differen __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [BackwardDerivativeOf(mul)] +[PreferRecompute] void __d_mul(inout DifferentialPair<vector<T, N>> left, inout DifferentialPair<matrix<T, N, M>> right, vector<T, M>.Differential dOut) { vector<T, N>.Differential left_d_result; @@ -410,6 +414,7 @@ void __d_mul(inout DifferentialPair<vector<T, N>> left, inout DifferentialPair<m __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] [ForwardDerivativeOf(mul)] +[PreferRecompute] DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, DifferentialPair<vector<T,M>> right) { let primal = mul(left.p, right.p); @@ -419,6 +424,7 @@ DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, Differen __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [BackwardDerivativeOf(mul)] +[PreferRecompute] void __d_mul(inout DifferentialPair<matrix<T, N, M>> left, inout DifferentialPair<vector<T, M>> right, vector<T, N>.Differential dOut) { matrix<T, N, M>.Differential left_d_result; @@ -441,6 +447,7 @@ void __d_mul(inout DifferentialPair<matrix<T, N, M>> left, inout DifferentialPai __generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int> [ForceInline] [ForwardDerivativeOf(mul)] +[PreferRecompute] DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, DifferentialPair<matrix<T,N,C>> right) { let primal = mul(left.p, right.p); @@ -450,6 +457,7 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, Differ __generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int> [BackwardDerivativeOf(mul)] +[PreferRecompute] void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<matrix<T, N, C>> right, matrix<T, R, C>.Differential dOut) { matrix<T, R, N>.Differential left_d_result; @@ -480,6 +488,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma // Vector dot product __generic<T : __BuiltinFloatingPointType, let N : int> [ForwardDerivativeOf(dot)] +[PreferRecompute] DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy) { T result = T(0); @@ -496,6 +505,7 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair __generic<T : __BuiltinFloatingPointType, let N : int> [BackwardDerivativeOf(dot)] +[PreferRecompute] void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut) { vector<T, N>.Differential x_d_result, y_d_result; @@ -512,6 +522,7 @@ void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<ve // Cross product __generic<T : __BuiltinFloatingPointType> [ForwardDerivativeOf(cross)] +[PreferRecompute] DifferentialPair<vector<T, 3>> __d_cross(DifferentialPair<vector<T, 3>> a, DifferentialPair<vector<T, 3>> b) { /* @@ -539,6 +550,7 @@ DifferentialPair<vector<T, 3>> __d_cross(DifferentialPair<vector<T, 3>> a, Diffe __generic<T : __BuiltinFloatingPointType> [BackwardDerivativeOf(cross)] +[PreferRecompute] void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<vector<T, 3>> b, vector<T, 3>.Differential dOut) { /* @@ -560,7 +572,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve #define VECTOR_MATRIX_BINARY_DIFF_IMPL(NAME) \ __generic<T : __BuiltinFloatingPointType, let N : int> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<vector<T, N>> __d_##NAME##_vector( \ DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy) \ @@ -578,7 +590,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<vector<T, N>>(result, d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \ DifferentialPair<matrix<T, M, N>> dpx, DifferentialPair<matrix<T, M, N>> dpy) \ @@ -597,7 +609,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<matrix<T, M, N>>(result, d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ inout DifferentialPair<vector<T, N>> dpx, \ @@ -617,7 +629,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve dpy = diffPair(dpy.p, right_d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_matrix( \ inout DifferentialPair<matrix<T, M, N>> dpx, \ @@ -640,7 +652,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve #define VECTOR_MATRIX_TERNARY_DIFF_IMPL(NAME) \ __generic<T : __BuiltinFloatingPointType, let N : int> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<vector<T, N>> __d_##NAME##_vector( \ DifferentialPair<vector<T, N>> dpx, \ @@ -661,7 +673,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<vector<T, N>>(result, d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \ DifferentialPair<matrix<T, M, N>> dpx, \ @@ -683,7 +695,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<matrix<T, M, N>>(result, d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ inout DifferentialPair<vector<T, N>> dpx, \ @@ -708,7 +720,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve dpz = diffPair(dpz.p, right_d_result); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_matrix( \ inout DifferentialPair<matrix<T, M, N>> dpx, \ @@ -736,7 +748,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve #define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \ __generic<T : __BuiltinFloatingPointType> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<T> __d_##NAME(DifferentialPair<T> dpx) \ { \ @@ -744,7 +756,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<T>(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<vector<T, N>> __d_##NAME##_vector(DifferentialPair<vector<T, N>> dpx) \ { \ @@ -752,7 +764,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<ReturnType>(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpm) \ { \ @@ -763,10 +775,10 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve var dpx = diffPair(dpm.p[i], dpm.d[i]); \ diff[i] = FWD_DIFF_FUNC; \ } \ - return diffPair(NAME(dpm.p), diff); \ + return diffPair(NAME(dpm.p), diff); \ } \ __generic<T : __BuiltinFloatingPointType> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME(inout DifferentialPair<T> dpx, T.Differential dOut) \ { \ @@ -774,7 +786,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \ @@ -783,7 +795,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ - [BackwardDifferentiable] \ + [BackwardDifferentiable][PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_matrix( \ inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \ @@ -848,6 +860,7 @@ SIMPLE_UNARY_DERIVATIVE_IMPL(atan, T(1.0) / (T(1.0) + dpx.p * dpx.p)) // Atan2 __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [ForwardDerivativeOf(atan2)] DifferentialPair<T> __d_atan2(DifferentialPair<T> dpy, DifferentialPair<T> dpx) { @@ -860,6 +873,7 @@ DifferentialPair<T> __d_atan2(DifferentialPair<T> dpy, DifferentialPair<T> dpx) __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [BackwardDerivativeOf(atan2)] void __d_atan2(inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpx, T.Differential dOut) { @@ -872,6 +886,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(atan2) // fmod __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [ForwardDerivativeOf(fmod)] DifferentialPair<T> __d_fmod(DifferentialPair<T> x, DifferentialPair<T> y) { @@ -879,6 +894,7 @@ DifferentialPair<T> __d_fmod(DifferentialPair<T> x, DifferentialPair<T> y) } __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [BackwardDerivativeOf(fmod)] void __d_fmod(inout DifferentialPair<T> x, inout DifferentialPair<T> y, T.Differential dOut) { @@ -890,6 +906,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(fmod) // Raise to a power __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [ForwardDerivativeOf(pow)] DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy) { @@ -910,6 +927,7 @@ DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy) __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [BackwardDerivativeOf(pow)] void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) { @@ -936,6 +954,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(pow) // Maximum __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [ForwardDerivativeOf(max)] DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy) { @@ -947,6 +966,7 @@ DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy) __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [BackwardDerivativeOf(max)] void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) { @@ -959,6 +979,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(max) // Minimum __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [ForwardDerivativeOf(min)] DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy) { @@ -970,6 +991,7 @@ DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy) __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [BackwardDerivativeOf(min)] void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) { @@ -982,6 +1004,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(min) // Lerp __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [ForwardDerivativeOf(lerp)] DifferentialPair<T> __d_lerp(DifferentialPair<T> dpx, DifferentialPair<T> dpy, DifferentialPair<T> dps) { @@ -992,6 +1015,7 @@ DifferentialPair<T> __d_lerp(DifferentialPair<T> dpx, DifferentialPair<T> dpy, D } __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [BackwardDerivativeOf(lerp)] void __d_lerp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dps, T.Differential dOut) { @@ -1004,6 +1028,7 @@ VECTOR_MATRIX_TERNARY_DIFF_IMPL(lerp) // Clamp __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [ForwardDerivativeOf(clamp)] DifferentialPair<T> __d_clamp(DifferentialPair<T> dpx, DifferentialPair<T> dpMin, DifferentialPair<T> dpMax) { @@ -1013,6 +1038,7 @@ DifferentialPair<T> __d_clamp(DifferentialPair<T> dpx, DifferentialPair<T> dpMin } __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] [BackwardDerivativeOf(clamp)] void __d_clamp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpMin, inout DifferentialPair<T> dpMax, T.Differential dOut) { @@ -1025,6 +1051,7 @@ VECTOR_MATRIX_TERNARY_DIFF_IMPL(clamp) // fma [BackwardDifferentiable] [ForwardDerivativeOf(fma)] +[PreferRecompute] DifferentialPair<double> __d_fma(DifferentialPair<double> dpx, DifferentialPair<double> dpy, DifferentialPair<double> dpz) { return DifferentialPair<double>( @@ -1033,6 +1060,7 @@ DifferentialPair<double> __d_fma(DifferentialPair<double> dpx, DifferentialPair< } [BackwardDifferentiable] [BackwardDerivativeOf(fma)] +[PreferRecompute] void __d_fma(inout DifferentialPair<double> dpx, inout DifferentialPair<double> dpy, inout DifferentialPair<double> dpz, double dOut) { dpx = diffPair(dpx.p, dpy.p * dOut); @@ -1042,6 +1070,7 @@ void __d_fma(inout DifferentialPair<double> dpx, inout DifferentialPair<double> __generic<let N : int> [BackwardDifferentiable] [ForwardDerivativeOf(fma)] +[PreferRecompute] DifferentialPair<vector<double, N>> __d_fma_vector( DifferentialPair<vector<double, N>> dpx, DifferentialPair<vector<double, N>> dpy, @@ -1063,6 +1092,7 @@ DifferentialPair<vector<double, N>> __d_fma_vector( __generic<let N : int> [BackwardDifferentiable] [BackwardDerivativeOf(fma)] +[PreferRecompute] void __d_fma_vector( inout DifferentialPair<vector<double, N>> dpx, inout DifferentialPair<vector<double, N>> dpy, @@ -1089,6 +1119,7 @@ void __d_fma_vector( __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] [ForwardDerivativeOf(mad)] +[PreferRecompute] DifferentialPair<T> __d_mad(DifferentialPair<T> dpx, DifferentialPair<T> dpy, DifferentialPair<T> dpz) { return DifferentialPair<T>( @@ -1098,6 +1129,7 @@ DifferentialPair<T> __d_mad(DifferentialPair<T> dpx, DifferentialPair<T> dpy, Di __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] [BackwardDerivativeOf(mad)] +[PreferRecompute] void __d_mad(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpz, T.Differential dOut) { dpx = diffPair(dpx.p, T.dmul(dpy.p, dOut)); @@ -1109,6 +1141,7 @@ VECTOR_MATRIX_TERNARY_DIFF_IMPL(mad) // Smoothstep __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PreferRecompute] T __smoothstep_impl(T minVal, T maxVal, T x) { let t = saturate((x - minVal) / (maxVal - minVal)); @@ -1117,6 +1150,7 @@ T __smoothstep_impl(T minVal, T maxVal, T x) __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] [ForwardDerivativeOf(smoothstep)] +[PreferRecompute] DifferentialPair<T> __d_smoothstep(DifferentialPair<T> minVal, DifferentialPair<T> maxVal, DifferentialPair<T> x) { return __fwd_diff(__smoothstep_impl)(minVal, maxVal, x); @@ -1124,6 +1158,7 @@ DifferentialPair<T> __d_smoothstep(DifferentialPair<T> minVal, DifferentialPair< __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] [BackwardDerivativeOf(smoothstep)] +[PreferRecompute] void __d_smoothstep(inout DifferentialPair<T> minVal, inout DifferentialPair<T> maxVal, inout DifferentialPair<T> x, T.Differential dOut) { __bwd_diff(__smoothstep_impl)(minVal, maxVal, x, dOut); @@ -1133,6 +1168,7 @@ VECTOR_MATRIX_TERNARY_DIFF_IMPL(smoothstep) // Vector length __generic<T: __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] +[PreferRecompute] T __length_impl(vector<T, N> x) { T len = T(0.0); @@ -1147,6 +1183,7 @@ __generic<T: __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] [ForwardDerivativeOf(length)] [ForceInline] +[PreferRecompute] DifferentialPair<T> __d_length(DifferentialPair<vector<T, N>> x) { return __fwd_diff(__length_impl)(x); @@ -1156,6 +1193,7 @@ __generic<T: __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] [BackwardDerivativeOf(length)] [ForceInline] +[PreferRecompute] void __d_length(inout DifferentialPair<vector<T, N>> x, T.Differential dOut) { return __bwd_diff(__length_impl)(x, dOut); @@ -1164,6 +1202,7 @@ void __d_length(inout DifferentialPair<vector<T, N>> x, T.Differential dOut) // Vector distance __generic<T: __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] +[PreferRecompute] T __distance_impl(vector<T, N> x, vector<T, N> y) { return length(y - x); @@ -1172,6 +1211,7 @@ __generic<T: __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] [ForwardDerivativeOf(distance)] [ForceInline] +[PreferRecompute] DifferentialPair<T> __d_distance(DifferentialPair<vector<T, N>> x, DifferentialPair<vector<T, N>> y) { return __fwd_diff(__distance_impl)(x, y); @@ -1181,6 +1221,7 @@ __generic<T: __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] [BackwardDerivativeOf(distance)] [ForceInline] +[PreferRecompute] void __d_distance(inout DifferentialPair<vector<T, N>> x, inout DifferentialPair<vector<T, N>> y, T.Differential dOut) { return __bwd_diff(__distance_impl)(x, y, dOut); @@ -1189,6 +1230,7 @@ void __d_distance(inout DifferentialPair<vector<T, N>> x, inout DifferentialPair // Vector normalize __generic<T : __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] +[PreferRecompute] vector<T, N> __normalize_impl(vector<T, N> x) { let r = T(1.0) / length(x); @@ -1198,6 +1240,7 @@ __generic<T: __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] [ForwardDerivativeOf(normalize)] [ForceInline] +[PreferRecompute] DifferentialPair<vector<T, N>> __d_normalize(DifferentialPair<vector<T, N>> x) { return __fwd_diff(__normalize_impl)(x); @@ -1206,6 +1249,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] [BackwardDerivativeOf(normalize)] [ForceInline] +[PreferRecompute] void __d_distance(inout DifferentialPair<vector<T, N>> x, vector<T, N>.Differential dOut) { return __bwd_diff(__normalize_impl)(x, dOut); @@ -1264,6 +1308,7 @@ void __d_refract(inout DifferentialPair<vector<T, N>> i, inout DifferentialPair< __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] [PrimalSubstituteOf(sincos)] +[PreferRecompute] void __sincos_impl(T x, out T s, out T c) { s = sin(x); @@ -1272,6 +1317,7 @@ void __sincos_impl(T x, out T s, out T c) __generic<T : __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] +[PreferRecompute] [PrimalSubstituteOf(sincos)] void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c) { @@ -1282,6 +1328,7 @@ void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c) __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [BackwardDifferentiable] [PrimalSubstituteOf(sincos)] +[PreferRecompute] void __sincos_impl(matrix<T, N, M> x, out matrix<T, N, M> s, out matrix<T, N, M> c) { s = sin(x); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index a47c0cb25..1580a7a23 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -1744,6 +1744,7 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(spirv_direct, "OpExtInst resultType resultId glsl450 Determinant _0") [__readNone] +[PreferCheckpoint] T determinant(matrix<T,N,N> m); // Barrier for device memory @@ -3810,6 +3811,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) [__readNone] +[PreferRecompute] matrix<T, M, N> transpose(matrix<T, N, M> x) { matrix<T,M,N> result; @@ -3822,6 +3824,7 @@ __generic<T : __BuiltinIntegerType, let N : int, let M : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) [__readNone] +[PreferRecompute] matrix<T, M, N> transpose(matrix<T, N, M> x) { matrix<T, M, N> result; @@ -3834,6 +3837,7 @@ __generic<T : __BuiltinLogicalType, let N : int, let M : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) [__readNone] +[PreferRecompute] matrix<T, M, N> transpose(matrix<T, N, M> x) { matrix<T, M, N> result; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index caa439704..6f79e900f 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1078,6 +1078,16 @@ class CudaHostAttribute : public Attribute SLANG_AST_CLASS(CudaHostAttribute) }; +class PreferRecomputeAttribute : public Attribute +{ + SLANG_AST_CLASS(PreferRecomputeAttribute) +}; + +class PreferCheckpointAttribute : public Attribute +{ + SLANG_AST_CLASS(PreferCheckpointAttribute) +}; + class DerivativeMemberAttribute : public Attribute { SLANG_AST_CLASS(DerivativeMemberAttribute) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 0b88bd057..4412eccb8 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -537,17 +537,12 @@ Result linkAndOptimizeIR( } else { -#if 0 - // On CPU/CUDA targets, we simply elminate any empty types. - // TODO: disable for now, since the CPU compute shader - // trampoline is still hard coded to assume there are - // entrypoint and global parameters. renable when - // we fix that logic. + // On CPU/CUDA targets, we simply elminate any empty types if + // they are not part of public interface. legalizeEmptyTypes( irModule, sink); eliminateDeadCode(irModule); -#endif } // Once specialization and type legalization have been performed, diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 598a1dde9..6a87550d2 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -12,6 +12,7 @@ #include "slang-ir-ssa-simplification.h" #include "slang-ir-validate.h" #include "slang-ir-inline.h" +#include "slang-ir-init-local-var.h" namespace Slang { @@ -1593,6 +1594,8 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) removeLinkageDecorations(func); performForceInlining(func); + + initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func); AutoDiffAddressConversionPolicy cvtPolicty; cvtPolicty.diffTypeContext = &differentiableTypeConformanceContext; diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 5d11b7fb3..af47ffbca 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -35,20 +35,6 @@ static bool isDifferentialBlock(IRBlock* block) return block->findDecoration<IRDifferentialInstDecoration>(); } -static Dictionary<IRBlock*, IRBlock*> reconstructDiffBlockMap(IRGlobalValueWithCode* func) -{ - Dictionary<IRBlock*, IRBlock*> diffBlockMap; - for (auto block : func->getBlocks()) - { - if (auto diffDecor = block->findDecoration<IRDifferentialInstDecoration>()) - { - if (diffDecor->getPrimalType()) - diffBlockMap[as<IRBlock>(diffDecor->getPrimalInst())] = block; - } - } - return diffBlockMap; -} - static IRBlock* getLoopRegionBodyBlock(IRLoop* loop) { auto condBlock = as<IRBlock>(loop->getTargetBlock()); @@ -397,7 +383,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( auto branchInst = as<IRUnconditionalBranch>(predecessor->getTerminator()); SLANG_ASSERT(branchInst->getOperandCount() > paramIndex); - workList.add(&branchInst->getOperands()[paramIndex]); + workList.add(&branchInst->getArgs()[paramIndex]); } } else @@ -511,17 +497,10 @@ void applyToInst( if (as<IRParam>(inst)) { // Can completely ignore first block parameters - if (getBlock(inst) != getBlock(inst)->getParent()->getFirstBlock()) + if (getBlock(inst) == getBlock(inst)->getParent()->getFirstBlock()) { - // TODO: We would need to clone in the control-flow for each region (without nested loops) - // prior to this, and then hoist this parameter into the within-region block, otherwise - // this parameter will not be visible to transposed insts. - // This will also include adding an extra case to 'ensurePrimalAvailability': if both insts - // are withing the _same_ indexed region, skip the indexed store/load and use a simple var. - // - SLANG_UNIMPLEMENTED_X("Parameter recompute is not currently supported"); + return; } - return; } auto recomputeInst = cloneCtx->cloneInstOutOfOrder(builder, inst); @@ -565,10 +544,6 @@ void applyCheckpointSet( Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock, IROutOfOrderCloneContext* cloneCtx) { - // Reconstruct diff block map. - Dictionary<IRBlock*, IRBlock*> diffBlockMap = reconstructDiffBlockMap(func); - - for (auto use : pendingUses) cloneCtx->pendingUses.add(use); @@ -600,13 +575,24 @@ void applyCheckpointSet( if (block->findDecoration<IRRecomputeBlockDecoration>()) continue; - auto diffBlock = as<IRBlock>(diffBlockMap[block]); + IRBlock* recomputeBlock = block; + mapPrimalBlockToRecomputeBlock.tryGetValue(block, recomputeBlock); + auto recomputeInsertBeforeInst = recomputeBlock->getFirstOrdinaryInst(); IRBuilder builder(func->getModule()); UIndex ii = 0; for (auto param : block->getParams()) { - builder.setInsertBefore(diffBlock->getFirstOrdinaryInst()); + builder.setInsertBefore(recomputeInsertBeforeInst); + bool isRecomputed = checkpointInfo->recomputeSet.contains(param); + bool isInverted = checkpointInfo->invertSet.contains(param); + + if (!isRecomputed && !isInverted) + continue; + + SLANG_RELEASE_ASSERT( + recomputeBlock != block && + "recomputed param should belong to block that has recompute block."); // Apply checkpoint rule to the parameter itself. applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, param); @@ -617,37 +603,36 @@ void applyCheckpointSet( { if (predecessorSet.contains(predecessor)) continue; - predecessorSet.add(predecessor); + + auto primalPhiArg = as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii); + auto recomputePredecessor = mapPrimalBlockToRecomputeBlock[predecessor].getValue(); - auto diffPredecessor = as<IRBlock>(diffBlockMap[block]); - - if (checkpointInfo->recomputeSet.contains(param)) + // For now, find the primal phi argument in this predecessor, + // and stick it into the recompute predecessor's branch inst. We + // will use a patch-up pass in the end to replace all these + // arguments to their recomputed versions if they exist. + + if (isRecomputed) { - IRInst* terminator = diffPredecessor->getTerminator(); + IRInst* terminator = recomputeBlock->getTerminator(); addPhiOutputArg(&builder, - diffPredecessor, + recomputePredecessor, terminator, - as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii)); + primalPhiArg); } - - if (checkpointInfo->invertSet.contains(param)) + else if (isInverted) { - IRInst* terminator = diffPredecessor->getTerminator(); - + IRInst* terminator = recomputeBlock->getTerminator(); addPhiOutputArg(&builder, - diffPredecessor, + recomputePredecessor, terminator, - as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii)); + primalPhiArg); } } - ii++; } - IRBlock* recomputeBlock = block; - mapPrimalBlockToRecomputeBlock.tryGetValue(block, recomputeBlock); - auto recomputeInsertBeforeInst = recomputeBlock->getFirstOrdinaryInst(); for (auto child : block->getChildren()) { @@ -670,6 +655,25 @@ void applyCheckpointSet( applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, child); } } + + // Go through phi arguments in recompute blocks and replace them to + // recomputed insts if they exist. + for (auto block : func->getBlocks()) + { + if (!block->findDecoration<IRRecomputeBlockDecoration>()) + continue; + auto terminator = block->getTerminator(); + for (UInt i = 0; i < terminator->getOperandCount(); i++) + { + auto arg = terminator->getOperand(i); + if (as<IRBlock>(arg)) + continue; + if (auto recomputeArg = cloneCtx->cloneEnv.mapOldValToNew.tryGetValue(arg)) + { + terminator->setOperand(i, *recomputeArg); + } + } + } } IRType* getTypeForLocalStorage( @@ -1313,7 +1317,7 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func) void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func) { - domTree = computeDominatorTree(func); + SLANG_UNUSED(func) return; } @@ -1412,6 +1416,30 @@ static bool shouldStoreVar(IRVar* var) return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType())); } +enum CheckpointPreference +{ + None, + PreferCheckpoint, + PreferRecompute +}; + +static CheckpointPreference getCheckpointPreference(IRInst* callee) +{ + callee = getResolvedInstForDecorations(callee, true); + for (auto decor : callee->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_PreferCheckpointDecoration: + return CheckpointPreference::PreferCheckpoint; + case kIROp_PreferRecomputeDecoration: + case kIROp_TargetIntrinsicDecoration: + return CheckpointPreference::PreferRecompute; + } + } + return CheckpointPreference::None; +} + static bool shouldStoreInst(IRInst* inst) { if (!inst->getDataType()) @@ -1422,9 +1450,9 @@ static bool shouldStoreInst(IRInst* inst) if (!canTypeBeStored(inst->getDataType())) return false; - // Never store certain opcodes. switch (inst->getOp()) { + // Never store these opcodes because they are not real computations. case kIROp_CastFloatToInt: case kIROp_CastIntToFloat: case kIROp_IntCast: @@ -1437,25 +1465,39 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_MakeStruct: case kIROp_MakeTuple: case kIROp_MakeArray: + case kIROp_MakeVector: + case kIROp_MakeMatrix: case kIROp_MakeArrayFromElement: case kIROp_MakeDifferentialPair: + case kIROp_MakeDifferentialPairUserCode: case kIROp_MakeOptionalNone: case kIROp_MakeOptionalValue: + case kIROp_MakeExistential: case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: + case kIROp_DifferentialPairGetDifferentialUserCode: + case kIROp_DifferentialPairGetPrimalUserCode: case kIROp_ExtractExistentialValue: case kIROp_ExtractExistentialType: case kIROp_ExtractExistentialWitnessTable: case kIROp_undefined: case kIROp_GetSequentialID: + case kIROp_GetStringHash: case kIROp_Specialize: case kIROp_LookupWitness: + case kIROp_Param: + case kIROp_DetachDerivative: + return false; + + // Never store these op codes because they are trivial to compute. case kIROp_Add: case kIROp_Sub: case kIROp_Mul: case kIROp_Div: case kIROp_Neg: case kIROp_Geq: + case kIROp_FRem: + case kIROp_IRem: case kIROp_Leq: case kIROp_Neq: case kIROp_Eql: @@ -1471,6 +1513,7 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_Lsh: case kIROp_Rsh: return false; + case kIROp_GetElement: case kIROp_FieldExtract: case kIROp_swizzle: @@ -1479,8 +1522,12 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_GetOptionalValue: case kIROp_MatrixReshape: case kIROp_VectorReshape: - // If the operand is already stored, don't store the result of these insts. - if (inst->getOperand(0)->findDecoration<IRPrimalValueStructKeyDecoration>()) + case kIROp_GetTupleElement: + return false; + + case kIROp_Call: + // If the callee prefers recompute policy, don't store. + if (getCheckpointPreference(inst->getOperand(0)) == CheckpointPreference::PreferRecompute) { return false; } @@ -1495,9 +1542,8 @@ static bool shouldStoreInst(IRInst* inst) return true; } -bool canRecompute(IRDominatorTree* domTree, IRUse* use) +bool canRecompute(IRUse* use) { - SLANG_UNUSED(domTree); if (auto load = as<IRLoad>(use->get())) { // Generally, we cannot recompute a load(ptr), since ptr may be modified @@ -1518,7 +1564,18 @@ bool canRecompute(IRDominatorTree* domTree, IRUse* use) auto param = as<IRParam>(use->get()); if (!param) return true; - return false; + + // We can recompute a phi param if it is not in a loop start block. + auto parentBlock = as<IRBlock>(param->getParent()); + for (auto pred : parentBlock->getPredecessors()) + { + if (auto loop = as<IRLoop>(pred->getTerminator())) + { + if (loop->getTargetBlock() == parentBlock) + return false; + } + } + return true; } HoistResult DefaultCheckpointPolicy::classify(IRUse* use) @@ -1541,7 +1598,7 @@ HoistResult DefaultCheckpointPolicy::classify(IRUse* use) { // We may not be able to recompute due to limitations of // the unzip pass. If so we will store the result. - if (canRecompute(domTree, use)) + if (canRecompute(use)) return HoistResult::recompute(use->get()); // The fallback is to store. diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h index fbac42c43..c0b56126d 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.h +++ b/source/slang/slang-ir-autodiff-primal-hoist.h @@ -286,8 +286,6 @@ namespace Slang virtual void preparePolicy(IRGlobalValueWithCode* func); virtual HoistResult classify(IRUse* use); - - RefPtr<IRDominatorTree> domTree; }; RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index e3575aceb..2994a8c31 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -142,6 +142,13 @@ namespace Slang builder->setInsertBefore(existingPrimalFunc); builder->setInsertInto(existingPrimalFunc); + + auto checkpointHint = udf->findDecoration<IRCheckpointHintDecoration>(); + if (!checkpointHint) + checkpointHint = originalFunc->findDecoration<IRCheckpointHintDecoration>(); + if (checkpointHint) + builder->addDecoration(existingPrimalFunc, checkpointHint->getOp()); + builder->emitBlock(); params = _defineFuncParams(builder, as<IRFunc>(existingPrimalFunc)); params.removeLast(); @@ -759,6 +766,14 @@ namespace Slang auto primalFuncGeneric = hoistValueFromGeneric(*builder, extractedPrimalFunc, specializedFunc, true); builder->setInsertBefore(primalFunc); + + // Copy over checkpoint preference hints. + { + auto diffPrimalFunc = getResolvedInstForDecorations(primalFuncGeneric, true); + auto checkpointHint = primalFunc->findDecoration<IRCheckpointHintDecoration>(); + if (checkpointHint) + builder->addDecoration(diffPrimalFunc, checkpointHint->getOp()); + } if (auto existingDecor = primalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>()) { diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp index de806d653..78ea2e79a 100644 --- a/source/slang/slang-ir-collect-global-uniforms.cpp +++ b/source/slang/slang-ir-collect-global-uniforms.cpp @@ -140,6 +140,7 @@ struct CollectGlobalUniformParametersContext // auto wrapperStructType = builder->createStructType(); builder->addNameHintDecoration(wrapperStructType, UnownedTerminatedStringSlice("GlobalParams")); + builder->addBinaryInterfaceTypeDecoration(wrapperStructType); // If the computed layout used a bare `struct` type, then we will use // our `GlobalParams` struct as-is, but if the layout involved an diff --git a/source/slang/slang-ir-entry-point-uniforms.cpp b/source/slang/slang-ir-entry-point-uniforms.cpp index d10aa7ff2..518f6ae2c 100644 --- a/source/slang/slang-ir-entry-point-uniforms.cpp +++ b/source/slang/slang-ir-entry-point-uniforms.cpp @@ -387,6 +387,7 @@ struct CollectEntryPointUniformParams : PerEntryPointPass builder.setInsertBefore(m_entryPoint.func); paramStructType = builder.createStructType(); builder.addNameHintDecoration(paramStructType, UnownedTerminatedStringSlice("EntryPointParams")); + builder.addBinaryInterfaceTypeDecoration(paramStructType); if( needConstantBuffer ) { diff --git a/source/slang/slang-ir-init-local-var.cpp b/source/slang/slang-ir-init-local-var.cpp index 645344f2e..145d8d569 100644 --- a/source/slang/slang-ir-init-local-var.cpp +++ b/source/slang/slang-ir-init-local-var.cpp @@ -9,16 +9,42 @@ namespace Slang void initializeLocalVariables(IRModule* module, IRGlobalValueWithCode* func) { IRBuilder builder(module); + HashSet<IRInst*> userSet; for (auto block : func->getBlocks()) { for (auto inst : block->getChildren()) { if (inst->getOp() == kIROp_Var) { - auto firstUse = inst->firstUse; - bool initialized = - (firstUse && firstUse->getUser()->getOp() == kIROp_Store && - firstUse->getUser()->getParent() == inst->getParent()); + bool initialized = false; + userSet.clear(); + for (auto use = inst->firstUse; use; use = use->nextUse) + userSet.add(use->getUser()); + + // Check if the variable is initialized in the same block. + for (auto nextInst = inst->next; nextInst; nextInst = nextInst->next) + { + switch (nextInst->getOp()) + { + case kIROp_Store: + if (nextInst->getOperand(0) == inst) + initialized = true; + break; + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + continue; + default: + if (userSet.contains(nextInst)) + { + // We encountered a user of the variable before it was initialized. + // Break out of the loop and insert the initialization code. + goto breakLabel; + } + } + if (initialized) + break; + } + breakLabel:; if (initialized) continue; builder.setInsertAfter(inst); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 11143cebb..b00972cce 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -635,6 +635,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(InterpolationModeDecoration, interpolationMode, 1, 0) INST(NameHintDecoration, nameHint, 1, 0) + // Marks a type as being used as binary interface (e.g. shader parameters). + // This prevents the legalizeEmptyType() pass from eliminating it on C++/CUDA targets. + INST(BinaryInterfaceTypeDecoration, BinaryInterfaceType, 0, 0) + /** The decorated _instruction_ is transitory. Such a decoration should NEVER be found on an output instruction a module. Typically used mark an instruction so can be specially handled - say when creating a IRConstant literal, and the payload of needs to be special cased for lookup. */ @@ -846,6 +850,15 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// Treat a function as differentiable function, or an IRCall as a call to a differentiable function. INST(TreatAsDifferentiableDecoration, treatAsDifferentiableDecoration, 0, 0) + + /// Hint that the result from a call to the decorated function should be stored in backward prop function. + INST(PreferCheckpointDecoration, PreferCheckpointDecoration, 0, 0) + + /// Hint that the result from a call to the decorated function should be recomputed in backward prop function. + INST(PreferRecomputeDecoration, PreferRecomputeDecoration, 0, 0) + + INST_RANGE(CheckpointHintDecoration, PreferCheckpointDecoration, PreferRecomputeDecoration) + /// Marks a class type as a COM interface implementation, which enables /// the witness table to be easily picked up by emit. INST(COMWitnessDecoration, COMWitnessDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index f515baf8d..4495d1a3d 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -712,6 +712,30 @@ struct IRBackwardDerivativeDecoration : IRDecoration IRInst* getBackwardDerivativeFunc() { return getOperand(0); } }; +struct IRCheckpointHintDecoration : public IRDecoration +{ + IR_PARENT_ISA(CheckpointHintDecoration) +}; + +struct IRPreferRecomputeDecoration : IRCheckpointHintDecoration +{ + enum + { + kOp = kIROp_PreferRecomputeDecoration + }; + IR_LEAF_ISA(PreferRecomputeDecoration) +}; + +struct IRPreferCheckpointDecoration : IRCheckpointHintDecoration +{ + enum + { + kOp = kIROp_PreferCheckpointDecoration + }; + IR_LEAF_ISA(PreferCheckpointDecoration) +}; + + struct IRLoopCounterDecoration : IRDecoration { enum @@ -3651,6 +3675,11 @@ public: addNameHintDecoration(value, getStringValue(text)); } + void addBinaryInterfaceTypeDecoration(IRInst* value) + { + addDecoration(value, kIROp_BinaryInterfaceTypeDecoration); + } + void addGLSLOuterArrayDecoration(IRInst* value, UnownedStringSlice const& text) { addDecoration(value, kIROp_GLSLOuterArrayDecoration, getStringValue(text)); diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 5bfbfe994..75e4bdacf 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -3730,6 +3730,11 @@ struct IRResourceTypeLegalizationContext : IRTypeLegalizationContext return isResourceType(type); } + bool isSimpleType(IRType*) override + { + return false; + } + LegalType createLegalUniformBufferType( IROp op, LegalType legalElementType) override @@ -3761,6 +3766,11 @@ struct IRExistentialTypeLegalizationContext : IRTypeLegalizationContext return as<IRPseudoPtrType>(type) != nullptr; } + bool isSimpleType(IRType*) override + { + return false; + } + LegalType createLegalUniformBufferType( IROp op, LegalType legalElementType) override @@ -3779,6 +3789,9 @@ struct IRExistentialTypeLegalizationContext : IRTypeLegalizationContext } }; +// This customization of type legalization is used to remove empty +// structs from cpp/cuda programs if the empty type isn't used in +// a public function signature. struct IREmptyTypeLegalizationContext : IRTypeLegalizationContext { IREmptyTypeLegalizationContext(IRModule* module) @@ -3790,6 +3803,26 @@ struct IREmptyTypeLegalizationContext : IRTypeLegalizationContext return false; } + bool isSimpleType(IRType* type) override + { + // If type is used as public interface, then treat it as simple. + for (auto decor : type->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_LayoutDecoration: + case kIROp_PublicDecoration: + case kIROp_ExternCppDecoration: + case kIROp_DllImportDecoration: + case kIROp_DllExportDecoration: + case kIROp_HLSLExportDecoration: + case kIROp_BinaryInterfaceTypeDecoration: + return true; + } + } + return false; + } + LegalType createLegalUniformBufferType(IROp, LegalType) override { return LegalType(); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 9225faf6b..e74a57424 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -7478,7 +7478,7 @@ namespace Slang return nullptr; } - IRInst* getResolvedInstForDecorations(IRInst* inst) + IRInst* getResolvedInstForDecorations(IRInst* inst, bool resolveThroughDifferentiation) { IRInst* candidate = inst; for(;;) @@ -7488,6 +7488,20 @@ namespace Slang candidate = specInst->getBase(); continue; } + if (resolveThroughDifferentiation) + { + switch (candidate->getOp()) + { + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: + candidate = candidate->getOperand(0); + continue; + default: + break; + } + } if( auto genericInst = as<IRGeneric>(candidate) ) { if( auto returnVal = findGenericReturnVal(genericInst) ) diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 47aa5333a..3dbd0c773 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1938,7 +1938,7 @@ IRInst* findSpecializeReturnVal(IRSpecialize* specialize); // then try to chase down the generic being specialized, and what // it seems to return). // -IRInst* getResolvedInstForDecorations(IRInst* inst); +IRInst* getResolvedInstForDecorations(IRInst* inst, bool resolveThroughDifferentiation = false); // The IR module itself is represented as an instruction, which // serves at the root of the tree of all instructions in the module. diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp index e17fa54bb..57735d459 100644 --- a/source/slang/slang-legalize-types.cpp +++ b/source/slang/slang-legalize-types.cpp @@ -1115,6 +1115,9 @@ LegalType legalizeTypeImpl( // if(type->findDecoration<IRTargetIntrinsicDecoration>()) return LegalType::simple(type); + + if (context->isSimpleType(type)) + return LegalType::simple(type); context->builder->setInsertBefore(type); diff --git a/source/slang/slang-legalize-types.h b/source/slang/slang-legalize-types.h index fd8176889..17029b6b6 100644 --- a/source/slang/slang-legalize-types.h +++ b/source/slang/slang-legalize-types.h @@ -640,6 +640,12 @@ struct IRTypeLegalizationContext /// types will get moved out of the `struc` itself. virtual bool isSpecialType(IRType* type) = 0; + /// Customization point to decide what types are "simple." + /// + /// When a type is "simple" it means that it should not be changed + /// during legalization. + virtual bool isSimpleType(IRType* type) = 0; + /// Customization point to construct uniform-buffer/block types. /// /// This function will only be called if `legalElementType` is diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 1ec508bfa..d644d01c7 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8848,6 +8848,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); } + else if (as<PreferCheckpointAttribute>(modifier)) + { + getBuilder()->addDecoration(irFunc, kIROp_PreferCheckpointDecoration); + } + else if (as<PreferRecomputeAttribute>(modifier)) + { + getBuilder()->addDecoration(irFunc, kIROp_PreferRecomputeDecoration); + } } // For convenience, ensure that any additional global // values that were emitted while outputting the function |
