summaryrefslogtreecommitdiff
path: root/source/slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/core.meta.slang6
-rw-r--r--source/slang/diff.meta.slang77
-rw-r--r--source/slang/hlsl.meta.slang4
-rw-r--r--source/slang/slang-ast-modifier.h10
-rw-r--r--source/slang/slang-emit.cpp9
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp3
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp167
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.h2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp15
-rw-r--r--source/slang/slang-ir-collect-global-uniforms.cpp1
-rw-r--r--source/slang/slang-ir-entry-point-uniforms.cpp1
-rw-r--r--source/slang/slang-ir-init-local-var.cpp34
-rw-r--r--source/slang/slang-ir-inst-defs.h13
-rw-r--r--source/slang/slang-ir-insts.h29
-rw-r--r--source/slang/slang-ir-legalize-types.cpp33
-rw-r--r--source/slang/slang-ir.cpp16
-rw-r--r--source/slang/slang-ir.h2
-rw-r--r--source/slang/slang-legalize-types.cpp3
-rw-r--r--source/slang/slang-legalize-types.h6
-rw-r--r--source/slang/slang-lower-to-ir.cpp8
20 files changed, 354 insertions, 85 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 5265d6cb6..c272da75d 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -3141,3 +3141,9 @@ attribute_syntax [payload] : PayloadAttribute;
__attributeTarget(DeclBase)
attribute_syntax [deprecated(message: String)] : DeprecatedAttribute;
+
+// `[PreferRecompute]`: attribute accepted on function declarations; it is
+// represented in the AST by `PreferRecomputeAttribute`.
+// NOTE(review): presumably a hint to the autodiff checkpointing policy that
+// results of the annotated function should be recomputed rather than stored;
+// confirm against the IR lowering.
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [PreferRecompute] : PreferRecomputeAttribute;
+
+// `[PreferCheckpoint]`: attribute accepted on function declarations; it is
+// represented in the AST by `PreferCheckpointAttribute`.
+// NOTE(review): presumably a hint to the autodiff checkpointing policy that
+// results of the annotated function should be stored rather than recomputed;
+// confirm against the IR lowering.
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [PreferCheckpoint] : PreferCheckpointAttribute;
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index cb87156f5..f8b36a3ac 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -362,6 +362,7 @@ extension Array<T, N> : IDifferentiable
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(transpose)]
+[PreferRecompute]
DifferentialPair<matrix<T, M, N>> __d_transpose(DifferentialPair<matrix<T, N, M>> m)
{
return DifferentialPair<matrix<T, M, N>>(transpose(m.p), transpose(m.d));
@@ -370,6 +371,7 @@ DifferentialPair<matrix<T, M, N>> __d_transpose(DifferentialPair<matrix<T, N, M>
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[BackwardDerivativeOf(transpose)]
+[PreferRecompute]
void __d_transpose(inout DifferentialPair<matrix<T, N, M>> m, matrix<T, M, N>.Differential dOut)
{
m = diffPair(m.p, transpose(dOut));
@@ -379,6 +381,7 @@ void __d_transpose(inout DifferentialPair<matrix<T, N, M>> m, matrix<T, M, N>.Di
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
+[PreferRecompute]
DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, DifferentialPair<matrix<T, N, M>> right)
{
let primal = mul(left.p, right.p);
@@ -388,6 +391,7 @@ DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, Differen
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[BackwardDerivativeOf(mul)]
+[PreferRecompute]
void __d_mul(inout DifferentialPair<vector<T, N>> left, inout DifferentialPair<matrix<T, N, M>> right, vector<T, M>.Differential dOut)
{
vector<T, N>.Differential left_d_result;
@@ -410,6 +414,7 @@ void __d_mul(inout DifferentialPair<vector<T, N>> left, inout DifferentialPair<m
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
+[PreferRecompute]
DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, DifferentialPair<vector<T,M>> right)
{
let primal = mul(left.p, right.p);
@@ -419,6 +424,7 @@ DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, Differen
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[BackwardDerivativeOf(mul)]
+[PreferRecompute]
void __d_mul(inout DifferentialPair<matrix<T, N, M>> left, inout DifferentialPair<vector<T, M>> right, vector<T, N>.Differential dOut)
{
matrix<T, N, M>.Differential left_d_result;
@@ -441,6 +447,7 @@ void __d_mul(inout DifferentialPair<matrix<T, N, M>> left, inout DifferentialPai
__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
+[PreferRecompute]
DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, DifferentialPair<matrix<T,N,C>> right)
{
let primal = mul(left.p, right.p);
@@ -450,6 +457,7 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, Differ
__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
[BackwardDerivativeOf(mul)]
+[PreferRecompute]
void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<matrix<T, N, C>> right, matrix<T, R, C>.Differential dOut)
{
matrix<T, R, N>.Differential left_d_result;
@@ -480,6 +488,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma
// Vector dot product
__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(dot)]
+[PreferRecompute]
DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy)
{
T result = T(0);
@@ -496,6 +505,7 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair
__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(dot)]
+[PreferRecompute]
void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut)
{
vector<T, N>.Differential x_d_result, y_d_result;
@@ -512,6 +522,7 @@ void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<ve
// Cross product
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(cross)]
+[PreferRecompute]
DifferentialPair<vector<T, 3>> __d_cross(DifferentialPair<vector<T, 3>> a, DifferentialPair<vector<T, 3>> b)
{
/*
@@ -539,6 +550,7 @@ DifferentialPair<vector<T, 3>> __d_cross(DifferentialPair<vector<T, 3>> a, Diffe
__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(cross)]
+[PreferRecompute]
void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<vector<T, 3>> b, vector<T, 3>.Differential dOut)
{
/*
@@ -560,7 +572,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
#define VECTOR_MATRIX_BINARY_DIFF_IMPL(NAME) \
__generic<T : __BuiltinFloatingPointType, let N : int> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<vector<T, N>> __d_##NAME##_vector( \
DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy) \
@@ -578,7 +590,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<vector<T, N>>(result, d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \
DifferentialPair<matrix<T, M, N>> dpx, DifferentialPair<matrix<T, M, N>> dpy) \
@@ -597,7 +609,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<matrix<T, M, N>>(result, d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_vector( \
inout DifferentialPair<vector<T, N>> dpx, \
@@ -617,7 +629,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
dpy = diffPair(dpy.p, right_d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_matrix( \
inout DifferentialPair<matrix<T, M, N>> dpx, \
@@ -640,7 +652,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
#define VECTOR_MATRIX_TERNARY_DIFF_IMPL(NAME) \
__generic<T : __BuiltinFloatingPointType, let N : int> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<vector<T, N>> __d_##NAME##_vector( \
DifferentialPair<vector<T, N>> dpx, \
@@ -661,7 +673,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<vector<T, N>>(result, d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \
DifferentialPair<matrix<T, M, N>> dpx, \
@@ -683,7 +695,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<matrix<T, M, N>>(result, d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_vector( \
inout DifferentialPair<vector<T, N>> dpx, \
@@ -708,7 +720,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
dpz = diffPair(dpz.p, right_d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_matrix( \
inout DifferentialPair<matrix<T, M, N>> dpx, \
@@ -736,7 +748,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
#define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \
__generic<T : __BuiltinFloatingPointType> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<T> __d_##NAME(DifferentialPair<T> dpx) \
{ \
@@ -744,7 +756,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<T>(NAME(dpx.p), FWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<vector<T, N>> __d_##NAME##_vector(DifferentialPair<vector<T, N>> dpx) \
{ \
@@ -752,7 +764,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<ReturnType>(NAME(dpx.p), FWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpm) \
{ \
@@ -763,10 +775,10 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
var dpx = diffPair(dpm.p[i], dpm.d[i]); \
diff[i] = FWD_DIFF_FUNC; \
} \
- return diffPair(NAME(dpm.p), diff); \
+ return diffPair(NAME(dpm.p), diff); \
} \
__generic<T : __BuiltinFloatingPointType> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME(inout DifferentialPair<T> dpx, T.Differential dOut) \
{ \
@@ -774,7 +786,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_vector( \
inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \
@@ -783,7 +795,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
- [BackwardDifferentiable] \
+ [BackwardDifferentiable][PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_matrix( \
inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \
@@ -848,6 +860,7 @@ SIMPLE_UNARY_DERIVATIVE_IMPL(atan, T(1.0) / (T(1.0) + dpx.p * dpx.p))
// Atan2
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[ForwardDerivativeOf(atan2)]
DifferentialPair<T> __d_atan2(DifferentialPair<T> dpy, DifferentialPair<T> dpx)
{
@@ -860,6 +873,7 @@ DifferentialPair<T> __d_atan2(DifferentialPair<T> dpy, DifferentialPair<T> dpx)
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[BackwardDerivativeOf(atan2)]
void __d_atan2(inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpx, T.Differential dOut)
{
@@ -872,6 +886,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(atan2)
// fmod
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[ForwardDerivativeOf(fmod)]
DifferentialPair<T> __d_fmod(DifferentialPair<T> x, DifferentialPair<T> y)
{
@@ -879,6 +894,7 @@ DifferentialPair<T> __d_fmod(DifferentialPair<T> x, DifferentialPair<T> y)
}
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[BackwardDerivativeOf(fmod)]
void __d_fmod(inout DifferentialPair<T> x, inout DifferentialPair<T> y, T.Differential dOut)
{
@@ -890,6 +906,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(fmod)
// Raise to a power
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[ForwardDerivativeOf(pow)]
DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
@@ -910,6 +927,7 @@ DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[BackwardDerivativeOf(pow)]
void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
@@ -936,6 +954,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(pow)
// Maximum
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[ForwardDerivativeOf(max)]
DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
@@ -947,6 +966,7 @@ DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[BackwardDerivativeOf(max)]
void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
@@ -959,6 +979,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(max)
// Minimum
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[ForwardDerivativeOf(min)]
DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
@@ -970,6 +991,7 @@ DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[BackwardDerivativeOf(min)]
void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
@@ -982,6 +1004,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(min)
// Lerp
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[ForwardDerivativeOf(lerp)]
DifferentialPair<T> __d_lerp(DifferentialPair<T> dpx, DifferentialPair<T> dpy, DifferentialPair<T> dps)
{
@@ -992,6 +1015,7 @@ DifferentialPair<T> __d_lerp(DifferentialPair<T> dpx, DifferentialPair<T> dpy, D
}
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[BackwardDerivativeOf(lerp)]
void __d_lerp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dps, T.Differential dOut)
{
@@ -1004,6 +1028,7 @@ VECTOR_MATRIX_TERNARY_DIFF_IMPL(lerp)
// Clamp
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[ForwardDerivativeOf(clamp)]
DifferentialPair<T> __d_clamp(DifferentialPair<T> dpx, DifferentialPair<T> dpMin, DifferentialPair<T> dpMax)
{
@@ -1013,6 +1038,7 @@ DifferentialPair<T> __d_clamp(DifferentialPair<T> dpx, DifferentialPair<T> dpMin
}
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
[BackwardDerivativeOf(clamp)]
void __d_clamp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpMin, inout DifferentialPair<T> dpMax, T.Differential dOut)
{
@@ -1025,6 +1051,7 @@ VECTOR_MATRIX_TERNARY_DIFF_IMPL(clamp)
// fma
[BackwardDifferentiable]
[ForwardDerivativeOf(fma)]
+[PreferRecompute]
DifferentialPair<double> __d_fma(DifferentialPair<double> dpx, DifferentialPair<double> dpy, DifferentialPair<double> dpz)
{
return DifferentialPair<double>(
@@ -1033,6 +1060,7 @@ DifferentialPair<double> __d_fma(DifferentialPair<double> dpx, DifferentialPair<
}
[BackwardDifferentiable]
[BackwardDerivativeOf(fma)]
+[PreferRecompute]
void __d_fma(inout DifferentialPair<double> dpx, inout DifferentialPair<double> dpy, inout DifferentialPair<double> dpz, double dOut)
{
dpx = diffPair(dpx.p, dpy.p * dOut);
@@ -1042,6 +1070,7 @@ void __d_fma(inout DifferentialPair<double> dpx, inout DifferentialPair<double>
__generic<let N : int>
[BackwardDifferentiable]
[ForwardDerivativeOf(fma)]
+[PreferRecompute]
DifferentialPair<vector<double, N>> __d_fma_vector(
DifferentialPair<vector<double, N>> dpx,
DifferentialPair<vector<double, N>> dpy,
@@ -1063,6 +1092,7 @@ DifferentialPair<vector<double, N>> __d_fma_vector(
__generic<let N : int>
[BackwardDifferentiable]
[BackwardDerivativeOf(fma)]
+[PreferRecompute]
void __d_fma_vector(
inout DifferentialPair<vector<double, N>> dpx,
inout DifferentialPair<vector<double, N>> dpy,
@@ -1089,6 +1119,7 @@ void __d_fma_vector(
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
[ForwardDerivativeOf(mad)]
+[PreferRecompute]
DifferentialPair<T> __d_mad(DifferentialPair<T> dpx, DifferentialPair<T> dpy, DifferentialPair<T> dpz)
{
return DifferentialPair<T>(
@@ -1098,6 +1129,7 @@ DifferentialPair<T> __d_mad(DifferentialPair<T> dpx, DifferentialPair<T> dpy, Di
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
[BackwardDerivativeOf(mad)]
+[PreferRecompute]
void __d_mad(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpz, T.Differential dOut)
{
dpx = diffPair(dpx.p, T.dmul(dpy.p, dOut));
@@ -1109,6 +1141,7 @@ VECTOR_MATRIX_TERNARY_DIFF_IMPL(mad)
// Smoothstep
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PreferRecompute]
T __smoothstep_impl(T minVal, T maxVal, T x)
{
let t = saturate((x - minVal) / (maxVal - minVal));
@@ -1117,6 +1150,7 @@ T __smoothstep_impl(T minVal, T maxVal, T x)
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
[ForwardDerivativeOf(smoothstep)]
+[PreferRecompute]
DifferentialPair<T> __d_smoothstep(DifferentialPair<T> minVal, DifferentialPair<T> maxVal, DifferentialPair<T> x)
{
return __fwd_diff(__smoothstep_impl)(minVal, maxVal, x);
@@ -1124,6 +1158,7 @@ DifferentialPair<T> __d_smoothstep(DifferentialPair<T> minVal, DifferentialPair<
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
[BackwardDerivativeOf(smoothstep)]
+[PreferRecompute]
void __d_smoothstep(inout DifferentialPair<T> minVal, inout DifferentialPair<T> maxVal, inout DifferentialPair<T> x, T.Differential dOut)
{
__bwd_diff(__smoothstep_impl)(minVal, maxVal, x, dOut);
@@ -1133,6 +1168,7 @@ VECTOR_MATRIX_TERNARY_DIFF_IMPL(smoothstep)
// Vector length
__generic<T: __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
+[PreferRecompute]
T __length_impl(vector<T, N> x)
{
T len = T(0.0);
@@ -1147,6 +1183,7 @@ __generic<T: __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
[ForwardDerivativeOf(length)]
[ForceInline]
+[PreferRecompute]
DifferentialPair<T> __d_length(DifferentialPair<vector<T, N>> x)
{
return __fwd_diff(__length_impl)(x);
@@ -1156,6 +1193,7 @@ __generic<T: __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
[BackwardDerivativeOf(length)]
[ForceInline]
+[PreferRecompute]
void __d_length(inout DifferentialPair<vector<T, N>> x, T.Differential dOut)
{
return __bwd_diff(__length_impl)(x, dOut);
@@ -1164,6 +1202,7 @@ void __d_length(inout DifferentialPair<vector<T, N>> x, T.Differential dOut)
// Vector distance
__generic<T: __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
+[PreferRecompute]
T __distance_impl(vector<T, N> x, vector<T, N> y)
{
return length(y - x);
@@ -1172,6 +1211,7 @@ __generic<T: __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
[ForwardDerivativeOf(distance)]
[ForceInline]
+[PreferRecompute]
DifferentialPair<T> __d_distance(DifferentialPair<vector<T, N>> x, DifferentialPair<vector<T, N>> y)
{
return __fwd_diff(__distance_impl)(x, y);
@@ -1181,6 +1221,7 @@ __generic<T: __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
[BackwardDerivativeOf(distance)]
[ForceInline]
+[PreferRecompute]
void __d_distance(inout DifferentialPair<vector<T, N>> x, inout DifferentialPair<vector<T, N>> y, T.Differential dOut)
{
return __bwd_diff(__distance_impl)(x, y, dOut);
@@ -1189,6 +1230,7 @@ void __d_distance(inout DifferentialPair<vector<T, N>> x, inout DifferentialPair
// Vector normalize
__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
+[PreferRecompute]
vector<T, N> __normalize_impl(vector<T, N> x)
{
let r = T(1.0) / length(x);
@@ -1198,6 +1240,7 @@ __generic<T: __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
[ForwardDerivativeOf(normalize)]
[ForceInline]
+[PreferRecompute]
DifferentialPair<vector<T, N>> __d_normalize(DifferentialPair<vector<T, N>> x)
{
return __fwd_diff(__normalize_impl)(x);
@@ -1206,6 +1249,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
[BackwardDerivativeOf(normalize)]
[ForceInline]
+[PreferRecompute]
void __d_distance(inout DifferentialPair<vector<T, N>> x, vector<T, N>.Differential dOut)
{
return __bwd_diff(__normalize_impl)(x, dOut);
@@ -1264,6 +1308,7 @@ void __d_refract(inout DifferentialPair<vector<T, N>> i, inout DifferentialPair<
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
[PrimalSubstituteOf(sincos)]
+[PreferRecompute]
void __sincos_impl(T x, out T s, out T c)
{
s = sin(x);
@@ -1272,6 +1317,7 @@ void __sincos_impl(T x, out T s, out T c)
__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
+[PreferRecompute]
[PrimalSubstituteOf(sincos)]
void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c)
{
@@ -1282,6 +1328,7 @@ void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c)
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[BackwardDifferentiable]
[PrimalSubstituteOf(sincos)]
+[PreferRecompute]
void __sincos_impl(matrix<T, N, M> x, out matrix<T, N, M> s, out matrix<T, N, M> c)
{
s = sin(x);
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index a47c0cb25..1580a7a23 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -1744,6 +1744,7 @@ __target_intrinsic(hlsl)
__target_intrinsic(glsl)
__target_intrinsic(spirv_direct, "OpExtInst resultType resultId glsl450 Determinant _0")
[__readNone]
+[PreferCheckpoint]
T determinant(matrix<T,N,N> m);
// Barrier for device memory
@@ -3810,6 +3811,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
__target_intrinsic(hlsl)
__target_intrinsic(glsl)
[__readNone]
+[PreferRecompute]
matrix<T, M, N> transpose(matrix<T, N, M> x)
{
matrix<T,M,N> result;
@@ -3822,6 +3824,7 @@ __generic<T : __BuiltinIntegerType, let N : int, let M : int>
__target_intrinsic(hlsl)
__target_intrinsic(glsl)
[__readNone]
+[PreferRecompute]
matrix<T, M, N> transpose(matrix<T, N, M> x)
{
matrix<T, M, N> result;
@@ -3834,6 +3837,7 @@ __generic<T : __BuiltinLogicalType, let N : int, let M : int>
__target_intrinsic(hlsl)
__target_intrinsic(glsl)
[__readNone]
+[PreferRecompute]
matrix<T, M, N> transpose(matrix<T, N, M> x)
{
matrix<T, M, N> result;
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index caa439704..6f79e900f 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1078,6 +1078,16 @@ class CudaHostAttribute : public Attribute
SLANG_AST_CLASS(CudaHostAttribute)
};
+// AST node for the `[PreferRecompute]` function attribute declared in
+// core.meta.slang.
+// NOTE(review): presumably lowered to an IR PreferRecompute decoration that the
+// autodiff checkpointing policy consults; confirm in slang-lower-to-ir.cpp.
+class PreferRecomputeAttribute : public Attribute
+{
+ SLANG_AST_CLASS(PreferRecomputeAttribute)
+};
+
+// AST node for the `[PreferCheckpoint]` function attribute declared in
+// core.meta.slang.
+// NOTE(review): presumably lowered to an IR PreferCheckpoint decoration that the
+// autodiff checkpointing policy consults; confirm in slang-lower-to-ir.cpp.
+class PreferCheckpointAttribute : public Attribute
+{
+ SLANG_AST_CLASS(PreferCheckpointAttribute)
+};
+
class DerivativeMemberAttribute : public Attribute
{
SLANG_AST_CLASS(DerivativeMemberAttribute)
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 0b88bd057..4412eccb8 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -537,17 +537,12 @@ Result linkAndOptimizeIR(
}
else
{
-#if 0
- // On CPU/CUDA targets, we simply elminate any empty types.
- // TODO: disable for now, since the CPU compute shader
- // trampoline is still hard coded to assume there are
- // entrypoint and global parameters. renable when
- // we fix that logic.
+ // On CPU/CUDA targets, we simply eliminate any empty types if
+ // they are not part of the public interface.
legalizeEmptyTypes(
irModule,
sink);
eliminateDeadCode(irModule);
-#endif
}
// Once specialization and type legalization have been performed,
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 598a1dde9..6a87550d2 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -12,6 +12,7 @@
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-validate.h"
#include "slang-ir-inline.h"
+#include "slang-ir-init-local-var.h"
namespace Slang
{
@@ -1593,6 +1594,8 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
removeLinkageDecorations(func);
performForceInlining(func);
+
+ initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func);
AutoDiffAddressConversionPolicy cvtPolicty;
cvtPolicty.diffTypeContext = &differentiableTypeConformanceContext;
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 5d11b7fb3..af47ffbca 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -35,20 +35,6 @@ static bool isDifferentialBlock(IRBlock* block)
return block->findDecoration<IRDifferentialInstDecoration>();
}
-static Dictionary<IRBlock*, IRBlock*> reconstructDiffBlockMap(IRGlobalValueWithCode* func)
-{
- Dictionary<IRBlock*, IRBlock*> diffBlockMap;
- for (auto block : func->getBlocks())
- {
- if (auto diffDecor = block->findDecoration<IRDifferentialInstDecoration>())
- {
- if (diffDecor->getPrimalType())
- diffBlockMap[as<IRBlock>(diffDecor->getPrimalInst())] = block;
- }
- }
- return diffBlockMap;
-}
-
static IRBlock* getLoopRegionBodyBlock(IRLoop* loop)
{
auto condBlock = as<IRBlock>(loop->getTargetBlock());
@@ -397,7 +383,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
auto branchInst = as<IRUnconditionalBranch>(predecessor->getTerminator());
SLANG_ASSERT(branchInst->getOperandCount() > paramIndex);
- workList.add(&branchInst->getOperands()[paramIndex]);
+ workList.add(&branchInst->getArgs()[paramIndex]);
}
}
else
@@ -511,17 +497,10 @@ void applyToInst(
if (as<IRParam>(inst))
{
// Can completely ignore first block parameters
- if (getBlock(inst) != getBlock(inst)->getParent()->getFirstBlock())
+ if (getBlock(inst) == getBlock(inst)->getParent()->getFirstBlock())
{
- // TODO: We would need to clone in the control-flow for each region (without nested loops)
- // prior to this, and then hoist this parameter into the within-region block, otherwise
- // this parameter will not be visible to transposed insts.
- // This will also include adding an extra case to 'ensurePrimalAvailability': if both insts
- // are withing the _same_ indexed region, skip the indexed store/load and use a simple var.
- //
- SLANG_UNIMPLEMENTED_X("Parameter recompute is not currently supported");
+ return;
}
- return;
}
auto recomputeInst = cloneCtx->cloneInstOutOfOrder(builder, inst);
@@ -565,10 +544,6 @@ void applyCheckpointSet(
Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock,
IROutOfOrderCloneContext* cloneCtx)
{
- // Reconstruct diff block map.
- Dictionary<IRBlock*, IRBlock*> diffBlockMap = reconstructDiffBlockMap(func);
-
-
for (auto use : pendingUses)
cloneCtx->pendingUses.add(use);
@@ -600,13 +575,24 @@ void applyCheckpointSet(
if (block->findDecoration<IRRecomputeBlockDecoration>())
continue;
- auto diffBlock = as<IRBlock>(diffBlockMap[block]);
+ IRBlock* recomputeBlock = block;
+ mapPrimalBlockToRecomputeBlock.tryGetValue(block, recomputeBlock);
+ auto recomputeInsertBeforeInst = recomputeBlock->getFirstOrdinaryInst();
IRBuilder builder(func->getModule());
UIndex ii = 0;
for (auto param : block->getParams())
{
- builder.setInsertBefore(diffBlock->getFirstOrdinaryInst());
+ builder.setInsertBefore(recomputeInsertBeforeInst);
+ bool isRecomputed = checkpointInfo->recomputeSet.contains(param);
+ bool isInverted = checkpointInfo->invertSet.contains(param);
+
+ if (!isRecomputed && !isInverted)
+ continue;
+
+ SLANG_RELEASE_ASSERT(
+ recomputeBlock != block &&
+ "recomputed param should belong to block that has recompute block.");
// Apply checkpoint rule to the parameter itself.
applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, param);
@@ -617,37 +603,36 @@ void applyCheckpointSet(
{
if (predecessorSet.contains(predecessor))
continue;
-
predecessorSet.add(predecessor);
+
+ auto primalPhiArg = as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii);
+ auto recomputePredecessor = mapPrimalBlockToRecomputeBlock[predecessor].getValue();
- auto diffPredecessor = as<IRBlock>(diffBlockMap[block]);
-
- if (checkpointInfo->recomputeSet.contains(param))
+ // For now, find the primal phi argument in this predecessor,
+ // and stick it into the recompute predecessor's branch inst. We
+ // will use a patch-up pass at the end to replace all these
+ // arguments with their recomputed versions if they exist.
+
+ if (isRecomputed)
{
- IRInst* terminator = diffPredecessor->getTerminator();
+ IRInst* terminator = recomputeBlock->getTerminator();
addPhiOutputArg(&builder,
- diffPredecessor,
+ recomputePredecessor,
terminator,
- as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii));
+ primalPhiArg);
}
-
- if (checkpointInfo->invertSet.contains(param))
+ else if (isInverted)
{
- IRInst* terminator = diffPredecessor->getTerminator();
-
+ IRInst* terminator = recomputeBlock->getTerminator();
addPhiOutputArg(&builder,
- diffPredecessor,
+ recomputePredecessor,
terminator,
- as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii));
+ primalPhiArg);
}
}
-
ii++;
}
- IRBlock* recomputeBlock = block;
- mapPrimalBlockToRecomputeBlock.tryGetValue(block, recomputeBlock);
- auto recomputeInsertBeforeInst = recomputeBlock->getFirstOrdinaryInst();
for (auto child : block->getChildren())
{
@@ -670,6 +655,25 @@ void applyCheckpointSet(
applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, child);
}
}
+
+ // Go through phi arguments in recompute blocks and replace them with
+ // their recomputed insts if they exist.
+ for (auto block : func->getBlocks())
+ {
+ if (!block->findDecoration<IRRecomputeBlockDecoration>())
+ continue;
+ auto terminator = block->getTerminator();
+ for (UInt i = 0; i < terminator->getOperandCount(); i++)
+ {
+ auto arg = terminator->getOperand(i);
+ if (as<IRBlock>(arg))
+ continue;
+ if (auto recomputeArg = cloneCtx->cloneEnv.mapOldValToNew.tryGetValue(arg))
+ {
+ terminator->setOperand(i, *recomputeArg);
+ }
+ }
+ }
}
IRType* getTypeForLocalStorage(
@@ -1313,7 +1317,7 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func)
{
- domTree = computeDominatorTree(func);
+ SLANG_UNUSED(func)
return;
}
@@ -1412,6 +1416,30 @@ static bool shouldStoreVar(IRVar* var)
return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType()));
}
+enum CheckpointPreference
+{
+ None,
+ PreferCheckpoint,
+ PreferRecompute
+};
+
+static CheckpointPreference getCheckpointPreference(IRInst* callee)
+{
+ callee = getResolvedInstForDecorations(callee, true);
+ for (auto decor : callee->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_PreferCheckpointDecoration:
+ return CheckpointPreference::PreferCheckpoint;
+ case kIROp_PreferRecomputeDecoration:
+ case kIROp_TargetIntrinsicDecoration:
+ return CheckpointPreference::PreferRecompute;
+ }
+ }
+ return CheckpointPreference::None;
+}
+
static bool shouldStoreInst(IRInst* inst)
{
if (!inst->getDataType())
@@ -1422,9 +1450,9 @@ static bool shouldStoreInst(IRInst* inst)
if (!canTypeBeStored(inst->getDataType()))
return false;
- // Never store certain opcodes.
switch (inst->getOp())
{
+ // Never store these opcodes because they are not real computations.
case kIROp_CastFloatToInt:
case kIROp_CastIntToFloat:
case kIROp_IntCast:
@@ -1437,25 +1465,39 @@ static bool shouldStoreInst(IRInst* inst)
case kIROp_MakeStruct:
case kIROp_MakeTuple:
case kIROp_MakeArray:
+ case kIROp_MakeVector:
+ case kIROp_MakeMatrix:
case kIROp_MakeArrayFromElement:
case kIROp_MakeDifferentialPair:
+ case kIROp_MakeDifferentialPairUserCode:
case kIROp_MakeOptionalNone:
case kIROp_MakeOptionalValue:
+ case kIROp_MakeExistential:
case kIROp_DifferentialPairGetDifferential:
case kIROp_DifferentialPairGetPrimal:
+ case kIROp_DifferentialPairGetDifferentialUserCode:
+ case kIROp_DifferentialPairGetPrimalUserCode:
case kIROp_ExtractExistentialValue:
case kIROp_ExtractExistentialType:
case kIROp_ExtractExistentialWitnessTable:
case kIROp_undefined:
case kIROp_GetSequentialID:
+ case kIROp_GetStringHash:
case kIROp_Specialize:
case kIROp_LookupWitness:
+ case kIROp_Param:
+ case kIROp_DetachDerivative:
+ return false;
+
+ // Never store these opcodes because they are trivial to compute.
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
case kIROp_Div:
case kIROp_Neg:
case kIROp_Geq:
+ case kIROp_FRem:
+ case kIROp_IRem:
case kIROp_Leq:
case kIROp_Neq:
case kIROp_Eql:
@@ -1471,6 +1513,7 @@ static bool shouldStoreInst(IRInst* inst)
case kIROp_Lsh:
case kIROp_Rsh:
return false;
+
case kIROp_GetElement:
case kIROp_FieldExtract:
case kIROp_swizzle:
@@ -1479,8 +1522,12 @@ static bool shouldStoreInst(IRInst* inst)
case kIROp_GetOptionalValue:
case kIROp_MatrixReshape:
case kIROp_VectorReshape:
- // If the operand is already stored, don't store the result of these insts.
- if (inst->getOperand(0)->findDecoration<IRPrimalValueStructKeyDecoration>())
+ case kIROp_GetTupleElement:
+ return false;
+
+ case kIROp_Call:
+ // If the callee prefers recompute policy, don't store.
+ if (getCheckpointPreference(inst->getOperand(0)) == CheckpointPreference::PreferRecompute)
{
return false;
}
@@ -1495,9 +1542,8 @@ static bool shouldStoreInst(IRInst* inst)
return true;
}
-bool canRecompute(IRDominatorTree* domTree, IRUse* use)
+bool canRecompute(IRUse* use)
{
- SLANG_UNUSED(domTree);
if (auto load = as<IRLoad>(use->get()))
{
// Generally, we cannot recompute a load(ptr), since ptr may be modified
@@ -1518,7 +1564,18 @@ bool canRecompute(IRDominatorTree* domTree, IRUse* use)
auto param = as<IRParam>(use->get());
if (!param)
return true;
- return false;
+
+ // We can recompute a phi param if it is not in a loop start block.
+ auto parentBlock = as<IRBlock>(param->getParent());
+ for (auto pred : parentBlock->getPredecessors())
+ {
+ if (auto loop = as<IRLoop>(pred->getTerminator()))
+ {
+ if (loop->getTargetBlock() == parentBlock)
+ return false;
+ }
+ }
+ return true;
}
HoistResult DefaultCheckpointPolicy::classify(IRUse* use)
@@ -1541,7 +1598,7 @@ HoistResult DefaultCheckpointPolicy::classify(IRUse* use)
{
// We may not be able to recompute due to limitations of
// the unzip pass. If so we will store the result.
- if (canRecompute(domTree, use))
+ if (canRecompute(use))
return HoistResult::recompute(use->get());
// The fallback is to store.
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
index fbac42c43..c0b56126d 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.h
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -286,8 +286,6 @@ namespace Slang
virtual void preparePolicy(IRGlobalValueWithCode* func);
virtual HoistResult classify(IRUse* use);
-
- RefPtr<IRDominatorTree> domTree;
};
RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func);
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index e3575aceb..2994a8c31 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -142,6 +142,13 @@ namespace Slang
builder->setInsertBefore(existingPrimalFunc);
builder->setInsertInto(existingPrimalFunc);
+
+ auto checkpointHint = udf->findDecoration<IRCheckpointHintDecoration>();
+ if (!checkpointHint)
+ checkpointHint = originalFunc->findDecoration<IRCheckpointHintDecoration>();
+ if (checkpointHint)
+ builder->addDecoration(existingPrimalFunc, checkpointHint->getOp());
+
builder->emitBlock();
params = _defineFuncParams(builder, as<IRFunc>(existingPrimalFunc));
params.removeLast();
@@ -759,6 +766,14 @@ namespace Slang
auto primalFuncGeneric = hoistValueFromGeneric(*builder, extractedPrimalFunc, specializedFunc, true);
builder->setInsertBefore(primalFunc);
+
+ // Copy over checkpoint preference hints.
+ {
+ auto diffPrimalFunc = getResolvedInstForDecorations(primalFuncGeneric, true);
+ auto checkpointHint = primalFunc->findDecoration<IRCheckpointHintDecoration>();
+ if (checkpointHint)
+ builder->addDecoration(diffPrimalFunc, checkpointHint->getOp());
+ }
if (auto existingDecor = primalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>())
{
diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp
index de806d653..78ea2e79a 100644
--- a/source/slang/slang-ir-collect-global-uniforms.cpp
+++ b/source/slang/slang-ir-collect-global-uniforms.cpp
@@ -140,6 +140,7 @@ struct CollectGlobalUniformParametersContext
//
auto wrapperStructType = builder->createStructType();
builder->addNameHintDecoration(wrapperStructType, UnownedTerminatedStringSlice("GlobalParams"));
+ builder->addBinaryInterfaceTypeDecoration(wrapperStructType);
// If the computed layout used a bare `struct` type, then we will use
// our `GlobalParams` struct as-is, but if the layout involved an
diff --git a/source/slang/slang-ir-entry-point-uniforms.cpp b/source/slang/slang-ir-entry-point-uniforms.cpp
index d10aa7ff2..518f6ae2c 100644
--- a/source/slang/slang-ir-entry-point-uniforms.cpp
+++ b/source/slang/slang-ir-entry-point-uniforms.cpp
@@ -387,6 +387,7 @@ struct CollectEntryPointUniformParams : PerEntryPointPass
builder.setInsertBefore(m_entryPoint.func);
paramStructType = builder.createStructType();
builder.addNameHintDecoration(paramStructType, UnownedTerminatedStringSlice("EntryPointParams"));
+ builder.addBinaryInterfaceTypeDecoration(paramStructType);
if( needConstantBuffer )
{
diff --git a/source/slang/slang-ir-init-local-var.cpp b/source/slang/slang-ir-init-local-var.cpp
index 645344f2e..145d8d569 100644
--- a/source/slang/slang-ir-init-local-var.cpp
+++ b/source/slang/slang-ir-init-local-var.cpp
@@ -9,16 +9,42 @@ namespace Slang
void initializeLocalVariables(IRModule* module, IRGlobalValueWithCode* func)
{
IRBuilder builder(module);
+ HashSet<IRInst*> userSet;
for (auto block : func->getBlocks())
{
for (auto inst : block->getChildren())
{
if (inst->getOp() == kIROp_Var)
{
- auto firstUse = inst->firstUse;
- bool initialized =
- (firstUse && firstUse->getUser()->getOp() == kIROp_Store &&
- firstUse->getUser()->getParent() == inst->getParent());
+ bool initialized = false;
+ userSet.clear();
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ userSet.add(use->getUser());
+
+ // Check if the variable is initialized in the same block.
+ for (auto nextInst = inst->next; nextInst; nextInst = nextInst->next)
+ {
+ switch (nextInst->getOp())
+ {
+ case kIROp_Store:
+ if (nextInst->getOperand(0) == inst)
+ initialized = true;
+ break;
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ continue;
+ default:
+ if (userSet.contains(nextInst))
+ {
+ // We encountered a user of the variable before it was initialized.
+ // Break out of the loop and insert the initialization code.
+ goto breakLabel;
+ }
+ }
+ if (initialized)
+ break;
+ }
+ breakLabel:;
if (initialized)
continue;
builder.setInsertAfter(inst);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 11143cebb..b00972cce 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -635,6 +635,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(InterpolationModeDecoration, interpolationMode, 1, 0)
INST(NameHintDecoration, nameHint, 1, 0)
+ // Marks a type as being used as a binary interface (e.g. shader parameters).
+ // This prevents the legalizeEmptyType() pass from eliminating it on C++/CUDA targets.
+ INST(BinaryInterfaceTypeDecoration, BinaryInterfaceType, 0, 0)
+
/** The decorated _instruction_ is transitory. Such a decoration should NEVER be found on an output instruction a module.
Typically used mark an instruction so can be specially handled - say when creating a IRConstant literal, and the payload of
needs to be special cased for lookup. */
@@ -846,6 +850,15 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// Treat a function as differentiable function, or an IRCall as a call to a differentiable function.
INST(TreatAsDifferentiableDecoration, treatAsDifferentiableDecoration, 0, 0)
+
+ /// Hint that the result from a call to the decorated function should be stored in the backward propagation function.
+ INST(PreferCheckpointDecoration, PreferCheckpointDecoration, 0, 0)
+
+ /// Hint that the result from a call to the decorated function should be recomputed in the backward propagation function.
+ INST(PreferRecomputeDecoration, PreferRecomputeDecoration, 0, 0)
+
+ INST_RANGE(CheckpointHintDecoration, PreferCheckpointDecoration, PreferRecomputeDecoration)
+
/// Marks a class type as a COM interface implementation, which enables
/// the witness table to be easily picked up by emit.
INST(COMWitnessDecoration, COMWitnessDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index f515baf8d..4495d1a3d 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -712,6 +712,30 @@ struct IRBackwardDerivativeDecoration : IRDecoration
IRInst* getBackwardDerivativeFunc() { return getOperand(0); }
};
+struct IRCheckpointHintDecoration : public IRDecoration
+{
+ IR_PARENT_ISA(CheckpointHintDecoration)
+};
+
+struct IRPreferRecomputeDecoration : IRCheckpointHintDecoration
+{
+ enum
+ {
+ kOp = kIROp_PreferRecomputeDecoration
+ };
+ IR_LEAF_ISA(PreferRecomputeDecoration)
+};
+
+struct IRPreferCheckpointDecoration : IRCheckpointHintDecoration
+{
+ enum
+ {
+ kOp = kIROp_PreferCheckpointDecoration
+ };
+ IR_LEAF_ISA(PreferCheckpointDecoration)
+};
+
+
struct IRLoopCounterDecoration : IRDecoration
{
enum
@@ -3651,6 +3675,11 @@ public:
addNameHintDecoration(value, getStringValue(text));
}
+ void addBinaryInterfaceTypeDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_BinaryInterfaceTypeDecoration);
+ }
+
void addGLSLOuterArrayDecoration(IRInst* value, UnownedStringSlice const& text)
{
addDecoration(value, kIROp_GLSLOuterArrayDecoration, getStringValue(text));
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 5bfbfe994..75e4bdacf 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -3730,6 +3730,11 @@ struct IRResourceTypeLegalizationContext : IRTypeLegalizationContext
return isResourceType(type);
}
+ bool isSimpleType(IRType*) override
+ {
+ return false;
+ }
+
LegalType createLegalUniformBufferType(
IROp op,
LegalType legalElementType) override
@@ -3761,6 +3766,11 @@ struct IRExistentialTypeLegalizationContext : IRTypeLegalizationContext
return as<IRPseudoPtrType>(type) != nullptr;
}
+ bool isSimpleType(IRType*) override
+ {
+ return false;
+ }
+
LegalType createLegalUniformBufferType(
IROp op,
LegalType legalElementType) override
@@ -3779,6 +3789,9 @@ struct IRExistentialTypeLegalizationContext : IRTypeLegalizationContext
}
};
+// This customization of type legalization is used to remove empty
+// structs from cpp/cuda programs if the empty type isn't used in
+// a public function signature.
struct IREmptyTypeLegalizationContext : IRTypeLegalizationContext
{
IREmptyTypeLegalizationContext(IRModule* module)
@@ -3790,6 +3803,26 @@ struct IREmptyTypeLegalizationContext : IRTypeLegalizationContext
return false;
}
+ bool isSimpleType(IRType* type) override
+ {
+ // If the type is used as a public interface, then treat it as simple.
+ for (auto decor : type->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_LayoutDecoration:
+ case kIROp_PublicDecoration:
+ case kIROp_ExternCppDecoration:
+ case kIROp_DllImportDecoration:
+ case kIROp_DllExportDecoration:
+ case kIROp_HLSLExportDecoration:
+ case kIROp_BinaryInterfaceTypeDecoration:
+ return true;
+ }
+ }
+ return false;
+ }
+
LegalType createLegalUniformBufferType(IROp, LegalType) override
{
return LegalType();
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 9225faf6b..e74a57424 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -7478,7 +7478,7 @@ namespace Slang
return nullptr;
}
- IRInst* getResolvedInstForDecorations(IRInst* inst)
+ IRInst* getResolvedInstForDecorations(IRInst* inst, bool resolveThroughDifferentiation)
{
IRInst* candidate = inst;
for(;;)
@@ -7488,6 +7488,20 @@ namespace Slang
candidate = specInst->getBase();
continue;
}
+ if (resolveThroughDifferentiation)
+ {
+ switch (candidate->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ case kIROp_BackwardDifferentiatePrimal:
+ case kIROp_BackwardDifferentiatePropagate:
+ candidate = candidate->getOperand(0);
+ continue;
+ default:
+ break;
+ }
+ }
if( auto genericInst = as<IRGeneric>(candidate) )
{
if( auto returnVal = findGenericReturnVal(genericInst) )
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 47aa5333a..3dbd0c773 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1938,7 +1938,7 @@ IRInst* findSpecializeReturnVal(IRSpecialize* specialize);
// then try to chase down the generic being specialized, and what
// it seems to return).
//
-IRInst* getResolvedInstForDecorations(IRInst* inst);
+IRInst* getResolvedInstForDecorations(IRInst* inst, bool resolveThroughDifferentiation = false);
// The IR module itself is represented as an instruction, which
// serves at the root of the tree of all instructions in the module.
diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp
index e17fa54bb..57735d459 100644
--- a/source/slang/slang-legalize-types.cpp
+++ b/source/slang/slang-legalize-types.cpp
@@ -1115,6 +1115,9 @@ LegalType legalizeTypeImpl(
//
if(type->findDecoration<IRTargetIntrinsicDecoration>())
return LegalType::simple(type);
+
+ if (context->isSimpleType(type))
+ return LegalType::simple(type);
context->builder->setInsertBefore(type);
diff --git a/source/slang/slang-legalize-types.h b/source/slang/slang-legalize-types.h
index fd8176889..17029b6b6 100644
--- a/source/slang/slang-legalize-types.h
+++ b/source/slang/slang-legalize-types.h
@@ -640,6 +640,12 @@ struct IRTypeLegalizationContext
/// types will get moved out of the `struc` itself.
virtual bool isSpecialType(IRType* type) = 0;
+ /// Customization point to decide what types are "simple."
+ ///
+ /// When a type is "simple" it means that it should not be changed
+ /// during legalization.
+ virtual bool isSimpleType(IRType* type) = 0;
+
/// Customization point to construct uniform-buffer/block types.
///
/// This function will only be called if `legalElementType` is
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 1ec508bfa..d644d01c7 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8848,6 +8848,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration);
}
+ else if (as<PreferCheckpointAttribute>(modifier))
+ {
+ getBuilder()->addDecoration(irFunc, kIROp_PreferCheckpointDecoration);
+ }
+ else if (as<PreferRecomputeAttribute>(modifier))
+ {
+ getBuilder()->addDecoration(irFunc, kIROp_PreferRecomputeDecoration);
+ }
}
// For convenience, ensure that any additional global
// values that were emitted while outputting the function