diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-07-10 13:53:35 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-10 10:53:35 -0700 |
| commit | 4a247244715e35872ab2359e9bc7cd55b5ea27d4 (patch) | |
| tree | ad5402cf83cd17cd923ad410a734d968c60def1b /source/slang/slang-check-decl.cpp | |
| parent | 8ed0f49d337338426c05aa643106098e755b8d9d (diff) | |
Various fixes around differentiable member associations `[DerivativeMember(<diff-member>)]` (#4525)
* Add diagnostic for missing diff-member associations
+ Automatically create diff member associations if differential type is the same as the primal type.
+ Move diff-member attribute checking to conformance-checking phase to avoid circularity issues.
Fixes #4103
* Update slang-check-decl.cpp
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 149 |
1 files changed, 103 insertions, 46 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 66bdbc18e..cb1c11d9c 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -90,7 +90,7 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} - void checkDerivativeMemberAttribute(VarDeclBase* varDecl, DerivativeMemberAttribute* attr); + void checkDerivativeMemberAttributeParent(VarDeclBase* varDecl, DerivativeMemberAttribute* attr); void checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* m); void checkMeshOutputDecl(VarDeclBase* varDecl); @@ -1461,7 +1461,7 @@ namespace Slang structDecl->buildMemberDictionary(); } - void SemanticsDeclHeaderVisitor::checkDerivativeMemberAttribute( + void SemanticsDeclHeaderVisitor::checkDerivativeMemberAttributeParent( VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr) { auto memberType = checkProperType(getLinkage(), varDecl->type, getSink()); @@ -1479,43 +1479,12 @@ namespace Slang derivativeMemberAttributeCanOnlyBeUsedOnMembers); } auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc); - if (!thisType) + if (!diffThisType) { getSink()->diagnose( derivativeMemberAttr, Diagnostics::invalidUseOfDerivativeMemberAttributeParentTypeIsNotDifferentiable); } - SLANG_ASSERT(derivativeMemberAttr->args.getCount() == 1); - auto checkedExpr = dispatchExpr(derivativeMemberAttr->args[0], allowStaticReferenceToNonStaticMember()); - if (auto declRefExpr = as<DeclRefExpr>(checkedExpr)) - { - derivativeMemberAttr->memberDeclRef = declRefExpr; - if (!diffType->equals(declRefExpr->type)) - { - getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeMismatch, diffType, declRefExpr->type); - } - if (!varDecl->parentDecl) - { - getSink()->diagnose(derivativeMemberAttr, Diagnostics::attributeNotApplicable, diffType, declRefExpr->type); - } - if (auto memberExpr = as<StaticMemberExpr>(declRefExpr)) - { - auto baseExprType = memberExpr->baseExpression->type.type; - if (auto typeType = as<TypeType>(baseExprType)) - { - if (diffThisType->equals(typeType->getType())) - { - return; - } - } - - } - } - getSink()->diagnose( - derivativeMemberAttr, - Diagnostics:: - derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType, - diffThisType); } void SemanticsDeclHeaderVisitor::checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* extensionExternMemberModifier) @@ -1751,7 +1720,7 @@ namespace Slang // Check modifiers that can't be checked earlier during modifier checking stage. if (auto derivativeMemberAttr = varDecl->findModifier<DerivativeMemberAttribute>()) { - checkDerivativeMemberAttribute(varDecl, derivativeMemberAttr); + checkDerivativeMemberAttributeParent(varDecl, derivativeMemberAttr); } if (auto extensionExternAttr = varDecl->findModifier<ExtensionExternVarModifier>()) { @@ -2588,19 +2557,85 @@ namespace Slang auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType); if (!inheritanceDecl->witnessTable->getRequirementDictionary().tryGetValue(requirementDecl, witnessValue)) return; - // A type used as differential type must have itself as its own differential type. + if (witnessValue.getFlavor() != RequirementWitness::Flavor::val) return; auto differentialType = as<DeclRefType>(witnessValue.getVal()); if (!differentialType) return; + + // Check that the type used as differential type must have itself as its own differential type. auto diffDiffType = tryGetDifferentialType(m_astBuilder, differentialType); if (!differentialType->equals(diffDiffType)) { SourceLoc sourceLoc = differentialType->getDeclRef().getDecl()->loc; - getSink()->diagnose(inheritanceDecl, Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, differentialType, diffDiffType); + getSink()->diagnose( + inheritanceDecl, + Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, + differentialType, + diffDiffType); getSink()->diagnose(sourceLoc, Diagnostics::seeDefinitionOf, differentialType); } + + // Check that all [DerivativeMember(...)] attributes have their references checked. + for (auto member : inheritanceDecl->parentDecl->getMembersOfType<VarDeclBase>()) + { + if (member->findModifier<NoDiffModifier>()) + continue; + auto derivativeMemberAttr = member->findModifier<DerivativeMemberAttribute>(); + if (!derivativeMemberAttr) + continue; + checkDerivativeMemberAttributeReferences(member, derivativeMemberAttr); + } + + // Check that either the differential type is the same as the base type, or all fields of the base type that are differentiable + // have a corresponding field in the differential type through the [DerivativeMember(...)] attribute. + // + // We only need to check the fields of the base type that are differentiable. + auto baseDecl = as<AggTypeDecl>(inheritanceDecl->parentDecl); + if (!baseDecl) + return; + + auto thisType = calcThisType(getDefaultDeclRef(baseDecl)); + + bool typeIsSelfDifferential = thisType->equals(differentialType); + + for (auto member : baseDecl->getMembersOfType<VarDeclBase>()) + { + if (member->findModifier<NoDiffModifier>()) + continue; + auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type); + if (!diffType) + continue; + + if (member->findModifier<DerivativeMemberAttribute>()) + continue; + else if (!typeIsSelfDifferential) + getSink()->diagnose( + member, + Diagnostics::differentiableMemberShouldHaveCorrespondingFieldInDiffType, + member->nameAndLoc.name, + differentialType); + else + { + // If the type is its own differential type, we can infer the differential + // members from the original type. + // + // Add a derivative member attribute referencing itself. + // + auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>(); + auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>(); + fieldLookupExpr->type.type = diffType; + auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>(); + baseTypeExpr->base.type = differentialType; + auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(differentialType); + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + fieldLookupExpr->declRef = makeDeclRef(member); + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(member, derivativeMemberModifier); + } + } } }; @@ -5174,6 +5209,7 @@ namespace Slang auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>(); if (!derivativeAttr) continue; + auto varMember = as<VarDeclBase>(member); if (!varMember) continue; @@ -5183,6 +5219,9 @@ namespace Slang if (!diffMemberType) continue; + // Pull up the derivative member name from the attribute + auto derivMemberName = derivativeAttr->memberDeclRef->declRef.getName(); + // Construct reference exprs to the member's corresponding fields in each parameter. List<Expr*> paramFields; List<bool> inductiveArgMask; @@ -5195,9 +5234,9 @@ namespace Slang { auto memberExpr = m_astBuilder->create<MemberExpr>(); memberExpr->baseExpression = arg; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is - // Differential type. - memberExpr->name = varMember->getName(); + + memberExpr->name = derivMemberName; + paramFields.add(memberExpr); inductiveArgMask.add(true); } @@ -5219,9 +5258,8 @@ namespace Slang { auto memberExpr = m_astBuilder->create<MemberExpr>(); memberExpr->baseExpression = arg; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is - // Differential type. - memberExpr->name = varMember->getName(); + + memberExpr->name = derivMemberName; paramFields.add(memberExpr); inductiveArgMask.add(true); @@ -5236,9 +5274,7 @@ namespace Slang } // Invoke the method for the field and assign the value to resultVar. - // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` - // is Differential type. - auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName()); + auto leftVal = synth.emitMemberExpr(resultVarExpr, derivMemberName); if (!_synthesizeMemberAssignMemberHelper( synth, requirementDeclRef.getName(), @@ -5855,6 +5891,17 @@ namespace Slang } } + void SemanticsVisitor::checkDifferentiableMembersInType(AggTypeDecl* decl) + { + for (auto member : decl->getMembersOfType<VarDeclBase>()) + { + if (auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>()) + { + checkDerivativeMemberAttributeReferences(member, derivativeAttr); + } + } + } + void SemanticsVisitor::checkAggTypeConformance(AggTypeDecl* decl) { // After we've checked members, we need to go through @@ -5892,6 +5939,16 @@ namespace Slang auto inheritanceDecls = decl->getMembersOfType<InheritanceDecl>().toList(); for (auto inheritanceDecl : inheritanceDecls) { + // Special handling for when we check for conformance against `IDifferentiable` + // We will reference-checking for the [DerivativeMember(DiffType.member)] + // attributes here, since they have to be performed after types can be referenced + // and before conformance checking, where this information can be used to synthesize + // member methods (such as `dzero`, `dadd`, etc..) + // + if (inheritanceDecl->getSup().type->equals( + astBuilder->getDifferentiableInterfaceType())) + checkDifferentiableMembersInType(decl); + checkConformance(type, inheritanceDecl, decl); } |
