summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2025-07-09 11:25:29 -0500
committerGitHub <noreply@github.com>2025-07-09 09:25:29 -0700
commita670bafc121c20168624f70a388dbe8556402c7f (patch)
tree79b48a80e7abc0744193716e400bb57a6c026bad /source/slang/slang-check-expr.cpp
parenta7cb36901ccaf8297136c58c1451d6e04420af73 (diff)
no_diff diagnostics improvement (#7655)
close #6286. This PR is to improve the diagnostics for no_diff usage. In a differentiable function, any calls to a non-diff function with constant arguments should not require no_diff attribute. This PR adds this extra check at `checkAutoDiffUsages` where it checks the differentiability on IR. In a differentiable method, we will force to use `[NoDiffThis]` attribute if there is access to non-differentiable `This` type. Once this access is detected we will report a warning to bring users attention that this access won't generate any derivative, they have to use `[NoDiffThis]` to suppress that warning. This PR adds this check at type checking stage, because it's the easiest way to find out all the `This` accesses.
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp31
1 files changed, 31 insertions, 0 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index b28b458da..c7e58a888 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1632,6 +1632,31 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder*
}
}
+// This checks that if a differentiable function access a non-diff type "This", in such case we
+// want to provide a non-error diagnostic to the user to notify that there could be an unexpected
+// behavior because every member access will not have derivative computed for it. User can use
+// [NoDiffThis] to clarify that this is intended.
+void SemanticsVisitor::maybeCheckMissingNoDiffThis(Expr* expr)
+{
+ if (auto memberExpr = as<MemberExpr>(expr))
+ {
+ auto thisExpr = as<ThisExpr>(memberExpr->baseExpression);
+ if (thisExpr && isTypeDifferentiable(memberExpr->type.type))
+ {
+ if (isTypeDifferentiable(calcThisType(thisExpr->type.type)) ||
+ this->m_parentFunc->findModifier<NoDiffThisAttribute>())
+ {
+ return;
+ }
+
+ getSink()->diagnose(
+ memberExpr->loc,
+ Diagnostics::noDerivativeOnNonDifferentiableThisType,
+ memberExpr->declRef.getDecl(),
+ this->m_parentFunc);
+ }
+ }
+}
Expr* SemanticsVisitor::CheckTerm(Expr* term)
{
@@ -1649,7 +1674,13 @@ Expr* SemanticsVisitor::CheckTerm(Expr* term)
if (this->m_parentFunc && this->m_parentFunc->findModifier<DifferentiableAttribute>())
{
maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
+
+ if (!this->m_parentFunc->findModifier<TreatAsDifferentiableAttribute>())
+ {
+ maybeCheckMissingNoDiffThis(checkedTerm);
+ }
}
+
return checkedTerm;
}