diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-09-19 03:10:28 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-19 00:10:28 -0700 |
| commit | ccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch) | |
| tree | 435e9c462a78fb848ab3b36c23287543d1a859de /source/slang/slang-check-expr.cpp | |
| parent | 1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff) | |
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface
* Initial support for `IDifferentiablePtrType`
* Fix unused vars
* More tests + fix switch case fallthrough.
* Update slang-ir-autodiff.cpp
* Update diff-ptr-type-loop.slang
* Add optimization to allow more complex pair types
* Update slang-ir-autodiff-primal-hoist.cpp
* Update diff-ptr-type-loop.slang
* Update slang-ir-autodiff-primal-hoist.cpp
* More fixes to address reviews
* Update slang-check-expr.cpp
* Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType`
* Move pair logic to ir-builder, unify the type dictionaries.
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 32 |
1 files changed, 25 insertions, 7 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 8f24ec5b0..5233008fd 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1131,7 +1131,8 @@ namespace Slang { if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>()) { - if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType) + if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType + || builtinRequirement->kind == BuiltinRequirementKind::DifferentialPtrType) { // We are trying to get differential type from a differential type. // The result is itself. @@ -1139,7 +1140,10 @@ namespace Slang } } type = resolveType(type); - if (const auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType()))) + auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType())); + if (!witness) + witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableRefInterfaceType())); + if (witness) { auto diffTypeLookupResult = lookUpMember( getASTBuilder(), @@ -1367,6 +1371,13 @@ namespace Slang { addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); } + + if (auto subtypeWitness = as<SubtypeWitness>( + tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableRefInterfaceType()))) + { + addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); + } + if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>()) { foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member) @@ -2899,18 +2910,25 @@ namespace Slang return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType)); } } + // Get a reference to the builtin 'IDifferentiable' interface auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType(); + auto differentiableRefInterface = getASTBuilder()->getDifferentiableRefInterfaceType(); - auto conformanceWitness = as<Witness>(isSubtype(primalType, differentiableInterface, IsSubTypeOptions::None)); // Check if the provided type inherits from IDifferentiable. // If not, return the original type. - if (conformanceWitness) + if (auto conformanceWitness = isTypeDifferentiable(primalType)) { - return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); + if (conformanceWitness->getSup() == differentiableInterface) + { + return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); + } + else if (conformanceWitness->getSup() == differentiableRefInterface) + { + return m_astBuilder->getDifferentialPtrPairType(primalType, conformanceWitness); + } } - else - return primalType; + return primalType; } Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType) |
