diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-11-22 18:55:47 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-11-22 23:55:47 +0000 |
| commit | 9913cfbf68dab8c3c8c418dd28b71c2a65a55ae0 (patch) | |
| tree | 735743bce54f0a4faf99925bc4582e3d0de8d5ea | |
| parent | 95125f280a3ee6cad08866baedc41fee8585b91e (diff) | |
[AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred (#5630)
* [AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred
* Fix failing tests
* Update custom-derivative-generic.slang
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 76 | ||||
| -rw-r--r-- | tests/autodiff/custom-derivative-enum-param.slang | 57 | ||||
| -rw-r--r-- | tests/autodiff/custom-intrinsic-1.slang (renamed from tests/autodiff/custom-intrinsic.slang) | 0 | ||||
| -rw-r--r-- | tests/autodiff/custom-intrinsic-1.slang.expected.txt (renamed from tests/autodiff/custom-intrinsic.slang.expected.txt) | 0 | ||||
| -rw-r--r-- | tests/diagnostics/custom-derivative-generic.slang | 2 |
5 files changed, 133 insertions, 2 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 251ce6a69..e4206827f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -10915,7 +10915,61 @@ void checkDerivativeAttributeImpl( SemanticsContext::ExprLocalScope scope; auto ctx = visitor->withExprLocalScope(&scope); auto subVisitor = SemanticsVisitor(ctx); - auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); + + auto exprToCheck = attr->funcExpr; + + // If this is a generic, we want to wrap the call to the derivative method + // with the generic parameters of the source. + // + if (as<GenericDecl>(funcDecl->parentDecl) && !as<GenericAppExpr>(attr->funcExpr)) + { + auto genericDecl = as<GenericDecl>(funcDecl->parentDecl); + auto substArgs = getDefaultSubstitutionArgs(ctx.getASTBuilder(), visitor, genericDecl); + auto appExpr = ctx.getASTBuilder()->create<GenericAppExpr>(); + + Index count = 0; + for (auto member : genericDecl->members) + { + if (as<GenericTypeParamDecl>(member) || as<GenericValueParamDecl>(member) || + as<GenericTypePackParamDecl>(member)) + count++; + } + + appExpr->functionExpr = attr->funcExpr; + + for (auto arg : substArgs) + { + if (count == 0) + break; + + if (auto declRefType = as<DeclRefType>(arg)) + { + auto baseTypeExpr = ctx.getASTBuilder()->create<SharedTypeExpr>(); + baseTypeExpr->base.type = declRefType; + auto baseTypeType = ctx.getASTBuilder()->getOrCreate<TypeType>(declRefType); + baseTypeExpr->type.type = baseTypeType; + + appExpr->arguments.add(baseTypeExpr); + } + else if (auto genericValParam = as<GenericParamIntVal>(arg)) + { + auto declRef = genericValParam->getDeclRef(); + appExpr->arguments.add( + subVisitor + .ConstructDeclRefExpr(declRef, nullptr, nullptr, SourceLoc(), nullptr)); + } + else + { + SLANG_UNEXPECTED("Unhandled substitution arg type"); + } + + count--; + } + + exprToCheck = appExpr; + } + + auto checkedFuncExpr = visitor->dispatchExpr(exprToCheck, ctx); attr->funcExpr = checkedFuncExpr; if (attr->args.getCount()) attr->args[0] = attr->funcExpr; @@ -11427,6 +11481,26 @@ void checkDerivativeOfAttributeImpl( calleeDeclRef = calleeDeclRefExpr->declRef; auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl()); + + if (!calleeFunc) + { + // If we couldn't find a direct function, it might be a generic. + if (auto genericDecl = as<GenericDecl>(calleeDeclRef.getDecl())) + { + calleeFunc = as<FunctionDeclBase>(genericDecl->inner); + + if (as<ErrorType>(resolved->type.type)) + { + // If we can't resolve a type, something went wrong. If we're working with a generic + // decl, the most likely cause is a failure of generic argument inference. + // + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + } + } + } + if (!calleeFunc) { visitor->getSink()->diagnose( diff --git a/tests/autodiff/custom-derivative-enum-param.slang b/tests/autodiff/custom-derivative-enum-param.slang new file mode 100644 index 000000000..aa6733873 --- /dev/null +++ b/tests/autodiff/custom-derivative-enum-param.slang @@ -0,0 +1,57 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type + +enum MyEnum { A, B, C }; + +[BackwardDerivative(mDiff)] +float m<let M : MyEnum>(float x) +{ + switch (M) + { + case MyEnum.A: + return x * x; + case MyEnum.B: + return x; + case MyEnum.C: + return 3 * x; + default: + return 0; + } +} + +void mDiff<let M : MyEnum>(inout DifferentialPair<float> x, float dResult) +{ + switch (M) + { + case MyEnum.A: + updateDiff(x, 2 * dResult * x.p); + break; + case MyEnum.B: + updateDiff(x, dResult); + break; + case MyEnum.C: + updateDiff(x, 3 * dResult); + break; + default: + updateDiff(x, 0); + break; + } +} + +[Differentiable] +float test(float x) +{ + return m<MyEnum.A>(x); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + var a = diffPair(3.0); + __bwd_diff(test)(a, 1.0); + outputBuffer[dispatchThreadID.x] = a.d; + // CHECK: 6.0 +} diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic-1.slang index 1fe204b58..1fe204b58 100644 --- a/tests/autodiff/custom-intrinsic.slang +++ b/tests/autodiff/custom-intrinsic-1.slang diff --git a/tests/autodiff/custom-intrinsic.slang.expected.txt b/tests/autodiff/custom-intrinsic-1.slang.expected.txt index ce22a5b95..ce22a5b95 100644 --- a/tests/autodiff/custom-intrinsic.slang.expected.txt +++ b/tests/autodiff/custom-intrinsic-1.slang.expected.txt diff --git a/tests/diagnostics/custom-derivative-generic.slang b/tests/diagnostics/custom-derivative-generic.slang index 5f2cd9951..fb65dd2cc 100644 --- a/tests/diagnostics/custom-derivative-generic.slang +++ b/tests/diagnostics/custom-derivative-generic.slang @@ -34,7 +34,7 @@ DifferentialPair<float> dd1(DifferentialPair<float> x) } // CHECK-DAG: {{.*}}(37): error 31151 -[BackwardDerivative(f)] +[BackwardDerivativeOf(f)] DifferentialPair<float> df<let N:int>(inout DifferentialPair<float> x, float dOut) { var primal = x.p * x.p; |
