summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-10-26 08:32:24 -0700
committerGitHub <noreply@github.com>2022-10-26 08:32:24 -0700
commit939be44ca23476e622dfb24a592383fe2a1da61f (patch)
tree7f45645897fe5735d58a7687290552d479e4d6fc /source/slang
parent4fc34b18da2f83ee6b4f094067503a66cab3d0b5 (diff)
Auto synthesis of Differential type (#2466)
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/core.meta.slang7
-rw-r--r--source/slang/diff.meta.slang4
-rw-r--r--source/slang/slang-ast-decl.cpp57
-rw-r--r--source/slang/slang-ast-decl.h15
-rw-r--r--source/slang/slang-ast-dump.cpp4
-rw-r--r--source/slang/slang-ast-modifier.h14
-rw-r--r--source/slang/slang-ast-support-types.h6
-rw-r--r--source/slang/slang-check-decl.cpp126
-rw-r--r--source/slang/slang-check-expr.cpp137
-rw-r--r--source/slang/slang-check-impl.h24
-rw-r--r--source/slang/slang-check-modifier.cpp53
-rw-r--r--source/slang/slang-check-shader.cpp10
-rw-r--r--source/slang/slang-doc-markdown-writer.cpp10
-rw-r--r--source/slang/slang-language-server-semantic-tokens.cpp1
-rw-r--r--source/slang/slang-lookup.cpp67
-rw-r--r--source/slang/slang-lookup.h4
-rw-r--r--source/slang/slang-parser.cpp7
-rw-r--r--source/slang/slang-syntax.cpp3
18 files changed, 421 insertions, 128 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 769a1091d..a25ce03bd 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2742,6 +2742,13 @@ attribute_syntax [Differentiable] : DifferentiableAttribute;
__attributeTarget(DeclBase)
attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
+enum _BuiltinAssociatedTypeRequirementKind
+{
+ Differential = $( (int) BuiltinAssociatedTypeRequirementKind::Differential),
+};
+__attributeTarget(DeclBase)
+attribute_syntax [__BuiltinAssociatedTypeRequirementAttribute(kind: _BuiltinAssociatedTypeRequirementKind)] : BuiltinAssociatedTypeRequirementAttribute;
+
__attributeTarget(DeclBase)
attribute_syntax [builtin] : BuiltinAttribute;
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 26fec224c..f314e0487 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -18,6 +18,10 @@ attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute;
__magic_type(DifferentiableType)
interface IDifferentiable
{
+ // Note: the compiler implementation requires the `Differential` associated type to be defined
+ // before anything else.
+
+ [__BuiltinAssociatedTypeRequirementAttribute(_BuiltinAssociatedTypeRequirementKind.Differential)]
associatedtype Differential;
static Differential zero();
diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp
index 2df9164fb..b2802e304 100644
--- a/source/slang/slang-ast-decl.cpp
+++ b/source/slang/slang-ast-decl.cpp
@@ -1,5 +1,6 @@
// slang-ast-decl.cpp
#include "slang-ast-builder.h"
+#include "slang-syntax.h"
#include <assert.h>
#include "slang-generated-ast-macro.h"
@@ -32,4 +33,60 @@ bool isInterfaceRequirement(Decl* decl)
return false;
}
+void ContainerDecl::buildMemberDictionary()
+{
+ // Don't rebuild if already built
+ if (isMemberDictionaryValid())
+ return;
+
+ // If it's < 0 it means that the dictionaries are entirely invalid
+ if (dictionaryLastCount < 0)
+ {
+ dictionaryLastCount = 0;
+ memberDictionary.Clear();
+ transparentMembers.clear();
+ }
+
+ // are we a generic?
+ GenericDecl* genericDecl = as<GenericDecl>(this);
+
+ const Index membersCount = members.getCount();
+
+ SLANG_ASSERT(dictionaryLastCount >= 0 && dictionaryLastCount <= membersCount);
+
+ for (Index i = dictionaryLastCount; i < membersCount; ++i)
+ {
+ Decl* m = members[i];
+
+ auto name = m->getName();
+
+ // Add any transparent members to a separate list for lookup
+ if (m->hasModifier<TransparentModifier>())
+ {
+ TransparentMemberInfo info;
+ info.decl = m;
+ transparentMembers.add(info);
+ }
+
+ // Ignore members with no name
+ if (!name)
+ continue;
+
+ // Ignore the "inner" member of a generic declaration
+ if (genericDecl && m == genericDecl->inner)
+ continue;
+
+ m->nextInContainerWithSameName = nullptr;
+
+ Decl* next = nullptr;
+ if (memberDictionary.TryGetValue(name, next))
+ m->nextInContainerWithSameName = next;
+
+ memberDictionary[name] = m;
+ }
+
+ dictionaryLastCount = membersCount;
+ SLANG_ASSERT(isMemberDictionaryValid());
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index b1b20dc93..87d696927 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -35,12 +35,27 @@ class ContainerDecl: public Decl
return FilteredMemberList<T>(members);
}
+ void buildMemberDictionary();
+
bool isMemberDictionaryValid() const { return dictionaryLastCount == members.getCount(); }
void invalidateMemberDictionary() { dictionaryLastCount = -1; }
+ Dictionary<Name*, Decl*>& getMemberDictionary()
+ {
+ buildMemberDictionary();
+ return memberDictionary;
+ }
+
+ List<TransparentMemberInfo>& getTransparentMembers()
+ {
+ buildMemberDictionary();
+ return transparentMembers;
+ }
+
SLANG_UNREFLECTED // We don't want to reflect the following fields
+private:
// Denotes how much of Members has been placed into the dictionary/transparentMembers.
// If this value equals the Members.getCount(), the dictionary is completely full and valid.
// If it's >= 0, then the Members after dictionaryLastCount are all that need to be added.
diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp
index d67a35174..32f9dd16f 100644
--- a/source/slang/slang-ast-dump.cpp
+++ b/source/slang/slang-ast-dump.cpp
@@ -345,6 +345,10 @@ struct ASTDumpContext
{
m_writer->emit(getTryClauseTypeName(clauseType));
}
+ void dump(BuiltinAssociatedTypeRequirementKind kind)
+ {
+ m_writer->emit((int)kind);
+ }
void dump(const String& string)
{
dump(string.getUnownedSlice());
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index b019953cb..c439c7437 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -32,6 +32,12 @@ class GloballyCoherentModifier : public Modifier { SLANG_AST_CLASS(GloballyCoher
class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)};
class JVPDerivativeModifier : public Modifier { SLANG_AST_CLASS(JVPDerivativeModifier)};
+// Marks that the definition of a decl is not yet synthesized.
+class ToBeSynthesizedModifier : public Modifier {SLANG_AST_CLASS(ToBeSynthesizedModifier)};
+
+// Marks that the definition of a decl is synthesized.
+class SynthesizedModifier : public Modifier { SLANG_AST_CLASS(SynthesizedModifier) };
+
// An `extern` variable in an extension is used to introduce additional attributes on an existing
// field.
class ExtensionExternVarModifier : public Modifier
@@ -584,6 +590,14 @@ class Attribute : public AttributeBase
AttributeArgumentValueDict intArgVals;
};
+// A modifier that indicates a built-in associated type requirement (e.g., `Differential`)
+class BuiltinAssociatedTypeRequirementAttribute : public Attribute
+{
+ SLANG_AST_CLASS(BuiltinAssociatedTypeRequirementAttribute);
+
+ BuiltinAssociatedTypeRequirementKind kind;
+};
+
class UserDefinedAttribute : public Attribute
{
SLANG_AST_CLASS(UserDefinedAttribute)
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 39ca71267..9a32d816c 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -1490,6 +1490,12 @@ namespace Slang
kParameterDirection_Ref, ///< By-reference
};
+ /// The type of a builtin associated type requirement.
+ enum class BuiltinAssociatedTypeRequirementKind
+ {
+ Differential
+ };
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 356105e4f..fa05dde11 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -981,7 +981,7 @@ namespace Slang
VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr)
{
auto memberType = checkProperType(getLinkage(), varDecl->type, getSink());
- auto diffType = _getDifferential(m_astBuilder, memberType);
+ auto diffType = getDifferentialType(m_astBuilder, memberType, varDecl->loc);
if (as<ErrorType>(diffType))
{
getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeIsNotDifferentiable, memberType);
@@ -994,7 +994,7 @@ namespace Slang
Diagnostics::
derivativeMemberAttributeCanOnlyBeUsedOnMembers);
}
- auto diffThisType = _getDifferential(m_astBuilder, thisType);
+ auto diffThisType = getDifferentialType(m_astBuilder, thisType, derivativeMemberAttr->loc);
if (!thisType)
{
getSink()->diagnose(
@@ -1359,6 +1359,104 @@ namespace Slang
}
}
+ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness(
+ ConformanceCheckingContext* context,
+ DeclRef<Decl> requirementDeclRef,
+ RefPtr<WitnessTable> witnessTable)
+ {
+ // We currently can't handle generic types.
+ if (GetOuterGeneric(context->parentDecl) != nullptr)
+ {
+ return false;
+ }
+
+ Decl* existingDecl = nullptr;
+ AggTypeDecl* aggTypeDecl = nullptr;
+ if (context->parentDecl->getMemberDictionary().TryGetValue(requirementDeclRef.getName(), existingDecl))
+ {
+ aggTypeDecl = as<AggTypeDecl>(existingDecl);
+ SLANG_RELEASE_ASSERT(aggTypeDecl);
+
+ // Remove the `ToBeSynthesizedModifier`.
+ if (as<ToBeSynthesizedModifier>(aggTypeDecl->modifiers.first))
+ {
+ aggTypeDecl->modifiers.first = aggTypeDecl->modifiers.first->next;
+ }
+ }
+ else
+ {
+ aggTypeDecl = m_astBuilder->create<StructDecl>();
+ aggTypeDecl->parentDecl = context->parentDecl;
+ context->parentDecl->members.add((aggTypeDecl));
+ aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName();
+ aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc;
+ context->parentDecl->getMemberDictionary().Add(aggTypeDecl->getName(), aggTypeDecl);
+ }
+
+ // TODO: if we want to make the synthesized type itself to be differentiable,
+ // add an inheritance decl here. Need to be careful to avoid infinite recursion
+ // trying to synthesize the higher order differential types.
+
+ // Helper function to add a `diffType` field into the synthesized type for the original
+ // `member`.
+ auto differentialType = GetTypeForDeclRef(makeDeclRef(aggTypeDecl), context->parentDecl->loc);
+ auto addDiffMember = [&](Decl* member, Type* diffMemberType)
+ {
+ // If the field is differentiable, add a corresponding field in the associated Differential type.
+ auto diffField = m_astBuilder->create<VarDecl>();
+ diffField->nameAndLoc = member->nameAndLoc;
+ diffField->type.type = diffMemberType;
+ diffField->checkState = DeclCheckState::SignatureChecked;
+ diffField->parentDecl = aggTypeDecl;
+ aggTypeDecl->members.add(diffField);
+
+ // Inject a `DerivativeMember` modifier on the original decl.
+ auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
+ auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
+ fieldLookupExpr->type.type = diffMemberType;
+ auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = differentialType;
+ auto baseTypeType = m_astBuilder->create<TypeType>();
+ baseTypeType->type = differentialType;
+ baseTypeExpr->type.type = baseTypeType;
+ fieldLookupExpr->baseExpression = baseTypeExpr;
+ fieldLookupExpr->declRef = makeDeclRef(diffField);
+ derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
+ addModifier(member, derivativeMemberModifier);
+ };
+
+ // Go through super types.
+ for (auto inheritance : context->parentDecl->getMembersOfType<InheritanceDecl>())
+ {
+ if (auto baseDeclRefType = as<DeclRefType>(inheritance->base.type))
+ {
+ // Skip interface super types.
+ if (baseDeclRefType->declRef.as<InterfaceDecl>())
+ continue;
+ if (auto superDiffType = tryGetDifferentialType(m_astBuilder, baseDeclRefType))
+ {
+ addDiffMember(inheritance, superDiffType);
+ }
+ }
+ }
+
+ // We go through all members and generate their differential counterparts.
+ for (auto member : context->parentDecl->getMembersOfType<VarDeclBase>())
+ {
+ auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type);
+ if (!diffType)
+ continue;
+ addDiffMember(member, diffType);
+ }
+
+ // In the future when the Differential type itself needs to conform to some interface,
+ // this is the place to synthesize requirements for them.
+ addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>());
+ auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, nullptr);
+ witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType));
+ return true;
+ }
+
void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*)
{
// If the autodiff core library (diff.meta.slang) has not been loaded yet, ignore any
@@ -2146,6 +2244,13 @@ namespace Slang
DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef,
RefPtr<WitnessTable> witnessTable)
{
+ if (auto declRefType = as<DeclRefType>(satisfyingType))
+ {
+ // If we are seeing a placeholder that awaits synthesis, return false now to trigger
+ // auto synthesis.
+ if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>())
+ return false;
+ }
// We need to confirm that the chosen type `satisfyingType`,
// meets all the constraints placed on the associated type
// requirement `requiredAssociatedTypeDeclRef`.
@@ -2947,6 +3052,21 @@ namespace Slang
witnessTable);
}
+ if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>())
+ {
+ if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinAssociatedTypeRequirementAttribute>())
+ {
+ switch (builtinAttr->kind)
+ {
+ case BuiltinAssociatedTypeRequirementKind::Differential:
+ return trySynthesizeDifferentialAssociatedTypeRequirementWitness(
+ context,
+ requiredAssocTypeDeclRef,
+ witnessTable);
+ }
+ }
+ }
+
// TODO: There are other kinds of requirements for which synthesis should
// be possible:
//
@@ -4876,7 +4996,7 @@ namespace Slang
// We will now look for other declarations with
// the same name in the same parent/container.
//
- buildMemberDictionary(parentDecl);
+ parentDecl->buildMemberDictionary();
for (auto oldDecl = newDecl->nextInContainerWithSameName; oldDecl; oldDecl = oldDecl->nextInContainerWithSameName)
{
// For each matching declaration, we will check
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 29b44e726..d69cd39ed 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -393,12 +393,107 @@ namespace Slang
return derefExpr;
}
+ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(LookupResultItem const& item, Expr* originalExpr)
+ {
+ // If the only result from lookup is an entry in an interface decl, it could be that
+ // the user is leaving out an explicit definition for the requirement and depending on
+ // the compiler to synthesis the definition.
+ // In this case, if the lookup is triggered from a location such that the satisfying
+ // definition should be returned should it existed, we should create a placeholder decl for
+ // the definition and return a reference to to newly created decl instead of the requirement
+ // decl in the interface.
+ switch (item.declRef.getDecl()->astNodeType)
+ {
+ case ASTNodeType::AssocTypeDecl:
+ return maybeUseSynthesizedTypeDeclForLookupResult(item, originalExpr);
+ default:
+ return nullptr;
+ }
+ }
+
+ Expr* SemanticsVisitor::maybeUseSynthesizedTypeDeclForLookupResult(LookupResultItem const& item, Expr* originalExpr)
+ {
+ // We need to check if the lookup should resolve to a definition in an implementation type
+ // if it existed.
+ // This will be the case when the lookup is initiated from the concrete implementation type instead of
+ // directly from the Interface decl. The breadcrumbs of the lookup should provide this information.
+
+ // If no breadcrumbs existed, then the lookup should just resolve to the interface requirement.
+
+ if (!item.breadcrumbs)
+ return nullptr;
+
+ // We will only ever need to synthesis a type to satisfy an associatedtype requirement.
+ // In this case the lookup should have resolved to a known associatedtype decl.
+ auto builtinAssocTypeAttr = item.declRef.getDecl()->findModifier<BuiltinAssociatedTypeRequirementAttribute>();
+ if (!builtinAssocTypeAttr)
+ return nullptr;
+
+ DeclRefType* subType = nullptr;
+
+ // Check if we are reaching the associated type decl through inheritance from a concrete type.
+ for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next)
+ {
+ switch (breadcrumb->kind)
+ {
+ case LookupResultItem::Breadcrumb::Kind::SuperType:
+ {
+ auto witness = as<SubtypeWitness>(breadcrumb->val);
+ if (auto subDeclRefType = as<DeclRefType>(witness->sub))
+ {
+ if (!as<InterfaceDecl>(subDeclRefType->declRef.getDecl()))
+ {
+ // Store the inner most concrete super type.
+ subType = subDeclRefType;
+ }
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ if (!subType)
+ return nullptr;
+
+ subType = as<DeclRefType>(subType->getCanonicalType());
+ if (!subType)
+ return nullptr;
+
+ // Don't synthesize for generic parameters.
+ auto parent = as<AggTypeDecl>(subType->declRef.getDecl());
+ if (!parent)
+ return nullptr;
+
+ // If we reach here, we are expecting a synthesized associated type defined in `subType`.
+ // Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder type
+ // in `subType` and return a DeclRefExpr to the synthesized decl.
+ auto assocType = m_astBuilder->create<StructDecl>();
+ assocType->parentDecl = parent;
+ assocType->nameAndLoc.name = item.declRef.getName();
+ assocType->loc = parent->loc;
+ parent->members.add(assocType);
+ parent->invalidateMemberDictionary();
+
+ // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
+ // from user-provided definitions, and proceed to fill in its definition.
+ auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
+ addModifier(assocType, toBeSynthesized);
+
+ return ConstructDeclRefExpr(makeDeclRef(assocType), nullptr, originalExpr->loc, originalExpr);
+ }
+
Expr* SemanticsVisitor::ConstructLookupResultExpr(
LookupResultItem const& item,
Expr* baseExpr,
SourceLoc loc,
Expr* originalExpr)
{
+ // We could be referencing a decl that will be synthesized. If so create a placeholder
+ // and return a DeclRefExpr to it.
+ if (auto lookupResultExpr = maybeUseSynthesizedDeclForLookupResult(item, originalExpr))
+ return lookupResultExpr;
+
// If we collected any breadcrumbs, then these represent
// additional segments of the lookup path that we need
// to expand here.
@@ -719,21 +814,25 @@ namespace Slang
return _resolveOverloadedExprImpl(overloadedExpr, mask, getSink());
}
- Type* SemanticsVisitor::_getDifferential(ASTBuilder* builder, Type* type)
+ Type* SemanticsVisitor::tryGetDifferentialType(ASTBuilder* builder, Type* type)
{
if (auto ptrType = as<PtrTypeBase>(type))
{
+ auto baseDiffType = tryGetDifferentialType(builder, ptrType->getValueType());
+ if (!baseDiffType) return nullptr;
return builder->getPtrType(
- _getDifferential(builder, ptrType->getValueType()),
+ baseDiffType,
ptrType->getClassInfo().m_name);
}
else if (auto arrayType = as<ArrayExpressionType>(type))
{
+ auto baseDiffType = tryGetDifferentialType(builder, arrayType->baseType);
+ if (!baseDiffType) return nullptr;
return builder->getArrayType(
- _getDifferential(builder, arrayType->baseType),
+ baseDiffType,
arrayType->arrayLength);
}
-
+
if (auto declRefType = as<DeclRefType>(type))
{
if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterface())))
@@ -745,17 +844,16 @@ namespace Slang
type,
Slang::LookupMask::type,
Slang::LookupOptions::None);
-
+
diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult);
if (!diffTypeLookupResult.isValid())
{
- // Diagnose no 'Differential' member.
- getSink()->diagnose(declRefType->declRef, Diagnostics::typeDoesntImplementInterfaceRequirement, type, getName("Differential"));
+ return nullptr;
}
else if (diffTypeLookupResult.isOverloaded())
{
- getSink()->diagnose(declRefType->declRef, Diagnostics::ambiguousReference, getName("Differential"));
+ return nullptr;
}
else
{
@@ -764,17 +862,28 @@ namespace Slang
baseTypeExpr->type.type = m_astBuilder->getTypeType(type);
auto diffTypeExpr = ConstructLookupResultExpr(
- diffTypeLookupResult.item,
- baseTypeExpr,
- declRefType->declRef.getLoc(),
- baseTypeExpr);
-
+ diffTypeLookupResult.item,
+ baseTypeExpr,
+ declRefType->declRef.getLoc(),
+ baseTypeExpr);
+
return ExtractTypeFromTypeRepr(diffTypeExpr);
}
}
}
- return m_astBuilder->getErrorType();
+ return nullptr;
+ }
+
+ Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc)
+ {
+ auto result = tryGetDifferentialType(builder, type);
+ if (!result)
+ {
+ getSink()->diagnose(loc, Diagnostics::typeDoesntImplementInterfaceRequirement, type, getName("Differential"));
+ return m_astBuilder->getErrorType();
+ }
+ return result;
}
void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type)
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 0877f2d6e..ac1d624c2 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -630,6 +630,14 @@ namespace Slang
Expr* base,
SourceLoc loc);
+ Expr* maybeUseSynthesizedTypeDeclForLookupResult(
+ LookupResultItem const& item,
+ Expr* orignalExpr);
+
+ Expr* maybeUseSynthesizedDeclForLookupResult(
+ LookupResultItem const& item,
+ Expr* orignalExpr);
+
Expr* ConstructLookupResultExpr(
LookupResultItem const& item,
Expr* baseExpr,
@@ -804,7 +812,9 @@ namespace Slang
void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type);
// Construct the differential for 'type', if it exists.
- Type* _getDifferential(ASTBuilder* builder, Type* type);
+ Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc);
+ Type* tryGetDifferentialType(ASTBuilder* builder, Type* type);
+
public:
@@ -1094,6 +1104,18 @@ namespace Slang
DeclRef<Decl> requiredMemberDeclRef,
RefPtr<WitnessTable> witnessTable);
+ /// Attempt to synthesize an associated `Differential` type for a type that conforms to
+ /// `IDifferentiable`.
+ ///
+ /// On success, installs the syntethesized type in `witnessTable`, injects `[DerivativeMember]`
+ /// modifiers on differentiable fields to point to the corresponding field in the synthesized
+ /// differential type, and returns `true`.
+ /// Otherwise, returns `false`.
+ bool trySynthesizeDifferentialAssociatedTypeRequirementWitness(
+ ConformanceCheckingContext* context,
+ DeclRef<Decl> requirementDeclRef,
+ RefPtr<WitnessTable> witnessTable);
+
/// Registers a type as differentiable in the currrent semantic context, if the declaration represents
/// a subtype of IDifferentable. Does nothing otherwise.
void tryAddDifferentiableConformanceToContext(
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index f977721dd..7e11ee3ca 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -292,7 +292,7 @@ namespace Slang
bool SemanticsVisitor::validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget)
{
- if(auto numThreadsAttr = as<NumThreadsAttribute>(attr))
+ if (auto numThreadsAttr = as<NumThreadsAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 3);
@@ -320,9 +320,9 @@ namespace Slang
values[i] = value;
}
- numThreadsAttr->x = values[0];
- numThreadsAttr->y = values[1];
- numThreadsAttr->z = values[2];
+ numThreadsAttr->x = values[0];
+ numThreadsAttr->y = values[1];
+ numThreadsAttr->z = values[2];
}
else if (auto anyValueSizeAttr = as<AnyValueSizeAttribute>(attr))
{
@@ -368,7 +368,7 @@ namespace Slang
{
return false;
}
-
+
bindingAttr->binding = int32_t(binding->value);
bindingAttr->set = int32_t(set->value);
}
@@ -395,31 +395,31 @@ namespace Slang
SLANG_ASSERT(attr->args.getCount() == 1);
auto val = checkConstantIntVal(attr->args[0]);
- if(!val) return false;
+ if (!val) return false;
maxVertexCountAttr->value = (int32_t)val->value;
}
- else if(auto instanceAttr = as<InstanceAttribute>(attr))
+ else if (auto instanceAttr = as<InstanceAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
auto val = checkConstantIntVal(attr->args[0]);
- if(!val) return false;
+ if (!val) return false;
instanceAttr->value = (int32_t)val->value;
}
- else if(auto entryPointAttr = as<EntryPointAttribute>(attr))
+ else if (auto entryPointAttr = as<EntryPointAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
String stageName;
- if(!checkLiteralStringVal(attr->args[0], &stageName))
+ if (!checkLiteralStringVal(attr->args[0], &stageName))
{
return false;
}
auto stage = findStageByName(stageName);
- if(stage == Stage::Unknown)
+ if (stage == Stage::Unknown)
{
getSink()->diagnose(attr->args[0], Diagnostics::unknownStageName, stageName);
}
@@ -427,10 +427,10 @@ namespace Slang
entryPointAttr->stage = stage;
}
else if ((as<DomainAttribute>(attr)) ||
- (as<MaxTessFactorAttribute>(attr)) ||
- (as<OutputTopologyAttribute>(attr)) ||
- (as<PartitioningAttribute>(attr)) ||
- (as<PatchConstantFuncAttribute>(attr)))
+ (as<MaxTessFactorAttribute>(attr)) ||
+ (as<OutputTopologyAttribute>(attr)) ||
+ (as<PartitioningAttribute>(attr)) ||
+ (as<PatchConstantFuncAttribute>(attr)))
{
// Let it go thru iff single string attribute
if (!hasStringArgs(attr, 1))
@@ -439,7 +439,7 @@ namespace Slang
}
}
else if (as<OutputControlPointsAttribute>(attr) ||
- as<SPIRVInstructionOpAttribute>(attr))
+ as<SPIRVInstructionOpAttribute>(attr))
{
// Let it go thru iff single integral attribute
if (!hasIntArgs(attr, 1))
@@ -484,6 +484,27 @@ namespace Slang
return false;
}
}
+ else if (auto builtinAssocTypeAttr = as<BuiltinAssociatedTypeRequirementAttribute>(attr))
+ {
+ if (attr->args.getCount() == 1)
+ {
+ //IntVal* outIntVal;
+ if (auto cInt = checkConstantEnumVal(attr->args[0]))
+ {
+ builtinAssocTypeAttr->kind = (BuiltinAssociatedTypeRequirementKind)(cInt->value);
+ }
+ else
+ {
+ getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName);
+ return false;
+ }
+ }
+ else
+ {
+ getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName);
+ return false;
+ }
+ }
else if (auto unrollAttr = as<UnrollAttribute>(attr))
{
// Check has an argument. We need this because default behavior is to give an error
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index d7200d47c..a84e40768 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -236,13 +236,10 @@ namespace Slang
{
auto translationUnitSyntax = translationUnit->getModuleDecl();
- // Make sure we've got a query-able member dictionary
- buildMemberDictionary(translationUnitSyntax);
-
// We will look up any global-scope declarations in the translation
// unit that match the name of our entry point.
Decl* firstDeclWithName = nullptr;
- if (!translationUnitSyntax->memberDictionary.TryGetValue(name, firstDeclWithName))
+ if (!translationUnitSyntax->getMemberDictionary().TryGetValue(name, firstDeclWithName))
{
// If there doesn't appear to be any such declaration, then we are done.
@@ -454,13 +451,10 @@ namespace Slang
auto entryPointName = entryPointReq->getName();
- // Make sure we've got a query-able member dictionary
- buildMemberDictionary(translationUnitSyntax);
-
// We will look up any global-scope declarations in the translation
// unit that match the name of our entry point.
Decl* firstDeclWithName = nullptr;
- if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPointName, firstDeclWithName) )
+ if( !translationUnitSyntax->getMemberDictionary().TryGetValue(entryPointName, firstDeclWithName))
{
// If there doesn't appear to be any such declaration, then
// we need to diagnose it as an error, and then bail out.
diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp
index 4d8afd763..9130c05ed 100644
--- a/source/slang/slang-doc-markdown-writer.cpp
+++ b/source/slang/slang-doc-markdown-writer.cpp
@@ -667,13 +667,10 @@ static bool _isFirstOverridden(Decl* decl)
ContainerDecl* parentDecl = decl->parentDecl;
- // Make sure we have the member dictionary.
- buildMemberDictionary(parentDecl);
-
Name* declName = decl->getName();
if (declName)
{
- Decl** firstDeclPtr = parentDecl->memberDictionary.TryGetValue(declName);
+ Decl** firstDeclPtr = parentDecl->getMemberDictionary().TryGetValue(declName);
return (firstDeclPtr && *firstDeclPtr == decl) || (firstDeclPtr == nullptr);
}
@@ -1061,11 +1058,10 @@ void DocMarkdownWriter::writeAggType(const ASTMarkup::Entry& entry, AggTypeDeclB
{
// Make sure we've got a query-able member dictionary
- buildMemberDictionary(aggTypeDecl);
- SLANG_ASSERT(aggTypeDecl->isMemberDictionaryValid());
+ auto& memberDict = aggTypeDecl->getMemberDictionary();
List<Decl*> uniqueMethods;
- for (const auto& pair : aggTypeDecl->memberDictionary)
+ for (const auto& pair : memberDict)
{
CallableDecl* callableDecl = as<CallableDecl>(pair.Value);
if (callableDecl && isVisible(callableDecl))
diff --git a/source/slang/slang-language-server-semantic-tokens.cpp b/source/slang/slang-language-server-semantic-tokens.cpp
index 3754c46aa..485dd7a44 100644
--- a/source/slang/slang-language-server-semantic-tokens.cpp
+++ b/source/slang/slang-language-server-semantic-tokens.cpp
@@ -60,7 +60,6 @@ List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedS
.pathInfo.foundPath.getUnownedSlice()
.endsWithCaseInsensitive(fileName))
return;
-
SemanticToken token =
_createSemanticToken(manager, declRef->loc, declRef->name);
auto target = declRef->declRef.decl;
diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp
index c574be4ea..c560b67f9 100644
--- a/source/slang/slang-lookup.cpp
+++ b/source/slang/slang-lookup.cpp
@@ -30,63 +30,6 @@ struct BreadcrumbInfo
//
-void buildMemberDictionary(ContainerDecl* decl)
-{
- // Don't rebuild if already built
- if (decl->isMemberDictionaryValid())
- return;
-
- // If it's < 0 it means that the dictionaries are entirely invalid
- if (decl->dictionaryLastCount < 0)
- {
- decl->dictionaryLastCount = 0;
- decl->memberDictionary.Clear();
- decl->transparentMembers.clear();
- }
-
- // are we a generic?
- GenericDecl* genericDecl = as<GenericDecl>(decl);
-
- const Index membersCount = decl->members.getCount();
-
- SLANG_ASSERT(decl->dictionaryLastCount >= 0 && decl->dictionaryLastCount <= membersCount);
-
- for (Index i = decl->dictionaryLastCount; i < membersCount; ++i)
- {
- Decl* m = decl->members[i];
-
- auto name = m->getName();
-
- // Add any transparent members to a separate list for lookup
- if (m->hasModifier<TransparentModifier>())
- {
- TransparentMemberInfo info;
- info.decl = m;
- decl->transparentMembers.add(info);
- }
-
- // Ignore members with no name
- if (!name)
- continue;
-
- // Ignore the "inner" member of a generic declaration
- if (genericDecl && m == genericDecl->inner)
- continue;
-
- m->nextInContainerWithSameName = nullptr;
-
- Decl* next = nullptr;
- if (decl->memberDictionary.TryGetValue(name, next))
- m->nextInContainerWithSameName = next;
-
- decl->memberDictionary[name] = m;
- }
-
- decl->dictionaryLastCount = membersCount;
- SLANG_ASSERT(decl->isMemberDictionaryValid());
-}
-
-
bool DeclPassesLookupMask(Decl* decl, LookupMask mask)
{
// Always exclude extern members from lookup result.
@@ -229,15 +172,9 @@ static void _lookUpDirectAndTransparentMembers(
}
else
{
- // Ensure that the lookup dictionary in the container is up to date
- if (!containerDecl->isMemberDictionaryValid())
- {
- buildMemberDictionary(containerDecl);
- }
-
// Look up the declarations with the chosen name in the container.
Decl* firstDecl = nullptr;
- containerDecl->memberDictionary.TryGetValue(name, firstDecl);
+ containerDecl->getMemberDictionary().TryGetValue(name, firstDecl);
// Now iterate over those declarations (if any) and see if
// we find any that meet our filtering criteria.
@@ -255,7 +192,7 @@ static void _lookUpDirectAndTransparentMembers(
// TODO(tfoley): should we look up in the transparent decls
// if we already has a hit in the current container?
- for(auto transparentInfo : containerDecl->transparentMembers)
+ for(auto transparentInfo : containerDecl->getTransparentMembers())
{
// The reference to the transparent member should use whatever
// substitutions we used in referring to its outer container
diff --git a/source/slang/slang-lookup.h b/source/slang/slang-lookup.h
index 0f034d100..7a9346498 100644
--- a/source/slang/slang-lookup.h
+++ b/source/slang/slang-lookup.h
@@ -11,10 +11,6 @@ struct SemanticsVisitor;
// results that pass the given `LookupMask`.
LookupResult refineLookup(LookupResult const& inResult, LookupMask mask);
-// Ensure that the dictionary for name-based member lookup has been
-// built for the given container declaration.
-void buildMemberDictionary(ContainerDecl* decl);
-
// Look up a name in the given scope, proceeding up through
// parent scopes as needed.
LookupResult lookUp(
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index c0c035211..f2284a121 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -3078,11 +3078,6 @@ namespace Slang
// would trigger a rebuild of the member dictionary that
// would take O(N) time.
//
- // Eventually we should make `builtMemberDictionary()`
- // incremental, so that it only has to process members
- // added since the last time it was invoked.
- //
- buildMemberDictionary(parentDecl);
// There might be multiple members of the same name
// (if we define a namespace `foo` after an overloaded
@@ -3090,7 +3085,7 @@ namespace Slang
// lookup will only give us the first.
//
Decl* firstDecl = nullptr;
- parentDecl->memberDictionary.TryGetValue(nameAndLoc.name, firstDecl);
+ parentDecl->getMemberDictionary().TryGetValue(nameAndLoc.name, firstDecl);
//
// We will search through the declarations of the name
// and find the first that is a namespace (if any).
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 5b4b61849..c779b4510 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -682,9 +682,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
return Slang::as<Type>(type->substitute(astBuilder, substs));
}
-
- void buildMemberDictionary(ContainerDecl* decl);
-
InterfaceDecl* findOuterInterfaceDecl(Decl* decl)
{
Decl* dd = decl;