diff options
| -rw-r--r-- | docs/user-guide/07-autodiff.md | 7 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 195 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-expr.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 15 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 192 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 25 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 2 | ||||
| -rw-r--r-- | tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang | 12 | ||||
| -rw-r--r-- | tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt | 5 | ||||
| -rw-r--r-- | tests/autodiff-dstdlib/dstdlib-sqrt.slang | 11 | ||||
| -rw-r--r-- | tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt | 6 |
15 files changed, 306 insertions, 192 deletions
diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md index 244ebb47b..6b21ee28c 100644 --- a/docs/user-guide/07-autodiff.md +++ b/docs/user-guide/07-autodiff.md @@ -483,10 +483,9 @@ The following builtin functions are backward differentiable and both their forwa - Hyperbolic functions: `sinh`, `cosh`, `tanh` - Exponential and logarithmic functions: `exp`, `exp2`, `pow`, `log`, `log2`, `log10` - Vector functions: `dot`, `cross`, `length`, `distance`, `normalize`, `reflect`, `refract` -- Matrix transform: `mul(matrix, vector)`, `mul(vector, matrix)`, `mul(matrix, matrix)`, `transpose` - -Derivatives for the following legacy HLSL intrinsic functions are not implemented: -- `dst`, `lit`, +- Matrix transforms: `mul(matrix, vector)`, `mul(vector, matrix)`, `mul(matrix, matrix)` +- Matrix operations: `transpose`, `determinant` +- Legacy blending and lighting intrinsics: `dst`, `lit` ## Excluding Parameters From Differentiation diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index c303b39d9..54f927816 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -380,6 +380,24 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve } \ return DifferentialPair<vector<T, N>>(result, d_result); \ } \ + __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [ForwardDerivativeOf(NAME)] \ + DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \ + DifferentialPair<matrix<T, M, N>> dpx, DifferentialPair<matrix<T, M, N>> dpy) \ + { \ + matrix<T, M, N> result; \ + matrix<T, M, N>.Differential d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair<T> dp_elem = __d_##NAME( \ + DifferentialPair<T>(dpx.p[i][j], __slang_noop_cast<T.Differential>(dpx.d[i][j])), \ + DifferentialPair<T>(dpy.p[i][j], __slang_noop_cast<T.Differential>(dpy.d[i][j]))); \ + result[i][j] = dp_elem.p; \ + d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \ + } \ + return DifferentialPair<matrix<T, M, N>>(result, d_result); \ + } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ @@ -398,6 +416,26 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve } \ dpx = diffPair(dpx.p, left_d_result); \ dpy = diffPair(dpy.p, right_d_result); \ + } \ + __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [BackwardDerivativeOf(NAME)] \ + void __d_##NAME##_matrix( \ + inout DifferentialPair<matrix<T, M, N>> dpx, \ + inout DifferentialPair<matrix<T, M, N>> dpy, \ + matrix<T, M, N>.Differential dOut) \ + { \ + matrix<T, M, N>.Differential left_d_result, right_d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair<T> left_dp = diffPair(dpx.p[i][j], T.dzero()); \ + DifferentialPair<T> right_dp = diffPair(dpy.p[i][j], T.dzero()); \ + __d_##NAME(left_dp, right_dp, __slang_noop_cast<T.Differential>(dOut[i][j])); \ + left_d_result[i][j] = __slang_noop_cast<T>(left_dp.d); \ + right_d_result[i][j] = __slang_noop_cast<T>(right_dp.d); \ + } \ + dpx = diffPair(dpx.p, left_d_result); \ + dpy = diffPair(dpy.p, right_d_result); \ } #define VECTOR_MATRIX_TERNARY_DIFF_IMPL(NAME) \ @@ -407,7 +445,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve DifferentialPair<vector<T, N>> dpx, \ DifferentialPair<vector<T, N>> dpy, \ DifferentialPair<vector<T, N>> dpz) \ -{ \ + { \ vector<T, N> result; \ vector<T, N>.Differential d_result; \ [ForceUnroll] for (int i = 0; i < N; ++i) \ @@ -421,6 +459,27 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve } \ return DifferentialPair<vector<T, N>>(result, d_result); \ } \ + __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [ForwardDerivativeOf(NAME)] \ + DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \ + DifferentialPair<matrix<T, M, N>> dpx, \ + DifferentialPair<matrix<T, M, N>> dpy, \ + DifferentialPair<matrix<T, M, N>> dpz) \ + { \ + matrix<T, M, N> result; \ + matrix<T, M, N>.Differential d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair<T> dp_elem = __d_##NAME( \ + DifferentialPair<T>(dpx.p[i][j], __slang_noop_cast<T.Differential>(dpx.d[i][j])), \ + DifferentialPair<T>(dpy.p[i][j], __slang_noop_cast<T.Differential>(dpy.d[i][j])), \ + DifferentialPair<T>(dpz.p[i][j], __slang_noop_cast<T.Differential>(dpz.d[i][j]))); \ + result[i][j] = dp_elem.p; \ + d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \ + } \ + return DifferentialPair<matrix<T, M, N>>(result, d_result); \ + } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ @@ -444,6 +503,31 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve dpx = diffPair(dpx.p, left_d_result); \ dpy = diffPair(dpy.p, middle_d_result); \ dpz = diffPair(dpz.p, right_d_result); \ + } \ + __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [BackwardDerivativeOf(NAME)] \ + void __d_##NAME##_matrix( \ + inout DifferentialPair<matrix<T, M, N>> dpx, \ + inout DifferentialPair<matrix<T, M, N>> dpy, \ + inout DifferentialPair<matrix<T, M, N>> dpz, \ + matrix<T, M, N>.Differential dOut) \ + { \ + matrix<T, M, N>.Differential left_d_result, middle_d_result, right_d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair<T> left_dp = diffPair(dpx.p[i][j], T.dzero()); \ + DifferentialPair<T> middle_dp = diffPair(dpy.p[i][j], T.dzero()); \ + DifferentialPair<T> right_dp = diffPair(dpz.p[i][j], T.dzero()); \ + __d_##NAME(left_dp, middle_dp, right_dp, \ + __slang_noop_cast<T.Differential>(dOut[i][j])); \ + left_d_result[i][j] = __slang_noop_cast<T>(left_dp.d); \ + middle_d_result[i][j] = __slang_noop_cast<T>(middle_dp.d); \ + right_d_result[i][j] = __slang_noop_cast<T>(right_dp.d); \ + } \ + dpx = diffPair(dpx.p, left_d_result); \ + dpy = diffPair(dpy.p, middle_d_result); \ + dpz = diffPair(dpz.p, right_d_result); \ } #define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \ @@ -999,24 +1083,19 @@ void __d_sincos(DifferentialPair<matrix<T, N, M>> x, out DifferentialPair<matrix __fwd_diff(__sincos_impl)(x, s, c); } -#if 0 -// TODO: this is not working right now since our type system can't resolve -// the overload to `sincos` in `[BackwardDerivativeOf]` attribute. We need to implement -// a proper overload resolver for custom backward derivatives. - __generic<T: __BuiltinFloatingPointType> [BackwardDerivativeOf(sincos)] [ForceInline] void __d_sincos(inout DifferentialPair<T> x, T.Differential dS, T.Differential dC) { - __bwd_diff(__sincos_impl)(x, s, c); + __bwd_diff(__sincos_impl)(x, dS, dC); } __generic<T: __BuiltinFloatingPointType, let N : int> [BackwardDerivativeOf(sincos)] [ForceInline] void __d_sincos(inout DifferentialPair<vector<T, N>> x, vector<T, N>.Differential dS, vector<T, N>.Differential dC) { - __bwd_diff(__sincos_impl)(x, s, c); + __bwd_diff(__sincos_impl)(x, dS, dC); } __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> @@ -1024,7 +1103,103 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] void __d_sincos(inout DifferentialPair<matrix<T, N, M>> x, matrix<T, N, M>.Differential dS, matrix<T, N, M>.Differential dC) { - __bwd_diff(__sincos_impl)(x, s, c); + __bwd_diff(__sincos_impl)(x, dS, dC); +} + +// dst (obsolete) +__generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] +vector<T, 4> __dst_impl(vector<T, 4> src0, vector<T, 4> src1) +{ + vector<T, 4> dest; + dest.x = T(1.0); + dest.y = src0.y * src1.y; + dest.z = src0.z; + dest.w = src1.w; ; + return dest; +} +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(dst)] +[ForceInline] +DifferentialPair<vector<T, 4>> __d_dst(DifferentialPair<vector<T, 4>> src0, DifferentialPair<vector<T, 4>> src1) +{ + return __fwd_diff(__dst_impl)(src0, src1); +} +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(dst)] +[ForceInline] +void __d_dst(inout DifferentialPair<vector<T, 4>> src0, inout DifferentialPair<vector<T, 4>> src1, vector<T, 4>.Differential dOut) +{ + __bwd_diff(__dst_impl)(src0, src1, dOut); +} + +// Legacy lighting function (obsolete) +__target_intrinsic(hlsl) +[__readNone] +[BackwardDifferentiable] +float4 __lit_impl(float n_dot_l, float n_dot_h, float m) +{ + let ambient = 1.0f; + let diffuse = max(n_dot_l, 0.0f); + let specular = ((n_dot_l < 0.0f || n_dot_h < 0.0) ? 0.0 : pow(n_dot_h, m)); + return float4(ambient, diffuse, specular, 1.0f); +} +[ForwardDerivativeOf(lit)] +[ForceInline] +DifferentialPair<float4> __d_lit(DifferentialPair<float> n_dot_l, DifferentialPair<float> n_dot_h, DifferentialPair<float> m) +{ + return __fwd_diff(__lit_impl)(n_dot_l, n_dot_h, m); +} +[BackwardDerivativeOf(lit)] +[ForceInline] +void __d_lit(inout DifferentialPair<float> n_dot_l, inout DifferentialPair<float> n_dot_h, inout DifferentialPair<float> m, float4 dOut) +{ + __bwd_diff(__lit_impl)(n_dot_l, n_dot_h, m, dOut); } -#endif
\ No newline at end of file +// Matrix determinant +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] +[__readNone] +T __determinant_impl(matrix<T,N,N> m) +{ + if (N == 1) + return m[0][0]; + else if (N == 2) + return m[0][0] * m[1][1] - m[0][1] * m[1][0]; + else if (N == 3) + { + return m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2]) + - m[1][0] * (m[0][1] * m[2][2] - m[2][1] * m[0][2]) + + m[2][0] * (m[0][1] * m[1][2] - m[1][1] * m[0][2]); + } + else if (N == 4) + { + T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; + T s01 = m[2][1] * m[3][3] - m[3][1] * m[2][3]; + T s02 = m[2][1] * m[3][2] - m[3][1] * m[2][2]; + T s03 = m[2][0] * m[3][3] - m[3][0] * m[2][3]; + T s04 = m[2][0] * m[3][2] - m[3][0] * m[2][2]; + T s05 = m[2][0] * m[3][1] - m[3][0] * m[2][1]; + + return m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02) + - m[0][1] * (m[1][0] * s00 - m[1][2] * s03 + m[1][3] * s04) + + m[0][2] * (m[1][0] * s01 - m[1][1] * s03 + m[1][3] * s05) + - m[0][3] * (m[1][0] * s02 - m[1][1] * s04 + m[1][2] * s05); + } + return T(0.0); +} +__generic<T : __BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(determinant)] +[ForceInline] +DifferentialPair<T> __determinant_impl(DifferentialPair<matrix<T,N,N>> m) +{ + return __fwd_diff(__determinant_impl)(m); +} +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(determinant)] +[ForceInline] +void __d_determinant(inout DifferentialPair<matrix<T,N,N>> m, T.Differential dOut) +{ + __bwd_diff(__determinant_impl)(m, dOut); +} diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 5a01bc132..633c3d87e 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -2619,7 +2619,7 @@ float4 lit(float n_dot_l, float n_dot_h, float m) { let ambient = 1.0f; let diffuse = max(n_dot_l, 0.0f); - let specular = step(0.0f, n_dot_l) * max(n_dot_h * m, 0.0f); + let specular = step(0.0f, n_dot_l) * max(pow(n_dot_h, m), 0.0f); return float4(ambient, diffuse, specular, 1.0f); } diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index d49db89a2..ba0b4ce7a 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -448,7 +448,7 @@ class OpenRefExpr : public Expr class HigherOrderInvokeExpr : public Expr { SLANG_ABSTRACT_AST_CLASS(HigherOrderInvokeExpr) - Expr* baseFunction; + Expr* baseFunction; List<Name*> newParameterNames; }; diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index aa0513569..3feed7541 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -36,11 +36,16 @@ void removeModifier(ModifiableSyntaxNode* syntax, Modifier* toRemove) } } -Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr) +Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLevel& outLevel) { HashSet<Expr*> workListSet; + outLevel = FunctionDifferentiableLevel::None; while (auto higherOrder = as<HigherOrderInvokeExpr>(expr)) { + if (as<BackwardDifferentiateExpr>(expr)) + outLevel = FunctionDifferentiableLevel::Backward; + else if (as<ForwardDifferentiateExpr>(expr) && outLevel == FunctionDifferentiableLevel::None) + outLevel = FunctionDifferentiableLevel::Forward; if (workListSet.Add(higherOrder)) { expr = higherOrder->baseFunction; diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 5fd9df400..07b3a5eac 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1537,9 +1537,22 @@ namespace Slang DMulFunc, ///< The `IDifferentiable.dmul` function requirement }; + enum class FunctionDifferentiableLevel + { + None, + Forward, + Backward + }; + /// Get the inner most expr from an higher order expr chain, e.g. `__fwd_diff(__fwd_diff(f))`'s /// inner most expr is `f`. - Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr); + Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLevel& outDiffLevel); + inline Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr) + { + FunctionDifferentiableLevel level; + return getInnerMostExprFromHigherOrderExpr(expr, level); + } + /// Get the operator name from the higher order invoke expr. UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index a1d5acfb0..7c42c1892 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -4663,7 +4663,8 @@ namespace Slang TDerivativeAttr* attr, const List<Expr*>& imaginaryArguments) { - auto invokeExpr = visitor->constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); + auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, *visitor); + auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); auto resolved = visitor->ResolveInvoke(invokeExpr); if (auto resolvedInvoke = as<InvokeExpr>(resolved)) { @@ -4690,38 +4691,34 @@ namespace Slang return "BackwardDerivative"; } - List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) { List<Expr*> imaginaryArguments; - for (auto param : originalFuncDecl->getParameters()) + for (auto param : func->getParameters()) { - auto arg = visitor->getASTBuilder()->create<VarExpr>(); + auto arg = astBuilder->create<VarExpr>(); arg->declRef.decl = param; arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) - { - arg->type.type = pairType; - } imaginaryArguments.add(arg); } return imaginaryArguments; } - List<Expr*> getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* fwdDiffFunc, SourceLoc loc) + List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) { List<Expr*> imaginaryArguments; - for (auto param : fwdDiffFunc->getParameters()) + for (auto param : originalFuncDecl->getParameters()) { - auto arg = astBuilder->create<VarExpr>(); + auto arg = visitor->getASTBuilder()->create<VarExpr>(); arg->declRef.decl = param; arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = as<DifferentialPairType>(param->getType())) + if (auto pairType = visitor->getDifferentialPairType(param->getType())) { - arg->type.type = pairType->getPrimalType(); + arg->type.type = pairType; } imaginaryArguments.add(arg); } @@ -4731,6 +4728,11 @@ namespace Slang List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) { List<Expr*> imaginaryArguments; + auto isOutParam = [&](ParamDecl* param) + { + return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr; + }; + for (auto param : originalFuncDecl->getParameters()) { auto arg = visitor->getASTBuilder()->create<VarExpr>(); @@ -4738,16 +4740,23 @@ namespace Slang arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) + if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType()))) { arg->type.type = pairType; - if (auto diffPairType = as<DifferentialPairType>(pairType)) + if (isOutParam(param)) { - if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr) - { - arg->type.isLeftValue = false; - arg->type.type = diffPairType->getPrimalType(); - } + // out T -> in T.Differential + arg->type.isLeftValue = false; + arg->type.type = visitor->tryGetDifferentialType( + visitor->getASTBuilder(), pairType->getPrimalType()); + } + } + else + { + if (isOutParam(param)) + { + // Skip non-differentiable out params. + continue; } } imaginaryArguments.add(arg); @@ -4763,38 +4772,6 @@ namespace Slang return imaginaryArguments; } - List<Expr*> getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* bwdDiffFunc, SourceLoc loc) - { - // Note: it isn't always possible to construct original arguments from - // backward propagation arguments because backward propagation function - // may drop certain parameters. - List<Expr*> imaginaryArguments; - for (auto param : bwdDiffFunc->getParameters()) - { - auto arg = astBuilder->create<VarExpr>(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - if (auto pairType = as<DifferentialPairType>(param->getType())) - { - if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr) - { - arg->type.isLeftValue = false; - } - arg->type.type = pairType->getPrimalType(); - } - imaginaryArguments.add(arg); - } - // Assume the last parameter is `dOut`. - // This is not true if the function returns a non-differentiable value. - // However in that uncommon case we just fail the overload resolution - // and require the user to provide disambiguate themselves. - if (imaginaryArguments.getCount()) - imaginaryArguments.fastRemoveAt(imaginaryArguments.getCount() - 1); - return imaginaryArguments; - } - // This helper function is needed to workaround a gcc bug. // Remove when we upgrade to a newer version of gcc. template <typename T> @@ -4803,76 +4780,41 @@ namespace Slang return decl->findModifier<T>(); } - template <typename TDerivativeAttr, typename TDerivativeOfAttr> + template <typename TDerivativeAttr, typename TDifferentiateExpr, typename TDerivativeOfAttr> void checkDerivativeOfAttributeImpl( SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, TDerivativeOfAttr* derivativeOfAttr, - DeclAssociationKind assocKind, - const List<Expr*>& imaginaryArgsToOriginal) + DeclAssociationKind assocKind) { DeclRef<Decl> calleeDeclRef; - auto calleeDeclRefExpr = as<DeclRefExpr>(derivativeOfAttr->funcExpr); - if (!calleeDeclRefExpr) + DeclRefExpr* calleeDeclRefExpr = nullptr; + DifferentiateExpr* diffFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>(); + diffFuncExpr->baseFunction = derivativeOfAttr->funcExpr; + diffFuncExpr->loc = derivativeOfAttr->loc; + Expr* checkedDiffFuncExpr = visitor->dispatchExpr(diffFuncExpr, *visitor); + if (!checkedDiffFuncExpr) { - auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal); - auto resolved = visitor->ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as<InvokeExpr>(resolved)) - { - calleeDeclRefExpr = as<DeclRefExpr>(resolvedInvoke->functionExpr); - } + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc); + auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedDiffFuncExpr, imaginaryArgs); + auto resolved = visitor->ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + auto resolvedDiffFuncExpr = as<DifferentiateExpr>(resolvedInvoke->functionExpr); + if (resolvedDiffFuncExpr) + calleeDeclRefExpr = as<DeclRefExpr>(resolvedDiffFuncExpr->baseFunction); } + if (!calleeDeclRefExpr) { visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } calleeDeclRef = calleeDeclRefExpr->declRef; - if (auto calleeGenDecl = as<GenericDecl>(calleeDeclRef.getDecl())) - { - auto parentGenericDecl = as<GenericDecl>(funcDecl->parentDecl); - if (!parentGenericDecl) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - FunctionDeclBase* funcReturnVal = nullptr; - List<Val*> args; - for (auto mm : parentGenericDecl->members) - { - if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm)) - { - args.add(DeclRefType::create(visitor->getASTBuilder(), DeclRef<Decl>(genericTypeParamDecl, nullptr))); - } - else if (auto genericValueParamDecl = as<GenericValueParamDecl>(mm)) - { - args.add(visitor->getASTBuilder()->getOrCreate<GenericParamIntVal>( - genericValueParamDecl->getType(), - genericValueParamDecl, nullptr)); - } - } - auto funcs = calleeGenDecl->getMembersOfType<FunctionDeclBase>(); - if (funcs.isEmpty()) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - funcReturnVal = funcs.getFirst(); - if (funcReturnVal) - { - auto subst = visitor->getASTBuilder()->getOrCreateGenericSubstitution(calleeGenDecl, args, nullptr); - calleeDeclRef.decl = funcReturnVal; - calleeDeclRef.substitutions = subst; - calleeDeclRefExpr = as<DeclRefExpr>(visitor->ConstructDeclRefExpr( - calleeDeclRef, nullptr, derivativeOfAttr->loc, nullptr)); - } - else - { - calleeDeclRef = DeclRef<Decl>(); - calleeDeclRefExpr = nullptr; - } - } - + auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl()); if (!calleeFunc) { @@ -4953,9 +4895,8 @@ namespace Slang if (!attr) return; - List<Expr*> imaginaryArgsToOriginal = getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc); - checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute>( - this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc, imaginaryArgsToOriginal); + checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute, ForwardDifferentiateExpr>( + this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc); } void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl) @@ -4964,33 +4905,8 @@ namespace Slang if (!attr) return; - List<Expr*> imaginaryArguments = getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc); - - // The tricky part here is that we can't easily derive the arguments to original func just - // from the definition of a backward derivative function, because we don't know if the last - // parameter is just a normal parameter of the original func, or if it is the additional - // derivative of the return value. The solution here is to try to resolve the original - // function with or without the last argument. However if the type of the last argument - // isn't differentiable, we know that it can't possibly be the result derivative. - - if (imaginaryArguments.getCount() == 0 || - !tryGetDifferentialType(m_astBuilder, imaginaryArguments.getLast()->type.type)) - { - checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>( - this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments); - return; - } - - // Otherwise, try resolve with all the arguments, if failed, resolve without the last - // argument. - if (tryCheckDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments)) - { - return; - } - - imaginaryArguments.removeLast(); - checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>( - this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments); + checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute, BackwardDifferentiateExpr>( + this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc); } void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2803b5959..bebaa63a2 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1971,31 +1971,26 @@ namespace Slang if (auto higherOrderInvoke = as<DifferentiateExpr>(invoke->functionExpr)) { - if (auto funcDeclExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(higherOrderInvoke))) + FunctionDifferentiableLevel requiredLevel; + if (auto funcDeclExpr = as<DeclRefExpr>( + getInnerMostExprFromHigherOrderExpr(higherOrderInvoke, requiredLevel))) { auto funcDecl = as<FunctionDeclBase>(funcDeclExpr->declRef.getDecl()); if (funcDecl) { - DifferentiateExpr* forwardDiff = nullptr; - DifferentiateExpr* backwardDiff = nullptr; - for (auto node = as<DifferentiateExpr>(invoke->functionExpr); node; node = as<DifferentiateExpr>(node->baseFunction)) + if (requiredLevel == FunctionDifferentiableLevel::Forward && + !getShared()->isDifferentiableFunc(funcDecl)) { - if (auto fwd = as<ForwardDifferentiateExpr>(node)) - forwardDiff = fwd; - if (auto bwd = as<BackwardDifferentiateExpr>(node)) - backwardDiff = bwd; + getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); } - if (forwardDiff && !getShared()->isDifferentiableFunc(funcDecl)) + if (requiredLevel == FunctionDifferentiableLevel::Backward && + !getShared()->isBackwardDifferentiableFunc(funcDecl)) { - getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); - } - if (backwardDiff && !getShared()->isBackwardDifferentiableFunc(funcDecl)) - { - getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); + getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); } if (!isEffectivelyStatic(funcDecl) && !isGlobalDecl(funcDecl)) { - getSink()->diagnose(forwardDiff, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl); + getSink()->diagnose(invoke->functionExpr, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl); } } } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 11aacd255..fc1b622cc 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -11,12 +11,6 @@ namespace Slang { - enum class FunctionDifferentiableLevel - { - None, - Forward, - Backward - }; /// Should the given `decl` be treated as a static rather than instance declaration? bool isEffectivelyStatic( Decl* decl); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 520d85971..e6a524645 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -680,16 +680,7 @@ namespace Slang { SLANG_ASSERT(attr->args.getCount() == 1); SLANG_ASSERT(as<Decl>(attrTarget)); - - // Ensure that the argument is a reference to a function definition or declaration. - auto primalFunc = CheckTerm(attr->args[0]); - if (primalFunc->type == getASTBuilder()->getErrorType()) - { - // Could not resolve the term. - getSink()->diagnose(primalFunc, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget)); - return false; - } - derivativeOfAttr->funcExpr = primalFunc; + derivativeOfAttr->funcExpr = attr->args[0]; } else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr)) { diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 09ddde2de..91af731ad 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1622,10 +1622,10 @@ namespace Slang else { // Unhandled case for the inner expr. - funcExpr->type = this->getASTBuilder()->getErrorType(); getSink()->diagnose(funcExpr->loc, Diagnostics::expectedFunction, funcExpr->type); + funcExpr->type = this->getASTBuilder()->getErrorType(); } } diff --git a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang index 379e2c3ef..53972ac2c 100644 --- a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang +++ b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang @@ -2,7 +2,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; @@ -43,4 +43,14 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) __bwd_diff(diffSin)(dpx, 1.0); outputBuffer[4] = dpx.d; // Expect: -1.000000 } + + { + dpfloat dpx = dpfloat(float.getPi() / 3.0, 1.0); + __bwd_diff(sincos)(dpx, 1.0, 0.0); + outputBuffer[5] = dpx.d; // Expect: 0.5 + __bwd_diff(sincos)(dpx, 0.0, 1.0); + outputBuffer[6] = dpx.d; // Expect: -0.8660254 + __bwd_diff(sincos)(dpx, 1.0, 1.0); + outputBuffer[7] = dpx.d; // Expect: -0.3660254 + } } diff --git a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt index a4b804cb8..17627df68 100644 --- a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt @@ -3,4 +3,7 @@ type: float 7.389056 0.000000 1.000000 --1.00000
\ No newline at end of file +-1.000000 +0.500000 +-0.866025 +-0.366025 diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang b/tests/autodiff-dstdlib/dstdlib-sqrt.slang index 15573c4ef..d68a2697c 100644 --- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang +++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang @@ -1,7 +1,7 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; @@ -50,4 +50,13 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) outputBuffer[6] = dpx.d[0]; // Expect: 0.158114 outputBuffer[7] = dpx.d[1]; // Expect: 0.577350 } + + { + var dpx = diffPair(float2x2(4.0, 9.0, 16.0, 25.0), float2x2(0.0, 0.0, 0.0, 0.0)); + __bwd_diff(sqrt)(dpx, float2x2(1.0, 2.0, 3.0, 4.0)); + outputBuffer[8] = dpx.d[0][0]; // Expect: 0.25 + outputBuffer[9] = dpx.d[0][1]; // Expect: 0.3333 + outputBuffer[10] = dpx.d[1][0]; // Expect: 0.375 + outputBuffer[11] = dpx.d[1][1]; // Expect: 0.4 + } } diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt index fe6487fef..7e0fdf02f 100644 --- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt @@ -6,4 +6,8 @@ type: float 0.000000 0.000000 0.158114 -0.577350
\ No newline at end of file +0.577350 +0.250000 +0.333333 +0.375000 +0.400000
\ No newline at end of file |
