From ccc310fa4e8096cda8a6c127aacc1a1fa9d8503a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 19 Sep 2024 03:10:28 -0400 Subject: Support `IDifferentiablePtrType` (#5031) * initial diff-ref-type interface * Initial support for `IDifferentiablePtrType` * Fix unused vars * More tests + fix switch case fallthrough. * Update slang-ir-autodiff.cpp * Update diff-ptr-type-loop.slang * Add optimization to allow more complex pair types * Update slang-ir-autodiff-primal-hoist.cpp * Update diff-ptr-type-loop.slang * Update slang-ir-autodiff-primal-hoist.cpp * More fixes to address reviews * Update slang-check-expr.cpp * Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType` * Move pair logic to ir-builder, unify the type dictionaries. --------- Co-authored-by: Yong He --- source/slang/slang-ast-builder.cpp | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) (limited to 'source/slang/slang-ast-builder.cpp') diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 9879a4187..b66af34fa 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -408,18 +408,32 @@ MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCo DifferentialPairType* ASTBuilder::getDifferentialPairType( Type* valueType, - Witness* primalIsDifferentialWitness) + Witness* diffTypeWitness) { - Val* args[] = { valueType, primalIsDifferentialWitness }; + Val* args[] = { valueType, diffTypeWitness }; return as(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType")); } +DifferentialPtrPairType* ASTBuilder::getDifferentialPtrPairType( + Type* valueType, + Witness* diffRefTypeWitness) +{ + Val* args[] = { valueType, diffRefTypeWitness }; + return as(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType")); +} + DeclRef ASTBuilder::getDifferentiableInterfaceDecl() { DeclRef declRef = DeclRef(getBuiltinDeclRef("DifferentiableType", nullptr)); return declRef; } +DeclRef ASTBuilder::getDifferentiableRefInterfaceDecl() +{ + DeclRef declRef = DeclRef(getBuiltinDeclRef("DifferentiablePtrType", nullptr)); + return declRef; +} + bool ASTBuilder::isDifferentiableInterfaceAvailable() { return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr); @@ -459,6 +473,11 @@ Type* ASTBuilder::getDifferentiableInterfaceType() return DeclRefType::create(this, getDifferentiableInterfaceDecl()); } +Type* ASTBuilder::getDifferentiableRefInterfaceType() +{ + return DeclRefType::create(this, getDifferentiableRefInterfaceDecl()); +} + DeclRef ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg) { auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); -- cgit v1.2.3