summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-syntax.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-30 19:24:09 -0800
committerGitHub <noreply@github.com>2023-01-30 19:24:09 -0800
commit499b0253c224e68ceed6e5b6b1ee9cd7d65aad0f (patch)
tree4c570a36d305c8909d633183694e0d1225f044c2 /source/slang/slang-syntax.cpp
parent134dd7eb26fc7988ae13559d276cbf337b4b9d27 (diff)
Make ArrayExpressionType a DeclRefType and define its autodiff extension in stdlib. (#2615)
* Allow array parameters in forward diff. * Use type canonicalization instead of coersion. * Reimplement array type. * Fix. * Update test case. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-syntax.cpp')
-rw-r--r--source/slang/slang-syntax.cpp16
1 files changed, 9 insertions, 7 deletions
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 4e5db17c0..6076a41ca 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -513,6 +513,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
vecType->elementCount = ExtractGenericArgInteger(subst->getArgs()[1]);
return vecType;
}
+ else if (magicMod->magicName == "ArrayType")
+ {
+ SLANG_ASSERT(subst && subst->getArgs().getCount() == 2);
+ auto vecType = astBuilder->getOrCreate<ArrayExpressionType>(ExtractGenericArgType(subst->getArgs()[0]), ExtractGenericArgInteger(subst->getArgs()[1]));
+ vecType->declRef = declRef;
+ return vecType;
+ }
else if (magicMod->magicName == "Matrix")
{
SLANG_ASSERT(subst && subst->getArgs().getCount() == 3);
@@ -1097,19 +1104,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
Type* elementType,
IntVal* elementCount)
{
- auto arrayType = astBuilder->create<ArrayExpressionType>();
- arrayType->baseType = elementType;
- arrayType->arrayLength = elementCount;
- return arrayType;
+ return astBuilder->getArrayType(elementType, elementCount);
}
ArrayExpressionType* getArrayType(
ASTBuilder* astBuilder,
Type* elementType)
{
- auto arrayType = astBuilder->create<ArrayExpressionType>();
- arrayType->baseType = elementType;
- return arrayType;
+ return astBuilder->getArrayType(elementType, nullptr);
}
NamedExpressionType* getNamedType(