From 41cb7c13e37ec32ffb6557d21da079d77151e136 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 24 Oct 2022 22:19:38 -0700 Subject: Rework differentiation of member access through `[DerivativeMember(DiffType.field)]` (#2460) * wip: remove auto-diff for member access, add diff through property accessors. * Fix getter-setter test. * Fix getter-setter-multi test. * Fix nested-jvp test. * Use [DerivativeMember] attribute to differentiate through member access. * Clean up. * More cleanup. Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 104 +++++++++++++++++++++++++++++++++++++- 1 file changed, 103 insertions(+), 1 deletion(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 2d6e20622..356105e4f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -45,7 +45,10 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} - + + void checkDerivativeMemberAttribute(VarDeclBase* varDecl, DerivativeMemberAttribute* attr); + void checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* m); + void checkVarDeclCommon(VarDeclBase* varDecl); void visitVarDecl(VarDecl* varDecl) @@ -78,6 +81,8 @@ namespace Slang void checkCallableDeclCommon(CallableDecl* decl); + void maybeCheckDifferentiableAccessorSignature(FuncDecl* funcDecl); + void visitFuncDecl(FuncDecl* funcDecl); void visitParamDecl(ParamDecl* paramDecl); @@ -636,6 +641,9 @@ namespace Slang bool SemanticsVisitor::isDeclUsableAsStaticMember( Decl* decl) { + if (m_allowStaticReferenceToNonStaticMember) + return true; + if(auto genericDecl = as(decl)) decl = genericDecl->inner; @@ -663,6 +671,9 @@ namespace Slang bool SemanticsVisitor::isUsableAsStaticMember( LookupResultItem const& item) { + if (m_allowStaticReferenceToNonStaticMember) + return true; + // There's a bit of a gotcha here, because a lookup result // item might include "breadcrumbs" that indicate more steps // along the lookup path. As a result it isn't always @@ -966,6 +977,87 @@ namespace Slang tryConstantFoldDeclRef(DeclRef(varDecl, nullptr), nullptr); } + void SemanticsDeclHeaderVisitor::checkDerivativeMemberAttribute( + VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr) + { + auto memberType = checkProperType(getLinkage(), varDecl->type, getSink()); + auto diffType = _getDifferential(m_astBuilder, memberType); + if (as(diffType)) + { + getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeIsNotDifferentiable, memberType); + } + auto thisType = calcThisType(makeDeclRef(varDecl->parentDecl)); + if (!thisType) + { + getSink()->diagnose( + derivativeMemberAttr, + Diagnostics:: + derivativeMemberAttributeCanOnlyBeUsedOnMembers); + } + auto diffThisType = _getDifferential(m_astBuilder, thisType); + if (!thisType) + { + getSink()->diagnose( + derivativeMemberAttr, + Diagnostics::invalidUseOfDerivativeMemberAttributeParentTypeIsNotDifferentiable); + } + SLANG_ASSERT(derivativeMemberAttr->args.getCount() == 1); + auto checkedExpr = dispatchExpr(derivativeMemberAttr->args[0], allowStaticReferenceToNonStaticMember()); + if (auto declRefExpr = as(checkedExpr)) + { + derivativeMemberAttr->memberDeclRef = declRefExpr; + if (!diffType->equals(declRefExpr->type)) + { + getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeMismatch, diffType, declRefExpr->type); + } + if (!varDecl->parentDecl) + { + getSink()->diagnose(derivativeMemberAttr, Diagnostics::attributeNotApplicable, diffType, declRefExpr->type); + } + if (auto memberExpr = as(declRefExpr)) + { + auto baseExprType = memberExpr->baseExpression->type.type; + if (auto typeType = as(baseExprType)) + { + if (diffThisType->equals(typeType->type)) + { + return; + } + } + + } + } + getSink()->diagnose( + derivativeMemberAttr, + Diagnostics:: + derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType, + diffThisType); + } + + void SemanticsDeclHeaderVisitor::checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* extensionExternMemberModifier) + { + if (auto parentExtension = as(varDecl->parentDecl)) + { + if (auto originalVarDecl = extensionExternMemberModifier->originalDecl.as()) + { + auto originalType = GetTypeForDeclRef(originalVarDecl, originalVarDecl.getLoc()); + auto extVarType = varDecl->type; + if (!extVarType.type || !extVarType.type->equals(originalType)) + { + getSink()->diagnose(varDecl, Diagnostics::typeOfExternDeclMismatchesOriginalDefinition, varDecl, originalType); + } + else + { + return; + } + } + else + { + getSink()->diagnose(varDecl, Diagnostics::definitionOfExternDeclMismatchesOriginalDefinition, varDecl); + } + } + } + void SemanticsDeclHeaderVisitor::checkVarDeclCommon(VarDeclBase* varDecl) { // A variable that didn't have an explicit type written must @@ -1136,6 +1228,16 @@ namespace Slang getSink()->diagnose(varDecl, Diagnostics::valueRequirementMustBeCompileTimeConst); } } + + // Check modifiers that can't be checked earlier during modifier checking stage. + if (auto derivativeMemberAttr = varDecl->findModifier()) + { + checkDerivativeMemberAttribute(varDecl, derivativeMemberAttr); + } + if (auto extensionExternAttr = varDecl->findModifier()) + { + checkExtensionExternVarAttribute(varDecl, extensionExternAttr); + } } void SemanticsDeclHeaderVisitor::visitStructDecl(StructDecl* structDecl) -- cgit v1.2.3