summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp32
1 files changed, 25 insertions, 7 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 8f24ec5b0..5233008fd 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1131,7 +1131,8 @@ namespace Slang
{
if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>())
{
- if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType)
+ if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType
+ || builtinRequirement->kind == BuiltinRequirementKind::DifferentialPtrType)
{
// We are trying to get differential type from a differential type.
// The result is itself.
@@ -1139,7 +1140,10 @@ namespace Slang
}
}
type = resolveType(type);
- if (const auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType())))
+ auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType()));
+ if (!witness)
+ witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableRefInterfaceType()));
+ if (witness)
{
auto diffTypeLookupResult = lookUpMember(
getASTBuilder(),
@@ -1367,6 +1371,13 @@ namespace Slang
{
addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
}
+
+ if (auto subtypeWitness = as<SubtypeWitness>(
+ tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableRefInterfaceType())))
+ {
+ addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
+ }
+
if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>())
{
foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member)
@@ -2899,18 +2910,25 @@ namespace Slang
return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType));
}
}
+
// Get a reference to the builtin 'IDifferentiable' interface
auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType();
+ auto differentiableRefInterface = getASTBuilder()->getDifferentiableRefInterfaceType();
- auto conformanceWitness = as<Witness>(isSubtype(primalType, differentiableInterface, IsSubTypeOptions::None));
// Check if the provided type inherits from IDifferentiable.
// If not, return the original type.
- if (conformanceWitness)
+ if (auto conformanceWitness = isTypeDifferentiable(primalType))
{
- return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness);
+ if (conformanceWitness->getSup() == differentiableInterface)
+ {
+ return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness);
+ }
+ else if (conformanceWitness->getSup() == differentiableRefInterface)
+ {
+ return m_astBuilder->getDifferentialPtrPairType(primalType, conformanceWitness);
+ }
}
- else
- return primalType;
+ return primalType;
}
Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType)