diff options
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 540 |
1 files changed, 466 insertions, 74 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7d083e53b..b03126512 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -28,6 +28,7 @@ static List<ConstructorDecl*> _getCtorList( SemanticsVisitor* visitor, StructDecl* structDecl, ConstructorDecl** defaultCtorOut); +static Expr* constructDefaultInitExprForType(SemanticsVisitor* visitor, VarDeclBase* varDecl); /// Visitor to transition declarations to `DeclCheckState::CheckedModifiers` struct SemanticsDeclModifiersVisitor : public SemanticsDeclVisitorBase, @@ -94,6 +95,18 @@ struct SemanticsDeclAttributesVisitor : public SemanticsDeclVisitorBase, void checkVarDeclCommon(VarDeclBase* varDecl); void visitVarDecl(VarDecl* varDecl) { checkVarDeclCommon(varDecl); } + + // Synthesize the constructor declaration for a struct during header visit, as we + // need to have such declaration first such that the overloading resolution can lookup + // such constructor and complete the initialize list to constructor translation. + // + // We will defer the actual implementation of the constructor to the body visit, because + // we will have full information about each field in the struct during that stage. + bool _synthesizeCtorSignature(StructDecl* structDecl); + bool collectInitializableMembers( + StructDecl* structDecl, + const DeclVisibility ctorVisibility, + List<VarDeclBase*>& resultMembers); }; struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase, @@ -319,6 +332,7 @@ struct SemanticsDeclBodyVisitor : public SemanticsDeclVisitorBase, SemanticsContext registerDifferentiableTypesForFunc(FunctionDeclBase* funcDecl); +private: struct DeclAndCtorInfo { StructDecl* parent = nullptr; @@ -350,13 +364,21 @@ struct SemanticsDeclBodyVisitor : public SemanticsDeclVisitorBase, ConstructorDecl* ctor, List<DeclAndCtorInfo>& inheritanceDefaultCtorList, ThisExpr* thisExpr, - SeqStmt* seqStmtChild); + SeqStmt* seqStmtChild, + bool isMemberInitCtor, + Index& paramIndex); + void synthesizeCtorBodyForMember( ConstructorDecl* ctor, Decl* member, ThisExpr* thisExpr, Dictionary<Decl*, Expr*>& cachedDeclToCheckedVar, - SeqStmt* seqStmtChild); + SeqStmt* seqStmtChild, + bool isMemberInitCtor, + Index& paramIndex); + + MemberExpr* createMemberExpr(ThisExpr* thisExpr, Scope* scope, Decl* member); + Expr* createCtorParamExpr(ConstructorDecl* ctor, Index paramIndex); }; template<typename VisitorType> @@ -2025,45 +2047,80 @@ void SemanticsDeclHeaderVisitor::checkVarDeclCommon(VarDeclBase* varDecl) checkVisibility(varDecl); } -static ConstructorDecl* _createCtor( +static void addAutoDiffModifiersToFunc( SemanticsDeclVisitorBase* visitor, ASTBuilder* m_astBuilder, - AggTypeDecl* decl) + FunctionDeclBase* func) +{ + if (visitor->isTypeDifferentiable(func->returnType.type)) + { + addModifier(func, m_astBuilder->create<BackwardDifferentiableAttribute>()); + addModifier(func, m_astBuilder->create<ForwardDifferentiableAttribute>()); + } + else + addModifier(func, m_astBuilder->create<TreatAsDifferentiableAttribute>()); +} + +ConstructorDecl* SemanticsDeclVisitorBase::createCtor( + AggTypeDecl* decl, + DeclVisibility ctorVisibility) { auto ctor = m_astBuilder->create<ConstructorDecl>(); addModifier(ctor, m_astBuilder->create<SynthesizedModifier>()); - auto ctorName = visitor->getName("$init"); + auto ctorName = getName("$init"); ctor->ownedScope = m_astBuilder->create<Scope>(); ctor->ownedScope->containerDecl = ctor; - ctor->ownedScope->parent = visitor->getScope(decl); + ctor->ownedScope->parent = getScope(decl); ctor->parentDecl = decl; ctor->loc = decl->loc; ctor->closingSourceLoc = ctor->loc; ctor->nameAndLoc.name = ctorName; ctor->nameAndLoc.loc = ctor->loc; - ctor->returnType.type = visitor->calcThisType(makeDeclRef(decl)); + ctor->returnType.type = calcThisType(makeDeclRef(decl)); auto body = m_astBuilder->create<BlockStmt>(); body->scopeDecl = m_astBuilder->create<ScopeDecl>(); body->scopeDecl->ownedScope = m_astBuilder->create<Scope>(); - body->scopeDecl->ownedScope->parent = visitor->getScope(ctor); + body->scopeDecl->ownedScope->parent = getScope(ctor); body->scopeDecl->parentDecl = ctor; body->scopeDecl->loc = ctor->loc; body->scopeDecl->closingSourceLoc = ctor->loc; body->closingSourceLoc = ctor->closingSourceLoc; ctor->body = body; body->body = m_astBuilder->create<SeqStmt>(); - ctor->isSynthesized = true; + ctor->addFlavor(ConstructorDecl::ConstructorFlavor::SynthesizedDefault); decl->addMember(ctor); + addAutoDiffModifiersToFunc(this, m_astBuilder, ctor); + addVisibilityModifier(ctor, ctorVisibility); return ctor; } +static inline bool _isDefaultCtor(ConstructorDecl* ctor) +{ + auto allParamHaveInitExpr = [](ConstructorDecl* ctor) + { + for (auto i : ctor->getParameters()) + if (!i->initExpr) + return false; + return true; + }; + + // 1. default ctor must have no parameters + // 2. default ctor can have parameters, but all parameters have init expr (Because we won't + // differentiate this case from 2.) + if (ctor->members.getCount() == 0 || allParamHaveInitExpr(ctor)) + { + return true; + } + + return false; +} + static ConstructorDecl* _getDefaultCtor(StructDecl* structDecl) { for (auto ctor : structDecl->getMembersOfType<ConstructorDecl>()) { - if (!ctor->body || ctor->members.getCount() != 0) - continue; - return ctor; + if (_isDefaultCtor(ctor)) + return ctor; } return nullptr; } @@ -2095,9 +2152,8 @@ static List<ConstructorDecl*> _getCtorList( if (!ctor || !ctor->body) return; ctorList.add(ctor); - if (ctor->members.getCount() != 0) - return; - *defaultCtorOut = ctor; + if (_isDefaultCtor(ctor)) + *defaultCtorOut = ctor; }; if (ctorLookupResult.items.getCount() == 0) { @@ -2208,16 +2264,10 @@ bool isDefaultInitializable(VarDeclBase* varDecl) return true; } -static Expr* constructDefaultInitExprForVar(SemanticsVisitor* visitor, VarDeclBase* varDecl) +static Expr* constructDefaultConstructorForType(SemanticsVisitor* visitor, Type* type) { - if (!varDecl->type || !varDecl->type.type) - return nullptr; - - if (!isDefaultInitializable(varDecl)) - return nullptr; - ConstructorDecl* defaultCtor = nullptr; - auto declRefType = as<DeclRefType>(varDecl->type.type); + auto declRefType = as<DeclRefType>(type); if (declRefType) { if (auto structDecl = as<StructDecl>(declRefType->getDeclRef().getDecl())) @@ -2225,7 +2275,6 @@ static Expr* constructDefaultInitExprForVar(SemanticsVisitor* visitor, VarDeclBa defaultCtor = _getDefaultCtor(structDecl); } } - if (defaultCtor) { auto* invoke = visitor->getASTBuilder()->create<InvokeExpr>(); @@ -2239,6 +2288,22 @@ static Expr* constructDefaultInitExprForVar(SemanticsVisitor* visitor, VarDeclBa nullptr); return invoke; } + + return nullptr; +} + +static Expr* constructDefaultInitExprForType(SemanticsVisitor* visitor, VarDeclBase* varDecl) +{ + if (!varDecl->type || !varDecl->type.type) + return nullptr; + + if (!isDefaultInitializable(varDecl)) + return nullptr; + + if (auto defaultInitExpr = constructDefaultConstructorForType(visitor, varDecl->type.type)) + { + return defaultInitExpr; + } else { auto* defaultCall = visitor->getASTBuilder()->create<DefaultConstructExpr>(); @@ -2255,7 +2320,7 @@ void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl) if (getOptionSet().hasOption(CompilerOptionName::ZeroInitialize) && !varDecl->initExpr && as<VarDecl>(varDecl)) { - varDecl->initExpr = constructDefaultInitExprForVar(this, varDecl); + varDecl->initExpr = constructDefaultInitExprForType(this, varDecl); } if (auto initExpr = varDecl->initExpr) @@ -2274,6 +2339,7 @@ void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl) if (initExpr->type.isWriteOnly) getSink()->diagnose(initExpr, Diagnostics::readingFromWriteOnly); + initExpr = coerce(CoercionSite::Initializer, varDecl->type.Ptr(), initExpr); varDecl->initExpr = initExpr; @@ -2361,11 +2427,21 @@ void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl) // for the variable, that will be used for all downstream // code generation. // - varDecl->initExpr = - CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate); - getShared()->cacheImplicitCastMethod( - key, - ImplicitCastMethod{*overloadContext.bestCandidate, 0}); + auto constructorDecl = + as<ConstructorDecl>(overloadContext.bestCandidate->item.declRef).getDecl(); + // We don't allow implicit initialization of struct only have synthesized default + // ctor. + if ((constructorDecl && + !constructorDecl->containsFlavor( + ConstructorDecl::ConstructorFlavor::SynthesizedDefault)) || + !constructorDecl) + { + varDecl->initExpr = + CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate); + getShared()->cacheImplicitCastMethod( + key, + ImplicitCastMethod{*overloadContext.bestCandidate, 0}); + } } } } @@ -2478,18 +2554,18 @@ void SemanticsVisitor::CheckConstraintSubType(TypeExp& typeExp) } } -void addVisibilityModifier(ASTBuilder* builder, Decl* decl, DeclVisibility vis) +void SemanticsVisitor::addVisibilityModifier(Decl* decl, DeclVisibility vis) { switch (vis) { case DeclVisibility::Public: - addModifier(decl, builder->create<PublicModifier>()); + addModifier(decl, m_astBuilder->create<PublicModifier>()); break; case DeclVisibility::Internal: - addModifier(decl, builder->create<InternalModifier>()); + addModifier(decl, m_astBuilder->create<InternalModifier>()); break; case DeclVisibility::Private: - addModifier(decl, builder->create<PrivateModifier>()); + addModifier(decl, m_astBuilder->create<PrivateModifier>()); break; default: break; @@ -2635,7 +2711,7 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness aggTypeDecl->members.add(diffField); auto visibility = getDeclVisibility(member); - addVisibilityModifier(m_astBuilder, diffField, visibility); + addVisibilityModifier(diffField, visibility); aggTypeDecl->invalidateMemberDictionary(); @@ -2748,7 +2824,7 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness auto requirementVisibility = getDeclVisibility(requirementDeclRef.getDecl()); auto thisVisibility = getDeclVisibility(context->parentDecl); auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(m_astBuilder, aggTypeDecl, visibility); + addVisibilityModifier(aggTypeDecl, visibility); } // Synthesize the rest of IDifferential method conformances by recursively checking @@ -4411,7 +4487,7 @@ void SemanticsVisitor::addModifiersToSynthesizedDecl( auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); auto thisVisibility = getDeclVisibility(context->parentDecl); auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(m_astBuilder, synthesized, visibility); + addVisibilityModifier(synthesized, visibility); } } @@ -5318,7 +5394,7 @@ bool SemanticsVisitor::trySynthesizePropertyRequirementWitness( auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); auto thisVisibility = getDeclVisibility(context->parentDecl); auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(m_astBuilder, synPropertyDecl, visibility); + addVisibilityModifier(synPropertyDecl, visibility); } return true; } @@ -5487,7 +5563,7 @@ bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness( if (innerProperty.getDecl()->findModifier<VisibilityModifier>()) { auto vis = getDeclVisibility(innerProperty.getDecl()); - addVisibilityModifier(m_astBuilder, synPropertyDecl, vis); + addVisibilityModifier(synPropertyDecl, vis); } context->parentDecl->addMember(synPropertyDecl); @@ -5874,7 +5950,7 @@ bool SemanticsVisitor::trySynthesizeWrapperTypeSubscriptRequirementWitness( auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); auto thisVisibility = getDeclVisibility(context->parentDecl); auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(m_astBuilder, synSubscriptDecl, visibility); + addVisibilityModifier(synSubscriptDecl, visibility); } return true; @@ -5995,7 +6071,7 @@ bool SemanticsVisitor::trySynthesizeSubscriptRequirementWitness( auto requirementVisibility = getDeclVisibility(requiredMemberDeclRef.getDecl()); auto thisVisibility = getDeclVisibility(context->parentDecl); auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(m_astBuilder, synSubscriptDecl, visibility); + addVisibilityModifier(synSubscriptDecl, visibility); } return true; @@ -7670,6 +7746,43 @@ void SemanticsVisitor::validateEnumTagType(Type* type, SourceLoc const& loc) getSink()->diagnose(loc, Diagnostics::invalidEnumTagType, type); } +bool SemanticsVisitor::_hasExplicitConstructor(StructDecl* structDecl, bool checkBaseType) +{ + if (!structDecl) + return false; + + auto _hasExplicitCtor = [](AggTypeDecl* aggDecl) -> bool + { + // First check if the extension of this struct defines an explicit constructor. + for (auto ctor : aggDecl->getMembersOfType<ConstructorDecl>()) + { + // constructor that is not synthesized must be user defined. + if (ctor->findModifier<SynthesizedModifier>() == nullptr) + { + return true; + } + } + return false; + }; + + if (_hasExplicitCtor(structDecl)) + return true; + + if (!checkBaseType) + return false; + + for (auto inheritanceMember : structDecl->getMembersOfType<InheritanceDecl>()) + { + auto baseTypeDecl = isDeclRefTypeOf<AggTypeDecl>(inheritanceMember->base.type); + if (baseTypeDecl && !as<InterfaceDecl>(baseTypeDecl)) + { + if (_hasExplicitCtor(baseTypeDecl.getDecl())) + return true; + } + } + return false; +} + void SemanticsDeclBasesVisitor::visitEnumDecl(EnumDecl* decl) { SLANG_OUTER_SCOPE_CONTEXT_DECL_RAII(this, decl); @@ -9064,31 +9177,97 @@ static SeqStmt* _ensureCtorBodyIsSeqStmt(ASTBuilder* m_astBuilder, ConstructorDe return as<SeqStmt>(stmt->body); } +MemberExpr* SemanticsDeclBodyVisitor::createMemberExpr( + ThisExpr* thisExpr, + Scope* scope, + Decl* member) +{ + MemberExpr* memberExpr = m_astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = thisExpr; + memberExpr->declRef = member->getDefaultDeclRef(); + memberExpr->scope = scope; + memberExpr->loc = member->loc; + memberExpr->name = member->getName(); + memberExpr->type = GetTypeForDeclRef(member->getDefaultDeclRef(), member->loc); + + return memberExpr; +} + +Expr* SemanticsDeclBodyVisitor::createCtorParamExpr(ConstructorDecl* ctor, Index paramIndex) +{ + if (paramIndex < ctor->members.getCount()) + { + if (auto param = as<ParamDecl>(ctor->members[paramIndex])) + { + auto paramType = param->getType(); + auto paramExpr = m_astBuilder->create<VarExpr>(); + paramExpr->scope = ctor->ownedScope; + paramExpr->declRef = param; + paramExpr->type = paramType; + paramExpr->loc = param->loc; + return paramExpr; + } + } + return nullptr; +} + void SemanticsDeclBodyVisitor::synthesizeCtorBodyForBases( ConstructorDecl* ctor, List<DeclAndCtorInfo>& inheritanceDefaultCtorList, ThisExpr* thisExpr, - SeqStmt* seqStmtChild) + SeqStmt* seqStmtChild, + bool isMemberInitCtor, + Index& ioParamIndex) { - // e.g. this->base = BaseType(); for (auto& declInfo : inheritanceDefaultCtorList) { - if (!declInfo.defaultCtor) - continue; + ConstructorDecl* baseCtor = nullptr; + List<Expr*> argumentList; + + if (isMemberInitCtor) + { + // Pick the parameters from the member initialize ctor, and use them to invoke the + // base's member initialize ctor. e.g. base->init(...); + baseCtor = _getSynthesizedConstructor( + declInfo.parent, + ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit); + if (baseCtor) + { + Index idx = 0; + for (; idx < baseCtor->getParameters().getCount(); idx++) + { + auto paramExpr = createCtorParamExpr(ctor, idx); + argumentList.add(paramExpr); + } + ioParamIndex += idx; + } + } + + // It's possible that the base type doesn't have a member initialize ctor, in this case, we + // should use the default ctor. + if (!baseCtor) + { + // If the base type has no default constructor, it means that it's not default + // initializable, e.g. unsized array, resource type, etc. We will not synthesize code to + // initialize it. + if (!declInfo.defaultCtor) + continue; + baseCtor = declInfo.defaultCtor; + } auto declRefType = as<DeclRefType>(declInfo.type); auto ctorToInvoke = m_astBuilder->create<VarExpr>(); ctorToInvoke->declRef = declRefType->getDeclRef(); - ctorToInvoke->name = declInfo.defaultCtor->getName(); - ctorToInvoke->loc = declInfo.defaultCtor->loc; + ctorToInvoke->name = baseCtor->getName(); + ctorToInvoke->loc = baseCtor->loc; ctorToInvoke->type = m_astBuilder->getFuncType(ArrayView<Type*>(), declRefType); auto invoke = m_astBuilder->create<InvokeExpr>(); invoke->functionExpr = ctorToInvoke; + invoke->arguments.addRange(argumentList); auto assign = m_astBuilder->create<AssignExpr>(); - assign->left = coerce(CoercionSite::Initializer, declRefType, thisExpr); assign->right = invoke; @@ -9105,25 +9284,57 @@ void SemanticsDeclBodyVisitor::synthesizeCtorBodyForMember( Decl* member, ThisExpr* thisExpr, Dictionary<Decl*, Expr*>& cachedDeclToCheckedVar, - SeqStmt* seqStmtChild) + SeqStmt* seqStmtChild, + bool isMemberInitCtor, + Index& paramIndex) { auto varDeclBase = as<VarDeclBase>(member); // Static variables are initialized at start of runtime, not inside a constructor - if (!varDeclBase || !varDeclBase->initExpr || varDeclBase->hasModifier<HLSLStaticModifier>()) + // Once thing to notice is that if a member variable doesn't have name, it must be synthesized + // instead of defined by user, we should not put it into the constructor because it's not a real + // member. + if (!varDeclBase || varDeclBase->hasModifier<HLSLStaticModifier>() || + varDeclBase->getName() == nullptr) return; - MemberExpr* memberExpr = m_astBuilder->create<MemberExpr>(); - memberExpr->baseExpression = thisExpr; - memberExpr->declRef = member->getDefaultDeclRef(); - memberExpr->scope = ctor->ownedScope; - memberExpr->loc = member->loc; - memberExpr->name = member->getName(); - memberExpr->type = DeclRefType::create(getASTBuilder(), member->getDefaultDeclRef()); + Expr* initExpr = nullptr; + auto structDecl = as<StructDecl>(member->parentDecl); + bool useParamList = isMemberInitCtor; + useParamList = isMemberInitCtor && structDecl->m_membersVisibleInCtor.contains(varDeclBase); + + if (!useParamList) + { + // If this is not a synthesized constructor (e.g. explicit ctor), or + // the member has no visibility, we can only use it's init expression to initialize it. + if (!varDeclBase->initExpr) + return; + initExpr = varDeclBase->initExpr; + } + else + { + // Find the corresponding parameter, if we can't find it, there + // must be something wrong, it indicates that the ctor signature + // is incorrect that the parameter list doesn't match the member list. + initExpr = createCtorParamExpr(ctor, paramIndex++); + if (!initExpr) + { + const char* structName = + (structDecl->getName() ? structDecl->getName()->text.begin() : "unknown"); + StringBuilder msg; + msg << "Fail to synthesize the member initialize constructor for struct '" << structName + << "', the parameter list doesn't match the member list."; + SLANG_ABORT_COMPILATION(msg.produceString().begin()); + } + } + + MemberExpr* memberExpr = createMemberExpr(thisExpr, ctor->ownedScope, member); + if (!memberExpr->type.isLeftValue) + return; auto assign = m_astBuilder->create<AssignExpr>(); assign->left = memberExpr; - assign->right = varDeclBase->initExpr; + assign->right = initExpr; assign->loc = member->loc; auto stmt = m_astBuilder->create<ExpressionStmt>(); @@ -9139,9 +9350,6 @@ void SemanticsDeclBodyVisitor::synthesizeCtorBodyForMember( cachedDeclToCheckedVar.add({member, checkedMemberVarExpr}); } - if (!checkedMemberVarExpr->type.isLeftValue) - return; - seqStmtChild->stmts.add(stmt); } @@ -9163,14 +9371,37 @@ void SemanticsDeclBodyVisitor::synthesizeCtorBody( thisExpr->scope = ctor->ownedScope; thisExpr->type = ctor->returnType.type; - // Initialize base type by using its default constructor if it has one. - synthesizeCtorBodyForBases(ctor, inheritanceDefaultCtorList, thisExpr, seqStmtChild); - - // Initialize member variables by using their default value if they have one - // e.g. this->member = default_value + // We treat the ctor with parameters and all parameters have default value as default ctor + // as well, but the method to synthesize them are totally different, therefore, we need to + // differentiate them here. + bool isMemberInitCtor = + ctor->containsFlavor(ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit); + + // When we synthesize the member initialize constructor, we need to use the parameters in + // the function body, so this inout parameter is used to keep track of the index of the + // parameters. + Index ioParamIndex = 0; + + // The first step is to synthesize the initialization of the base member. + synthesizeCtorBodyForBases( + ctor, + inheritanceDefaultCtorList, + thisExpr, + seqStmtChild, + isMemberInitCtor, + ioParamIndex); + + // Then synthesize the initialization of the other members. for (auto& m : structDecl->members) { - synthesizeCtorBodyForMember(ctor, m, thisExpr, cachedDeclToCheckedVar, seqStmtChild); + synthesizeCtorBodyForMember( + ctor, + m, + thisExpr, + cachedDeclToCheckedVar, + seqStmtChild, + isMemberInitCtor, + ioParamIndex); } if (seqStmtChild->stmts.getCount() != 0) @@ -9224,7 +9455,7 @@ void SemanticsDeclBodyVisitor::visitAggTypeDecl(AggTypeDecl* aggTypeDecl) ensureDecl(m->getDefaultDeclRef(), DeclCheckState::DefaultConstructorReadyForUse); if (!isDefaultInitializableType || varDeclBase->initExpr) continue; - varDeclBase->initExpr = constructDefaultInitExprForVar(this, varDeclBase); + varDeclBase->initExpr = constructDefaultInitExprForType(this, varDeclBase); } synthesizeCtorBody(structDeclInfo, inheritanceDefaultCtorList, structDecl); @@ -9776,7 +10007,7 @@ Type* SemanticsVisitor::findResultTypeForConstructorDecl(ConstructorDecl* decl) void SemanticsDeclHeaderVisitor::visitConstructorDecl(ConstructorDecl* decl) { - // We need to compute the result tyep for this declaration, + // We need to compute the result type for this declaration, // since it wasn't filled in for us. decl->returnType.type = findResultTypeForConstructorDecl(decl); @@ -11976,13 +12207,174 @@ void SemanticsDeclAttributesVisitor::checkPrimalSubstituteOfAttribute( DeclAssociationKind::PrimalSubstituteFunc); } +bool SemanticsDeclAttributesVisitor::collectInitializableMembers( + StructDecl* structDecl, + const DeclVisibility ctorVisibility, + List<VarDeclBase*>& resultMembers) +{ + auto findMembers = [&](StructDecl* structDecl) + { + for (auto varDeclRef : getMembersOfType<VarDeclBase>( + getASTBuilder(), + structDecl, + MemberFilterStyle::Instance)) + { + auto varDecl = varDeclRef.getDecl(); + if (getDeclVisibility(varDecl) < ctorVisibility) + continue; + + auto type = GetTypeForDeclRef(varDeclRef, varDecl->loc); + if (!type.isLeftValue) + continue; + + resultMembers.add(varDecl); + structDecl->m_membersVisibleInCtor.add(varDecl); + } + }; + + // Find the base type's members first + for (auto inheritanceMember : structDecl->getMembersOfType<InheritanceDecl>()) + { + // For base types, we need to pick their parameters of the constructor to the derived type's + // constructor + if (auto baseTypeDeclRef = isDeclRefTypeOf<StructDecl>(inheritanceMember->base.type)) + { + // We should only find the member initialization constructor because it is the + // constructor has parameters + ConstructorDecl* ctor = _getSynthesizedConstructor( + baseTypeDeclRef.getDecl(), + ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit); + + // The constructor has to have higher or equal visibility level than the struct itself, + // otherwise, it's not accessible so we will not pick up. + if (ctor && getDeclVisibility(ctor) >= ctorVisibility) + { + for (ParamDecl* param : ctor->getParameters()) + { + // Because the parameters in the ctor must have the higher or equal visibility + // than the ctor itself, we don't need to check the visibility level of the + // parameter. + resultMembers.add(param); + } + } + } + } + + // Find the struct's members + findMembers(structDecl); + return (resultMembers.getCount() > 0); +} + +// If a struct's member: +// 1. has an initialize expression: Struct S {int a = 1;}; or +// 2. is a default initializable type +// Note, If a type is not default initializable, it doesn't have default value. +// it can be associated with default value expression in the constructor signature. +// This function helps to check whether either of those 2 conditions are met and create +// a default value for the parameter. +// It's totally fine that there is no default value for the parameter, in this case, user +// code has to provide the argument for this parameter. +static Expr* _getParamDefaultValue(SemanticsVisitor* visitor, VarDeclBase* varDecl) +{ + // 1st condition is easy, we can just use the init expression as the default value. + if (varDecl->initExpr) + { + return varDecl->initExpr; + } + + if (!varDecl->type || !varDecl->type.type) + return nullptr; + + if (!isDefaultInitializable(varDecl)) + return nullptr; + + return constructDefaultConstructorForType(visitor, varDecl->type.type); +} + +bool SemanticsDeclAttributesVisitor::_synthesizeCtorSignature(StructDecl* structDecl) +{ + // If a type or its base type already defines any explicit constructors, do not synthesize any + // constructors. see: + // https://github.com/shader-slang/slang/blob/master/docs/proposals/004-initialization.md#inheritance-initialization + if (_hasExplicitConstructor(structDecl, true)) + return false; + + // synthesize the signature first. + // The constructor's visibility level is the same as the struct itself. + // See: + // https://github.com/shader-slang/slang/blob/master/docs/proposals/004-initialization.md#synthesis-of-constructors-for-member-initialization + DeclVisibility ctorVisibility = getDeclVisibility(structDecl); + + // Only the members whose visibility level is higher or equal than the + // constructor's visibility level will appear in the constructor's parameter list. + List<VarDeclBase*> resultMembers; + if (!collectInitializableMembers(structDecl, ctorVisibility, resultMembers)) + return false; + + // synthesize the constructor signature: + // 1. The constructor's name is always `$init`, we create one without parameters now. + ConstructorDecl* ctor = createCtor(structDecl, ctorVisibility); + ctor->addFlavor(ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit); + + ctor->members.reserve(resultMembers.getCount()); + + // 2. Add the parameter list + bool stopProcessingDefaultValues = false; + for (SlangInt i = resultMembers.getCount() - 1; i >= 0; i--) + { + auto member = resultMembers[i]; + auto parentAggDecl = getParentAggTypeDecl(member); + ; + + auto ctorParam = m_astBuilder->create<ParamDecl>(); + ctorParam->type = (TypeExp)member->type; + + if (!stopProcessingDefaultValues) + ctorParam->initExpr = _getParamDefaultValue(this, member); + + if (!ctorParam->initExpr) + stopProcessingDefaultValues = true; + + ctorParam->parentDecl = ctor; + + Name* paramName = + (parentAggDecl == structDecl) + ? member->getName() + : getName(parentAggDecl->getName()->text + "_" + member->getName()->text); + + ctorParam->nameAndLoc = NameLoc(paramName, ctor->loc); + + ctorParam->loc = ctor->loc; + ctor->members.add(ctorParam); + + // We need to ensure member is `no_diff` if it cannot be differentiated, `ctor` modifiers do + // not matter in this case since member-wise ctor is always differentiable or "treat as + // differentiable". + if (!isTypeDifferentiable(member->getType()) || member->hasModifier<NoDiffModifier>()) + { + auto noDiffMod = m_astBuilder->create<NoDiffModifier>(); + noDiffMod->loc = ctorParam->loc; + addModifier(ctorParam, noDiffMod); + } + } + ctor->members.reverse(); + return true; +} + void SemanticsDeclAttributesVisitor::visitStructDecl(StructDecl* structDecl) { - // add a empty deault CTor if missing; checking in attributes - // to avoid circular checking logic - auto defaultCtor = _getDefaultCtor(structDecl); - if (!defaultCtor) - _createCtor(this, m_astBuilder, structDecl); + // add the member initialize constructor here to avoid circular checking logic + if (!_synthesizeCtorSignature(structDecl)) + { + // add a default CTor if missing; checking in attributes + // to avoid circular checking logic + auto defaultCtor = _getDefaultCtor(structDecl); + if (!defaultCtor) + { + DeclVisibility ctorVisibility = getDeclVisibility(structDecl); + createCtor(structDecl, ctorVisibility); + } + } int backingWidth = 0; [[maybe_unused]] int totalWidth = 0; |
