diff options
| author | kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> | 2025-07-09 11:25:29 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-09 09:25:29 -0700 |
| commit | a670bafc121c20168624f70a388dbe8556402c7f (patch) | |
| tree | 79b48a80e7abc0744193716e400bb57a6c026bad /source/slang/slang-check-expr.cpp | |
| parent | a7cb36901ccaf8297136c58c1451d6e04420af73 (diff) | |
no_diff diagnostics improvement (#7655)
close #6286.
This PR is to improve the diagnostics for no_diff usage.
In a differentiable function, any calls to a non-diff function with constant arguments should not require no_diff attribute.
This PR adds this extra check at `checkAutoDiffUsages` where it checks the differentiability on IR.
In a differentiable method, we will force to use `[NoDiffThis]` attribute if there is access to non-differentiable `This` type. Once this access is detected we will report a warning to bring users attention that this access won't generate any derivative, they have to use `[NoDiffThis]` to suppress that warning.
This PR adds this check at type checking stage, because it's the easiest way to find out all the `This` accesses.
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index b28b458da..c7e58a888 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1632,6 +1632,31 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* } } +// This checks that if a differentiable function access a non-diff type "This", in such case we +// want to provide a non-error diagnostic to the user to notify that there could be an unexpected +// behavior because every member access will not have derivative computed for it. User can use +// [NoDiffThis] to clarify that this is intended. +void SemanticsVisitor::maybeCheckMissingNoDiffThis(Expr* expr) +{ + if (auto memberExpr = as<MemberExpr>(expr)) + { + auto thisExpr = as<ThisExpr>(memberExpr->baseExpression); + if (thisExpr && isTypeDifferentiable(memberExpr->type.type)) + { + if (isTypeDifferentiable(calcThisType(thisExpr->type.type)) || + this->m_parentFunc->findModifier<NoDiffThisAttribute>()) + { + return; + } + + getSink()->diagnose( + memberExpr->loc, + Diagnostics::noDerivativeOnNonDifferentiableThisType, + memberExpr->declRef.getDecl(), + this->m_parentFunc); + } + } +} Expr* SemanticsVisitor::CheckTerm(Expr* term) { @@ -1649,7 +1674,13 @@ Expr* SemanticsVisitor::CheckTerm(Expr* term) if (this->m_parentFunc && this->m_parentFunc->findModifier<DifferentiableAttribute>()) { maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); + + if (!this->m_parentFunc->findModifier<TreatAsDifferentiableAttribute>()) + { + maybeCheckMissingNoDiffThis(checkedTerm); + } } + return checkedTerm; } |
