diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-30 19:24:09 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-30 19:24:09 -0800 |
| commit | 499b0253c224e68ceed6e5b6b1ee9cd7d65aad0f (patch) | |
| tree | 4c570a36d305c8909d633183694e0d1225f044c2 /source/slang/slang-ast-type.cpp | |
| parent | 134dd7eb26fc7988ae13559d276cbf337b4b9d27 (diff) | |
Make ArrayExpressionType a DeclRefType and define its autodiff extension in stdlib. (#2615)
* Allow array parameters in forward diff.
* Use type canonicalization instead of coersion.
* Reimplement array type.
* Fix.
* Update test case.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-type.cpp')
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 61 |
1 files changed, 17 insertions, 44 deletions
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 3801b99b0..362503a64 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -314,64 +314,37 @@ Type* MatrixExpressionType::getRowType() // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ArrayExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool ArrayExpressionType::_equalsImplOverride(Type* type) +Type* ArrayExpressionType::getElementType() { - auto arrType = as<ArrayExpressionType>(type); - if (!arrType) - return false; - return (areValsEqual(arrayLength, arrType->arrayLength) && baseType->equals(arrType->baseType)); -} - -Val* ArrayExpressionType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - auto elementType = as<Type>(baseType->substituteImpl(astBuilder, subst, &diff)); - IntVal* newArrayLength = nullptr; - if (arrayLength) - { - newArrayLength = as<IntVal>(arrayLength->substituteImpl(astBuilder, subst, &diff)); - SLANG_ASSERT(newArrayLength); - } - if (diff) - { - *ioDiff = 1; - auto rsType = getArrayType( - astBuilder, - elementType, - newArrayLength); - return rsType; - } - return this; -} - -Type* ArrayExpressionType::_createCanonicalTypeOverride() -{ - auto canonicalElementType = baseType->getCanonicalType(); - auto canonicalArrayType = getASTBuilder()->getArrayType( - canonicalElementType, - arrayLength); - return canonicalArrayType; + return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); } -HashCode ArrayExpressionType::_getHashCodeOverride() +IntVal* ArrayExpressionType::getElementCount() { - if (arrayLength) - return (baseType->getHashCode() * 16777619) ^ arrayLength->getHashCode(); - else - return baseType->getHashCode(); + return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]); } void ArrayExpressionType::_toTextOverride(StringBuilder& out) { - out << baseType; + out << getElementType(); out.appendChar('['); - if (arrayLength) + if (!isUnsized()) { - out << arrayLength; + out << getElementCount(); } out.appendChar(']'); } +bool ArrayExpressionType::isUnsized() +{ + if (auto constSize = as<ConstantIntVal>(getElementCount())) + { + if (constSize->value == kUnsizedArrayMagicLength) + return true; + } + return false; +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void TypeType::_toTextOverride(StringBuilder& out) |
