From 927d176be9ba03be161375b8695de1f0a37f1785 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 26 Oct 2023 10:45:08 -0700 Subject: Fix generic specialization bug. (#3290) * Fix generic specialization bug. * Update test. --------- Co-authored-by: Yong He --- source/slang/slang-ir-link.cpp | 2 +- tests/autodiff/bug-1.slang | 61 +++++++++++++++++++++++++++ tests/pipeline/rasterization/mesh/hello.slang | 13 +++++- 3 files changed, 73 insertions(+), 3 deletions(-) create mode 100644 tests/autodiff/bug-1.slang diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index b8f43c5f2..87b4f3fde 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -307,7 +307,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) registerClonedValue(this, clonedValue, originalValue); cloneDecorationsAndChildren(this, clonedValue, originalValue); - builder->addInst(clonedValue); + addHoistableInst(builder, clonedValue); return clonedValue; } diff --git a/tests/autodiff/bug-1.slang b/tests/autodiff/bug-1.slang new file mode 100644 index 000000000..deb7e8461 --- /dev/null +++ b/tests/autodiff/bug-1.slang @@ -0,0 +1,61 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type + +#define DO_FLOOR +#define MANUAL_DERIVATIVE + +#ifndef MANUAL_DERIVATIVE +[BackwardDifferentiable] +#endif +float unusual_norm(Array x) +{ + float result = 0.f; + #ifndef MANUAL_DERIVATIVE + [ForceUnroll] + #endif + for(uint i = 0; i < N; i++) + { + #ifdef DO_FLOOR + result += pow(floor(x[i]), 4); + #else + result += pow(x[i], 4); + #endif + } + return result; +} + +#ifdef MANUAL_DERIVATIVE +[BackwardDerivativeOf(unusual_norm)] +void unusual_norm_bwd(inout DifferentialPair> x, float dResult) +{ + Array derivatives; + for(uint i = 0; i < N; i++) + { + derivatives[i] = 4.f * dResult * pow(x.p[i], 3); + } + x = diffPair(x.p, derivatives); +} +#endif + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=g_out + +RWStructuredBuffer g_out; +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 dtid : SV_DispatchThreadID) +{ + Array x; + for(uint i = 0; i < 5; i++) + { + x[i] = float(i + dtid.x); + } + + DifferentialPair> x_pd = diffPair(x, {}); + bwd_diff(unusual_norm)(x_pd, 1.0f); + for (int i = 0; i < 5; i++) + g_out[i] = x_pd.d[i]; + // CHECK: 0.0 + // CHECK: 4.0 + // CHECK: 32.0 + // CHECK: 108.0 + // CHECK: 256.0 +} diff --git a/tests/pipeline/rasterization/mesh/hello.slang b/tests/pipeline/rasterization/mesh/hello.slang index 5eea900d3..54754c42e 100644 --- a/tests/pipeline/rasterization/mesh/hello.slang +++ b/tests/pipeline/rasterization/mesh/hello.slang @@ -2,8 +2,17 @@ // Test that a simple mesh shader compiles -//TEST:CROSS_COMPILE:-target spirv-assembly -entry main -stage mesh -profile glsl_450+spirv_1_4 -//TEST:CROSS_COMPILE:-target dxil-assembly -entry main -stage mesh -profile sm_6_6 +//TEST:CROSS_COMPILE(filecheck=SPIRV):-target spirv-assembly -entry main -stage mesh -profile glsl_450+spirv_1_4 +//TEST:CROSS_COMPILE(filecheck=DXIL):-target dxil-assembly -entry main -stage mesh -profile sm_6_6 + +// DXIL: call void @dx.op.setMeshOutputCounts +// DXIL: call void @dx.op.storeVertexOutput.f32 +// DXIL: call void @dx.op.emitIndices +// SPIRV: OpEntryPoint MeshEXT %main +// SPIRV: OpExecutionMode %main OutputVertices 3 +// SPIRV: OpExecutionMode %main OutputPrimitivesNV 1 +// SPIRV: OpExecutionMode %main OutputTrianglesNV +// SPIRV: OpSetMeshOutputsEXT const static float2 positions[3] = { float2(0.0, -0.5), -- cgit v1.2.3