diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-09-19 03:10:28 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-19 00:10:28 -0700 |
| commit | ccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch) | |
| tree | 435e9c462a78fb848ab3b36c23287543d1a859de /source/slang/slang-ast-builder.cpp | |
| parent | 1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff) | |
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface
* Initial support for `IDifferentiablePtrType`
* Fix unused vars
* More tests + fix switch case fallthrough.
* Update slang-ir-autodiff.cpp
* Update diff-ptr-type-loop.slang
* Add optimization to allow more complex pair types
* Update slang-ir-autodiff-primal-hoist.cpp
* Update diff-ptr-type-loop.slang
* Update slang-ir-autodiff-primal-hoist.cpp
* More fixes to address reviews
* Update slang-check-expr.cpp
* Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType`
* Move pair logic to ir-builder, unify the type dictionaries.
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ast-builder.cpp')
| -rw-r--r-- | source/slang/slang-ast-builder.cpp | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 9879a4187..b66af34fa 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -408,18 +408,32 @@ MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCo DifferentialPairType* ASTBuilder::getDifferentialPairType( Type* valueType, - Witness* primalIsDifferentialWitness) + Witness* diffTypeWitness) { - Val* args[] = { valueType, primalIsDifferentialWitness }; + Val* args[] = { valueType, diffTypeWitness }; return as<DifferentialPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType")); } +DifferentialPtrPairType* ASTBuilder::getDifferentialPtrPairType( + Type* valueType, + Witness* diffRefTypeWitness) +{ + Val* args[] = { valueType, diffRefTypeWitness }; + return as<DifferentialPtrPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType")); +} + DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterfaceDecl() { DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiableType", nullptr)); return declRef; } +DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableRefInterfaceDecl() +{ + DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiablePtrType", nullptr)); + return declRef; +} + bool ASTBuilder::isDifferentiableInterfaceAvailable() { return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr); @@ -459,6 +473,11 @@ Type* ASTBuilder::getDifferentiableInterfaceType() return DeclRefType::create(this, getDifferentiableInterfaceDecl()); } +Type* ASTBuilder::getDifferentiableRefInterfaceType() +{ + return DeclRefType::create(this, getDifferentiableRefInterfaceDecl()); +} + DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg) { auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); |
