diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-02-17 16:22:47 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-17 16:22:47 -0500 |
| commit | 79049bc7617be0d20f6ed5d9d1dfe75006aa675a (patch) | |
| tree | 882bf8c76e36ee770a8645f2b200bd90d22ec728 | |
| parent | f253d15a3b2681dfa40491451fcb3f21f1dbe412 (diff) | |
Cleaned up legacy differential type handling + type casting bugfixes (#2660)
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa.cpp | 5 | ||||
| -rw-r--r-- | tests/autodiff/reverse-uint-vector.slang | 40 | ||||
| -rw-r--r-- | tests/autodiff/reverse-uint-vector.slang.expected.txt | 3 |
5 files changed, 47 insertions, 7 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 214e97ff6..26d84720f 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1363,8 +1363,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_Neq: return transcribeBinaryLogic(builder, origInst); - case kIROp_CastIntToFloat: - case kIROp_CastFloatToInt: case kIROp_MakeVector: case kIROp_MakeMatrix: case kIROp_MakeMatrixFromScalar: @@ -1470,6 +1468,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_PackAnyValue: case kIROp_UnpackAnyValue: case kIROp_GetNativePtr: + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, // so we treat this inst as non differentiable. // We can extend the frontend and IR with a separate op-code that can provide an explicit diff value. diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 38e4636ac..7bcd4c90b 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -173,8 +173,8 @@ struct DifferentiableTypeConformanceContext case kIROp_FloatType: case kIROp_HalfType: case kIROp_DoubleType: - case kIROp_VectorType: return origType; + case kIROp_ArrayType: { auto diffElementType = (IRType*)getDifferentialForType( diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index 20a8d7d13..d8246edae 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -1056,10 +1056,7 @@ bool constructSSA(ConstructSSAContext* context) // Figure out what variables we can promote to // SSA temporaries. - if (!(context->promotableVars.getCount() > 0)) - { - identifyPromotableVars(context); - } + identifyPromotableVars(context); // If none of the variables are promote-able, // then we can exit without making any changes diff --git a/tests/autodiff/reverse-uint-vector.slang b/tests/autodiff/reverse-uint-vector.slang new file mode 100644 index 000000000..3c940c5de --- /dev/null +++ b/tests/autodiff/reverse-uint-vector.slang @@ -0,0 +1,40 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float3> dpfloat3; +typedef float3.Differential dfloat3; + +typedef DifferentialPair<float2> dpfloat2; +typedef float2.Differential dfloat2; + +// This test case should hit a lot of conversion insts +// kIROp_MakeVectorFromScalar (both uint and float), +// kIROp_CastIntToFloat, etc.. +// +[BackwardDifferentiable] +float3 test_uint_offset(float3 x, float3 y) +{ + uint3 u4 = 1; + + uint3 u5 = u4 + 2; + + return x + y + u5; +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat3 dpx = dpfloat3(float3(2.0, 3.0, 4.0), float3(0.0, 0.0, 0.0)); + dpfloat3 dpy = dpfloat3(float3(1.5, 2.5, 3.5), float3(0.0, 0.0, 0.0)); + + __bwd_diff(test_uint_offset)(dpx, dpy, dfloat3(1.0, 2.0, 3.0)); + outputBuffer[0] = dpx.d.y; // Expect: 2 + outputBuffer[1] = dpy.d.y; // Expect: 2 + } + +}
\ No newline at end of file diff --git a/tests/autodiff/reverse-uint-vector.slang.expected.txt b/tests/autodiff/reverse-uint-vector.slang.expected.txt new file mode 100644 index 000000000..4ca096282 --- /dev/null +++ b/tests/autodiff/reverse-uint-vector.slang.expected.txt @@ -0,0 +1,3 @@ +type: float +2.000000 +2.000000
\ No newline at end of file |
