summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-03-23 22:27:30 -0400
committerGitHub <noreply@github.com>2023-03-23 19:27:30 -0700
commite004511b5f75bb24df1adec71b005146917afb39 (patch)
treeb9d8a41dd2dfb92e5808d643c5a1693c78af839f
parent6e4eae1050ab9282b460a33a013652c387c1e585 (diff)
AD: Fix type checking for higher-order custom derivatives definitions (#2729)
* Fixed type coercion issue with higher-order user defined methods * Placed associated type lookup method in a loop * Update high-order-user-defined-derivative.slang * Revert changes to associated type lookup method
-rw-r--r--source/slang/slang-check-decl.cpp4
-rw-r--r--tests/autodiff/high-order-user-defined-derivative.slang77
-rw-r--r--tests/autodiff/high-order-user-defined-derivative.slang.expected.txt10
3 files changed, 91 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index eaab43ef8..6083ce9c0 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -238,6 +238,9 @@ namespace Slang
void visitCallableDecl(CallableDecl* decl)
{
+ for (auto paramDecl : decl->getMembersOfType<ParamDecl>())
+ visitTypeExp(paramDecl->type);
+
visitTypeExp(decl->returnType);
visitTypeExp(decl->errorType);
}
@@ -6916,6 +6919,7 @@ namespace Slang
auto ctx = visitor->withExprLocalScope(&scope);
auto subVisitor = SemanticsVisitor(ctx);
auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx);
+ visitor->ensureDecl(as<DeclRefExpr>(checkedFuncExpr)->declRef, DeclCheckState::TypesFullyResolved);
auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments);
auto resolved = subVisitor.ResolveInvoke(invokeExpr);
if (auto resolvedInvoke = as<InvokeExpr>(resolved))
diff --git a/tests/autodiff/high-order-user-defined-derivative.slang b/tests/autodiff/high-order-user-defined-derivative.slang
new file mode 100644
index 000000000..4ad4aad12
--- /dev/null
+++ b/tests/autodiff/high-order-user-defined-derivative.slang
@@ -0,0 +1,77 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+//TEST_INPUT:ubuffer(data=[0.0 1.0 2.0 3.0], stride=4):name=endpointBuffer
+RWStructuredBuffer<float> endpointBuffer;
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=endpointDifferentialBuffer
+RWStructuredBuffer<float> endpointDifferentialBuffer;
+
+struct LineSegment : IDifferentiable
+{
+ float x0;
+ float x1;
+
+ [BackwardDifferentiable]
+ __init(float _x0, float _x1)
+ {
+ x0 = _x0;
+ x1 = _x1;
+ }
+};
+
+[BackwardDerivative(d_loadLineSegment)]
+[ForwardDerivative(fwd_loadLineSegment)]
+LineSegment loadLineSegment(uint id)
+{
+ return {endpointBuffer[id * 2], endpointBuffer[id * 2 + 1]};
+}
+
+[BackwardDerivative(d_fwd_loadLineSegment)]
+DifferentialPair<LineSegment> fwd_loadLineSegment(uint id)
+{
+ return DifferentialPair<LineSegment>(loadLineSegment(id), LineSegment.dzero());
+}
+
+void d_loadLineSegment(uint id, LineSegment.Differential d_ls)
+{
+ endpointDifferentialBuffer[id * 2] += d_ls.x0;
+ endpointDifferentialBuffer[id * 2 + 1] += d_ls.x1;
+}
+
+void d_fwd_loadLineSegment(uint id, DifferentialPair<LineSegment>.Differential dp_ls)
+{
+ endpointDifferentialBuffer[id * 2] += dp_ls.p.x0;
+ endpointDifferentialBuffer[id * 2 + 1] += dp_ls.p.x1;
+}
+
+[BackwardDifferentiable]
+float something()
+{
+ LineSegment ls = __fwd_diff(loadLineSegment)(1).p;
+ return ls.x0 + ls.x1;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ LineSegment ls = __fwd_diff(loadLineSegment)(0).p;
+ outputBuffer[0] = ls.x0; // Expect: 0
+ outputBuffer[1] = ls.x1; // Expect: 1
+ }
+
+ {
+ LineSegment.Differential d_ls = __fwd_diff(loadLineSegment)(0).d;
+ outputBuffer[2] = d_ls.x1; // Expect: 0
+ }
+
+ {
+ // Expect: 2.0 in endpointDifferentialBuffer[2]
+ // Expect: 2.0 in endpointDifferentialBuffer[3]
+ __bwd_diff(something)(2.0);
+ }
+} \ No newline at end of file
diff --git a/tests/autodiff/high-order-user-defined-derivative.slang.expected.txt b/tests/autodiff/high-order-user-defined-derivative.slang.expected.txt
new file mode 100644
index 000000000..b5f259133
--- /dev/null
+++ b/tests/autodiff/high-order-user-defined-derivative.slang.expected.txt
@@ -0,0 +1,10 @@
+type: float
+0.000000
+1.000000
+0.000000
+0.000000
+type: float
+0.000000
+0.000000
+2.000000
+2.000000