From a670bafc121c20168624f70a388dbe8556402c7f Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Wed, 9 Jul 2025 11:25:29 -0500 Subject: no_diff diagnostics improvement (#7655) close #6286. This PR is to improve the diagnostics for no_diff usage. In a differentiable function, any calls to a non-diff function with constant arguments should not require no_diff attribute. This PR adds this extra check at `checkAutoDiffUsages` where it checks the differentiability on IR. In a differentiable method, we will force to use `[NoDiffThis]` attribute if there is access to non-differentiable `This` type. Once this access is detected we will report a warning to bring users attention that this access won't generate any derivative, they have to use `[NoDiffThis]` to suppress that warning. This PR adds this check at type checking stage, because it's the easiest way to find out all the `This` accesses. --- source/slang/slang-check-expr.cpp | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) (limited to 'source/slang/slang-check-expr.cpp') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index b28b458da..c7e58a888 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1632,6 +1632,31 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* } } +// This checks that if a differentiable function access a non-diff type "This", in such case we +// want to provide a non-error diagnostic to the user to notify that there could be an unexpected +// behavior because every member access will not have derivative computed for it. User can use +// [NoDiffThis] to clarify that this is intended. +void SemanticsVisitor::maybeCheckMissingNoDiffThis(Expr* expr) +{ + if (auto memberExpr = as(expr)) + { + auto thisExpr = as(memberExpr->baseExpression); + if (thisExpr && isTypeDifferentiable(memberExpr->type.type)) + { + if (isTypeDifferentiable(calcThisType(thisExpr->type.type)) || + this->m_parentFunc->findModifier()) + { + return; + } + + getSink()->diagnose( + memberExpr->loc, + Diagnostics::noDerivativeOnNonDifferentiableThisType, + memberExpr->declRef.getDecl(), + this->m_parentFunc); + } + } +} Expr* SemanticsVisitor::CheckTerm(Expr* term) { @@ -1649,7 +1674,13 @@ Expr* SemanticsVisitor::CheckTerm(Expr* term) if (this->m_parentFunc && this->m_parentFunc->findModifier()) { maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); + + if (!this->m_parentFunc->findModifier()) + { + maybeCheckMissingNoDiffThis(checkedTerm); + } } + return checkedTerm; } -- cgit v1.2.3