summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp302
1 files changed, 229 insertions, 73 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index fa05dde11..f28f46deb 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -81,8 +81,6 @@ namespace Slang
void checkCallableDeclCommon(CallableDecl* decl);
- void maybeCheckDifferentiableAccessorSignature(FuncDecl* funcDecl);
-
void visitFuncDecl(FuncDecl* funcDecl);
void visitParamDecl(ParamDecl* paramDecl);
@@ -889,6 +887,7 @@ namespace Slang
auto diffTypeDict = subDiffTypeContext->makeDifferentiableTypeDictionaryNode(visitor->getASTBuilder());
diffTypeDict->parentDecl = containerDecl;
containerDecl->members.add(diffTypeDict);
+ containerDecl->invalidateMemberDictionary();
}
}
}
@@ -1072,8 +1071,11 @@ namespace Slang
auto initExpr = varDecl->initExpr;
if(!initExpr)
{
- getSink()->diagnose(varDecl, Diagnostics::varWithoutTypeMustHaveInitializer);
- varDecl->type.type = m_astBuilder->getErrorType();
+ if (!varDecl->type.type)
+ {
+ getSink()->diagnose(varDecl, Diagnostics::varWithoutTypeMustHaveInitializer);
+ varDecl->type.type = m_astBuilder->getErrorType();
+ }
}
else
{
@@ -1390,7 +1392,7 @@ namespace Slang
context->parentDecl->members.add((aggTypeDecl));
aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName();
aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc;
- context->parentDecl->getMemberDictionary().Add(aggTypeDecl->getName(), aggTypeDecl);
+ context->parentDecl->invalidateMemberDictionary();
}
// TODO: if we want to make the synthesized type itself to be differentiable,
@@ -1409,6 +1411,7 @@ namespace Slang
diffField->checkState = DeclCheckState::SignatureChecked;
diffField->parentDecl = aggTypeDecl;
aggTypeDecl->members.add(diffField);
+ aggTypeDecl->invalidateMemberDictionary();
// Inject a `DerivativeMember` modifier on the original decl.
auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
@@ -2412,54 +2415,12 @@ namespace Slang
return false;
}
- bool SemanticsVisitor::trySynthesizeMethodRequirementWitness(
+ FuncDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness(
ConformanceCheckingContext* context,
- LookupResult const& lookupResult,
- DeclRef<FuncDecl> requiredMemberDeclRef,
- RefPtr<WitnessTable> witnessTable)
+ DeclRef<FuncDecl> requiredMemberDeclRef,
+ List<Expr*>& synArgs,
+ ThisExpr*& synThis)
{
- // The situation here is that the context of an inheritance
- // declaration didn't provide an exact match for a required
- // method. E.g.:
- //
- // interface ICounter { [mutating] int increment(); }
- // struct MyCounter : ICounter
- // {
- // [murtating] int increment(int val = 1) { ... }
- // }
- //
- // It is clear in this case that the `MyCounter` type *can*
- // satisfy the signature required by `ICounter`, but it has
- // no explicit method declaration that is a perfect match.
- //
- // The approach in this function will be to construct a
- // synthesized method along the lines of:
- //
- // struct MyCounter ...
- // {
- // ...
- // [murtating] int synthesized()
- // {
- // return this.increment();
- // }
- // }
- //
- // That is, we construct a method with the exact signature
- // of the requirement (same parameter and result types),
- // and then provide it with a body that simple `return`s
- // the result of applying the desired requirement name
- // (`increment` in this case) to those parameters.
- //
- // If the synthesized method type-checks, then we can say
- // that the type must satisfy the requirement structurally,
- // even if there isn't an exact signature match. More
- // importantly, the method we just synthesized can be
- // used as a witness to the fact that the requirement is
- // satisfied.
-
- // With the big picture spelled out, we can settle into
- // the work of constructing our synthesized method.
- //
auto synFuncDecl = m_astBuilder->create<FuncDecl>();
// For now our synthesized method will use the name and source
@@ -2497,8 +2458,7 @@ namespace Slang
// that reference those parametesr as arguments for the call expresison
// that makes up the body.
//
- List<Expr*> synArgs;
- for( auto paramDeclRef : getParameters(requiredMemberDeclRef) )
+ for (auto paramDeclRef : getParameters(requiredMemberDeclRef))
{
auto paramType = getType(m_astBuilder, paramDeclRef);
@@ -2524,13 +2484,13 @@ namespace Slang
synArgs.add(synArg);
}
+
// Required interface methods can be `static` or non-`static`,
// and non-`static` methods can be `[mutating]` or non-`[mutating]`.
// All of these details affect how we introduce our `this` parameter,
// if any.
//
- ThisExpr* synThis = nullptr;
- if( requiredMemberDeclRef.getDecl()->hasModifier<HLSLStaticModifier>() )
+ if (requiredMemberDeclRef.getDecl()->hasModifier<HLSLStaticModifier>())
{
auto synStaticModifier = m_astBuilder->create<HLSLStaticModifier>();
synFuncDecl->modifiers.first = synStaticModifier;
@@ -2546,7 +2506,7 @@ namespace Slang
//
synThis->type.type = context->conformingType;
- if( requiredMemberDeclRef.getDecl()->hasModifier<MutatingAttribute>() )
+ if (requiredMemberDeclRef.getDecl()->hasModifier<MutatingAttribute>())
{
// If the interface requirement is `[mutating]` then our
// synthesized method should be too, and also the `this`
@@ -2559,6 +2519,64 @@ namespace Slang
}
}
+ return synFuncDecl;
+ }
+
+ bool SemanticsVisitor::trySynthesizeMethodRequirementWitness(
+ ConformanceCheckingContext* context,
+ LookupResult const& lookupResult,
+ DeclRef<FuncDecl> requiredMemberDeclRef,
+ RefPtr<WitnessTable> witnessTable)
+ {
+ // The situation here is that the context of an inheritance
+ // declaration didn't provide an exact match for a required
+ // method. E.g.:
+ //
+ // interface ICounter { [mutating] int increment(); }
+ // struct MyCounter : ICounter
+ // {
+ // [murtating] int increment(int val = 1) { ... }
+ // }
+ //
+ // It is clear in this case that the `MyCounter` type *can*
+ // satisfy the signature required by `ICounter`, but it has
+ // no explicit method declaration that is a perfect match.
+ //
+ // The approach in this function will be to construct a
+ // synthesized method along the lines of:
+ //
+ // struct MyCounter ...
+ // {
+ // ...
+ // [murtating] int synthesized()
+ // {
+ // return this.increment();
+ // }
+ // }
+ //
+ // That is, we construct a method with the exact signature
+ // of the requirement (same parameter and result types),
+ // and then provide it with a body that simple `return`s
+ // the result of applying the desired requirement name
+ // (`increment` in this case) to those parameters.
+ //
+ // If the synthesized method type-checks, then we can say
+ // that the type must satisfy the requirement structurally,
+ // even if there isn't an exact signature match. More
+ // importantly, the method we just synthesized can be
+ // used as a witness to the fact that the requirement is
+ // satisfied.
+
+ // With the big picture spelled out, we can settle into
+ // the work of constructing our synthesized method.
+ //
+ ThisExpr* synThis = nullptr;
+ List<Expr*> synArgs;
+ auto synFuncDecl = synthesizeMethodSignatureForRequirementWitness(
+ context, requiredMemberDeclRef, synArgs, synThis);
+
+ auto resultType = synFuncDecl->returnType.type;
+
// The body of our synthesized method is going to try to
// make a call using the name of the method requirement (e.g.,
// the name `increment` in our example at the top of this function).
@@ -3036,11 +3054,27 @@ namespace Slang
if (auto requiredFuncDeclRef = requiredMemberDeclRef.as<FuncDecl>())
{
// Check signature match.
- return trySynthesizeMethodRequirementWitness(
+ if (trySynthesizeMethodRequirementWitness(
context,
lookupResult,
requiredFuncDeclRef,
- witnessTable);
+ witnessTable))
+ return true;
+
+ if (auto builtinAttr = requiredFuncDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>())
+ {
+ switch (builtinAttr->kind)
+ {
+ case BuiltinRequirementKind::DAddFunc:
+ case BuiltinRequirementKind::DMulFunc:
+ case BuiltinRequirementKind::DZeroFunc:
+ return trySynthesizeDifferentialMethodRequirementWitness(
+ context,
+ requiredFuncDeclRef,
+ witnessTable);
+ }
+ }
+ return false;
}
if( auto requiredPropertyDeclRef = requiredMemberDeclRef.as<PropertyDecl>() )
@@ -3054,11 +3088,11 @@ namespace Slang
if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>())
{
- if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinAssociatedTypeRequirementAttribute>())
+ if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>())
{
switch (builtinAttr->kind)
{
- case BuiltinAssociatedTypeRequirementKind::Differential:
+ case BuiltinRequirementKind::DifferentialType:
return trySynthesizeDifferentialAssociatedTypeRequirementWitness(
context,
requiredAssocTypeDeclRef,
@@ -3088,6 +3122,130 @@ namespace Slang
return false;
}
+ bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness(
+ ConformanceCheckingContext* context,
+ DeclRef<Decl> requirementDeclRef,
+ RefPtr<WitnessTable> witnessTable)
+ {
+ // This method implements a general code synthesis pattern.
+ // For requirement of the form:
+ // ```
+ // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...)
+ // ```
+ // Where TResult, TParam1, TParam2 is either `This` or `Differential`,
+ // We synthesize a memberwise dispatch to compute each field of `TResult`,
+ // resulting an implementation of the form:
+ // ```
+ // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...)
+ // {
+ // TResult result;
+ // result.member0 = decltype(result.member0).requiredMethod(p0.member0, p1.member0);
+ // result.member1 = decltype(result.member1).requiredMethod(p0.member1, p1.member1);
+ // ...
+ // return result;
+ // }
+ // ```
+ List<Expr*> synArgs;
+ ThisExpr* synThis = nullptr;
+ auto synFunc = synthesizeMethodSignatureForRequirementWitness(
+ context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis);
+
+ auto blockStmt = m_astBuilder->create<BlockStmt>();
+ synFunc->body = blockStmt;
+ auto seqStmt = m_astBuilder->create<SeqStmt>();
+ blockStmt->body = seqStmt;
+
+ // Create a variable for return value.
+ auto scopeDecl = m_astBuilder->create<ScopeDecl>();
+ synFunc->members.add(scopeDecl);
+ scopeDecl->parentDecl = synFunc;
+ auto varStmt = m_astBuilder->create<DeclStmt>();
+ seqStmt->stmts.add(varStmt);
+
+ auto returnVar = m_astBuilder->create<VarDecl>();
+ returnVar->parentDecl = scopeDecl;
+ scopeDecl->members.add(returnVar);
+
+ returnVar->type.type = synFunc->returnType.type;
+ returnVar->nameAndLoc.name = getName("result");
+ varStmt->decl = returnVar;
+ auto resultVarExpr = m_astBuilder->create<VarExpr>();
+ resultVarExpr->declRef = makeDeclRef(returnVar);
+ resultVarExpr->type.type = synFunc->returnType.type;
+ resultVarExpr->type.isLeftValue = true;
+
+ for (auto member : context->parentDecl->members)
+ {
+ auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>();
+ if (!derivativeAttr)
+ continue;
+ auto varMember = as<VarDeclBase>(member);
+ if (!varMember)
+ continue;
+ ensureDecl(varMember, DeclCheckState::ReadyForReference);
+ auto memberType = varMember->getType();
+ auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType);
+ if (!diffMemberType)
+ continue;
+
+ // Construct reference exprs to the member's corresponding fields in each parameter.
+ List<Expr*> paramFields;
+ int paramIndex = 0;
+ for (auto arg : synArgs)
+ {
+ auto memberExpr = m_astBuilder->create<MemberExpr>();
+ memberExpr->baseExpression = arg;
+ // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
+ // Differential type.
+ memberExpr->name = varMember->getName();
+ paramFields.add(memberExpr);
+ paramIndex++;
+ }
+
+ // Invoke the method for the field.
+ auto callee = m_astBuilder->create<StaticMemberExpr>();
+ auto baseSharedType = m_astBuilder->create<SharedTypeExpr>();
+ auto baseSharedTypeType = m_astBuilder->create<TypeType>();
+ baseSharedTypeType->type = memberType;
+ baseSharedType->type = baseSharedTypeType;
+ baseSharedType->base.type = memberType;
+ callee->baseExpression = baseSharedType;
+ callee->name = requirementDeclRef.getName();
+ callee->loc = synFunc->loc;
+ auto invokeExpr = m_astBuilder->create<InvokeExpr>();
+ invokeExpr->functionExpr = callee;
+ invokeExpr->arguments = _Move(paramFields);
+
+ // Assign the value to resultVar.
+ auto leftVal = m_astBuilder->create<MemberExpr>();
+ leftVal->baseExpression = resultVarExpr;
+ // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr`
+ // is Differential type.
+ leftVal->name = varMember->getName();
+
+ auto assignExpr = m_astBuilder->create<AssignExpr>();
+ assignExpr->left = leftVal;
+ assignExpr->right = invokeExpr;
+ auto assignStmt = m_astBuilder->create<ExpressionStmt>();
+ assignStmt->expression = assignExpr;
+ seqStmt->stmts.add(assignStmt);
+ }
+
+ // TODO: synthesize assignments for inherited members here.
+
+ auto synReturn = m_astBuilder->create<ReturnStmt>();
+ synReturn->expression = resultVarExpr;
+ seqStmt->stmts.add(synReturn);
+
+ synFunc->parentDecl = context->parentDecl;
+ context->parentDecl->members.add(synFunc);
+ context->parentDecl->invalidateMemberDictionary();
+ addModifier(synFunc, m_astBuilder->create<SynthesizedModifier>());
+
+ witnessTable->add(requirementDeclRef, RequirementWitness(makeDeclRef(synFunc)));
+ return true;
+ }
+
bool SemanticsVisitor::findWitnessForInterfaceRequirement(
ConformanceCheckingContext* context,
Type* subType,
@@ -3210,18 +3368,16 @@ namespace Slang
if(!lookupResult.isValid())
{
- // If we failed to even look up a member with the name of the
- // requirement, then we can be certain that the type doesn't
- // satisfy the requirement.
- //
- // TODO: If we ever allowed certain kinds of requirements to
- // be inferred (e.g., inferring associated types from the
- // signatures of methods, as is done for Swift), we'd
- // need to revisit this step.
- //
- getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef);
- getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef);
- return false;
+ // If we failed to look up a member with the name of the
+ // requirement, it may be possible that we can still synthesis the
+ // implementation if this is one of the known builtin requirements.
+ // Otherwise, report diagnostic now.
+ if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementAttribute>())
+ {
+ getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef);
+ getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef);
+ return false;
+ }
}
// Iterate over the members and look for one that matches
@@ -5103,7 +5259,7 @@ namespace Slang
{
resultType = CheckProperType(resultType);
}
- else
+ else if (!funcDecl->returnType.type)
{
resultType = TypeExp(m_astBuilder->getVoidType());
}