summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-02-27 23:42:06 -0500
committerGitHub <noreply@github.com>2023-02-27 23:42:06 -0500
commit10e2d9c7c532c204f26bb2c9f383f21b121b2ff2 (patch)
tree9ae0dd84b505a7ecd3fb45de9dbde74f8dd1ebe9 /tests
parenta3ba22b51c371d5a20d61aa4e35233ba4f4f68db (diff)
More fixes for reverse-mode on complicated loops (#2675)
* Multiple fixes to get various loop tests to pass. * Create reverse-nested-loop.slang * Fix for variables becoming inaccessible during cfg normalization * Removed comments and moved break-branch-normalization to eliminateMultiLevelBreaks * Fix. * Override liveness tests
Diffstat (limited to 'tests')
-rw-r--r--tests/autodiff/reverse-continue-loop.slang43
-rw-r--r--tests/autodiff/reverse-continue-loop.slang.expected.txt6
-rw-r--r--tests/autodiff/reverse-hybrid-control-flow.slang47
-rw-r--r--tests/autodiff/reverse-hybrid-control-flow.slang.expected.txt6
-rw-r--r--tests/autodiff/reverse-nested-loop.slang43
-rw-r--r--tests/autodiff/reverse-nested-loop.slang.expected.txt6
-rw-r--r--tests/experimental/liveness/liveness-2.slang.expected2
-rw-r--r--tests/experimental/liveness/liveness-3.slang.expected12
-rw-r--r--tests/experimental/liveness/liveness-4.slang.expected4
-rw-r--r--tests/experimental/liveness/liveness-5.slang.expected4
-rw-r--r--tests/experimental/liveness/liveness-6.slang.expected4
-rw-r--r--tests/experimental/liveness/liveness.slang.expected14
12 files changed, 171 insertions, 20 deletions
diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang
new file mode 100644
index 000000000..0f9502673
--- /dev/null
+++ b/tests/autodiff/reverse-continue-loop.slang
@@ -0,0 +1,43 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+[BackwardDifferentiable]
+float test_loop_with_continue(float y)
+{
+ float t = y;
+
+ for (int i = 0; i < 3; i++)
+ {
+ if (t > 4.0)
+ continue;
+
+ t = t * t;
+ }
+
+ return t;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(2.0, 0.0);
+
+ __bwd_diff(test_loop_with_continue)(dpa, 1.0f);
+ outputBuffer[0] = dpa.d; // Expect: 32.0
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_loop_with_continue)(dpa, 1.0f);
+ outputBuffer[1] = dpa.d; // Expect: 0.0131072
+ }
+}
diff --git a/tests/autodiff/reverse-continue-loop.slang.expected.txt b/tests/autodiff/reverse-continue-loop.slang.expected.txt
new file mode 100644
index 000000000..17dbb061d
--- /dev/null
+++ b/tests/autodiff/reverse-continue-loop.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+32.000000
+0.013107
+0.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/reverse-hybrid-control-flow.slang b/tests/autodiff/reverse-hybrid-control-flow.slang
new file mode 100644
index 000000000..9379df4cf
--- /dev/null
+++ b/tests/autodiff/reverse-hybrid-control-flow.slang
@@ -0,0 +1,47 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+[BackwardDifferentiable]
+float test_simple_loop(float y)
+{
+ float t = y;
+
+ if (y > 0.5)
+ {
+ for (int i = 0; i < 3; i++)
+ {
+ t = t * t;
+ }
+ }
+ else
+ {
+ t = t * 10.f;
+ }
+
+ return t;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(1.0, 0.0);
+
+ __bwd_diff(test_simple_loop)(dpa, 1.0f);
+ outputBuffer[0] = dpa.d; // Expect: 8.0
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_simple_loop)(dpa, 1.0f);
+ outputBuffer[1] = dpa.d; // Expect: 10.0
+ }
+}
diff --git a/tests/autodiff/reverse-hybrid-control-flow.slang.expected.txt b/tests/autodiff/reverse-hybrid-control-flow.slang.expected.txt
new file mode 100644
index 000000000..3cb76c394
--- /dev/null
+++ b/tests/autodiff/reverse-hybrid-control-flow.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+8.000000
+10.000000
+0.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/reverse-nested-loop.slang b/tests/autodiff/reverse-nested-loop.slang
new file mode 100644
index 000000000..08cde5230
--- /dev/null
+++ b/tests/autodiff/reverse-nested-loop.slang
@@ -0,0 +1,43 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+[BackwardDifferentiable]
+float test_simple_nested_loop(float y)
+{
+ float t = y;
+
+ for (int i = 0; i < 2; i++)
+ {
+ for (int j = 0; j < 2; j++)
+ {
+ t = t * (i + j + 1);
+ }
+ }
+
+ return t;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(1.0, 0.0);
+
+ __bwd_diff(test_simple_nested_loop)(dpa, 1.0f);
+ outputBuffer[0] = dpa.d; // Expect: 12.0 * 1
+ }
+
+ {
+ dpfloat dpa = dpfloat(1.0, 0.0);
+
+ __bwd_diff(test_simple_nested_loop)(dpa, 0.4f);
+ outputBuffer[1] = dpa.d; // Expect: 12 * 0.4 = 4.8
+ }
+}
diff --git a/tests/autodiff/reverse-nested-loop.slang.expected.txt b/tests/autodiff/reverse-nested-loop.slang.expected.txt
new file mode 100644
index 000000000..59c14cf1d
--- /dev/null
+++ b/tests/autodiff/reverse-nested-loop.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+12.000000
+4.800000
+0.000000
+0.000000
+0.000000
diff --git a/tests/experimental/liveness/liveness-2.slang.expected b/tests/experimental/liveness/liveness-2.slang.expected
index 16883c1fd..e0486f8da 100644
--- a/tests/experimental/liveness/liveness-2.slang.expected
+++ b/tests/experimental/liveness/liveness-2.slang.expected
@@ -32,12 +32,12 @@ int calcThing_0(int offset_0)
}
else
{
+ livenessEnd_0(i_0, 0);
break;
}
idx_0[i_0] = offset_0 + i_0;
i_0 = i_0 + 1;
}
- livenessEnd_0(i_0, 0);
int _S1 = idx_0[0] + idx_0[1];
int _S2 = idx_0[2];
livenessEnd_1(idx_0, 0);
diff --git a/tests/experimental/liveness/liveness-3.slang.expected b/tests/experimental/liveness/liveness-3.slang.expected
index cb093a640..d4b417082 100644
--- a/tests/experimental/liveness/liveness-3.slang.expected
+++ b/tests/experimental/liveness/liveness-3.slang.expected
@@ -12,12 +12,12 @@ void livenessStart_0(spirv_by_reference int _0[2], spirv_literal int _1);
spirv_instruction(id = 256)
void livenessStart_1(spirv_by_reference int _0, spirv_literal int _1);
-spirv_instruction(id = 256)
-void livenessStart_2(spirv_by_reference int _0[3], spirv_literal int _1);
-
spirv_instruction(id = 257)
void livenessEnd_0(spirv_by_reference int _0, spirv_literal int _1);
+spirv_instruction(id = 256)
+void livenessStart_2(spirv_by_reference int _0[3], spirv_literal int _1);
+
spirv_instruction(id = 257)
void livenessEnd_1(spirv_by_reference int _0[3], spirv_literal int _1);
@@ -46,6 +46,8 @@ int calcThing_0(int offset_0)
}
else
{
+ livenessEnd_0(_S1, 0);
+ livenessEnd_0(k_0, 0);
break;
}
int idx_0[3];
@@ -69,6 +71,7 @@ int calcThing_0(int offset_0)
}
else
{
+ livenessEnd_0(i_0, 0);
break;
}
int modRange_0 = i_0 % 3;
@@ -97,7 +100,6 @@ int calcThing_0(int offset_0)
livenessEnd_0(_S6, 0);
_S4 = _S10;
}
- livenessEnd_0(i_0, 0);
livenessEnd_0(_S1, 0);
livenessEnd_0(k_0, 0);
if(_S3)
@@ -123,8 +125,6 @@ int calcThing_0(int offset_0)
livenessStart_1(total_0, 0);
total_0 = total_1;
}
- livenessEnd_0(_S1, 0);
- livenessEnd_0(k_0, 0);
livenessEnd_2(another_0, 0);
int _S16 = total_0;
livenessEnd_0(total_0, 0);
diff --git a/tests/experimental/liveness/liveness-4.slang.expected b/tests/experimental/liveness/liveness-4.slang.expected
index efc2e3846..483247ecd 100644
--- a/tests/experimental/liveness/liveness-4.slang.expected
+++ b/tests/experimental/liveness/liveness-4.slang.expected
@@ -34,6 +34,7 @@ int calcThing_0(int offset_0)
}
else
{
+ livenessEnd_0(k_0, 0);
break;
}
int _S1 = (k_0 + 7) % 5;
@@ -49,12 +50,12 @@ int calcThing_0(int offset_0)
}
else
{
+ livenessEnd_0(i_0, 0);
break;
}
another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0);
i_0 = i_0 + 1;
}
- livenessEnd_0(i_0, 0);
livenessEnd_0(k_0, 0);
if(_S2)
{
@@ -64,7 +65,6 @@ int calcThing_0(int offset_0)
livenessStart_1(k_0, 0);
k_0 = k_1;
}
- livenessEnd_0(k_0, 0);
livenessEnd_1(another_0, 0);
return -2;
}
diff --git a/tests/experimental/liveness/liveness-5.slang.expected b/tests/experimental/liveness/liveness-5.slang.expected
index e9fe9d652..5e144a095 100644
--- a/tests/experimental/liveness/liveness-5.slang.expected
+++ b/tests/experimental/liveness/liveness-5.slang.expected
@@ -37,6 +37,7 @@ int calcThing_0(int offset_0)
}
else
{
+ livenessEnd_0(k_0, 0);
break;
}
int _S1 = (k_0 + 7) % 5;
@@ -52,12 +53,12 @@ int calcThing_0(int offset_0)
}
else
{
+ livenessEnd_0(i_0, 0);
break;
}
another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0);
i_0 = i_0 + 1;
}
- livenessEnd_0(i_0, 0);
livenessEnd_0(k_0, 0);
int _S3 = total_0;
livenessEnd_0(total_0, 0);
@@ -72,7 +73,6 @@ int calcThing_0(int offset_0)
livenessStart_1(total_0, 0);
total_0 = total_1;
}
- livenessEnd_0(k_0, 0);
livenessEnd_1(another_0, 0);
if(total_0 > 4)
{
diff --git a/tests/experimental/liveness/liveness-6.slang.expected b/tests/experimental/liveness/liveness-6.slang.expected
index b661c09bf..0d2e997b2 100644
--- a/tests/experimental/liveness/liveness-6.slang.expected
+++ b/tests/experimental/liveness/liveness-6.slang.expected
@@ -37,6 +37,7 @@ int calcThing_0(int offset_0)
}
else
{
+ livenessEnd_0(k_0, 0);
break;
}
int arr_0[2];
@@ -57,13 +58,13 @@ int calcThing_0(int offset_0)
}
else
{
+ livenessEnd_0(i_0, 0);
break;
}
another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0);
arr_0[_S1] = arr_0[_S1] + i_0;
i_0 = i_0 + 1;
}
- livenessEnd_0(i_0, 0);
livenessEnd_0(k_0, 0);
int _S4 = total_0;
livenessEnd_0(total_0, 0);
@@ -81,7 +82,6 @@ int calcThing_0(int offset_0)
livenessStart_1(total_0, 0);
total_0 = total_2;
}
- livenessEnd_0(k_0, 0);
livenessEnd_1(another_0, 0);
if(total_0 > 4)
{
diff --git a/tests/experimental/liveness/liveness.slang.expected b/tests/experimental/liveness/liveness.slang.expected
index 06809ffc3..b0017ea9d 100644
--- a/tests/experimental/liveness/liveness.slang.expected
+++ b/tests/experimental/liveness/liveness.slang.expected
@@ -13,10 +13,10 @@ spirv_instruction(id = 256)
void livenessStart_1(spirv_by_reference int _0, spirv_literal int _1);
spirv_instruction(id = 257)
-void livenessEnd_0(spirv_by_reference uint _0, spirv_literal int _1);
+void livenessEnd_0(spirv_by_reference int _0, spirv_literal int _1);
spirv_instruction(id = 257)
-void livenessEnd_1(spirv_by_reference int _0, spirv_literal int _1);
+void livenessEnd_1(spirv_by_reference uint _0, spirv_literal int _1);
int someSlowFunc_0(int a_0)
{
@@ -35,18 +35,18 @@ int someSlowFunc_0(int a_0)
}
else
{
+ livenessEnd_0(i_0, 0);
break;
}
uint _S3 = v_0 >> 1;
uint _S4 = v_0;
- livenessEnd_0(v_0, 0);
+ livenessEnd_1(v_0, 0);
uint _S5 = (_S3 | _S4 << 31) * uint(i_0);
int i_1 = i_0 + 1;
livenessStart_0(v_0, 0);
v_0 = _S5;
i_0 = i_1;
}
- livenessEnd_1(i_0, 0);
return int(v_0);
}
@@ -111,6 +111,7 @@ void main()
}
else
{
+ livenessEnd_0(i_2, 0);
break;
}
SomeStruct_0 s_3;
@@ -153,15 +154,14 @@ void main()
livenessEnd_2(s_3, 0);
int _S22 = _S20 + _S21;
int _S23 = res_0;
- livenessEnd_1(res_0, 0);
+ livenessEnd_0(res_0, 0);
int res_1 = _S23 + _S22;
i_2 = i_2 + 1;
livenessStart_1(res_0, 0);
res_0 = res_1;
}
- livenessEnd_1(i_2, 0);
int _S24 = res_0;
- livenessEnd_1(res_0, 0);
+ livenessEnd_0(res_0, 0);
((outputBuffer_0)._data[(uint(index_0))]) = _S24;
return;
}