summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-conformance.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-01 08:46:57 -0700
committerGitHub <noreply@github.com>2022-11-01 08:46:57 -0700
commitcbc1eff56057f199183bb7c17d8a360326512367 (patch)
tree487865e928cd2ceecbb509f0bfd06aa8d9584411 /source/slang/slang-check-conformance.cpp
parentb707a07b1de3535cb0a8ccb6fe2ed4afa4a016d1 (diff)
Make `DifferentialPair` able to nest. (#2477)
Diffstat (limited to 'source/slang/slang-check-conformance.cpp')
-rw-r--r--source/slang/slang-check-conformance.cpp42
1 files changed, 40 insertions, 2 deletions
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index 2c9977082..eb072e9dd 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -33,7 +33,6 @@ namespace Slang
else
return conjunction->rightWitness;
}
-
ExtractFromConjunctionSubtypeWitness* simplExtractFromConjunction = builder->create<ExtractFromConjunctionSubtypeWitness>();
simplExtractFromConjunction->sub = extractFromConjunction->sub;
simplExtractFromConjunction->sup = extractFromConjunction->sup;
@@ -145,7 +144,7 @@ namespace Slang
m_astBuilder->getOrCreate<DeclaredSubtypeWitness>(
bb->sub, bb->sup, bb->declRef.decl, bb->declRef.substitutions.substitutions);
- TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(subType, bb->sup, declaredWitness);
+ TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>();
transitiveWitness->sub = subType;
transitiveWitness->sup = bb->sup;
transitiveWitness->midToSup = declaredWitness;
@@ -379,6 +378,45 @@ namespace Slang
}
}
}
+
+ // If a generic type parameter does not declare itself to conform to `IDifferentiable`,
+ // we treat it as a subtype of `DifferentialBottom` to make it conform to `IDifferentiable`.
+ // Note: we only consider this option for `originalSubType` so a type that implements `IDifferential` but
+ // inherits from some other non differentiable types don't get to inherit `DifferentialBottom`.
+ if (m_astBuilder->isDifferentiableInterfaceAvailable() &&
+ subType == originalSubType &&
+ superTypeDeclRef.getDecl() == m_astBuilder->getDifferentiableInterface())
+ {
+ if (as<GenericTypeParamDecl>(declRefType->declRef.getDecl()) ||
+ as<AssocTypeDecl>(declRefType->declRef.getDecl()))
+ {
+ auto sup = DeclRefType::create(m_astBuilder, superTypeDeclRef);
+ auto differentialBottomType = as<DeclRefType>(m_astBuilder->getDifferentialBottomType());
+ auto container = differentialBottomType->declRef.as<ContainerDecl>().getDecl();
+ SLANG_RELEASE_ASSERT(container);
+ auto inheritanceDecl = container->getMembersOfType<InheritanceDecl>().getFirst();
+ auto witnessDifferentialBottomIsIDifferentiable =
+ m_astBuilder->getOrCreate<DeclaredSubtypeWitness>(
+ m_astBuilder->getDifferentialBottomType(),
+ sup,
+ inheritanceDecl,
+ nullptr);
+
+ auto witnessSubIsDifferentialBottom =
+ m_astBuilder->getOrCreate<DifferentialBottomSubtypeWitness>(
+ subType, differentialBottomType);
+
+ TransitiveSubtypeWitness* transitiveWitness =
+ m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(
+ witnessSubIsDifferentialBottom, witnessDifferentialBottomIsIDifferentiable);
+ transitiveWitness->sub = subType;
+ transitiveWitness->sup = sup;
+ transitiveWitness->midToSup = witnessDifferentialBottomIsIDifferentiable;
+ transitiveWitness->subToMid = witnessSubIsDifferentialBottom;
+ *outWitness = transitiveWitness;
+ return true;
+ }
+ }
}
else if (auto extractExistentialType = as<ExtractExistentialType>(subType))
{