summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-01 08:46:57 -0700
committerGitHub <noreply@github.com>2022-11-01 08:46:57 -0700
commitcbc1eff56057f199183bb7c17d8a360326512367 (patch)
tree487865e928cd2ceecbb509f0bfd06aa8d9584411 /source/slang/slang-check-decl.cpp
parentb707a07b1de3535cb0a8ccb6fe2ed4afa4a016d1 (diff)
Make `DifferentialPair` able to nest. (#2477)
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp438
1 files changed, 311 insertions, 127 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 457ae229b..f60fbcc2c 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -11,9 +11,8 @@
// and when things get checked.
#include "slang-lookup.h"
-
#include "slang-syntax.h"
-
+#include "slang-ast-synthesis.h"
#include <limits>
namespace Slang
@@ -166,6 +165,65 @@ namespace Slang
void visitExtensionDecl(ExtensionDecl* decl);
};
+ struct SemanticsDeclTypeResolutionVisitor
+ : public SemanticsDeclVisitorBase
+ , public DeclVisitor<SemanticsDeclTypeResolutionVisitor>
+ {
+ SemanticsDeclTypeResolutionVisitor(SemanticsContext const& outer)
+ : SemanticsDeclVisitorBase(outer)
+ {}
+
+ void visitDecl(Decl*) {}
+ void visitDeclGroup(DeclGroup*) {}
+
+ Val* resolveVal(Val* val);
+ Type* resolveType(Type* type)
+ {
+ return (Type*)resolveVal(type);
+ }
+
+ void visitTypeExp(TypeExp& exp)
+ {
+ exp.type = resolveType(exp.type);
+ }
+
+ void visitVarDeclBase(VarDeclBase* varDecl)
+ {
+ visitTypeExp(varDecl->type);
+ }
+
+ void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl)
+ {
+ visitTypeExp(decl->sup);
+ }
+
+ void visitTypeDefDecl(TypeDefDecl* decl)
+ {
+ visitTypeExp(decl->type);
+ }
+
+ void visitGenericTypeParamDecl(GenericTypeParamDecl* paramDecl)
+ {
+ visitTypeExp(paramDecl->initType);
+ }
+
+ void visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
+ {
+ visitTypeExp(inheritanceDecl->base);
+ }
+
+ void visitCallableDecl(CallableDecl* decl)
+ {
+ visitTypeExp(decl->returnType);
+ visitTypeExp(decl->errorType);
+ }
+
+ void visitPropertyDecl(PropertyDecl* decl)
+ {
+ visitTypeExp(decl->type);
+ }
+ };
+
struct SemanticsDeclBodyVisitor
: public SemanticsDeclVisitorBase
, public DeclVisitor<SemanticsDeclBodyVisitor>
@@ -1363,27 +1421,30 @@ namespace Slang
bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness(
ConformanceCheckingContext* context,
- DeclRef<Decl> requirementDeclRef,
+ DeclRef<AssocTypeDecl> requirementDeclRef,
RefPtr<WitnessTable> witnessTable)
{
- // We currently can't handle generic types.
- if (GetOuterGeneric(context->parentDecl) != nullptr)
- {
- return false;
- }
-
+ ASTSynthesizer synth(m_astBuilder, getNamePool());
Decl* existingDecl = nullptr;
AggTypeDecl* aggTypeDecl = nullptr;
if (context->parentDecl->getMemberDictionary().TryGetValue(requirementDeclRef.getName(), existingDecl))
{
- aggTypeDecl = as<AggTypeDecl>(existingDecl);
- SLANG_RELEASE_ASSERT(aggTypeDecl);
-
// Remove the `ToBeSynthesizedModifier`.
- if (as<ToBeSynthesizedModifier>(aggTypeDecl->modifiers.first))
+ if (as<ToBeSynthesizedModifier>(existingDecl->modifiers.first))
{
- aggTypeDecl->modifiers.first = aggTypeDecl->modifiers.first->next;
+ existingDecl->modifiers.first = existingDecl->modifiers.first->next;
}
+ else
+ {
+ // The user has defined an associatedtype explicitly but that we reach here because
+ // that type failed to satisfy the `IDifferential` requirement.
+ // We stop the synthesis and let the follow-up logic to report a diagnostic.
+ return false;
+ }
+
+ aggTypeDecl = as<AggTypeDecl>(existingDecl);
+ SLANG_RELEASE_ASSERT(aggTypeDecl);
+ synth.pushContainerScope(aggTypeDecl);
}
else
{
@@ -1393,15 +1454,12 @@ namespace Slang
aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName();
aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc;
context->parentDecl->invalidateMemberDictionary();
+ synth.pushScopeForContainer(aggTypeDecl);
}
- // TODO: if we want to make the synthesized type itself to be differentiable,
- // add an inheritance decl here. Need to be careful to avoid infinite recursion
- // trying to synthesize the higher order differential types.
-
// Helper function to add a `diffType` field into the synthesized type for the original
// `member`.
- auto differentialType = GetTypeForDeclRef(makeDeclRef(aggTypeDecl), context->parentDecl->loc);
+ auto differentialType = DeclRefType::create(m_astBuilder, makeDeclRef(aggTypeDecl));
auto addDiffMember = [&](Decl* member, Type* diffMemberType)
{
// If the field is differentiable, add a corresponding field in the associated Differential type.
@@ -1452,12 +1510,35 @@ namespace Slang
addDiffMember(member, diffType);
}
- // In the future when the Differential type itself needs to conform to some interface,
- // this is the place to synthesize requirements for them.
addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>());
- auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, nullptr);
- witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType));
- return true;
+
+ // If `This` is nested inside a generic, we need to form a complete declref type to the
+ // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution
+ // from requirementDeclRef to get the generic substitution for outer generic parameters, and
+ // apply it to the newly synthesized decl.
+ SubstitutionSet substSet;
+ if (auto thisTypeSusbt = findThisTypeSubstitution(
+ requirementDeclRef.substitutions,
+ as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl)))
+ {
+ if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub))
+ {
+ substSet = declRefType->declRef.substitutions;
+ }
+ }
+
+ auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet);
+
+ if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable))
+ {
+ witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType));
+ return true;
+ }
+
+ // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always succeed.
+ // If not, there is something wrong with the code synthesis logic. For now we just return false
+ // instead of crashing so the user can work around the issues.
+ return false;
}
void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*)
@@ -2242,22 +2323,8 @@ namespace Slang
witnessTable);
}
- bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeRequirement(
- Type* satisfyingType,
- DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef,
- RefPtr<WitnessTable> witnessTable)
+ bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeConstraintRequirement(Type* satisfyingType, DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, RefPtr<WitnessTable> witnessTable)
{
- if (auto declRefType = as<DeclRefType>(satisfyingType))
- {
- // If we are seeing a placeholder that awaits synthesis, return false now to trigger
- // auto synthesis.
- if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>())
- return false;
- }
- // We need to confirm that the chosen type `satisfyingType`,
- // meets all the constraints placed on the associated type
- // requirement `requiredAssociatedTypeDeclRef`.
- //
// We will enumerate the type constraints placed on the
// associated type and see if they can be satisfied.
//
@@ -2269,7 +2336,7 @@ namespace Slang
// Perform a search for a witness to the subtype relationship.
auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType);
- if(witness)
+ if (witness)
{
// If a subtype witness was found, then the conformance
// appears to hold, and we can satisfy that requirement.
@@ -2282,6 +2349,30 @@ namespace Slang
conformance = false;
}
}
+ return conformance;
+ }
+
+ bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeRequirement(
+ Type* satisfyingType,
+ DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef,
+ RefPtr<WitnessTable> witnessTable)
+ {
+ if (auto declRefType = as<DeclRefType>(satisfyingType))
+ {
+ // If we are seeing a placeholder that awaits synthesis, return false now to trigger
+ // auto synthesis.
+ if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>())
+ return false;
+ }
+ // We need to confirm that the chosen type `satisfyingType`,
+ // meets all the constraints placed on the associated type
+ // requirement `requiredAssociatedTypeDeclRef`.
+ //
+ // We will enumerate the type constraints placed on the
+ // associated type and see if they can be satisfied.
+ //
+ bool conformance = doesTypeSatisfyAssociatedTypeConstraintRequirement(
+ satisfyingType, requiredAssociatedTypeDeclRef, witnessTable);
// TODO: if any conformance check failed, we should probably include
// that in an error message produced about not satisfying the requirement.
@@ -3122,12 +3213,43 @@ namespace Slang
return false;
}
+ Stmt* _synthesizeMemberAssignMemberHelper(ASTSynthesizer& synth, Name* funcName, Type* leftType, Expr* leftValue, List<Expr*>&& args, int nestingLevel = 0)
+ {
+ if (nestingLevel > 16)
+ return nullptr;
+
+ // If field type is an array, assign each element individually.
+ if (auto arrayType = as<ArrayExpressionType>(leftType))
+ {
+ VarDecl* indexVar = nullptr;
+ auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar);
+ auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar));
+ for (auto& arg : args)
+ {
+ arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar));
+ }
+ auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->baseType, innerLeft, _Move(args), nestingLevel + 1);
+ synth.popScope();
+ if (!assignStmt)
+ return nullptr;
+ forStmt->statement = assignStmt;
+ return forStmt;
+ }
+
+ auto callee = synth.emitMemberExpr(leftType, funcName);
+ return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args)));
+ }
+
bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness(
ConformanceCheckingContext* context,
DeclRef<Decl> requirementDeclRef,
RefPtr<WitnessTable> witnessTable)
{
- // This method implements a general code synthesis pattern.
+ // We support two cases of synthesis here.
+ // Case 1 is that there the associated Differential type is defined to be `DifferentialBottom`.
+ // In this case we just trivially return `DifferentialBottom` in all synthesized methods.
+ // Case 2 is that the `Differential` type contains members corresponding to each primal member.
+ // We will apply a general code synthesis pattern to reflect that structure.
// For requirement of the form:
// ```
// static TResult requiredMethod(TParam1 p0, TParam2 p1, ...)
@@ -3145,104 +3267,123 @@ namespace Slang
// return result;
// }
// ```
+
+ // First we need to make sure the associated `Differential` type requirement is satisfied.
+ bool hasDifferentialAssocType = false;
+ for (auto existingEntry : witnessTable->requirementList)
+ {
+ if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementAttribute>())
+ {
+ if (builtinReqAttr->kind == BuiltinRequirementKind::DifferentialType &&
+ existingEntry.Value.getFlavor() != RequirementWitness::Flavor::none)
+ {
+ hasDifferentialAssocType = true;
+ }
+ }
+ }
+ if (!hasDifferentialAssocType)
+ return false;
+
+ ASTSynthesizer synth(m_astBuilder, getNamePool());
List<Expr*> synArgs;
ThisExpr* synThis = nullptr;
auto synFunc = synthesizeMethodSignatureForRequirementWitness(
context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis);
-
+ synFunc->parentDecl = context->parentDecl;
+ synth.pushContainerScope(synFunc);
auto blockStmt = m_astBuilder->create<BlockStmt>();
synFunc->body = blockStmt;
- auto seqStmt = m_astBuilder->create<SeqStmt>();
+ auto seqStmt = synth.pushSeqStmtScope();
blockStmt->body = seqStmt;
- // Create a variable for return value.
- auto scopeDecl = m_astBuilder->create<ScopeDecl>();
- synFunc->members.add(scopeDecl);
- scopeDecl->parentDecl = synFunc;
- auto varStmt = m_astBuilder->create<DeclStmt>();
- seqStmt->stmts.add(varStmt);
-
- auto returnVar = m_astBuilder->create<VarDecl>();
- returnVar->parentDecl = scopeDecl;
- scopeDecl->members.add(returnVar);
-
- returnVar->type.type = synFunc->returnType.type;
- returnVar->nameAndLoc.name = getName("result");
- varStmt->decl = returnVar;
- auto resultVarExpr = m_astBuilder->create<VarExpr>();
- resultVarExpr->declRef = makeDeclRef(returnVar);
- resultVarExpr->type.type = synFunc->returnType.type;
- resultVarExpr->type.isLeftValue = true;
-
- for (auto member : context->parentDecl->members)
- {
- auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>();
- if (!derivativeAttr)
- continue;
- auto varMember = as<VarDeclBase>(member);
- if (!varMember)
- continue;
- ensureDecl(varMember, DeclCheckState::ReadyForReference);
- auto memberType = varMember->getType();
- auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType);
- if (!diffMemberType)
- continue;
+ if (synFunc->returnType.type->equals(m_astBuilder->getDifferentialBottomType()))
+ {
+ // Trivial case, the `Differential` type is `DifferentialBottom`.
+ // We will just return `DifferentialBottom.dzero()`.
+ auto resultExpr = m_astBuilder->create<InvokeExpr>();
+ auto dzeroMember = m_astBuilder->create<StaticMemberExpr>();
+ auto base = m_astBuilder->create<SharedTypeExpr>();
+ auto typetype = m_astBuilder->create<TypeType>();
+ typetype->type = m_astBuilder->getDifferentialBottomType();
+ base->type.type = typetype;
+ dzeroMember->baseExpression = base;
+ dzeroMember->name = getName("dzero");
+ resultExpr->functionExpr = dzeroMember;
+ auto synReturn = m_astBuilder->create<ReturnStmt>();
+ synReturn->expression = resultExpr;
+ seqStmt->stmts.add(synReturn);
+ }
+ else
+ {
+ // The general case.
+ // Create a variable for return value.
+ synth.pushVarScope();
+ auto varStmt = synth.emitVarDeclStmt(synFunc->returnType.type, getName("result"));
+ auto resultVarExpr = synth.emitVarExpr(varStmt, synFunc->returnType.type);
- // Construct reference exprs to the member's corresponding fields in each parameter.
- List<Expr*> paramFields;
- int paramIndex = 0;
- for (auto arg : synArgs)
- {
- auto memberExpr = m_astBuilder->create<MemberExpr>();
- memberExpr->baseExpression = arg;
- // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
- // Differential type.
- memberExpr->name = varMember->getName();
- paramFields.add(memberExpr);
- paramIndex++;
- }
-
- // Invoke the method for the field.
- auto callee = m_astBuilder->create<StaticMemberExpr>();
- auto baseSharedType = m_astBuilder->create<SharedTypeExpr>();
- auto baseSharedTypeType = m_astBuilder->create<TypeType>();
- baseSharedTypeType->type = memberType;
- baseSharedType->type = baseSharedTypeType;
- baseSharedType->base.type = memberType;
- callee->baseExpression = baseSharedType;
- callee->name = requirementDeclRef.getName();
- callee->loc = synFunc->loc;
- auto invokeExpr = m_astBuilder->create<InvokeExpr>();
- invokeExpr->functionExpr = callee;
- invokeExpr->arguments = _Move(paramFields);
-
- // Assign the value to resultVar.
- auto leftVal = m_astBuilder->create<MemberExpr>();
- leftVal->baseExpression = resultVarExpr;
- // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr`
- // is Differential type.
- leftVal->name = varMember->getName();
-
- auto assignExpr = m_astBuilder->create<AssignExpr>();
- assignExpr->left = leftVal;
- assignExpr->right = invokeExpr;
- auto assignStmt = m_astBuilder->create<ExpressionStmt>();
- assignStmt->expression = assignExpr;
- seqStmt->stmts.add(assignStmt);
- }
-
- // TODO: synthesize assignments for inherited members here.
-
- auto synReturn = m_astBuilder->create<ReturnStmt>();
- synReturn->expression = resultVarExpr;
- seqStmt->stmts.add(synReturn);
+ for (auto member : context->parentDecl->members)
+ {
+ auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>();
+ if (!derivativeAttr)
+ continue;
+ auto varMember = as<VarDeclBase>(member);
+ if (!varMember)
+ continue;
+ ensureDecl(varMember, DeclCheckState::ReadyForReference);
+ auto memberType = varMember->getType();
+ auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType);
+ if (!diffMemberType)
+ continue;
- synFunc->parentDecl = context->parentDecl;
+ // Construct reference exprs to the member's corresponding fields in each parameter.
+ List<Expr*> paramFields;
+ int paramIndex = 0;
+ for (auto arg : synArgs)
+ {
+ auto memberExpr = m_astBuilder->create<MemberExpr>();
+ memberExpr->baseExpression = arg;
+ // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
+ // Differential type.
+ memberExpr->name = varMember->getName();
+ paramFields.add(memberExpr);
+ paramIndex++;
+ }
+
+ // Invoke the method for the field and assign the value to resultVar.
+ // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr`
+ // is Differential type.
+ auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName());
+ if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields)))
+ return false;
+ }
+
+ // TODO: synthesize assignments for inherited members here.
+
+ auto synReturn = m_astBuilder->create<ReturnStmt>();
+ synReturn->expression = resultVarExpr;
+ seqStmt->stmts.add(synReturn);
+ }
+
context->parentDecl->members.add(synFunc);
context->parentDecl->invalidateMemberDictionary();
addModifier(synFunc, m_astBuilder->create<SynthesizedModifier>());
- witnessTable->add(requirementDeclRef, RequirementWitness(makeDeclRef(synFunc)));
+ // If `This` is nested inside a generic, we need to form a complete declref type to the
+ // newly synthesized method here in order to fill into the witness table.
+ // This can be done by obtaining ThisTypeSubstitution from requirementDeclRef to get the
+ // generic substitution for outer generic parameters, and apply it here.
+ SubstitutionSet substSet;
+ if (auto thisTypeSusbt = findThisTypeSubstitution(
+ requirementDeclRef.substitutions,
+ as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl)))
+ {
+ if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub))
+ {
+ substSet = declRefType->declRef.substitutions;
+ }
+ }
+
+ witnessTable->add(requirementDeclRef, RequirementWitness(DeclRef<Decl>(synFunc, substSet)));
return true;
}
@@ -3801,7 +3942,10 @@ namespace Slang
// be required to implement all interface requirements,
// just with `abstract` methods that replicate things?
// (That's what C# does).
- for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>())
+
+ // Make a copy of inhertanceDecls firstsince `checkConformance` may modify decl->members.
+ auto inheritanceDecls = decl->getMembersOfType<InheritanceDecl>().toList();
+ for (auto inheritanceDecl : inheritanceDecls)
{
checkConformance(type, inheritanceDecl, decl);
}
@@ -5230,7 +5374,7 @@ namespace Slang
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
{
- if (decl->findModifier<ForwardDifferentiableAttribute>())
+ if (decl->findModifier<DifferentiableAttribute>())
{
this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary();
}
@@ -6274,6 +6418,10 @@ namespace Slang
SemanticsDeclConformancesVisitor(shared).dispatch(decl);
break;
+ case DeclCheckState::TypesFullyResolved:
+ SemanticsDeclTypeResolutionVisitor(shared).dispatch(decl);
+ break;
+
case DeclCheckState::Checked:
SemanticsDeclBodyVisitor(shared).dispatch(decl);
break;
@@ -6325,4 +6473,40 @@ namespace Slang
return result;
}
+ Val* SemanticsDeclTypeResolutionVisitor::resolveVal(Val* val)
+ {
+ if (auto declRefType = as<DeclRefType>(val))
+ {
+ if (auto concreteType = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(m_astBuilder, declRefType->declRef))
+ return as<Type>(concreteType);
+ for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst=subst->outer)
+ {
+ if (auto genericSubst = as<GenericSubstitution>(subst))
+ {
+ ShortList<Val*> newArgs;
+ for (auto& arg : genericSubst->getArgs())
+ {
+ arg = resolveVal(arg);
+ SLANG_RELEASE_ASSERT(arg);
+ }
+ }
+ }
+ }
+ else if (auto subtypeWitness = as<SubtypeWitness>(val))
+ {
+ auto sub = as<Type>(resolveVal(subtypeWitness->sub));
+ auto sup = as<Type>(resolveVal(subtypeWitness->sup));
+ if (sub && sup)
+ {
+ if (sub != subtypeWitness->sub || sup != subtypeWitness->sup)
+ {
+ auto newVal = tryGetSubtypeWitness(as<Type>(sub), as<Type>(sup));
+ if (newVal)
+ val = newVal;
+ }
+ }
+ }
+ return val;
+ }
+
}