summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ast-builder.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-30 19:24:09 -0800
committerGitHub <noreply@github.com>2023-01-30 19:24:09 -0800
commit499b0253c224e68ceed6e5b6b1ee9cd7d65aad0f (patch)
tree4c570a36d305c8909d633183694e0d1225f044c2 /source/slang/slang-ast-builder.cpp
parent134dd7eb26fc7988ae13559d276cbf337b4b9d27 (diff)
Make ArrayExpressionType a DeclRefType and define its autodiff extension in stdlib. (#2615)
* Allow array parameters in forward diff. * Use type canonicalization instead of coersion. * Reimplement array type. * Fix. * Update test case. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-builder.cpp')
-rw-r--r--source/slang/slang-ast-builder.cpp15
1 files changed, 10 insertions, 5 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index ab161065d..03725901e 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -290,13 +290,18 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName)
ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount)
{
- ArrayExpressionType* arrayType = getOrCreateWithDefaultCtor<ArrayExpressionType>(elementType, elementCount);
- if (!arrayType->baseType)
+ if (!elementCount)
+ elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength);
+
+ auto result = getOrCreate<ArrayExpressionType>(elementType, elementCount);
+ if (!result->declRef.decl)
{
- arrayType->baseType = elementType;
- arrayType->arrayLength = elementCount;
+ auto arrayGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("ArrayType"));
+ auto arrayTypeDecl = arrayGenericDecl->inner;
+ auto substitutions = getOrCreate<GenericSubstitution>(arrayGenericDecl, elementType, elementCount);
+ result->declRef = DeclRef<Decl>(arrayTypeDecl, substitutions);
}
- return arrayType;
+ return result;
}
VectorExpressionType* ASTBuilder::getVectorType(