From 257733f328f38a763c8b0c8830ff4c0d34ec9491 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 7 Mar 2023 11:22:32 -0800 Subject: Reuse higher-order `ResolveInvoke` logic to resolve func refs in `[*DerivativeOf]` attribs. (#2688) * Reuse higher-order `ResolveInvoke` logic to resolve func refs in [*DerivativeOf] attribs. * Add diff implementation matrix versions of binary and ternary intrinsics. * Add diff impl for legacy intrinsics. * Fix diagnostics of using non-differentiable function in a diff operator. * Add diff implementation for `determinant`. --------- Co-authored-by: Yong He --- source/slang/diff.meta.slang | 195 +++++++++++++++++++++++++++++-- source/slang/hlsl.meta.slang | 2 +- source/slang/slang-ast-expr.h | 2 +- source/slang/slang-ast-support-types.cpp | 7 +- source/slang/slang-ast-support-types.h | 15 ++- source/slang/slang-check-decl.cpp | 192 +++++++++--------------------- source/slang/slang-check-expr.cpp | 25 ++-- source/slang/slang-check-impl.h | 6 - source/slang/slang-check-modifier.cpp | 11 +- source/slang/slang-check-overload.cpp | 2 +- 10 files changed, 273 insertions(+), 184 deletions(-) (limited to 'source') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index c303b39d9..54f927816 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -380,6 +380,24 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair>(result, d_result); \ } \ + __generic \ + [ForwardDerivativeOf(NAME)] \ + DifferentialPair> __d_##NAME##_matrix( \ + DifferentialPair> dpx, DifferentialPair> dpy) \ + { \ + matrix result; \ + matrix.Differential d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair dp_elem = __d_##NAME( \ + DifferentialPair(dpx.p[i][j], __slang_noop_cast(dpx.d[i][j])), \ + DifferentialPair(dpy.p[i][j], __slang_noop_cast(dpy.d[i][j]))); \ + result[i][j] = dp_elem.p; \ + d_result[i][j] = __slang_noop_cast(dp_elem.d); \ + } \ + return DifferentialPair>(result, d_result); \ + } \ __generic \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ @@ -398,6 +416,26 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair \ + [BackwardDerivativeOf(NAME)] \ + void __d_##NAME##_matrix( \ + inout DifferentialPair> dpx, \ + inout DifferentialPair> dpy, \ + matrix.Differential dOut) \ + { \ + matrix.Differential left_d_result, right_d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair left_dp = diffPair(dpx.p[i][j], T.dzero()); \ + DifferentialPair right_dp = diffPair(dpy.p[i][j], T.dzero()); \ + __d_##NAME(left_dp, right_dp, __slang_noop_cast(dOut[i][j])); \ + left_d_result[i][j] = __slang_noop_cast(left_dp.d); \ + right_d_result[i][j] = __slang_noop_cast(right_dp.d); \ + } \ + dpx = diffPair(dpx.p, left_d_result); \ + dpy = diffPair(dpy.p, right_d_result); \ } #define VECTOR_MATRIX_TERNARY_DIFF_IMPL(NAME) \ @@ -407,7 +445,7 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair> dpx, \ DifferentialPair> dpy, \ DifferentialPair> dpz) \ -{ \ + { \ vector result; \ vector.Differential d_result; \ [ForceUnroll] for (int i = 0; i < N; ++i) \ @@ -421,6 +459,27 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair>(result, d_result); \ } \ + __generic \ + [ForwardDerivativeOf(NAME)] \ + DifferentialPair> __d_##NAME##_matrix( \ + DifferentialPair> dpx, \ + DifferentialPair> dpy, \ + DifferentialPair> dpz) \ + { \ + matrix result; \ + matrix.Differential d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair dp_elem = __d_##NAME( \ + DifferentialPair(dpx.p[i][j], __slang_noop_cast(dpx.d[i][j])), \ + DifferentialPair(dpy.p[i][j], __slang_noop_cast(dpy.d[i][j])), \ + DifferentialPair(dpz.p[i][j], __slang_noop_cast(dpz.d[i][j]))); \ + result[i][j] = dp_elem.p; \ + d_result[i][j] = __slang_noop_cast(dp_elem.d); \ + } \ + return DifferentialPair>(result, d_result); \ + } \ __generic \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ @@ -444,6 +503,31 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair \ + [BackwardDerivativeOf(NAME)] \ + void __d_##NAME##_matrix( \ + inout DifferentialPair> dpx, \ + inout DifferentialPair> dpy, \ + inout DifferentialPair> dpz, \ + matrix.Differential dOut) \ + { \ + matrix.Differential left_d_result, middle_d_result, right_d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair left_dp = diffPair(dpx.p[i][j], T.dzero()); \ + DifferentialPair middle_dp = diffPair(dpy.p[i][j], T.dzero()); \ + DifferentialPair right_dp = diffPair(dpz.p[i][j], T.dzero()); \ + __d_##NAME(left_dp, middle_dp, right_dp, \ + __slang_noop_cast(dOut[i][j])); \ + left_d_result[i][j] = __slang_noop_cast(left_dp.d); \ + middle_d_result[i][j] = __slang_noop_cast(middle_dp.d); \ + right_d_result[i][j] = __slang_noop_cast(right_dp.d); \ + } \ + dpx = diffPair(dpx.p, left_d_result); \ + dpy = diffPair(dpy.p, middle_d_result); \ + dpz = diffPair(dpz.p, right_d_result); \ } #define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \ @@ -999,24 +1083,19 @@ void __d_sincos(DifferentialPair> x, out DifferentialPair [BackwardDerivativeOf(sincos)] [ForceInline] void __d_sincos(inout DifferentialPair x, T.Differential dS, T.Differential dC) { - __bwd_diff(__sincos_impl)(x, s, c); + __bwd_diff(__sincos_impl)(x, dS, dC); } __generic [BackwardDerivativeOf(sincos)] [ForceInline] void __d_sincos(inout DifferentialPair> x, vector.Differential dS, vector.Differential dC) { - __bwd_diff(__sincos_impl)(x, s, c); + __bwd_diff(__sincos_impl)(x, dS, dC); } __generic @@ -1024,7 +1103,103 @@ __generic [ForceInline] void __d_sincos(inout DifferentialPair> x, matrix.Differential dS, matrix.Differential dC) { - __bwd_diff(__sincos_impl)(x, s, c); + __bwd_diff(__sincos_impl)(x, dS, dC); +} + +// dst (obsolete) +__generic +[BackwardDifferentiable] +vector __dst_impl(vector src0, vector src1) +{ + vector dest; + dest.x = T(1.0); + dest.y = src0.y * src1.y; + dest.z = src0.z; + dest.w = src1.w; ; + return dest; +} +__generic +[ForwardDerivativeOf(dst)] +[ForceInline] +DifferentialPair> __d_dst(DifferentialPair> src0, DifferentialPair> src1) +{ + return __fwd_diff(__dst_impl)(src0, src1); +} +__generic +[BackwardDerivativeOf(dst)] +[ForceInline] +void __d_dst(inout DifferentialPair> src0, inout DifferentialPair> src1, vector.Differential dOut) +{ + __bwd_diff(__dst_impl)(src0, src1, dOut); +} + +// Legacy lighting function (obsolete) +__target_intrinsic(hlsl) +[__readNone] +[BackwardDifferentiable] +float4 __lit_impl(float n_dot_l, float n_dot_h, float m) +{ + let ambient = 1.0f; + let diffuse = max(n_dot_l, 0.0f); + let specular = ((n_dot_l < 0.0f || n_dot_h < 0.0) ? 0.0 : pow(n_dot_h, m)); + return float4(ambient, diffuse, specular, 1.0f); +} +[ForwardDerivativeOf(lit)] +[ForceInline] +DifferentialPair __d_lit(DifferentialPair n_dot_l, DifferentialPair n_dot_h, DifferentialPair m) +{ + return __fwd_diff(__lit_impl)(n_dot_l, n_dot_h, m); +} +[BackwardDerivativeOf(lit)] +[ForceInline] +void __d_lit(inout DifferentialPair n_dot_l, inout DifferentialPair n_dot_h, inout DifferentialPair m, float4 dOut) +{ + __bwd_diff(__lit_impl)(n_dot_l, n_dot_h, m, dOut); } -#endif \ No newline at end of file +// Matrix determinant +__generic +[BackwardDifferentiable] +[__readNone] +T __determinant_impl(matrix m) +{ + if (N == 1) + return m[0][0]; + else if (N == 2) + return m[0][0] * m[1][1] - m[0][1] * m[1][0]; + else if (N == 3) + { + return m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2]) + - m[1][0] * (m[0][1] * m[2][2] - m[2][1] * m[0][2]) + + m[2][0] * (m[0][1] * m[1][2] - m[1][1] * m[0][2]); + } + else if (N == 4) + { + T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; + T s01 = m[2][1] * m[3][3] - m[3][1] * m[2][3]; + T s02 = m[2][1] * m[3][2] - m[3][1] * m[2][2]; + T s03 = m[2][0] * m[3][3] - m[3][0] * m[2][3]; + T s04 = m[2][0] * m[3][2] - m[3][0] * m[2][2]; + T s05 = m[2][0] * m[3][1] - m[3][0] * m[2][1]; + + return m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02) + - m[0][1] * (m[1][0] * s00 - m[1][2] * s03 + m[1][3] * s04) + + m[0][2] * (m[1][0] * s01 - m[1][1] * s03 + m[1][3] * s05) + - m[0][3] * (m[1][0] * s02 - m[1][1] * s04 + m[1][2] * s05); + } + return T(0.0); +} +__generic +[ForwardDerivativeOf(determinant)] +[ForceInline] +DifferentialPair __determinant_impl(DifferentialPair> m) +{ + return __fwd_diff(__determinant_impl)(m); +} +__generic +[BackwardDerivativeOf(determinant)] +[ForceInline] +void __d_determinant(inout DifferentialPair> m, T.Differential dOut) +{ + __bwd_diff(__determinant_impl)(m, dOut); +} diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 5a01bc132..633c3d87e 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -2619,7 +2619,7 @@ float4 lit(float n_dot_l, float n_dot_h, float m) { let ambient = 1.0f; let diffuse = max(n_dot_l, 0.0f); - let specular = step(0.0f, n_dot_l) * max(n_dot_h * m, 0.0f); + let specular = step(0.0f, n_dot_l) * max(pow(n_dot_h, m), 0.0f); return float4(ambient, diffuse, specular, 1.0f); } diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index d49db89a2..ba0b4ce7a 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -448,7 +448,7 @@ class OpenRefExpr : public Expr class HigherOrderInvokeExpr : public Expr { SLANG_ABSTRACT_AST_CLASS(HigherOrderInvokeExpr) - Expr* baseFunction; + Expr* baseFunction; List newParameterNames; }; diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index aa0513569..3feed7541 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -36,11 +36,16 @@ void removeModifier(ModifiableSyntaxNode* syntax, Modifier* toRemove) } } -Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr) +Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLevel& outLevel) { HashSet workListSet; + outLevel = FunctionDifferentiableLevel::None; while (auto higherOrder = as(expr)) { + if (as(expr)) + outLevel = FunctionDifferentiableLevel::Backward; + else if (as(expr) && outLevel == FunctionDifferentiableLevel::None) + outLevel = FunctionDifferentiableLevel::Forward; if (workListSet.Add(higherOrder)) { expr = higherOrder->baseFunction; diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 5fd9df400..07b3a5eac 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1537,9 +1537,22 @@ namespace Slang DMulFunc, ///< The `IDifferentiable.dmul` function requirement }; + enum class FunctionDifferentiableLevel + { + None, + Forward, + Backward + }; + /// Get the inner most expr from an higher order expr chain, e.g. `__fwd_diff(__fwd_diff(f))`'s /// inner most expr is `f`. - Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr); + Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLevel& outDiffLevel); + inline Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr) + { + FunctionDifferentiableLevel level; + return getInnerMostExprFromHigherOrderExpr(expr, level); + } + /// Get the operator name from the higher order invoke expr. UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index a1d5acfb0..7c42c1892 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -4663,7 +4663,8 @@ namespace Slang TDerivativeAttr* attr, const List& imaginaryArguments) { - auto invokeExpr = visitor->constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); + auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, *visitor); + auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); auto resolved = visitor->ResolveInvoke(invokeExpr); if (auto resolvedInvoke = as(resolved)) { @@ -4690,38 +4691,34 @@ namespace Slang return "BackwardDerivative"; } - List getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + List getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) { List imaginaryArguments; - for (auto param : originalFuncDecl->getParameters()) + for (auto param : func->getParameters()) { - auto arg = visitor->getASTBuilder()->create(); + auto arg = astBuilder->create(); arg->declRef.decl = param; arg->type.isLeftValue = param->findModifier() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) - { - arg->type.type = pairType; - } imaginaryArguments.add(arg); } return imaginaryArguments; } - List getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* fwdDiffFunc, SourceLoc loc) + List getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) { List imaginaryArguments; - for (auto param : fwdDiffFunc->getParameters()) + for (auto param : originalFuncDecl->getParameters()) { - auto arg = astBuilder->create(); + auto arg = visitor->getASTBuilder()->create(); arg->declRef.decl = param; arg->type.isLeftValue = param->findModifier() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = as(param->getType())) + if (auto pairType = visitor->getDifferentialPairType(param->getType())) { - arg->type.type = pairType->getPrimalType(); + arg->type.type = pairType; } imaginaryArguments.add(arg); } @@ -4731,6 +4728,11 @@ namespace Slang List getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) { List imaginaryArguments; + auto isOutParam = [&](ParamDecl* param) + { + return param->findModifier() != nullptr && param->findModifier() == nullptr; + }; + for (auto param : originalFuncDecl->getParameters()) { auto arg = visitor->getASTBuilder()->create(); @@ -4738,16 +4740,23 @@ namespace Slang arg->type.isLeftValue = param->findModifier() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) + if (auto pairType = as(visitor->getDifferentialPairType(param->getType()))) { arg->type.type = pairType; - if (auto diffPairType = as(pairType)) + if (isOutParam(param)) { - if (param->findModifier() != nullptr && param->findModifier() == nullptr) - { - arg->type.isLeftValue = false; - arg->type.type = diffPairType->getPrimalType(); - } + // out T -> in T.Differential + arg->type.isLeftValue = false; + arg->type.type = visitor->tryGetDifferentialType( + visitor->getASTBuilder(), pairType->getPrimalType()); + } + } + else + { + if (isOutParam(param)) + { + // Skip non-differentiable out params. + continue; } } imaginaryArguments.add(arg); @@ -4763,38 +4772,6 @@ namespace Slang return imaginaryArguments; } - List getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* bwdDiffFunc, SourceLoc loc) - { - // Note: it isn't always possible to construct original arguments from - // backward propagation arguments because backward propagation function - // may drop certain parameters. - List imaginaryArguments; - for (auto param : bwdDiffFunc->getParameters()) - { - auto arg = astBuilder->create(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - if (auto pairType = as(param->getType())) - { - if (param->findModifier() != nullptr && param->findModifier() == nullptr) - { - arg->type.isLeftValue = false; - } - arg->type.type = pairType->getPrimalType(); - } - imaginaryArguments.add(arg); - } - // Assume the last parameter is `dOut`. - // This is not true if the function returns a non-differentiable value. - // However in that uncommon case we just fail the overload resolution - // and require the user to provide disambiguate themselves. - if (imaginaryArguments.getCount()) - imaginaryArguments.fastRemoveAt(imaginaryArguments.getCount() - 1); - return imaginaryArguments; - } - // This helper function is needed to workaround a gcc bug. // Remove when we upgrade to a newer version of gcc. template @@ -4803,76 +4780,41 @@ namespace Slang return decl->findModifier(); } - template + template void checkDerivativeOfAttributeImpl( SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, TDerivativeOfAttr* derivativeOfAttr, - DeclAssociationKind assocKind, - const List& imaginaryArgsToOriginal) + DeclAssociationKind assocKind) { DeclRef calleeDeclRef; - auto calleeDeclRefExpr = as(derivativeOfAttr->funcExpr); - if (!calleeDeclRefExpr) + DeclRefExpr* calleeDeclRefExpr = nullptr; + DifferentiateExpr* diffFuncExpr = visitor->getASTBuilder()->create(); + diffFuncExpr->baseFunction = derivativeOfAttr->funcExpr; + diffFuncExpr->loc = derivativeOfAttr->loc; + Expr* checkedDiffFuncExpr = visitor->dispatchExpr(diffFuncExpr, *visitor); + if (!checkedDiffFuncExpr) { - auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal); - auto resolved = visitor->ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as(resolved)) - { - calleeDeclRefExpr = as(resolvedInvoke->functionExpr); - } + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + List imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc); + auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedDiffFuncExpr, imaginaryArgs); + auto resolved = visitor->ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as(resolved)) + { + auto resolvedDiffFuncExpr = as(resolvedInvoke->functionExpr); + if (resolvedDiffFuncExpr) + calleeDeclRefExpr = as(resolvedDiffFuncExpr->baseFunction); } + if (!calleeDeclRefExpr) { visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } calleeDeclRef = calleeDeclRefExpr->declRef; - if (auto calleeGenDecl = as(calleeDeclRef.getDecl())) - { - auto parentGenericDecl = as(funcDecl->parentDecl); - if (!parentGenericDecl) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - FunctionDeclBase* funcReturnVal = nullptr; - List args; - for (auto mm : parentGenericDecl->members) - { - if (auto genericTypeParamDecl = as(mm)) - { - args.add(DeclRefType::create(visitor->getASTBuilder(), DeclRef(genericTypeParamDecl, nullptr))); - } - else if (auto genericValueParamDecl = as(mm)) - { - args.add(visitor->getASTBuilder()->getOrCreate( - genericValueParamDecl->getType(), - genericValueParamDecl, nullptr)); - } - } - auto funcs = calleeGenDecl->getMembersOfType(); - if (funcs.isEmpty()) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - funcReturnVal = funcs.getFirst(); - if (funcReturnVal) - { - auto subst = visitor->getASTBuilder()->getOrCreateGenericSubstitution(calleeGenDecl, args, nullptr); - calleeDeclRef.decl = funcReturnVal; - calleeDeclRef.substitutions = subst; - calleeDeclRefExpr = as(visitor->ConstructDeclRefExpr( - calleeDeclRef, nullptr, derivativeOfAttr->loc, nullptr)); - } - else - { - calleeDeclRef = DeclRef(); - calleeDeclRefExpr = nullptr; - } - } - + auto calleeFunc = as(calleeDeclRef.getDecl()); if (!calleeFunc) { @@ -4953,9 +4895,8 @@ namespace Slang if (!attr) return; - List imaginaryArgsToOriginal = getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc); - checkDerivativeOfAttributeImpl( - this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc, imaginaryArgsToOriginal); + checkDerivativeOfAttributeImpl( + this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc); } void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl) @@ -4964,33 +4905,8 @@ namespace Slang if (!attr) return; - List imaginaryArguments = getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc); - - // The tricky part here is that we can't easily derive the arguments to original func just - // from the definition of a backward derivative function, because we don't know if the last - // parameter is just a normal parameter of the original func, or if it is the additional - // derivative of the return value. The solution here is to try to resolve the original - // function with or without the last argument. However if the type of the last argument - // isn't differentiable, we know that it can't possibly be the result derivative. - - if (imaginaryArguments.getCount() == 0 || - !tryGetDifferentialType(m_astBuilder, imaginaryArguments.getLast()->type.type)) - { - checkDerivativeOfAttributeImpl( - this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments); - return; - } - - // Otherwise, try resolve with all the arguments, if failed, resolve without the last - // argument. - if (tryCheckDerivativeOfAttributeImpl(this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments)) - { - return; - } - - imaginaryArguments.removeLast(); - checkDerivativeOfAttributeImpl( - this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments); + checkDerivativeOfAttributeImpl( + this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc); } void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2803b5959..bebaa63a2 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1971,31 +1971,26 @@ namespace Slang if (auto higherOrderInvoke = as(invoke->functionExpr)) { - if (auto funcDeclExpr = as(getInnerMostExprFromHigherOrderExpr(higherOrderInvoke))) + FunctionDifferentiableLevel requiredLevel; + if (auto funcDeclExpr = as( + getInnerMostExprFromHigherOrderExpr(higherOrderInvoke, requiredLevel))) { auto funcDecl = as(funcDeclExpr->declRef.getDecl()); if (funcDecl) { - DifferentiateExpr* forwardDiff = nullptr; - DifferentiateExpr* backwardDiff = nullptr; - for (auto node = as(invoke->functionExpr); node; node = as(node->baseFunction)) + if (requiredLevel == FunctionDifferentiableLevel::Forward && + !getShared()->isDifferentiableFunc(funcDecl)) { - if (auto fwd = as(node)) - forwardDiff = fwd; - if (auto bwd = as(node)) - backwardDiff = bwd; + getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); } - if (forwardDiff && !getShared()->isDifferentiableFunc(funcDecl)) + if (requiredLevel == FunctionDifferentiableLevel::Backward && + !getShared()->isBackwardDifferentiableFunc(funcDecl)) { - getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); - } - if (backwardDiff && !getShared()->isBackwardDifferentiableFunc(funcDecl)) - { - getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); + getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); } if (!isEffectivelyStatic(funcDecl) && !isGlobalDecl(funcDecl)) { - getSink()->diagnose(forwardDiff, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl); + getSink()->diagnose(invoke->functionExpr, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl); } } } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 11aacd255..fc1b622cc 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -11,12 +11,6 @@ namespace Slang { - enum class FunctionDifferentiableLevel - { - None, - Forward, - Backward - }; /// Should the given `decl` be treated as a static rather than instance declaration? bool isEffectivelyStatic( Decl* decl); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 520d85971..e6a524645 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -680,16 +680,7 @@ namespace Slang { SLANG_ASSERT(attr->args.getCount() == 1); SLANG_ASSERT(as(attrTarget)); - - // Ensure that the argument is a reference to a function definition or declaration. - auto primalFunc = CheckTerm(attr->args[0]); - if (primalFunc->type == getASTBuilder()->getErrorType()) - { - // Could not resolve the term. - getSink()->diagnose(primalFunc, Slang::Diagnostics::invalidCustomDerivative, as(attrTarget)); - return false; - } - derivativeOfAttr->funcExpr = primalFunc; + derivativeOfAttr->funcExpr = attr->args[0]; } else if (auto comInterfaceAttr = as(attr)) { diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 09ddde2de..91af731ad 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1622,10 +1622,10 @@ namespace Slang else { // Unhandled case for the inner expr. - funcExpr->type = this->getASTBuilder()->getErrorType(); getSink()->diagnose(funcExpr->loc, Diagnostics::expectedFunction, funcExpr->type); + funcExpr->type = this->getASTBuilder()->getErrorType(); } } -- cgit v1.2.3