From 499b0253c224e68ceed6e5b6b1ee9cd7d65aad0f Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 30 Jan 2023 19:24:09 -0800 Subject: Make ArrayExpressionType a DeclRefType and define its autodiff extension in stdlib. (#2615) * Allow array parameters in forward diff. * Use type canonicalization instead of coersion. * Reimplement array type. * Fix. * Update test case. --------- Co-authored-by: Yong He --- source/slang/slang-ast-type.cpp | 61 ++++++++++++----------------------------- 1 file changed, 17 insertions(+), 44 deletions(-) (limited to 'source/slang/slang-ast-type.cpp') diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 3801b99b0..362503a64 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -314,64 +314,37 @@ Type* MatrixExpressionType::getRowType() // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ArrayExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool ArrayExpressionType::_equalsImplOverride(Type* type) +Type* ArrayExpressionType::getElementType() { - auto arrType = as(type); - if (!arrType) - return false; - return (areValsEqual(arrayLength, arrType->arrayLength) && baseType->equals(arrType->baseType)); -} - -Val* ArrayExpressionType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - auto elementType = as(baseType->substituteImpl(astBuilder, subst, &diff)); - IntVal* newArrayLength = nullptr; - if (arrayLength) - { - newArrayLength = as(arrayLength->substituteImpl(astBuilder, subst, &diff)); - SLANG_ASSERT(newArrayLength); - } - if (diff) - { - *ioDiff = 1; - auto rsType = getArrayType( - astBuilder, - elementType, - newArrayLength); - return rsType; - } - return this; -} - -Type* ArrayExpressionType::_createCanonicalTypeOverride() -{ - auto canonicalElementType = baseType->getCanonicalType(); - auto canonicalArrayType = getASTBuilder()->getArrayType( - canonicalElementType, - arrayLength); - return canonicalArrayType; + return as(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); } -HashCode ArrayExpressionType::_getHashCodeOverride() +IntVal* ArrayExpressionType::getElementCount() { - if (arrayLength) - return (baseType->getHashCode() * 16777619) ^ arrayLength->getHashCode(); - else - return baseType->getHashCode(); + return as(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]); } void ArrayExpressionType::_toTextOverride(StringBuilder& out) { - out << baseType; + out << getElementType(); out.appendChar('['); - if (arrayLength) + if (!isUnsized()) { - out << arrayLength; + out << getElementCount(); } out.appendChar(']'); } +bool ArrayExpressionType::isUnsized() +{ + if (auto constSize = as(getElementCount())) + { + if (constSize->value == kUnsizedArrayMagicLength) + return true; + } + return false; +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void TypeType::_toTextOverride(StringBuilder& out) -- cgit v1.2.3