summaryrefslogtreecommitdiff
path: root/source/slang/slang-ast-builder.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-09-19 03:10:28 -0400
committerGitHub <noreply@github.com>2024-09-19 00:10:28 -0700
commitccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch)
tree435e9c462a78fb848ab3b36c23287543d1a859de /source/slang/slang-ast-builder.cpp
parent1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff)
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface * Initial support for `IDifferentiablePtrType` * Fix unused vars * More tests + fix switch case fallthrough. * Update slang-ir-autodiff.cpp * Update diff-ptr-type-loop.slang * Add optimization to allow more complex pair types * Update slang-ir-autodiff-primal-hoist.cpp * Update diff-ptr-type-loop.slang * Update slang-ir-autodiff-primal-hoist.cpp * More fixes to address reviews * Update slang-check-expr.cpp * Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType` * Move pair logic to ir-builder, unify the type dictionaries. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ast-builder.cpp')
-rw-r--r--source/slang/slang-ast-builder.cpp23
1 files changed, 21 insertions, 2 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 9879a4187..b66af34fa 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -408,18 +408,32 @@ MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCo
DifferentialPairType* ASTBuilder::getDifferentialPairType(
Type* valueType,
- Witness* primalIsDifferentialWitness)
+ Witness* diffTypeWitness)
{
- Val* args[] = { valueType, primalIsDifferentialWitness };
+ Val* args[] = { valueType, diffTypeWitness };
return as<DifferentialPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType"));
}
+DifferentialPtrPairType* ASTBuilder::getDifferentialPtrPairType(
+ Type* valueType,
+ Witness* diffRefTypeWitness)
+{
+ Val* args[] = { valueType, diffRefTypeWitness };
+ return as<DifferentialPtrPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType"));
+}
+
DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterfaceDecl()
{
DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiableType", nullptr));
return declRef;
}
+DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableRefInterfaceDecl()
+{
+ DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiablePtrType", nullptr));
+ return declRef;
+}
+
bool ASTBuilder::isDifferentiableInterfaceAvailable()
{
return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr);
@@ -459,6 +473,11 @@ Type* ASTBuilder::getDifferentiableInterfaceType()
return DeclRefType::create(this, getDifferentiableInterfaceDecl());
}
+Type* ASTBuilder::getDifferentiableRefInterfaceType()
+{
+ return DeclRefType::create(this, getDifferentiableRefInterfaceDecl());
+}
+
DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg)
{
auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);