From 9913cfbf68dab8c3c8c418dd28b71c2a65a55ae0 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 22 Nov 2024 18:55:47 -0500 Subject: [AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred (#5630) * [AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred * Fix failing tests * Update custom-derivative-generic.slang --- source/slang/slang-check-decl.cpp | 76 ++++++++++++- tests/autodiff/custom-derivative-enum-param.slang | 57 ++++++++++ tests/autodiff/custom-intrinsic-1.slang | 126 +++++++++++++++++++++ .../autodiff/custom-intrinsic-1.slang.expected.txt | 6 + tests/autodiff/custom-intrinsic.slang | 126 --------------------- tests/autodiff/custom-intrinsic.slang.expected.txt | 6 - tests/diagnostics/custom-derivative-generic.slang | 2 +- 7 files changed, 265 insertions(+), 134 deletions(-) create mode 100644 tests/autodiff/custom-derivative-enum-param.slang create mode 100644 tests/autodiff/custom-intrinsic-1.slang create mode 100644 tests/autodiff/custom-intrinsic-1.slang.expected.txt delete mode 100644 tests/autodiff/custom-intrinsic.slang delete mode 100644 tests/autodiff/custom-intrinsic.slang.expected.txt diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 251ce6a69..e4206827f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -10915,7 +10915,61 @@ void checkDerivativeAttributeImpl( SemanticsContext::ExprLocalScope scope; auto ctx = visitor->withExprLocalScope(&scope); auto subVisitor = SemanticsVisitor(ctx); - auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); + + auto exprToCheck = attr->funcExpr; + + // If this is a generic, we want to wrap the call to the derivative method + // with the generic parameters of the source. + // + if (as(funcDecl->parentDecl) && !as(attr->funcExpr)) + { + auto genericDecl = as(funcDecl->parentDecl); + auto substArgs = getDefaultSubstitutionArgs(ctx.getASTBuilder(), visitor, genericDecl); + auto appExpr = ctx.getASTBuilder()->create(); + + Index count = 0; + for (auto member : genericDecl->members) + { + if (as(member) || as(member) || + as(member)) + count++; + } + + appExpr->functionExpr = attr->funcExpr; + + for (auto arg : substArgs) + { + if (count == 0) + break; + + if (auto declRefType = as(arg)) + { + auto baseTypeExpr = ctx.getASTBuilder()->create(); + baseTypeExpr->base.type = declRefType; + auto baseTypeType = ctx.getASTBuilder()->getOrCreate(declRefType); + baseTypeExpr->type.type = baseTypeType; + + appExpr->arguments.add(baseTypeExpr); + } + else if (auto genericValParam = as(arg)) + { + auto declRef = genericValParam->getDeclRef(); + appExpr->arguments.add( + subVisitor + .ConstructDeclRefExpr(declRef, nullptr, nullptr, SourceLoc(), nullptr)); + } + else + { + SLANG_UNEXPECTED("Unhandled substitution arg type"); + } + + count--; + } + + exprToCheck = appExpr; + } + + auto checkedFuncExpr = visitor->dispatchExpr(exprToCheck, ctx); attr->funcExpr = checkedFuncExpr; if (attr->args.getCount()) attr->args[0] = attr->funcExpr; @@ -11427,6 +11481,26 @@ void checkDerivativeOfAttributeImpl( calleeDeclRef = calleeDeclRefExpr->declRef; auto calleeFunc = as(calleeDeclRef.getDecl()); + + if (!calleeFunc) + { + // If we couldn't find a direct function, it might be a generic. + if (auto genericDecl = as(calleeDeclRef.getDecl())) + { + calleeFunc = as(genericDecl->inner); + + if (as(resolved->type.type)) + { + // If we can't resolve a type, something went wrong. If we're working with a generic + // decl, the most likely cause is a failure of generic argument inference. + // + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + } + } + } + if (!calleeFunc) { visitor->getSink()->diagnose( diff --git a/tests/autodiff/custom-derivative-enum-param.slang b/tests/autodiff/custom-derivative-enum-param.slang new file mode 100644 index 000000000..aa6733873 --- /dev/null +++ b/tests/autodiff/custom-derivative-enum-param.slang @@ -0,0 +1,57 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type + +enum MyEnum { A, B, C }; + +[BackwardDerivative(mDiff)] +float m(float x) +{ + switch (M) + { + case MyEnum.A: + return x * x; + case MyEnum.B: + return x; + case MyEnum.C: + return 3 * x; + default: + return 0; + } +} + +void mDiff(inout DifferentialPair x, float dResult) +{ + switch (M) + { + case MyEnum.A: + updateDiff(x, 2 * dResult * x.p); + break; + case MyEnum.B: + updateDiff(x, dResult); + break; + case MyEnum.C: + updateDiff(x, 3 * dResult); + break; + default: + updateDiff(x, 0); + break; + } +} + +[Differentiable] +float test(float x) +{ + return m(x); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + var a = diffPair(3.0); + __bwd_diff(test)(a, 1.0); + outputBuffer[dispatchThreadID.x] = a.d; + // CHECK: 6.0 +} diff --git a/tests/autodiff/custom-intrinsic-1.slang b/tests/autodiff/custom-intrinsic-1.slang new file mode 100644 index 000000000..1fe204b58 --- /dev/null +++ b/tests/autodiff/custom-intrinsic-1.slang @@ -0,0 +1,126 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +// slang-test/WGPU: IR opcode during code emit #5263 +//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +typedef DifferentialPair dpfloat; + +typealias IDFloat = __BuiltinFloatingPointType & IDifferentiable; + +namespace myintrinsiclib +{ + __generic + __target_intrinsic(hlsl, "exp($0)") + __target_intrinsic(glsl, "exp($0)") + __target_intrinsic(cuda, "$P_exp($0)") + __target_intrinsic(cpp, "$P_exp($0)") + __target_intrinsic(spirv, "12 resultType resultId glsl450 27 _0") + __target_intrinsic(metal, "exp($0)") + [ForwardDerivative(d_myexp)] + T myexp(T x); + + __generic + DifferentialPair d_myexp(DifferentialPair dpx) + { + return DifferentialPair( + myexp(dpx.p), + T.dmul(myexp(dpx.p), dpx.d)); + } + + + // Sine + __generic + __target_intrinsic(hlsl, "sin($0)") + __target_intrinsic(glsl, "sin($0)") + __target_intrinsic(metal, "sin($0)") + __target_intrinsic(cuda, "$P_sin($0)") + __target_intrinsic(cpp, "$P_sin($0)") + __target_intrinsic(spirv, "12 resultType resultId glsl450 13 _0") + [ForwardDerivative(d_mysin)] + T mysin(T x); + + __generic + DifferentialPair d_mysin(DifferentialPair dpx) + { + return DifferentialPair( + mysin(dpx.p), + T.dmul(mycos(dpx.p), dpx.d)); + } + + // Cosine + __generic + __target_intrinsic(hlsl, "cos($0)") + __target_intrinsic(glsl, "cos($0)") + __target_intrinsic(metal, "cos($0)") + __target_intrinsic(cuda, "$P_cos($0)") + __target_intrinsic(cpp, "$P_cos($0)") + __target_intrinsic(spirv, "12 resultType resultId glsl450 14 _0") + [ForwardDerivative(d_mycos)] + T mycos(T x); + + __generic + DifferentialPair d_mycos(DifferentialPair dpx) + { + return DifferentialPair( + mycos(dpx.p), + T.dmul(-sin(dpx.p), dpx.d)); + } + + // Sine and cosine + __generic + __target_intrinsic(hlsl, "sincos($0, $1, $2)") + __target_intrinsic(cuda, "$P_sincos($0, $1, $2)") + [ForwardDerivative(d_mysincos)] + void mysincos(T x, out T s, out T c) + { + s = sin(x); + c = cos(x); + } + + __generic + void d_mysincos(DifferentialPair x, out DifferentialPair s, out DifferentialPair c) + { + T _s; + T _c; + mysincos(x.p, _s, _c); + + s = DifferentialPair(_s, T.dmul(_c, x.d)); + c = DifferentialPair(_c, T.dmul(-_s, x.d)); + } +}; + +[ForwardDifferentiable] +float f(float x) +{ + return myintrinsiclib.myexp(x); +} + +[ForwardDifferentiable] +float g(float x) +{ + float s; + float t; + myintrinsiclib.mysincos(x, s, t); + + return s + t; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 1.0); + + outputBuffer[0] = f(dpa.p); // Expect: 7.389056 + outputBuffer[1] = __fwd_diff(f)(dpa).d; // Expect: 7.389056 + + // g() needs additional handling of IRMakeDifferentialPair(PtrType). This needs to + // generate a new var, load from the individual vars and store into the pair var. + + //outputBuffer[2] = g(dpa.p); // Expect: 1.381773 + //outputBuffer[3] = __fwd_diff(g)(dpa).d; // Expect: -0.301168 + } +} diff --git a/tests/autodiff/custom-intrinsic-1.slang.expected.txt b/tests/autodiff/custom-intrinsic-1.slang.expected.txt new file mode 100644 index 000000000..ce22a5b95 --- /dev/null +++ b/tests/autodiff/custom-intrinsic-1.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +7.389056 +7.389056 +0.0 +0.0 +0.0 \ No newline at end of file diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang deleted file mode 100644 index 1fe204b58..000000000 --- a/tests/autodiff/custom-intrinsic.slang +++ /dev/null @@ -1,126 +0,0 @@ -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -// slang-test/WGPU: IR opcode during code emit #5263 -//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu - -//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer -RWStructuredBuffer outputBuffer; - -typedef DifferentialPair dpfloat; - -typealias IDFloat = __BuiltinFloatingPointType & IDifferentiable; - -namespace myintrinsiclib -{ - __generic - __target_intrinsic(hlsl, "exp($0)") - __target_intrinsic(glsl, "exp($0)") - __target_intrinsic(cuda, "$P_exp($0)") - __target_intrinsic(cpp, "$P_exp($0)") - __target_intrinsic(spirv, "12 resultType resultId glsl450 27 _0") - __target_intrinsic(metal, "exp($0)") - [ForwardDerivative(d_myexp)] - T myexp(T x); - - __generic - DifferentialPair d_myexp(DifferentialPair dpx) - { - return DifferentialPair( - myexp(dpx.p), - T.dmul(myexp(dpx.p), dpx.d)); - } - - - // Sine - __generic - __target_intrinsic(hlsl, "sin($0)") - __target_intrinsic(glsl, "sin($0)") - __target_intrinsic(metal, "sin($0)") - __target_intrinsic(cuda, "$P_sin($0)") - __target_intrinsic(cpp, "$P_sin($0)") - __target_intrinsic(spirv, "12 resultType resultId glsl450 13 _0") - [ForwardDerivative(d_mysin)] - T mysin(T x); - - __generic - DifferentialPair d_mysin(DifferentialPair dpx) - { - return DifferentialPair( - mysin(dpx.p), - T.dmul(mycos(dpx.p), dpx.d)); - } - - // Cosine - __generic - __target_intrinsic(hlsl, "cos($0)") - __target_intrinsic(glsl, "cos($0)") - __target_intrinsic(metal, "cos($0)") - __target_intrinsic(cuda, "$P_cos($0)") - __target_intrinsic(cpp, "$P_cos($0)") - __target_intrinsic(spirv, "12 resultType resultId glsl450 14 _0") - [ForwardDerivative(d_mycos)] - T mycos(T x); - - __generic - DifferentialPair d_mycos(DifferentialPair dpx) - { - return DifferentialPair( - mycos(dpx.p), - T.dmul(-sin(dpx.p), dpx.d)); - } - - // Sine and cosine - __generic - __target_intrinsic(hlsl, "sincos($0, $1, $2)") - __target_intrinsic(cuda, "$P_sincos($0, $1, $2)") - [ForwardDerivative(d_mysincos)] - void mysincos(T x, out T s, out T c) - { - s = sin(x); - c = cos(x); - } - - __generic - void d_mysincos(DifferentialPair x, out DifferentialPair s, out DifferentialPair c) - { - T _s; - T _c; - mysincos(x.p, _s, _c); - - s = DifferentialPair(_s, T.dmul(_c, x.d)); - c = DifferentialPair(_c, T.dmul(-_s, x.d)); - } -}; - -[ForwardDifferentiable] -float f(float x) -{ - return myintrinsiclib.myexp(x); -} - -[ForwardDifferentiable] -float g(float x) -{ - float s; - float t; - myintrinsiclib.mysincos(x, s, t); - - return s + t; -} - -[numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) -{ - { - dpfloat dpa = dpfloat(2.0, 1.0); - - outputBuffer[0] = f(dpa.p); // Expect: 7.389056 - outputBuffer[1] = __fwd_diff(f)(dpa).d; // Expect: 7.389056 - - // g() needs additional handling of IRMakeDifferentialPair(PtrType). This needs to - // generate a new var, load from the individual vars and store into the pair var. - - //outputBuffer[2] = g(dpa.p); // Expect: 1.381773 - //outputBuffer[3] = __fwd_diff(g)(dpa).d; // Expect: -0.301168 - } -} diff --git a/tests/autodiff/custom-intrinsic.slang.expected.txt b/tests/autodiff/custom-intrinsic.slang.expected.txt deleted file mode 100644 index ce22a5b95..000000000 --- a/tests/autodiff/custom-intrinsic.slang.expected.txt +++ /dev/null @@ -1,6 +0,0 @@ -type: float -7.389056 -7.389056 -0.0 -0.0 -0.0 \ No newline at end of file diff --git a/tests/diagnostics/custom-derivative-generic.slang b/tests/diagnostics/custom-derivative-generic.slang index 5f2cd9951..fb65dd2cc 100644 --- a/tests/diagnostics/custom-derivative-generic.slang +++ b/tests/diagnostics/custom-derivative-generic.slang @@ -34,7 +34,7 @@ DifferentialPair dd1(DifferentialPair x) } // CHECK-DAG: {{.*}}(37): error 31151 -[BackwardDerivative(f)] +[BackwardDerivativeOf(f)] DifferentialPair df(inout DifferentialPair x, float dOut) { var primal = x.p * x.p; -- cgit v1.2.3