diff options
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 3 | ||||
| -rw-r--r-- | tests/autodiff/overloaded-custom-deriv.slang | 41 | ||||
| -rw-r--r-- | tests/autodiff/overloaded-custom-deriv.slang.expected.txt | 2 |
4 files changed, 69 insertions, 11 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 3ffb6c100..63c7d9741 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -6923,20 +6923,34 @@ namespace Slang auto ctx = visitor->withExprLocalScope(&scope); auto subVisitor = SemanticsVisitor(ctx); auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); - if (auto derivFuncDeclRef = as<DeclRefExpr>(checkedFuncExpr)->declRef) + if (auto declRefExpr = as<DeclRefExpr>(checkedFuncExpr)) { - visitor->ensureDecl(derivFuncDeclRef, DeclCheckState::TypesFullyResolved); - auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); - auto resolved = subVisitor.ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + visitor->ensureDecl(declRefExpr->declRef, DeclCheckState::TypesFullyResolved); + } + else if (auto overloadedExpr = as<OverloadedExpr>(checkedFuncExpr)) + { + for (auto candidate : overloadedExpr->lookupResult2.items) { - if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) - { - attr->funcExpr = calleeDeclRef; - return; - } + visitor->ensureDecl(candidate.declRef, DeclCheckState::TypesFullyResolved); } } + else + { + visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveDerivativeFunction); + return; + } + + auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); + auto resolved = subVisitor.ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) + { + attr->funcExpr = calleeDeclRef; + return; + } + } + visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 4a9b83c6c..128142d84 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -356,8 +356,9 @@ DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original defi DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative attribute.") DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.") DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.") +DIAGNOSTIC(31148, Error, cannotResolveDerivativeFunction, "cannot resolve the custom derivative function") -DIAGNOSTIC(31148, Error, differentiableGenericInterfaceMethodNotSupported, "`[ForwardDifferentiable] and [BackwardDifferentiable] are not supported on generic interface requirements.") +DIAGNOSTIC(31149, Error, differentiableGenericInterfaceMethodNotSupported, "`[ForwardDifferentiable] and [BackwardDifferentiable] are not supported on generic interface requirements.") // Enums diff --git a/tests/autodiff/overloaded-custom-deriv.slang b/tests/autodiff/overloaded-custom-deriv.slang new file mode 100644 index 000000000..81f91974f --- /dev/null +++ b/tests/autodiff/overloaded-custom-deriv.slang @@ -0,0 +1,41 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[ForwardDerivative(diff_f)] +float f(float v) +{ + return v * v; +} + +DifferentialPair<float> diff_f(DifferentialPair<float> v) +{ + return diffPair(v.p * v.p, 2 * v.d * v.p); +} + +[ForwardDerivative(diff_f)] +float2 f(float2 v) +{ + return v * v; +} + +DifferentialPair<float2> diff_f(DifferentialPair<float2> v) +{ + return diffPair(v.p * v.p, float2(2 * v.d.x * v.p.x, 0.0)); +} + +[ForwardDifferentiable] +float test(float v) +{ + return f(v) + f(float2(v, v)).x; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var p = diffPair(3.0, 1.0); + let rs = __fwd_diff(test)(p); + outputBuffer[0] = rs.d; +} diff --git a/tests/autodiff/overloaded-custom-deriv.slang.expected.txt b/tests/autodiff/overloaded-custom-deriv.slang.expected.txt new file mode 100644 index 000000000..7da9c9037 --- /dev/null +++ b/tests/autodiff/overloaded-custom-deriv.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +12.000000 |
