summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-09-19 03:10:28 -0400
committerGitHub <noreply@github.com>2024-09-19 00:10:28 -0700
commitccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch)
tree435e9c462a78fb848ab3b36c23287543d1a859de /source/slang/slang-check-expr.cpp
parent1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff)
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface * Initial support for `IDifferentiablePtrType` * Fix unused vars * More tests + fix switch case fallthrough. * Update slang-ir-autodiff.cpp * Update diff-ptr-type-loop.slang * Add optimization to allow more complex pair types * Update slang-ir-autodiff-primal-hoist.cpp * Update diff-ptr-type-loop.slang * Update slang-ir-autodiff-primal-hoist.cpp * More fixes to address reviews * Update slang-check-expr.cpp * Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType` * Move pair logic to ir-builder, unify the type dictionaries. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp32
1 files changed, 25 insertions, 7 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 8f24ec5b0..5233008fd 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1131,7 +1131,8 @@ namespace Slang
{
if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>())
{
- if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType)
+ if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType
+ || builtinRequirement->kind == BuiltinRequirementKind::DifferentialPtrType)
{
// We are trying to get differential type from a differential type.
// The result is itself.
@@ -1139,7 +1140,10 @@ namespace Slang
}
}
type = resolveType(type);
- if (const auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType())))
+ auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType()));
+ if (!witness)
+ witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableRefInterfaceType()));
+ if (witness)
{
auto diffTypeLookupResult = lookUpMember(
getASTBuilder(),
@@ -1367,6 +1371,13 @@ namespace Slang
{
addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
}
+
+ if (auto subtypeWitness = as<SubtypeWitness>(
+ tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableRefInterfaceType())))
+ {
+ addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
+ }
+
if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>())
{
foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member)
@@ -2899,18 +2910,25 @@ namespace Slang
return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType));
}
}
+
// Get a reference to the builtin 'IDifferentiable' interface
auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType();
+ auto differentiableRefInterface = getASTBuilder()->getDifferentiableRefInterfaceType();
- auto conformanceWitness = as<Witness>(isSubtype(primalType, differentiableInterface, IsSubTypeOptions::None));
// Check if the provided type inherits from IDifferentiable.
// If not, return the original type.
- if (conformanceWitness)
+ if (auto conformanceWitness = isTypeDifferentiable(primalType))
{
- return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness);
+ if (conformanceWitness->getSup() == differentiableInterface)
+ {
+ return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness);
+ }
+ else if (conformanceWitness->getSup() == differentiableRefInterface)
+ {
+ return m_astBuilder->getDifferentialPtrPairType(primalType, conformanceWitness);
+ }
}
- else
- return primalType;
+ return primalType;
}
Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType)