diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-07-10 13:53:35 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-10 10:53:35 -0700 |
| commit | 4a247244715e35872ab2359e9bc7cd55b5ea27d4 (patch) | |
| tree | ad5402cf83cd17cd923ad410a734d968c60def1b /source | |
| parent | 8ed0f49d337338426c05aa643106098e755b8d9d (diff) | |
Various fixes around differentiable member associations `[DerivativeMember(<diff-member>)]` (#4525)
* Add diagnostic for missing diff-member associations
+ Automatically create diff member associations if differential type is the same as the primal type.
+ Move diff-member attribute checking to conformance-checking phase to avoid circularity issues.
Fixes #4103
* Update slang-check-decl.cpp
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 149 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 51 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 2 |
4 files changed, 163 insertions, 46 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 66bdbc18e..cb1c11d9c 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -90,7 +90,7 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} - void checkDerivativeMemberAttribute(VarDeclBase* varDecl, DerivativeMemberAttribute* attr); + void checkDerivativeMemberAttributeParent(VarDeclBase* varDecl, DerivativeMemberAttribute* attr); void checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* m); void checkMeshOutputDecl(VarDeclBase* varDecl); @@ -1461,7 +1461,7 @@ namespace Slang structDecl->buildMemberDictionary(); } - void SemanticsDeclHeaderVisitor::checkDerivativeMemberAttribute( + void SemanticsDeclHeaderVisitor::checkDerivativeMemberAttributeParent( VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr) { auto memberType = checkProperType(getLinkage(), varDecl->type, getSink()); @@ -1479,43 +1479,12 @@ namespace Slang derivativeMemberAttributeCanOnlyBeUsedOnMembers); } auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc); - if (!thisType) + if (!diffThisType) { getSink()->diagnose( derivativeMemberAttr, Diagnostics::invalidUseOfDerivativeMemberAttributeParentTypeIsNotDifferentiable); } - SLANG_ASSERT(derivativeMemberAttr->args.getCount() == 1); - auto checkedExpr = dispatchExpr(derivativeMemberAttr->args[0], allowStaticReferenceToNonStaticMember()); - if (auto declRefExpr = as<DeclRefExpr>(checkedExpr)) - { - derivativeMemberAttr->memberDeclRef = declRefExpr; - if (!diffType->equals(declRefExpr->type)) - { - getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeMismatch, diffType, declRefExpr->type); - } - if (!varDecl->parentDecl) - { - getSink()->diagnose(derivativeMemberAttr, Diagnostics::attributeNotApplicable, diffType, declRefExpr->type); - } - if (auto memberExpr = as<StaticMemberExpr>(declRefExpr)) - { - auto baseExprType = memberExpr->baseExpression->type.type; - if (auto typeType = as<TypeType>(baseExprType)) - { - if (diffThisType->equals(typeType->getType())) - { - return; - } - } - - } - } - getSink()->diagnose( - derivativeMemberAttr, - Diagnostics:: - derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType, - diffThisType); } void SemanticsDeclHeaderVisitor::checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* extensionExternMemberModifier) @@ -1751,7 +1720,7 @@ namespace Slang // Check modifiers that can't be checked earlier during modifier checking stage. if (auto derivativeMemberAttr = varDecl->findModifier<DerivativeMemberAttribute>()) { - checkDerivativeMemberAttribute(varDecl, derivativeMemberAttr); + checkDerivativeMemberAttributeParent(varDecl, derivativeMemberAttr); } if (auto extensionExternAttr = varDecl->findModifier<ExtensionExternVarModifier>()) { @@ -2588,19 +2557,85 @@ namespace Slang auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType); if (!inheritanceDecl->witnessTable->getRequirementDictionary().tryGetValue(requirementDecl, witnessValue)) return; - // A type used as differential type must have itself as its own differential type. + if (witnessValue.getFlavor() != RequirementWitness::Flavor::val) return; auto differentialType = as<DeclRefType>(witnessValue.getVal()); if (!differentialType) return; + + // Check that the type used as differential type must have itself as its own differential type. auto diffDiffType = tryGetDifferentialType(m_astBuilder, differentialType); if (!differentialType->equals(diffDiffType)) { SourceLoc sourceLoc = differentialType->getDeclRef().getDecl()->loc; - getSink()->diagnose(inheritanceDecl, Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, differentialType, diffDiffType); + getSink()->diagnose( + inheritanceDecl, + Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, + differentialType, + diffDiffType); getSink()->diagnose(sourceLoc, Diagnostics::seeDefinitionOf, differentialType); } + + // Check that all [DerivativeMember(...)] attributes have their references checked. + for (auto member : inheritanceDecl->parentDecl->getMembersOfType<VarDeclBase>()) + { + if (member->findModifier<NoDiffModifier>()) + continue; + auto derivativeMemberAttr = member->findModifier<DerivativeMemberAttribute>(); + if (!derivativeMemberAttr) + continue; + checkDerivativeMemberAttributeReferences(member, derivativeMemberAttr); + } + + // Check that either the differential type is the same as the base type, or all fields of the base type that are differentiable + // have a corresponding field in the differential type through the [DerivativeMember(...)] attribute. + // + // We only need to check the fields of the base type that are differentiable. + auto baseDecl = as<AggTypeDecl>(inheritanceDecl->parentDecl); + if (!baseDecl) + return; + + auto thisType = calcThisType(getDefaultDeclRef(baseDecl)); + + bool typeIsSelfDifferential = thisType->equals(differentialType); + + for (auto member : baseDecl->getMembersOfType<VarDeclBase>()) + { + if (member->findModifier<NoDiffModifier>()) + continue; + auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type); + if (!diffType) + continue; + + if (member->findModifier<DerivativeMemberAttribute>()) + continue; + else if (!typeIsSelfDifferential) + getSink()->diagnose( + member, + Diagnostics::differentiableMemberShouldHaveCorrespondingFieldInDiffType, + member->nameAndLoc.name, + differentialType); + else + { + // If the type is its own differential type, we can infer the differential + // members from the original type. + // + // Add a derivative member attribute referencing itself. + // + auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>(); + auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>(); + fieldLookupExpr->type.type = diffType; + auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>(); + baseTypeExpr->base.type = differentialType; + auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(differentialType); + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + fieldLookupExpr->declRef = makeDeclRef(member); + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(member, derivativeMemberModifier); + } + } } }; @@ -5174,6 +5209,7 @@ namespace Slang auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>(); if (!derivativeAttr) continue; + auto varMember = as<VarDeclBase>(member); if (!varMember) continue; @@ -5183,6 +5219,9 @@ namespace Slang if (!diffMemberType) continue; + // Pull up the derivative member name from the attribute + auto derivMemberName = derivativeAttr->memberDeclRef->declRef.getName(); + // Construct reference exprs to the member's corresponding fields in each parameter. List<Expr*> paramFields; List<bool> inductiveArgMask; @@ -5195,9 +5234,9 @@ namespace Slang { auto memberExpr = m_astBuilder->create<MemberExpr>(); memberExpr->baseExpression = arg; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is - // Differential type. - memberExpr->name = varMember->getName(); + + memberExpr->name = derivMemberName; + paramFields.add(memberExpr); inductiveArgMask.add(true); } @@ -5219,9 +5258,8 @@ namespace Slang { auto memberExpr = m_astBuilder->create<MemberExpr>(); memberExpr->baseExpression = arg; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is - // Differential type. - memberExpr->name = varMember->getName(); + + memberExpr->name = derivMemberName; paramFields.add(memberExpr); inductiveArgMask.add(true); @@ -5236,9 +5274,7 @@ namespace Slang } // Invoke the method for the field and assign the value to resultVar. - // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` - // is Differential type. - auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName()); + auto leftVal = synth.emitMemberExpr(resultVarExpr, derivMemberName); if (!_synthesizeMemberAssignMemberHelper( synth, requirementDeclRef.getName(), @@ -5855,6 +5891,17 @@ namespace Slang } } + void SemanticsVisitor::checkDifferentiableMembersInType(AggTypeDecl* decl) + { + for (auto member : decl->getMembersOfType<VarDeclBase>()) + { + if (auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>()) + { + checkDerivativeMemberAttributeReferences(member, derivativeAttr); + } + } + } + void SemanticsVisitor::checkAggTypeConformance(AggTypeDecl* decl) { // After we've checked members, we need to go through @@ -5892,6 +5939,16 @@ namespace Slang auto inheritanceDecls = decl->getMembersOfType<InheritanceDecl>().toList(); for (auto inheritanceDecl : inheritanceDecls) { + // Special handling for when we check for conformance against `IDifferentiable` + // We will reference-checking for the [DerivativeMember(DiffType.member)] + // attributes here, since they have to be performed after types can be referenced + // and before conformance checking, where this information can be used to synthesize + // member methods (such as `dzero`, `dadd`, etc..) + // + if (inheritanceDecl->getSup().type->equals( + astBuilder->getDifferentiableInterfaceType())) + checkDifferentiableMembersInType(decl); + checkConformance(type, inheritanceDecl, decl); } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index f79e23b42..ee36a21fb 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1213,6 +1213,57 @@ namespace Slang } } + void SemanticsVisitor::checkDerivativeMemberAttributeReferences( + VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr) + { + if (derivativeMemberAttr->memberDeclRef) + { + // Already checked! This usually happens if this attribute is synthesized by the compiler. + return; + } + + SLANG_ASSERT(derivativeMemberAttr->args.getCount() == 1); + auto checkedExpr = dispatchExpr(derivativeMemberAttr->args[0], allowStaticReferenceToNonStaticMember()); + + auto memberType = varDecl->type.type; // All types must be fully checked by now. + auto diffType = getDifferentialType(m_astBuilder, memberType, varDecl->loc); + auto thisType = calcThisType(makeDeclRef(varDecl->parentDecl)); + if (!thisType) return; // Diagnostic should have been emitted previously. + + auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc); + if (!diffThisType) return; // Diagnostic should have been emitted previously. + + if (auto declRefExpr = as<DeclRefExpr>(checkedExpr)) + { + derivativeMemberAttr->memberDeclRef = declRefExpr; + if (!diffType->equals(declRefExpr->type)) + { + getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeMismatch, diffType, declRefExpr->type); + } + if (!varDecl->parentDecl) + { + getSink()->diagnose(derivativeMemberAttr, Diagnostics::attributeNotApplicable, diffType, declRefExpr->type); + } + if (auto memberExpr = as<StaticMemberExpr>(declRefExpr)) + { + auto baseExprType = memberExpr->baseExpression->type.type; + if (auto typeType = as<TypeType>(baseExprType)) + { + if (diffThisType->equals(typeType->getType())) + { + return; + } + } + + } + } + getSink()->diagnose( + derivativeMemberAttr, + Diagnostics:: + derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType, + diffThisType); + } + Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc) { auto result = tryGetDifferentialType(builder, type); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 232cb623c..39f5f46b3 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1393,6 +1393,9 @@ namespace Slang // Helper function to check if a struct can be used as its own differential type. bool canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl); void markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type); + + void checkDerivativeMemberAttributeReferences( + VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr); public: @@ -1811,6 +1814,10 @@ namespace Slang RefPtr<WitnessTable> witnessTable, BuiltinRequirementKind requirementKind); + /// Check references from`[DerivativeMember(...)]` attributes on members of the agg-decl. + /// this is typically deferred until after types are ready for reference. + void checkDifferentiableMembersInType(AggTypeDecl* decl); + struct DifferentiableMemberInfo { Decl* memberDecl; diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 7ebe77a8f..98af8a228 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -361,6 +361,8 @@ DIAGNOSTIC(30098, Error, nonStaticMemberFunctionNotAllowedAsDiffOperand, "non-st DIAGNOSTIC(30099, Error, sizeOfArgumentIsInvalid, "argument to sizeof is invalid") DIAGNOSTIC(30101, Error, readingFromWriteOnly, "cannot read from writeonly, check modifiers.") +DIAGNOSTIC(30102, Error, differentiableMemberShouldHaveCorrespondingFieldInDiffType, "differentiable member '$0' should have a corresponding field in '$1'. Use [DerivativeMember($1.<field-name>)] or mark as no_diff") + // Include DIAGNOSTIC(30500, Error, includedFileMissingImplementing, "missing 'implementing' declaration in the included source file '$0'.") |
