summaryrefslogtreecommitdiff
path: root/source/slang/slang-ast-builder.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ast-builder.cpp')
-rw-r--r--source/slang/slang-ast-builder.cpp23
1 files changed, 21 insertions, 2 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 9879a4187..b66af34fa 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -408,18 +408,32 @@ MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCo
DifferentialPairType* ASTBuilder::getDifferentialPairType(
Type* valueType,
- Witness* primalIsDifferentialWitness)
+ Witness* diffTypeWitness)
{
- Val* args[] = { valueType, primalIsDifferentialWitness };
+ Val* args[] = { valueType, diffTypeWitness };
return as<DifferentialPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType"));
}
+DifferentialPtrPairType* ASTBuilder::getDifferentialPtrPairType(
+ Type* valueType,
+ Witness* diffRefTypeWitness)
+{
+ Val* args[] = { valueType, diffRefTypeWitness };
+ return as<DifferentialPtrPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType"));
+}
+
DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterfaceDecl()
{
DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiableType", nullptr));
return declRef;
}
+DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableRefInterfaceDecl()
+{
+ DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiablePtrType", nullptr));
+ return declRef;
+}
+
bool ASTBuilder::isDifferentiableInterfaceAvailable()
{
return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr);
@@ -459,6 +473,11 @@ Type* ASTBuilder::getDifferentiableInterfaceType()
return DeclRefType::create(this, getDifferentiableInterfaceDecl());
}
+Type* ASTBuilder::getDifferentiableRefInterfaceType()
+{
+ return DeclRefType::create(this, getDifferentiableRefInterfaceDecl());
+}
+
DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg)
{
auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);