diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-03-23 22:27:30 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-23 19:27:30 -0700 |
| commit | e004511b5f75bb24df1adec71b005146917afb39 (patch) | |
| tree | b9d8a41dd2dfb92e5808d643c5a1693c78af839f | |
| parent | 6e4eae1050ab9282b460a33a013652c387c1e585 (diff) | |
AD: Fix type checking for higher-order custom derivatives definitions (#2729)
* Fixed type coercion issue with higher-order user defined methods
* Placed associated type lookup method in a loop
* Update high-order-user-defined-derivative.slang
* Revert changes to associated type lookup method
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 4 | ||||
| -rw-r--r-- | tests/autodiff/high-order-user-defined-derivative.slang | 77 | ||||
| -rw-r--r-- | tests/autodiff/high-order-user-defined-derivative.slang.expected.txt | 10 |
3 files changed, 91 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index eaab43ef8..6083ce9c0 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -238,6 +238,9 @@ namespace Slang void visitCallableDecl(CallableDecl* decl) { + for (auto paramDecl : decl->getMembersOfType<ParamDecl>()) + visitTypeExp(paramDecl->type); + visitTypeExp(decl->returnType); visitTypeExp(decl->errorType); } @@ -6916,6 +6919,7 @@ namespace Slang auto ctx = visitor->withExprLocalScope(&scope); auto subVisitor = SemanticsVisitor(ctx); auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); + visitor->ensureDecl(as<DeclRefExpr>(checkedFuncExpr)->declRef, DeclCheckState::TypesFullyResolved); auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); auto resolved = subVisitor.ResolveInvoke(invokeExpr); if (auto resolvedInvoke = as<InvokeExpr>(resolved)) diff --git a/tests/autodiff/high-order-user-defined-derivative.slang b/tests/autodiff/high-order-user-defined-derivative.slang new file mode 100644 index 000000000..4ad4aad12 --- /dev/null +++ b/tests/autodiff/high-order-user-defined-derivative.slang @@ -0,0 +1,77 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +//TEST_INPUT:ubuffer(data=[0.0 1.0 2.0 3.0], stride=4):name=endpointBuffer +RWStructuredBuffer<float> endpointBuffer; + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=endpointDifferentialBuffer +RWStructuredBuffer<float> endpointDifferentialBuffer; + +struct LineSegment : IDifferentiable +{ + float x0; + float x1; + + [BackwardDifferentiable] + __init(float _x0, float _x1) + { + x0 = _x0; + x1 = _x1; + } +}; + +[BackwardDerivative(d_loadLineSegment)] +[ForwardDerivative(fwd_loadLineSegment)] +LineSegment loadLineSegment(uint id) +{ + return {endpointBuffer[id * 2], endpointBuffer[id * 2 + 1]}; +} + +[BackwardDerivative(d_fwd_loadLineSegment)] +DifferentialPair<LineSegment> fwd_loadLineSegment(uint id) +{ + return DifferentialPair<LineSegment>(loadLineSegment(id), LineSegment.dzero()); +} + +void d_loadLineSegment(uint id, LineSegment.Differential d_ls) +{ + endpointDifferentialBuffer[id * 2] += d_ls.x0; + endpointDifferentialBuffer[id * 2 + 1] += d_ls.x1; +} + +void d_fwd_loadLineSegment(uint id, DifferentialPair<LineSegment>.Differential dp_ls) +{ + endpointDifferentialBuffer[id * 2] += dp_ls.p.x0; + endpointDifferentialBuffer[id * 2 + 1] += dp_ls.p.x1; +} + +[BackwardDifferentiable] +float something() +{ + LineSegment ls = __fwd_diff(loadLineSegment)(1).p; + return ls.x0 + ls.x1; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + LineSegment ls = __fwd_diff(loadLineSegment)(0).p; + outputBuffer[0] = ls.x0; // Expect: 0 + outputBuffer[1] = ls.x1; // Expect: 1 + } + + { + LineSegment.Differential d_ls = __fwd_diff(loadLineSegment)(0).d; + outputBuffer[2] = d_ls.x1; // Expect: 0 + } + + { + // Expect: 2.0 in endpointDifferentialBuffer[2] + // Expect: 2.0 in endpointDifferentialBuffer[3] + __bwd_diff(something)(2.0); + } +}
\ No newline at end of file diff --git a/tests/autodiff/high-order-user-defined-derivative.slang.expected.txt b/tests/autodiff/high-order-user-defined-derivative.slang.expected.txt new file mode 100644 index 000000000..b5f259133 --- /dev/null +++ b/tests/autodiff/high-order-user-defined-derivative.slang.expected.txt @@ -0,0 +1,10 @@ +type: float +0.000000 +1.000000 +0.000000 +0.000000 +type: float +0.000000 +0.000000 +2.000000 +2.000000 |
