summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-10-26 10:45:08 -0700
committerGitHub <noreply@github.com>2023-10-26 10:45:08 -0700
commit927d176be9ba03be161375b8695de1f0a37f1785 (patch)
tree1054acb0552721b6c296a5b9a5ce4ca507d06e78
parent4572976fd60817b9e2644b6fcadbf34511e770a9 (diff)
Fix generic specialization bug. (#3290)
* Fix generic specialization bug. * Update test. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-ir-link.cpp2
-rw-r--r--tests/autodiff/bug-1.slang61
-rw-r--r--tests/pipeline/rasterization/mesh/hello.slang13
3 files changed, 73 insertions, 3 deletions
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index b8f43c5f2..87b4f3fde 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -307,7 +307,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
registerClonedValue(this, clonedValue, originalValue);
cloneDecorationsAndChildren(this, clonedValue, originalValue);
- builder->addInst(clonedValue);
+ addHoistableInst(builder, clonedValue);
return clonedValue;
}
diff --git a/tests/autodiff/bug-1.slang b/tests/autodiff/bug-1.slang
new file mode 100644
index 000000000..deb7e8461
--- /dev/null
+++ b/tests/autodiff/bug-1.slang
@@ -0,0 +1,61 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type
+
+#define DO_FLOOR
+#define MANUAL_DERIVATIVE
+
+#ifndef MANUAL_DERIVATIVE
+[BackwardDifferentiable]
+#endif
+float unusual_norm<let N : uint>(Array<float, N> x)
+{
+ float result = 0.f;
+ #ifndef MANUAL_DERIVATIVE
+ [ForceUnroll]
+ #endif
+ for(uint i = 0; i < N; i++)
+ {
+ #ifdef DO_FLOOR
+ result += pow(floor(x[i]), 4);
+ #else
+ result += pow(x[i], 4);
+ #endif
+ }
+ return result;
+}
+
+#ifdef MANUAL_DERIVATIVE
+[BackwardDerivativeOf(unusual_norm)]
+void unusual_norm_bwd<let N : uint>(inout DifferentialPair<Array<float, N>> x, float dResult)
+{
+ Array<float, N> derivatives;
+ for(uint i = 0; i < N; i++)
+ {
+ derivatives[i] = 4.f * dResult * pow(x.p[i], 3);
+ }
+ x = diffPair(x.p, derivatives);
+}
+#endif
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=g_out
+
+RWStructuredBuffer<float> g_out;
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dtid : SV_DispatchThreadID)
+{
+ Array<float, 5> x;
+ for(uint i = 0; i < 5; i++)
+ {
+ x[i] = float(i + dtid.x);
+ }
+
+ DifferentialPair<Array<float, 5>> x_pd = diffPair(x, {});
+ bwd_diff(unusual_norm)(x_pd, 1.0f);
+ for (int i = 0; i < 5; i++)
+ g_out[i] = x_pd.d[i];
+ // CHECK: 0.0
+ // CHECK: 4.0
+ // CHECK: 32.0
+ // CHECK: 108.0
+ // CHECK: 256.0
+}
diff --git a/tests/pipeline/rasterization/mesh/hello.slang b/tests/pipeline/rasterization/mesh/hello.slang
index 5eea900d3..54754c42e 100644
--- a/tests/pipeline/rasterization/mesh/hello.slang
+++ b/tests/pipeline/rasterization/mesh/hello.slang
@@ -2,8 +2,17 @@
// Test that a simple mesh shader compiles
-//TEST:CROSS_COMPILE:-target spirv-assembly -entry main -stage mesh -profile glsl_450+spirv_1_4
-//TEST:CROSS_COMPILE:-target dxil-assembly -entry main -stage mesh -profile sm_6_6
+//TEST:CROSS_COMPILE(filecheck=SPIRV):-target spirv-assembly -entry main -stage mesh -profile glsl_450+spirv_1_4
+//TEST:CROSS_COMPILE(filecheck=DXIL):-target dxil-assembly -entry main -stage mesh -profile sm_6_6
+
+// DXIL: call void @dx.op.setMeshOutputCounts
+// DXIL: call void @dx.op.storeVertexOutput.f32
+// DXIL: call void @dx.op.emitIndices
+// SPIRV: OpEntryPoint MeshEXT %main
+// SPIRV: OpExecutionMode %main OutputVertices 3
+// SPIRV: OpExecutionMode %main OutputPrimitivesNV 1
+// SPIRV: OpExecutionMode %main OutputTrianglesNV
+// SPIRV: OpSetMeshOutputsEXT
const static float2 positions[3] = {
float2(0.0, -0.5),