summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-type.cpp11
-rw-r--r--source/slang/slang-check-decl.cpp50
-rw-r--r--source/slang/slang-check-expr.cpp28
-rw-r--r--source/slang/slang-ir-autodiff.cpp12
-rw-r--r--source/slang/slang-syntax.cpp37
5 files changed, 111 insertions, 27 deletions
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 362503a64..fdbd56377 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -208,8 +208,19 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe
if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(declRef.getDecl()))
{
if (auto result = maybeSubstituteGenericParam(this, genericTypeParamDecl, subst, ioDiff))
+ {
+ if (auto substDeclRefType = as<DeclRefType>(result))
+ {
+ // After generic substitution, we may be able to further simplify
+ // by looking up the actual type of an associated type.
+ if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(
+ astBuilder, substDeclRefType->declRef))
+ return satisfyingVal;
+ }
return result;
+ }
}
+
int diff = 0;
DeclRef<Decl> substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff);
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 580ad8402..c0253fd2c 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1543,18 +1543,48 @@ namespace Slang
};
// Make the Differential type itself conform to `IDifferential` interface.
- auto inheritanceIDiffernetiable = m_astBuilder->create<InheritanceDecl>();
- inheritanceIDiffernetiable->base.type = m_astBuilder->getDiffInterfaceType();
- inheritanceIDiffernetiable->parentDecl = aggTypeDecl;
- aggTypeDecl->members.add(inheritanceIDiffernetiable);
+ bool hasDifferentialConformance = false;
+ for (auto inheritanceDecl : aggTypeDecl->getMembersOfType<InheritanceDecl>())
+ {
+ if (auto declRefType = as<DeclRefType>(inheritanceDecl->base.type))
+ {
+ if (declRefType->declRef == m_astBuilder->getDifferentiableInterface())
+ {
+ hasDifferentialConformance = true;
+ break;
+ }
+ }
+ }
+ if (!hasDifferentialConformance)
+ {
+ auto inheritanceIDiffernetiable = m_astBuilder->create<InheritanceDecl>();
+ inheritanceIDiffernetiable->base.type = m_astBuilder->getDiffInterfaceType();
+ inheritanceIDiffernetiable->parentDecl = aggTypeDecl;
+ aggTypeDecl->members.add(inheritanceIDiffernetiable);
+ }
// The `Differential` type of a `Differential` type is always itself.
- auto assocTypeDef = m_astBuilder->create<TypeDefDecl>();
- assocTypeDef->nameAndLoc.name = getName("Differential");
- assocTypeDef->type.type = satisfyingType;
- assocTypeDef->parentDecl = aggTypeDecl;
- assocTypeDef->setCheckState(DeclCheckState::Checked);
- aggTypeDecl->members.add(assocTypeDef);
+ bool hasDifferentialTypeDef = false;
+ for (auto member : aggTypeDecl->members)
+ {
+ if (auto name = member->getName())
+ {
+ if (name->text == "Differential")
+ {
+ hasDifferentialTypeDef = true;
+ break;
+ }
+ }
+ }
+ if (!hasDifferentialTypeDef)
+ {
+ auto assocTypeDef = m_astBuilder->create<TypeDefDecl>();
+ assocTypeDef->nameAndLoc.name = getName("Differential");
+ assocTypeDef->type.type = satisfyingType;
+ assocTypeDef->parentDecl = aggTypeDecl;
+ assocTypeDef->setCheckState(DeclCheckState::Checked);
+ aggTypeDecl->members.add(assocTypeDef);
+ }
// Go through all members and collect their differential types.
// Go through super types.
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 8d8a72dd6..bfad1dbfe 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -487,12 +487,25 @@ namespace Slang
switch (builtinAssocTypeAttr->kind)
{
case BuiltinRequirementKind::DifferentialType:
- synthesizedDecl = m_astBuilder->create<StructDecl>();
+ {
+ auto structDecl = m_astBuilder->create<StructDecl>();
+ auto conformanceDecl = m_astBuilder->create<InheritanceDecl>();
+ conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
+ conformanceDecl->parentDecl = structDecl;
+ structDecl->members.add(conformanceDecl);
+
+ synthesizedDecl = structDecl;
+ auto typeDef = m_astBuilder->create<TypeAliasDecl>();
+ typeDef->nameAndLoc.name = getName("Differential");
+ auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, DeclRef<Decl>(structDecl, nullptr));
+ typeDef->type.type = m_astBuilder->getOrCreateDeclRefType(declRef.decl, declRef.substitutions);
+ typeDef->parentDecl = structDecl;
+ structDecl->members.add(typeDef);
+ }
break;
default:
- break;
+ return nullptr;
}
- synthesizedDecl = m_astBuilder->create<StructDecl>();
synthesizedDecl->parentDecl = parent;
synthesizedDecl->nameAndLoc.name = item.declRef.getName();
synthesizedDecl->loc = parent->loc;
@@ -645,6 +658,15 @@ namespace Slang
default:
SLANG_UNREACHABLE("all cases handle");
}
+ if (getShared()->isInLanguageServer())
+ {
+ // Don't make breadcrumb nodes carry any source loc info,
+ // as they may confuse language server functionalities.
+ if (bb)
+ {
+ bb->loc = SourceLoc();
+ }
+ }
}
return ConstructDeclRefExpr(item.declRef, bb, loc, originalExpr);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index e9b78696e..224cca9e0 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -384,6 +384,18 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
else
{
differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
+ if (auto diffPairType = as<IRDifferentialPairTypeBase>(item->getConcreteType()))
+ {
+ // For differential pair types, register the differential type as well.
+ IRBuilder builder(diffPairType);
+ builder.setInsertAfter(diffPairType->getWitness());
+ auto diffType = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey);
+ auto diffWitness = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeWitnessStructKey);
+ if (diffType && diffWitness)
+ {
+ differentiableWitnessDictionary.AddIfNotExists((IRType*)diffType, diffWitness);
+ }
+ }
}
}
}
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 470f5f983..27aba435f 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -1234,23 +1234,32 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
}
// Hard code implementation of T.Differential.Differential == T.Differential rule.
- if (auto builtinReq = substDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>())
+ auto foldResult = [&]() -> Val*
{
- if (builtinReq->kind == BuiltinRequirementKind::DifferentialType)
+ auto builtinReq = substDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>();
+ if (!builtinReq)
+ return nullptr;
+ if (builtinReq->kind != BuiltinRequirementKind::DifferentialType)
+ return nullptr;
+ // Is the concrete type a Differential associated type?
+ auto innerDeclRefType = as<DeclRefType>(thisSubst->witness->sub);
+ if (!innerDeclRefType)
+ return nullptr;
+ auto innerBuiltinReq = innerDeclRefType->declRef.decl->findModifier<BuiltinRequirementModifier>();
+ if (!innerBuiltinReq)
+ return nullptr;
+ if (innerBuiltinReq->kind != BuiltinRequirementKind::DifferentialType)
+ return nullptr;
+ if (!innerDeclRefType->declRef.equals(declRef))
{
- // Is the concrete type a Differential associated type?
- if (auto innerDeclRefType = as<DeclRefType>(thisSubst->witness->sub))
- {
- if (auto innerBuiltinReq = innerDeclRefType->declRef.decl->findModifier<BuiltinRequirementModifier>())
- {
- if (innerBuiltinReq->kind == BuiltinRequirementKind::DifferentialType)
- {
- return innerDeclRefType;
- }
- }
- }
+ auto result = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(builder, innerDeclRefType->declRef);
+ if (result)
+ return result;
}
- }
+ return innerDeclRefType;
+ }();
+ if (foldResult)
+ return foldResult;
}
}
}