summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-07-10 13:53:35 -0400
committerGitHub <noreply@github.com>2024-07-10 10:53:35 -0700
commit4a247244715e35872ab2359e9bc7cd55b5ea27d4 (patch)
treead5402cf83cd17cd923ad410a734d968c60def1b /source
parent8ed0f49d337338426c05aa643106098e755b8d9d (diff)
Various fixes around differentiable member associations `[DerivativeMember(<diff-member>)]` (#4525)
* Add diagnostic for missing diff-member associations + Automatically create diff member associations if differential type is the same as the primal type. + Move diff-member attribute checking to conformance-checking phase to avoid circularity issues. Fixes #4103 * Update slang-check-decl.cpp --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-decl.cpp149
-rw-r--r--source/slang/slang-check-expr.cpp51
-rw-r--r--source/slang/slang-check-impl.h7
-rw-r--r--source/slang/slang-diagnostic-defs.h2
4 files changed, 163 insertions, 46 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 66bdbc18e..cb1c11d9c 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -90,7 +90,7 @@ namespace Slang
void visitDecl(Decl*) {}
void visitDeclGroup(DeclGroup*) {}
- void checkDerivativeMemberAttribute(VarDeclBase* varDecl, DerivativeMemberAttribute* attr);
+ void checkDerivativeMemberAttributeParent(VarDeclBase* varDecl, DerivativeMemberAttribute* attr);
void checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* m);
void checkMeshOutputDecl(VarDeclBase* varDecl);
@@ -1461,7 +1461,7 @@ namespace Slang
structDecl->buildMemberDictionary();
}
- void SemanticsDeclHeaderVisitor::checkDerivativeMemberAttribute(
+ void SemanticsDeclHeaderVisitor::checkDerivativeMemberAttributeParent(
VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr)
{
auto memberType = checkProperType(getLinkage(), varDecl->type, getSink());
@@ -1479,43 +1479,12 @@ namespace Slang
derivativeMemberAttributeCanOnlyBeUsedOnMembers);
}
auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc);
- if (!thisType)
+ if (!diffThisType)
{
getSink()->diagnose(
derivativeMemberAttr,
Diagnostics::invalidUseOfDerivativeMemberAttributeParentTypeIsNotDifferentiable);
}
- SLANG_ASSERT(derivativeMemberAttr->args.getCount() == 1);
- auto checkedExpr = dispatchExpr(derivativeMemberAttr->args[0], allowStaticReferenceToNonStaticMember());
- if (auto declRefExpr = as<DeclRefExpr>(checkedExpr))
- {
- derivativeMemberAttr->memberDeclRef = declRefExpr;
- if (!diffType->equals(declRefExpr->type))
- {
- getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeMismatch, diffType, declRefExpr->type);
- }
- if (!varDecl->parentDecl)
- {
- getSink()->diagnose(derivativeMemberAttr, Diagnostics::attributeNotApplicable, diffType, declRefExpr->type);
- }
- if (auto memberExpr = as<StaticMemberExpr>(declRefExpr))
- {
- auto baseExprType = memberExpr->baseExpression->type.type;
- if (auto typeType = as<TypeType>(baseExprType))
- {
- if (diffThisType->equals(typeType->getType()))
- {
- return;
- }
- }
-
- }
- }
- getSink()->diagnose(
- derivativeMemberAttr,
- Diagnostics::
- derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType,
- diffThisType);
}
void SemanticsDeclHeaderVisitor::checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* extensionExternMemberModifier)
@@ -1751,7 +1720,7 @@ namespace Slang
// Check modifiers that can't be checked earlier during modifier checking stage.
if (auto derivativeMemberAttr = varDecl->findModifier<DerivativeMemberAttribute>())
{
- checkDerivativeMemberAttribute(varDecl, derivativeMemberAttr);
+ checkDerivativeMemberAttributeParent(varDecl, derivativeMemberAttr);
}
if (auto extensionExternAttr = varDecl->findModifier<ExtensionExternVarModifier>())
{
@@ -2588,19 +2557,85 @@ namespace Slang
auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType);
if (!inheritanceDecl->witnessTable->getRequirementDictionary().tryGetValue(requirementDecl, witnessValue))
return;
- // A type used as differential type must have itself as its own differential type.
+
if (witnessValue.getFlavor() != RequirementWitness::Flavor::val)
return;
auto differentialType = as<DeclRefType>(witnessValue.getVal());
if (!differentialType)
return;
+
+ // Check that the type used as differential type must have itself as its own differential type.
auto diffDiffType = tryGetDifferentialType(m_astBuilder, differentialType);
if (!differentialType->equals(diffDiffType))
{
SourceLoc sourceLoc = differentialType->getDeclRef().getDecl()->loc;
- getSink()->diagnose(inheritanceDecl, Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, differentialType, diffDiffType);
+ getSink()->diagnose(
+ inheritanceDecl,
+ Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType,
+ differentialType,
+ diffDiffType);
getSink()->diagnose(sourceLoc, Diagnostics::seeDefinitionOf, differentialType);
}
+
+ // Check that all [DerivativeMember(...)] attributes have their references checked.
+ for (auto member : inheritanceDecl->parentDecl->getMembersOfType<VarDeclBase>())
+ {
+ if (member->findModifier<NoDiffModifier>())
+ continue;
+ auto derivativeMemberAttr = member->findModifier<DerivativeMemberAttribute>();
+ if (!derivativeMemberAttr)
+ continue;
+ checkDerivativeMemberAttributeReferences(member, derivativeMemberAttr);
+ }
+
+ // Check that either the differential type is the same as the base type, or all fields of the base type that are differentiable
+ // have a corresponding field in the differential type through the [DerivativeMember(...)] attribute.
+ //
+ // We only need to check the fields of the base type that are differentiable.
+ auto baseDecl = as<AggTypeDecl>(inheritanceDecl->parentDecl);
+ if (!baseDecl)
+ return;
+
+ auto thisType = calcThisType(getDefaultDeclRef(baseDecl));
+
+ bool typeIsSelfDifferential = thisType->equals(differentialType);
+
+ for (auto member : baseDecl->getMembersOfType<VarDeclBase>())
+ {
+ if (member->findModifier<NoDiffModifier>())
+ continue;
+ auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type);
+ if (!diffType)
+ continue;
+
+ if (member->findModifier<DerivativeMemberAttribute>())
+ continue;
+ else if (!typeIsSelfDifferential)
+ getSink()->diagnose(
+ member,
+ Diagnostics::differentiableMemberShouldHaveCorrespondingFieldInDiffType,
+ member->nameAndLoc.name,
+ differentialType);
+ else
+ {
+ // If the type is its own differential type, we can infer the differential
+ // members from the original type.
+ //
+ // Add a derivative member attribute referencing itself.
+ //
+ auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
+ auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
+ fieldLookupExpr->type.type = diffType;
+ auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = differentialType;
+ auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(differentialType);
+ baseTypeExpr->type.type = baseTypeType;
+ fieldLookupExpr->baseExpression = baseTypeExpr;
+ fieldLookupExpr->declRef = makeDeclRef(member);
+ derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
+ addModifier(member, derivativeMemberModifier);
+ }
+ }
}
};
@@ -5174,6 +5209,7 @@ namespace Slang
auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>();
if (!derivativeAttr)
continue;
+
auto varMember = as<VarDeclBase>(member);
if (!varMember)
continue;
@@ -5183,6 +5219,9 @@ namespace Slang
if (!diffMemberType)
continue;
+ // Pull up the derivative member name from the attribute
+ auto derivMemberName = derivativeAttr->memberDeclRef->declRef.getName();
+
// Construct reference exprs to the member's corresponding fields in each parameter.
List<Expr*> paramFields;
List<bool> inductiveArgMask;
@@ -5195,9 +5234,9 @@ namespace Slang
{
auto memberExpr = m_astBuilder->create<MemberExpr>();
memberExpr->baseExpression = arg;
- // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
- // Differential type.
- memberExpr->name = varMember->getName();
+
+ memberExpr->name = derivMemberName;
+
paramFields.add(memberExpr);
inductiveArgMask.add(true);
}
@@ -5219,9 +5258,8 @@ namespace Slang
{
auto memberExpr = m_astBuilder->create<MemberExpr>();
memberExpr->baseExpression = arg;
- // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
- // Differential type.
- memberExpr->name = varMember->getName();
+
+ memberExpr->name = derivMemberName;
paramFields.add(memberExpr);
inductiveArgMask.add(true);
@@ -5236,9 +5274,7 @@ namespace Slang
}
// Invoke the method for the field and assign the value to resultVar.
- // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr`
- // is Differential type.
- auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName());
+ auto leftVal = synth.emitMemberExpr(resultVarExpr, derivMemberName);
if (!_synthesizeMemberAssignMemberHelper(
synth,
requirementDeclRef.getName(),
@@ -5855,6 +5891,17 @@ namespace Slang
}
}
+ void SemanticsVisitor::checkDifferentiableMembersInType(AggTypeDecl* decl)
+ {
+ for (auto member : decl->getMembersOfType<VarDeclBase>())
+ {
+ if (auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>())
+ {
+ checkDerivativeMemberAttributeReferences(member, derivativeAttr);
+ }
+ }
+ }
+
void SemanticsVisitor::checkAggTypeConformance(AggTypeDecl* decl)
{
// After we've checked members, we need to go through
@@ -5892,6 +5939,16 @@ namespace Slang
auto inheritanceDecls = decl->getMembersOfType<InheritanceDecl>().toList();
for (auto inheritanceDecl : inheritanceDecls)
{
+ // Special handling for when we check for conformance against `IDifferentiable`
+ // We will reference-checking for the [DerivativeMember(DiffType.member)]
+ // attributes here, since they have to be performed after types can be referenced
+ // and before conformance checking, where this information can be used to synthesize
+ // member methods (such as `dzero`, `dadd`, etc..)
+ //
+ if (inheritanceDecl->getSup().type->equals(
+ astBuilder->getDifferentiableInterfaceType()))
+ checkDifferentiableMembersInType(decl);
+
checkConformance(type, inheritanceDecl, decl);
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index f79e23b42..ee36a21fb 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1213,6 +1213,57 @@ namespace Slang
}
}
+ void SemanticsVisitor::checkDerivativeMemberAttributeReferences(
+ VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr)
+ {
+ if (derivativeMemberAttr->memberDeclRef)
+ {
+ // Already checked! This usually happens if this attribute is synthesized by the compiler.
+ return;
+ }
+
+ SLANG_ASSERT(derivativeMemberAttr->args.getCount() == 1);
+ auto checkedExpr = dispatchExpr(derivativeMemberAttr->args[0], allowStaticReferenceToNonStaticMember());
+
+ auto memberType = varDecl->type.type; // All types must be fully checked by now.
+ auto diffType = getDifferentialType(m_astBuilder, memberType, varDecl->loc);
+ auto thisType = calcThisType(makeDeclRef(varDecl->parentDecl));
+ if (!thisType) return; // Diagnostic should have been emitted previously.
+
+ auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc);
+ if (!diffThisType) return; // Diagnostic should have been emitted previously.
+
+ if (auto declRefExpr = as<DeclRefExpr>(checkedExpr))
+ {
+ derivativeMemberAttr->memberDeclRef = declRefExpr;
+ if (!diffType->equals(declRefExpr->type))
+ {
+ getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeMismatch, diffType, declRefExpr->type);
+ }
+ if (!varDecl->parentDecl)
+ {
+ getSink()->diagnose(derivativeMemberAttr, Diagnostics::attributeNotApplicable, diffType, declRefExpr->type);
+ }
+ if (auto memberExpr = as<StaticMemberExpr>(declRefExpr))
+ {
+ auto baseExprType = memberExpr->baseExpression->type.type;
+ if (auto typeType = as<TypeType>(baseExprType))
+ {
+ if (diffThisType->equals(typeType->getType()))
+ {
+ return;
+ }
+ }
+
+ }
+ }
+ getSink()->diagnose(
+ derivativeMemberAttr,
+ Diagnostics::
+ derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType,
+ diffThisType);
+ }
+
Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc)
{
auto result = tryGetDifferentialType(builder, type);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 232cb623c..39f5f46b3 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1393,6 +1393,9 @@ namespace Slang
// Helper function to check if a struct can be used as its own differential type.
bool canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl);
void markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type);
+
+ void checkDerivativeMemberAttributeReferences(
+ VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr);
public:
@@ -1811,6 +1814,10 @@ namespace Slang
RefPtr<WitnessTable> witnessTable,
BuiltinRequirementKind requirementKind);
+ /// Check references from`[DerivativeMember(...)]` attributes on members of the agg-decl.
+ /// this is typically deferred until after types are ready for reference.
+ void checkDifferentiableMembersInType(AggTypeDecl* decl);
+
struct DifferentiableMemberInfo
{
Decl* memberDecl;
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 7ebe77a8f..98af8a228 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -361,6 +361,8 @@ DIAGNOSTIC(30098, Error, nonStaticMemberFunctionNotAllowedAsDiffOperand, "non-st
DIAGNOSTIC(30099, Error, sizeOfArgumentIsInvalid, "argument to sizeof is invalid")
DIAGNOSTIC(30101, Error, readingFromWriteOnly, "cannot read from writeonly, check modifiers.")
+DIAGNOSTIC(30102, Error, differentiableMemberShouldHaveCorrespondingFieldInDiffType, "differentiable member '$0' should have a corresponding field in '$1'. Use [DerivativeMember($1.<field-name>)] or mark as no_diff")
+
// Include
DIAGNOSTIC(30500, Error, includedFileMissingImplementing, "missing 'implementing' declaration in the included source file '$0'.")