diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-26 08:32:24 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-26 08:32:24 -0700 |
| commit | 939be44ca23476e622dfb24a592383fe2a1da61f (patch) | |
| tree | 7f45645897fe5735d58a7687290552d479e4d6fc /source/slang | |
| parent | 4fc34b18da2f83ee6b4f094067503a66cab3d0b5 (diff) | |
Auto synthesis of Differential type (#2466)
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/core.meta.slang | 7 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 4 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.cpp | 57 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.h | 15 | ||||
| -rw-r--r-- | source/slang/slang-ast-dump.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 126 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 137 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 24 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 53 | ||||
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-doc-markdown-writer.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-language-server-semantic-tokens.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-lookup.cpp | 67 | ||||
| -rw-r--r-- | source/slang/slang-lookup.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 3 |
18 files changed, 421 insertions, 128 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 769a1091d..a25ce03bd 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2742,6 +2742,13 @@ attribute_syntax [Differentiable] : DifferentiableAttribute; __attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; +enum _BuiltinAssociatedTypeRequirementKind +{ + Differential = $( (int) BuiltinAssociatedTypeRequirementKind::Differential), +}; +__attributeTarget(DeclBase) +attribute_syntax [__BuiltinAssociatedTypeRequirementAttribute(kind: _BuiltinAssociatedTypeRequirementKind)] : BuiltinAssociatedTypeRequirementAttribute; + __attributeTarget(DeclBase) attribute_syntax [builtin] : BuiltinAttribute; diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 26fec224c..f314e0487 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -18,6 +18,10 @@ attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute; __magic_type(DifferentiableType) interface IDifferentiable { + // Note: the compiler implementation requires the `Differential` associated type to be defined + // before anything else. + + [__BuiltinAssociatedTypeRequirementAttribute(_BuiltinAssociatedTypeRequirementKind.Differential)] associatedtype Differential; static Differential zero(); diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp index 2df9164fb..b2802e304 100644 --- a/source/slang/slang-ast-decl.cpp +++ b/source/slang/slang-ast-decl.cpp @@ -1,5 +1,6 @@ // slang-ast-decl.cpp #include "slang-ast-builder.h" +#include "slang-syntax.h" #include <assert.h> #include "slang-generated-ast-macro.h" @@ -32,4 +33,60 @@ bool isInterfaceRequirement(Decl* decl) return false; } +void ContainerDecl::buildMemberDictionary() +{ + // Don't rebuild if already built + if (isMemberDictionaryValid()) + return; + + // If it's < 0 it means that the dictionaries are entirely invalid + if (dictionaryLastCount < 0) + { + dictionaryLastCount = 0; + memberDictionary.Clear(); + transparentMembers.clear(); + } + + // are we a generic? + GenericDecl* genericDecl = as<GenericDecl>(this); + + const Index membersCount = members.getCount(); + + SLANG_ASSERT(dictionaryLastCount >= 0 && dictionaryLastCount <= membersCount); + + for (Index i = dictionaryLastCount; i < membersCount; ++i) + { + Decl* m = members[i]; + + auto name = m->getName(); + + // Add any transparent members to a separate list for lookup + if (m->hasModifier<TransparentModifier>()) + { + TransparentMemberInfo info; + info.decl = m; + transparentMembers.add(info); + } + + // Ignore members with no name + if (!name) + continue; + + // Ignore the "inner" member of a generic declaration + if (genericDecl && m == genericDecl->inner) + continue; + + m->nextInContainerWithSameName = nullptr; + + Decl* next = nullptr; + if (memberDictionary.TryGetValue(name, next)) + m->nextInContainerWithSameName = next; + + memberDictionary[name] = m; + } + + dictionaryLastCount = membersCount; + SLANG_ASSERT(isMemberDictionaryValid()); +} + } // namespace Slang diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index b1b20dc93..87d696927 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -35,12 +35,27 @@ class ContainerDecl: public Decl return FilteredMemberList<T>(members); } + void buildMemberDictionary(); + bool isMemberDictionaryValid() const { return dictionaryLastCount == members.getCount(); } void invalidateMemberDictionary() { dictionaryLastCount = -1; } + Dictionary<Name*, Decl*>& getMemberDictionary() + { + buildMemberDictionary(); + return memberDictionary; + } + + List<TransparentMemberInfo>& getTransparentMembers() + { + buildMemberDictionary(); + return transparentMembers; + } + SLANG_UNREFLECTED // We don't want to reflect the following fields +private: // Denotes how much of Members has been placed into the dictionary/transparentMembers. // If this value equals the Members.getCount(), the dictionary is completely full and valid. // If it's >= 0, then the Members after dictionaryLastCount are all that need to be added. diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index d67a35174..32f9dd16f 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -345,6 +345,10 @@ struct ASTDumpContext { m_writer->emit(getTryClauseTypeName(clauseType)); } + void dump(BuiltinAssociatedTypeRequirementKind kind) + { + m_writer->emit((int)kind); + } void dump(const String& string) { dump(string.getUnownedSlice()); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index b019953cb..c439c7437 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -32,6 +32,12 @@ class GloballyCoherentModifier : public Modifier { SLANG_AST_CLASS(GloballyCoher class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)}; class JVPDerivativeModifier : public Modifier { SLANG_AST_CLASS(JVPDerivativeModifier)}; +// Marks that the definition of a decl is not yet synthesized. +class ToBeSynthesizedModifier : public Modifier {SLANG_AST_CLASS(ToBeSynthesizedModifier)}; + +// Marks that the definition of a decl is synthesized. +class SynthesizedModifier : public Modifier { SLANG_AST_CLASS(SynthesizedModifier) }; + // An `extern` variable in an extension is used to introduce additional attributes on an existing // field. class ExtensionExternVarModifier : public Modifier @@ -584,6 +590,14 @@ class Attribute : public AttributeBase AttributeArgumentValueDict intArgVals; }; +// A modifier that indicates a built-in associated type requirement (e.g., `Differential`) +class BuiltinAssociatedTypeRequirementAttribute : public Attribute +{ + SLANG_AST_CLASS(BuiltinAssociatedTypeRequirementAttribute); + + BuiltinAssociatedTypeRequirementKind kind; +}; + class UserDefinedAttribute : public Attribute { SLANG_AST_CLASS(UserDefinedAttribute) diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 39ca71267..9a32d816c 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1490,6 +1490,12 @@ namespace Slang kParameterDirection_Ref, ///< By-reference }; + /// The type of a builtin associated type requirement. + enum class BuiltinAssociatedTypeRequirementKind + { + Differential + }; + } // namespace Slang #endif diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 356105e4f..fa05dde11 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -981,7 +981,7 @@ namespace Slang VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr) { auto memberType = checkProperType(getLinkage(), varDecl->type, getSink()); - auto diffType = _getDifferential(m_astBuilder, memberType); + auto diffType = getDifferentialType(m_astBuilder, memberType, varDecl->loc); if (as<ErrorType>(diffType)) { getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeIsNotDifferentiable, memberType); @@ -994,7 +994,7 @@ namespace Slang Diagnostics:: derivativeMemberAttributeCanOnlyBeUsedOnMembers); } - auto diffThisType = _getDifferential(m_astBuilder, thisType); + auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc); if (!thisType) { getSink()->diagnose( @@ -1359,6 +1359,104 @@ namespace Slang } } + bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness( + ConformanceCheckingContext* context, + DeclRef<Decl> requirementDeclRef, + RefPtr<WitnessTable> witnessTable) + { + // We currently can't handle generic types. + if (GetOuterGeneric(context->parentDecl) != nullptr) + { + return false; + } + + Decl* existingDecl = nullptr; + AggTypeDecl* aggTypeDecl = nullptr; + if (context->parentDecl->getMemberDictionary().TryGetValue(requirementDeclRef.getName(), existingDecl)) + { + aggTypeDecl = as<AggTypeDecl>(existingDecl); + SLANG_RELEASE_ASSERT(aggTypeDecl); + + // Remove the `ToBeSynthesizedModifier`. + if (as<ToBeSynthesizedModifier>(aggTypeDecl->modifiers.first)) + { + aggTypeDecl->modifiers.first = aggTypeDecl->modifiers.first->next; + } + } + else + { + aggTypeDecl = m_astBuilder->create<StructDecl>(); + aggTypeDecl->parentDecl = context->parentDecl; + context->parentDecl->members.add((aggTypeDecl)); + aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName(); + aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc; + context->parentDecl->getMemberDictionary().Add(aggTypeDecl->getName(), aggTypeDecl); + } + + // TODO: if we want to make the synthesized type itself to be differentiable, + // add an inheritance decl here. Need to be careful to avoid infinite recursion + // trying to synthesize the higher order differential types. + + // Helper function to add a `diffType` field into the synthesized type for the original + // `member`. + auto differentialType = GetTypeForDeclRef(makeDeclRef(aggTypeDecl), context->parentDecl->loc); + auto addDiffMember = [&](Decl* member, Type* diffMemberType) + { + // If the field is differentiable, add a corresponding field in the associated Differential type. + auto diffField = m_astBuilder->create<VarDecl>(); + diffField->nameAndLoc = member->nameAndLoc; + diffField->type.type = diffMemberType; + diffField->checkState = DeclCheckState::SignatureChecked; + diffField->parentDecl = aggTypeDecl; + aggTypeDecl->members.add(diffField); + + // Inject a `DerivativeMember` modifier on the original decl. + auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>(); + auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>(); + fieldLookupExpr->type.type = diffMemberType; + auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>(); + baseTypeExpr->base.type = differentialType; + auto baseTypeType = m_astBuilder->create<TypeType>(); + baseTypeType->type = differentialType; + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + fieldLookupExpr->declRef = makeDeclRef(diffField); + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(member, derivativeMemberModifier); + }; + + // Go through super types. + for (auto inheritance : context->parentDecl->getMembersOfType<InheritanceDecl>()) + { + if (auto baseDeclRefType = as<DeclRefType>(inheritance->base.type)) + { + // Skip interface super types. + if (baseDeclRefType->declRef.as<InterfaceDecl>()) + continue; + if (auto superDiffType = tryGetDifferentialType(m_astBuilder, baseDeclRefType)) + { + addDiffMember(inheritance, superDiffType); + } + } + } + + // We go through all members and generate their differential counterparts. + for (auto member : context->parentDecl->getMembersOfType<VarDeclBase>()) + { + auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type); + if (!diffType) + continue; + addDiffMember(member, diffType); + } + + // In the future when the Differential type itself needs to conform to some interface, + // this is the place to synthesize requirements for them. + addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>()); + auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, nullptr); + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); + return true; + } + void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*) { // If the autodiff core library (diff.meta.slang) has not been loaded yet, ignore any @@ -2146,6 +2244,13 @@ namespace Slang DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, RefPtr<WitnessTable> witnessTable) { + if (auto declRefType = as<DeclRefType>(satisfyingType)) + { + // If we are seeing a placeholder that awaits synthesis, return false now to trigger + // auto synthesis. + if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>()) + return false; + } // We need to confirm that the chosen type `satisfyingType`, // meets all the constraints placed on the associated type // requirement `requiredAssociatedTypeDeclRef`. @@ -2947,6 +3052,21 @@ namespace Slang witnessTable); } + if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>()) + { + if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinAssociatedTypeRequirementAttribute>()) + { + switch (builtinAttr->kind) + { + case BuiltinAssociatedTypeRequirementKind::Differential: + return trySynthesizeDifferentialAssociatedTypeRequirementWitness( + context, + requiredAssocTypeDeclRef, + witnessTable); + } + } + } + // TODO: There are other kinds of requirements for which synthesis should // be possible: // @@ -4876,7 +4996,7 @@ namespace Slang // We will now look for other declarations with // the same name in the same parent/container. // - buildMemberDictionary(parentDecl); + parentDecl->buildMemberDictionary(); for (auto oldDecl = newDecl->nextInContainerWithSameName; oldDecl; oldDecl = oldDecl->nextInContainerWithSameName) { // For each matching declaration, we will check diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 29b44e726..d69cd39ed 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -393,12 +393,107 @@ namespace Slang return derefExpr; } + Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(LookupResultItem const& item, Expr* originalExpr) + { + // If the only result from lookup is an entry in an interface decl, it could be that + // the user is leaving out an explicit definition for the requirement and depending on + // the compiler to synthesis the definition. + // In this case, if the lookup is triggered from a location such that the satisfying + // definition should be returned should it existed, we should create a placeholder decl for + // the definition and return a reference to to newly created decl instead of the requirement + // decl in the interface. + switch (item.declRef.getDecl()->astNodeType) + { + case ASTNodeType::AssocTypeDecl: + return maybeUseSynthesizedTypeDeclForLookupResult(item, originalExpr); + default: + return nullptr; + } + } + + Expr* SemanticsVisitor::maybeUseSynthesizedTypeDeclForLookupResult(LookupResultItem const& item, Expr* originalExpr) + { + // We need to check if the lookup should resolve to a definition in an implementation type + // if it existed. + // This will be the case when the lookup is initiated from the concrete implementation type instead of + // directly from the Interface decl. The breadcrumbs of the lookup should provide this information. + + // If no breadcrumbs existed, then the lookup should just resolve to the interface requirement. + + if (!item.breadcrumbs) + return nullptr; + + // We will only ever need to synthesis a type to satisfy an associatedtype requirement. + // In this case the lookup should have resolved to a known associatedtype decl. + auto builtinAssocTypeAttr = item.declRef.getDecl()->findModifier<BuiltinAssociatedTypeRequirementAttribute>(); + if (!builtinAssocTypeAttr) + return nullptr; + + DeclRefType* subType = nullptr; + + // Check if we are reaching the associated type decl through inheritance from a concrete type. + for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next) + { + switch (breadcrumb->kind) + { + case LookupResultItem::Breadcrumb::Kind::SuperType: + { + auto witness = as<SubtypeWitness>(breadcrumb->val); + if (auto subDeclRefType = as<DeclRefType>(witness->sub)) + { + if (!as<InterfaceDecl>(subDeclRefType->declRef.getDecl())) + { + // Store the inner most concrete super type. + subType = subDeclRefType; + } + } + } + break; + default: + break; + } + } + if (!subType) + return nullptr; + + subType = as<DeclRefType>(subType->getCanonicalType()); + if (!subType) + return nullptr; + + // Don't synthesize for generic parameters. + auto parent = as<AggTypeDecl>(subType->declRef.getDecl()); + if (!parent) + return nullptr; + + // If we reach here, we are expecting a synthesized associated type defined in `subType`. + // Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder type + // in `subType` and return a DeclRefExpr to the synthesized decl. + auto assocType = m_astBuilder->create<StructDecl>(); + assocType->parentDecl = parent; + assocType->nameAndLoc.name = item.declRef.getName(); + assocType->loc = parent->loc; + parent->members.add(assocType); + parent->invalidateMemberDictionary(); + + // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it + // from user-provided definitions, and proceed to fill in its definition. + auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>(); + addModifier(assocType, toBeSynthesized); + + return ConstructDeclRefExpr(makeDeclRef(assocType), nullptr, originalExpr->loc, originalExpr); + } + Expr* SemanticsVisitor::ConstructLookupResultExpr( LookupResultItem const& item, Expr* baseExpr, SourceLoc loc, Expr* originalExpr) { + // We could be referencing a decl that will be synthesized. If so create a placeholder + // and return a DeclRefExpr to it. + if (auto lookupResultExpr = maybeUseSynthesizedDeclForLookupResult(item, originalExpr)) + return lookupResultExpr; + // If we collected any breadcrumbs, then these represent // additional segments of the lookup path that we need // to expand here. @@ -719,21 +814,25 @@ namespace Slang return _resolveOverloadedExprImpl(overloadedExpr, mask, getSink()); } - Type* SemanticsVisitor::_getDifferential(ASTBuilder* builder, Type* type) + Type* SemanticsVisitor::tryGetDifferentialType(ASTBuilder* builder, Type* type) { if (auto ptrType = as<PtrTypeBase>(type)) { + auto baseDiffType = tryGetDifferentialType(builder, ptrType->getValueType()); + if (!baseDiffType) return nullptr; return builder->getPtrType( - _getDifferential(builder, ptrType->getValueType()), + baseDiffType, ptrType->getClassInfo().m_name); } else if (auto arrayType = as<ArrayExpressionType>(type)) { + auto baseDiffType = tryGetDifferentialType(builder, arrayType->baseType); + if (!baseDiffType) return nullptr; return builder->getArrayType( - _getDifferential(builder, arrayType->baseType), + baseDiffType, arrayType->arrayLength); } - + if (auto declRefType = as<DeclRefType>(type)) { if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterface()))) @@ -745,17 +844,16 @@ namespace Slang type, Slang::LookupMask::type, Slang::LookupOptions::None); - + diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult); if (!diffTypeLookupResult.isValid()) { - // Diagnose no 'Differential' member. - getSink()->diagnose(declRefType->declRef, Diagnostics::typeDoesntImplementInterfaceRequirement, type, getName("Differential")); + return nullptr; } else if (diffTypeLookupResult.isOverloaded()) { - getSink()->diagnose(declRefType->declRef, Diagnostics::ambiguousReference, getName("Differential")); + return nullptr; } else { @@ -764,17 +862,28 @@ namespace Slang baseTypeExpr->type.type = m_astBuilder->getTypeType(type); auto diffTypeExpr = ConstructLookupResultExpr( - diffTypeLookupResult.item, - baseTypeExpr, - declRefType->declRef.getLoc(), - baseTypeExpr); - + diffTypeLookupResult.item, + baseTypeExpr, + declRefType->declRef.getLoc(), + baseTypeExpr); + return ExtractTypeFromTypeRepr(diffTypeExpr); } } } - return m_astBuilder->getErrorType(); + return nullptr; + } + + Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc) + { + auto result = tryGetDifferentialType(builder, type); + if (!result) + { + getSink()->diagnose(loc, Diagnostics::typeDoesntImplementInterfaceRequirement, type, getName("Differential")); + return m_astBuilder->getErrorType(); + } + return result; } void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 0877f2d6e..ac1d624c2 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -630,6 +630,14 @@ namespace Slang Expr* base, SourceLoc loc); + Expr* maybeUseSynthesizedTypeDeclForLookupResult( + LookupResultItem const& item, + Expr* orignalExpr); + + Expr* maybeUseSynthesizedDeclForLookupResult( + LookupResultItem const& item, + Expr* orignalExpr); + Expr* ConstructLookupResultExpr( LookupResultItem const& item, Expr* baseExpr, @@ -804,7 +812,9 @@ namespace Slang void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type); // Construct the differential for 'type', if it exists. - Type* _getDifferential(ASTBuilder* builder, Type* type); + Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc); + Type* tryGetDifferentialType(ASTBuilder* builder, Type* type); + public: @@ -1094,6 +1104,18 @@ namespace Slang DeclRef<Decl> requiredMemberDeclRef, RefPtr<WitnessTable> witnessTable); + /// Attempt to synthesize an associated `Differential` type for a type that conforms to + /// `IDifferentiable`. + /// + /// On success, installs the syntethesized type in `witnessTable`, injects `[DerivativeMember]` + /// modifiers on differentiable fields to point to the corresponding field in the synthesized + /// differential type, and returns `true`. + /// Otherwise, returns `false`. + bool trySynthesizeDifferentialAssociatedTypeRequirementWitness( + ConformanceCheckingContext* context, + DeclRef<Decl> requirementDeclRef, + RefPtr<WitnessTable> witnessTable); + /// Registers a type as differentiable in the currrent semantic context, if the declaration represents /// a subtype of IDifferentable. Does nothing otherwise. void tryAddDifferentiableConformanceToContext( diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index f977721dd..7e11ee3ca 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -292,7 +292,7 @@ namespace Slang bool SemanticsVisitor::validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget) { - if(auto numThreadsAttr = as<NumThreadsAttribute>(attr)) + if (auto numThreadsAttr = as<NumThreadsAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 3); @@ -320,9 +320,9 @@ namespace Slang values[i] = value; } - numThreadsAttr->x = values[0]; - numThreadsAttr->y = values[1]; - numThreadsAttr->z = values[2]; + numThreadsAttr->x = values[0]; + numThreadsAttr->y = values[1]; + numThreadsAttr->z = values[2]; } else if (auto anyValueSizeAttr = as<AnyValueSizeAttribute>(attr)) { @@ -368,7 +368,7 @@ namespace Slang { return false; } - + bindingAttr->binding = int32_t(binding->value); bindingAttr->set = int32_t(set->value); } @@ -395,31 +395,31 @@ namespace Slang SLANG_ASSERT(attr->args.getCount() == 1); auto val = checkConstantIntVal(attr->args[0]); - if(!val) return false; + if (!val) return false; maxVertexCountAttr->value = (int32_t)val->value; } - else if(auto instanceAttr = as<InstanceAttribute>(attr)) + else if (auto instanceAttr = as<InstanceAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); auto val = checkConstantIntVal(attr->args[0]); - if(!val) return false; + if (!val) return false; instanceAttr->value = (int32_t)val->value; } - else if(auto entryPointAttr = as<EntryPointAttribute>(attr)) + else if (auto entryPointAttr = as<EntryPointAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); String stageName; - if(!checkLiteralStringVal(attr->args[0], &stageName)) + if (!checkLiteralStringVal(attr->args[0], &stageName)) { return false; } auto stage = findStageByName(stageName); - if(stage == Stage::Unknown) + if (stage == Stage::Unknown) { getSink()->diagnose(attr->args[0], Diagnostics::unknownStageName, stageName); } @@ -427,10 +427,10 @@ namespace Slang entryPointAttr->stage = stage; } else if ((as<DomainAttribute>(attr)) || - (as<MaxTessFactorAttribute>(attr)) || - (as<OutputTopologyAttribute>(attr)) || - (as<PartitioningAttribute>(attr)) || - (as<PatchConstantFuncAttribute>(attr))) + (as<MaxTessFactorAttribute>(attr)) || + (as<OutputTopologyAttribute>(attr)) || + (as<PartitioningAttribute>(attr)) || + (as<PatchConstantFuncAttribute>(attr))) { // Let it go thru iff single string attribute if (!hasStringArgs(attr, 1)) @@ -439,7 +439,7 @@ namespace Slang } } else if (as<OutputControlPointsAttribute>(attr) || - as<SPIRVInstructionOpAttribute>(attr)) + as<SPIRVInstructionOpAttribute>(attr)) { // Let it go thru iff single integral attribute if (!hasIntArgs(attr, 1)) @@ -484,6 +484,27 @@ namespace Slang return false; } } + else if (auto builtinAssocTypeAttr = as<BuiltinAssociatedTypeRequirementAttribute>(attr)) + { + if (attr->args.getCount() == 1) + { + //IntVal* outIntVal; + if (auto cInt = checkConstantEnumVal(attr->args[0])) + { + builtinAssocTypeAttr->kind = (BuiltinAssociatedTypeRequirementKind)(cInt->value); + } + else + { + getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName); + return false; + } + } + else + { + getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName); + return false; + } + } else if (auto unrollAttr = as<UnrollAttribute>(attr)) { // Check has an argument. We need this because default behavior is to give an error diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index d7200d47c..a84e40768 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -236,13 +236,10 @@ namespace Slang { auto translationUnitSyntax = translationUnit->getModuleDecl(); - // Make sure we've got a query-able member dictionary - buildMemberDictionary(translationUnitSyntax); - // We will look up any global-scope declarations in the translation // unit that match the name of our entry point. Decl* firstDeclWithName = nullptr; - if (!translationUnitSyntax->memberDictionary.TryGetValue(name, firstDeclWithName)) + if (!translationUnitSyntax->getMemberDictionary().TryGetValue(name, firstDeclWithName)) { // If there doesn't appear to be any such declaration, then we are done. @@ -454,13 +451,10 @@ namespace Slang auto entryPointName = entryPointReq->getName(); - // Make sure we've got a query-able member dictionary - buildMemberDictionary(translationUnitSyntax); - // We will look up any global-scope declarations in the translation // unit that match the name of our entry point. Decl* firstDeclWithName = nullptr; - if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPointName, firstDeclWithName) ) + if( !translationUnitSyntax->getMemberDictionary().TryGetValue(entryPointName, firstDeclWithName)) { // If there doesn't appear to be any such declaration, then // we need to diagnose it as an error, and then bail out. diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp index 4d8afd763..9130c05ed 100644 --- a/source/slang/slang-doc-markdown-writer.cpp +++ b/source/slang/slang-doc-markdown-writer.cpp @@ -667,13 +667,10 @@ static bool _isFirstOverridden(Decl* decl) ContainerDecl* parentDecl = decl->parentDecl; - // Make sure we have the member dictionary. - buildMemberDictionary(parentDecl); - Name* declName = decl->getName(); if (declName) { - Decl** firstDeclPtr = parentDecl->memberDictionary.TryGetValue(declName); + Decl** firstDeclPtr = parentDecl->getMemberDictionary().TryGetValue(declName); return (firstDeclPtr && *firstDeclPtr == decl) || (firstDeclPtr == nullptr); } @@ -1061,11 +1058,10 @@ void DocMarkdownWriter::writeAggType(const ASTMarkup::Entry& entry, AggTypeDeclB { // Make sure we've got a query-able member dictionary - buildMemberDictionary(aggTypeDecl); - SLANG_ASSERT(aggTypeDecl->isMemberDictionaryValid()); + auto& memberDict = aggTypeDecl->getMemberDictionary(); List<Decl*> uniqueMethods; - for (const auto& pair : aggTypeDecl->memberDictionary) + for (const auto& pair : memberDict) { CallableDecl* callableDecl = as<CallableDecl>(pair.Value); if (callableDecl && isVisible(callableDecl)) diff --git a/source/slang/slang-language-server-semantic-tokens.cpp b/source/slang/slang-language-server-semantic-tokens.cpp index 3754c46aa..485dd7a44 100644 --- a/source/slang/slang-language-server-semantic-tokens.cpp +++ b/source/slang/slang-language-server-semantic-tokens.cpp @@ -60,7 +60,6 @@ List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedS .pathInfo.foundPath.getUnownedSlice() .endsWithCaseInsensitive(fileName)) return; - SemanticToken token = _createSemanticToken(manager, declRef->loc, declRef->name); auto target = declRef->declRef.decl; diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp index c574be4ea..c560b67f9 100644 --- a/source/slang/slang-lookup.cpp +++ b/source/slang/slang-lookup.cpp @@ -30,63 +30,6 @@ struct BreadcrumbInfo // -void buildMemberDictionary(ContainerDecl* decl) -{ - // Don't rebuild if already built - if (decl->isMemberDictionaryValid()) - return; - - // If it's < 0 it means that the dictionaries are entirely invalid - if (decl->dictionaryLastCount < 0) - { - decl->dictionaryLastCount = 0; - decl->memberDictionary.Clear(); - decl->transparentMembers.clear(); - } - - // are we a generic? - GenericDecl* genericDecl = as<GenericDecl>(decl); - - const Index membersCount = decl->members.getCount(); - - SLANG_ASSERT(decl->dictionaryLastCount >= 0 && decl->dictionaryLastCount <= membersCount); - - for (Index i = decl->dictionaryLastCount; i < membersCount; ++i) - { - Decl* m = decl->members[i]; - - auto name = m->getName(); - - // Add any transparent members to a separate list for lookup - if (m->hasModifier<TransparentModifier>()) - { - TransparentMemberInfo info; - info.decl = m; - decl->transparentMembers.add(info); - } - - // Ignore members with no name - if (!name) - continue; - - // Ignore the "inner" member of a generic declaration - if (genericDecl && m == genericDecl->inner) - continue; - - m->nextInContainerWithSameName = nullptr; - - Decl* next = nullptr; - if (decl->memberDictionary.TryGetValue(name, next)) - m->nextInContainerWithSameName = next; - - decl->memberDictionary[name] = m; - } - - decl->dictionaryLastCount = membersCount; - SLANG_ASSERT(decl->isMemberDictionaryValid()); -} - - bool DeclPassesLookupMask(Decl* decl, LookupMask mask) { // Always exclude extern members from lookup result. @@ -229,15 +172,9 @@ static void _lookUpDirectAndTransparentMembers( } else { - // Ensure that the lookup dictionary in the container is up to date - if (!containerDecl->isMemberDictionaryValid()) - { - buildMemberDictionary(containerDecl); - } - // Look up the declarations with the chosen name in the container. Decl* firstDecl = nullptr; - containerDecl->memberDictionary.TryGetValue(name, firstDecl); + containerDecl->getMemberDictionary().TryGetValue(name, firstDecl); // Now iterate over those declarations (if any) and see if // we find any that meet our filtering criteria. @@ -255,7 +192,7 @@ static void _lookUpDirectAndTransparentMembers( // TODO(tfoley): should we look up in the transparent decls // if we already has a hit in the current container? - for(auto transparentInfo : containerDecl->transparentMembers) + for(auto transparentInfo : containerDecl->getTransparentMembers()) { // The reference to the transparent member should use whatever // substitutions we used in referring to its outer container diff --git a/source/slang/slang-lookup.h b/source/slang/slang-lookup.h index 0f034d100..7a9346498 100644 --- a/source/slang/slang-lookup.h +++ b/source/slang/slang-lookup.h @@ -11,10 +11,6 @@ struct SemanticsVisitor; // results that pass the given `LookupMask`. LookupResult refineLookup(LookupResult const& inResult, LookupMask mask); -// Ensure that the dictionary for name-based member lookup has been -// built for the given container declaration. -void buildMemberDictionary(ContainerDecl* decl); - // Look up a name in the given scope, proceeding up through // parent scopes as needed. LookupResult lookUp( diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index c0c035211..f2284a121 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -3078,11 +3078,6 @@ namespace Slang // would trigger a rebuild of the member dictionary that // would take O(N) time. // - // Eventually we should make `builtMemberDictionary()` - // incremental, so that it only has to process members - // added since the last time it was invoked. - // - buildMemberDictionary(parentDecl); // There might be multiple members of the same name // (if we define a namespace `foo` after an overloaded @@ -3090,7 +3085,7 @@ namespace Slang // lookup will only give us the first. // Decl* firstDecl = nullptr; - parentDecl->memberDictionary.TryGetValue(nameAndLoc.name, firstDecl); + parentDecl->getMemberDictionary().TryGetValue(nameAndLoc.name, firstDecl); // // We will search through the declarations of the name // and find the first that is a namespace (if any). diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 5b4b61849..c779b4510 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -682,9 +682,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return Slang::as<Type>(type->substitute(astBuilder, substs)); } - - void buildMemberDictionary(ContainerDecl* decl); - InterfaceDecl* findOuterInterfaceDecl(Decl* decl) { Decl* dd = decl; |
