summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ast-type.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-30 19:24:09 -0800
committerGitHub <noreply@github.com>2023-01-30 19:24:09 -0800
commit499b0253c224e68ceed6e5b6b1ee9cd7d65aad0f (patch)
tree4c570a36d305c8909d633183694e0d1225f044c2 /source/slang/slang-ast-type.cpp
parent134dd7eb26fc7988ae13559d276cbf337b4b9d27 (diff)
Make ArrayExpressionType a DeclRefType and define its autodiff extension in stdlib. (#2615)
* Allow array parameters in forward diff. * Use type canonicalization instead of coersion. * Reimplement array type. * Fix. * Update test case. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-type.cpp')
-rw-r--r--source/slang/slang-ast-type.cpp61
1 files changed, 17 insertions, 44 deletions
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 3801b99b0..362503a64 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -314,64 +314,37 @@ Type* MatrixExpressionType::getRowType()
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ArrayExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-bool ArrayExpressionType::_equalsImplOverride(Type* type)
+Type* ArrayExpressionType::getElementType()
{
- auto arrType = as<ArrayExpressionType>(type);
- if (!arrType)
- return false;
- return (areValsEqual(arrayLength, arrType->arrayLength) && baseType->equals(arrType->baseType));
-}
-
-Val* ArrayExpressionType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
-{
- int diff = 0;
- auto elementType = as<Type>(baseType->substituteImpl(astBuilder, subst, &diff));
- IntVal* newArrayLength = nullptr;
- if (arrayLength)
- {
- newArrayLength = as<IntVal>(arrayLength->substituteImpl(astBuilder, subst, &diff));
- SLANG_ASSERT(newArrayLength);
- }
- if (diff)
- {
- *ioDiff = 1;
- auto rsType = getArrayType(
- astBuilder,
- elementType,
- newArrayLength);
- return rsType;
- }
- return this;
-}
-
-Type* ArrayExpressionType::_createCanonicalTypeOverride()
-{
- auto canonicalElementType = baseType->getCanonicalType();
- auto canonicalArrayType = getASTBuilder()->getArrayType(
- canonicalElementType,
- arrayLength);
- return canonicalArrayType;
+ return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]);
}
-HashCode ArrayExpressionType::_getHashCodeOverride()
+IntVal* ArrayExpressionType::getElementCount()
{
- if (arrayLength)
- return (baseType->getHashCode() * 16777619) ^ arrayLength->getHashCode();
- else
- return baseType->getHashCode();
+ return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]);
}
void ArrayExpressionType::_toTextOverride(StringBuilder& out)
{
- out << baseType;
+ out << getElementType();
out.appendChar('[');
- if (arrayLength)
+ if (!isUnsized())
{
- out << arrayLength;
+ out << getElementCount();
}
out.appendChar(']');
}
+bool ArrayExpressionType::isUnsized()
+{
+ if (auto constSize = as<ConstantIntVal>(getElementCount()))
+ {
+ if (constSize->value == kUnsizedArrayMagicLength)
+ return true;
+ }
+ return false;
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void TypeType::_toTextOverride(StringBuilder& out)