From ccc310fa4e8096cda8a6c127aacc1a1fa9d8503a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 19 Sep 2024 03:10:28 -0400 Subject: Support `IDifferentiablePtrType` (#5031) * initial diff-ref-type interface * Initial support for `IDifferentiablePtrType` * Fix unused vars * More tests + fix switch case fallthrough. * Update slang-ir-autodiff.cpp * Update diff-ptr-type-loop.slang * Add optimization to allow more complex pair types * Update slang-ir-autodiff-primal-hoist.cpp * Update diff-ptr-type-loop.slang * Update slang-ir-autodiff-primal-hoist.cpp * More fixes to address reviews * Update slang-check-expr.cpp * Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType` * Move pair logic to ir-builder, unify the type dictionaries. --------- Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index deb8c55eb..8e78ff084 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -10204,7 +10204,8 @@ namespace Slang bool isDiffParam = (!param->findModifier()); if (isDiffParam) { - if (auto pairType = as(visitor->getDifferentialPairType(param->getType()))) + auto diffPair = visitor->getDifferentialPairType(param->getType()); + if (auto pairType = as(diffPair)) { arg->type.type = pairType; arg->type.isLeftValue = true; @@ -10225,6 +10226,11 @@ namespace Slang direction = ParameterDirection::kParameterDirection_InOut; } } + else if (auto refPairType = as(diffPair)) + { + // no need to change direction of ref-pairs. + arg->type.type = refPairType; + } else { isDiffParam = false; -- cgit v1.2.3