diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-30 19:24:09 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-30 19:24:09 -0800 |
| commit | 499b0253c224e68ceed6e5b6b1ee9cd7d65aad0f (patch) | |
| tree | 4c570a36d305c8909d633183694e0d1225f044c2 /source | |
| parent | 134dd7eb26fc7988ae13559d276cbf337b4b9d27 (diff) | |
Make ArrayExpressionType a DeclRefType and define its autodiff extension in stdlib. (#2615)
* Allow array parameters in forward diff.
* Use type canonicalization instead of coersion.
* Reimplement array type.
* Fix.
* Update test case.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
26 files changed, 231 insertions, 240 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 97b9a227d..31dd5ed29 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -220,6 +220,12 @@ interface __FlagsEnumType : __EnumType { }; +__generic<T, let N:int> +__magic_type(ArrayType) +struct Array +{ +} + // The "comma operator" is effectively just a generic function that returns its second // argument. The left-to-right evaluation order guaranteed by Slang then ensures that // `left` is evaluated before `right`. diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index c732d1a5e..adbf8ae48 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -127,6 +127,41 @@ void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T p = DifferentialPair<T>(newPrimal, newDiff); } +__generic<T, let N:int> +__intrinsic_op($(kIROp_MakeArrayFromElement)) +Array<T,N> makeArrayFromElement(T element); + + +__generic<T:IDifferentiable, let N:int> +extension Array<T, N> : IDifferentiable +{ + typedef Array<T.Differential, N> Differential; + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return makeArrayFromElement<T.Differential, N>(T.dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + Array<T.Differential, N> result; + for (int i = 0; i < N; i++) + result[i] = T.dadd(a[i], b[i]); + return result; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + Array<T.Differential, N> result; + for (int i = 0; i < N; i++) + result[i] = T.dmul(a[i], b[i]); + return result; + } +} + // vector-matrix __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index ab161065d..03725901e 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -290,13 +290,18 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName) ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount) { - ArrayExpressionType* arrayType = getOrCreateWithDefaultCtor<ArrayExpressionType>(elementType, elementCount); - if (!arrayType->baseType) + if (!elementCount) + elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength); + + auto result = getOrCreate<ArrayExpressionType>(elementType, elementCount); + if (!result->declRef.decl) { - arrayType->baseType = elementType; - arrayType->arrayLength = elementCount; + auto arrayGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("ArrayType")); + auto arrayTypeDecl = arrayGenericDecl->inner; + auto substitutions = getOrCreate<GenericSubstitution>(arrayGenericDecl, elementType, elementCount); + result->declRef = DeclRef<Decl>(arrayTypeDecl, substitutions); } - return arrayType; + return result; } VectorExpressionType* ASTBuilder::getVectorType( diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index bc95ed63d..d44a15813 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -230,6 +230,11 @@ public: }); } + ConstantIntVal* getIntVal(Type* type, IntegerLiteralValue value) + { + return getOrCreate<ConstantIntVal>(type, value); + } + DeclRefType* getOrCreateDeclRefType(Decl* decl, Substitutions* outer) { NodeDesc desc; diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 3801b99b0..362503a64 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -314,64 +314,37 @@ Type* MatrixExpressionType::getRowType() // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ArrayExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -bool ArrayExpressionType::_equalsImplOverride(Type* type) +Type* ArrayExpressionType::getElementType() { - auto arrType = as<ArrayExpressionType>(type); - if (!arrType) - return false; - return (areValsEqual(arrayLength, arrType->arrayLength) && baseType->equals(arrType->baseType)); -} - -Val* ArrayExpressionType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - auto elementType = as<Type>(baseType->substituteImpl(astBuilder, subst, &diff)); - IntVal* newArrayLength = nullptr; - if (arrayLength) - { - newArrayLength = as<IntVal>(arrayLength->substituteImpl(astBuilder, subst, &diff)); - SLANG_ASSERT(newArrayLength); - } - if (diff) - { - *ioDiff = 1; - auto rsType = getArrayType( - astBuilder, - elementType, - newArrayLength); - return rsType; - } - return this; -} - -Type* ArrayExpressionType::_createCanonicalTypeOverride() -{ - auto canonicalElementType = baseType->getCanonicalType(); - auto canonicalArrayType = getASTBuilder()->getArrayType( - canonicalElementType, - arrayLength); - return canonicalArrayType; + return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); } -HashCode ArrayExpressionType::_getHashCodeOverride() +IntVal* ArrayExpressionType::getElementCount() { - if (arrayLength) - return (baseType->getHashCode() * 16777619) ^ arrayLength->getHashCode(); - else - return baseType->getHashCode(); + return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[1]); } void ArrayExpressionType::_toTextOverride(StringBuilder& out) { - out << baseType; + out << getElementType(); out.appendChar('['); - if (arrayLength) + if (!isUnsized()) { - out << arrayLength; + out << getElementCount(); } out.appendChar(']'); } +bool ArrayExpressionType::isUnsized() +{ + if (auto constSize = as<ConstantIntVal>(getElementCount())) + { + if (constSize->value == kUnsizedArrayMagicLength) + return true; + } + return false; +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void TypeType::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 0e7614dd6..47608405a 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -432,19 +432,18 @@ class ParameterBlockType : public UniformParameterGroupType SLANG_AST_CLASS(ParameterBlockType) }; -class ArrayExpressionType : public Type +class ArrayExpressionType : public DeclRefType { SLANG_AST_CLASS(ArrayExpressionType) - - Type* baseType = nullptr; - IntVal* arrayLength = nullptr; - - // Overrides should be public so base classes can access + ArrayExpressionType(Type* inElementType, IntVal* inElementCount) + { + SLANG_UNUSED(inElementType); + SLANG_UNUSED(inElementCount); + } + bool isUnsized(); void _toTextOverride(StringBuilder& out); - Type* _createCanonicalTypeOverride(); - bool _equalsImplOverride(Type* type); - Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - HashCode _getHashCodeOverride(); + Type* getElementType(); + IntVal* getElementCount(); }; // The "type" of an expression that resolves to a type. @@ -500,7 +499,7 @@ class VectorExpressionType : public ArithmeticExpressionType BasicExpressionType* _getScalarTypeOverride(); VectorExpressionType(Type* inElementType, IntVal* inElementCount) - : elementType(inElementType), elementCount(inElementCount) + : elementType(inElementType), elementCount(inElementCount) {} }; diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index e60c963a8..fde31c730 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -544,7 +544,6 @@ Val* ExtractExistentialSubtypeWitness::_substituteImplOverride(ASTBuilder* astBu return substValue; } - // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! bool TaggedUnionSubtypeWitness::_equalsValOverride(Val* val) @@ -618,41 +617,6 @@ Val* TaggedUnionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, return substWitness; } -bool DifferentialBottomSubtypeWitness::_equalsValOverride(Val* val) -{ - auto otherDiffBottomWitness = as<DifferentialBottomSubtypeWitness>(val); - if (!otherDiffBottomWitness) - return false; - - return otherDiffBottomWitness->sub && otherDiffBottomWitness->sub->equals(sub); -} - -void DifferentialBottomSubtypeWitness::_toTextOverride(StringBuilder& out) -{ - out << "DifferentialBottomSubtypeWitness(" << sub << ")"; -} - -HashCode DifferentialBottomSubtypeWitness::_getHashCodeOverride() -{ - return combineHash(3892, sub->getHashCode()); -} - -Val* DifferentialBottomSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); - if (!diff) - return this; - - *ioDiff += diff; - - DifferentialBottomSubtypeWitness* substWitness = - astBuilder->create<DifferentialBottomSubtypeWitness>(substSub, substSup); - return substWitness; -} - bool ConjunctionSubtypeWitness::_equalsValOverride(Val* val) { if (auto other = as<ConjunctionSubtypeWitness>(val)) @@ -940,7 +904,7 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut *ioDiff += diff; if (evaluatedTerms.getCount() == 0) - return astBuilder->getOrCreate<ConstantIntVal>(type, evaluatedConstantTerm); + return astBuilder->getIntVal(type, evaluatedConstantTerm); if (diff != 0) { auto newPolynomial = astBuilder->create<PolynomialIntVal>(type); @@ -1253,7 +1217,7 @@ IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder) return terms[0]->paramFactors[0]->param; } if (terms.getCount() == 0) - return builder->getOrCreate<ConstantIntVal>(type, constantTerm); + return builder->getIntVal(type, constantTerm); return this; } @@ -1425,7 +1389,7 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR { SLANG_UNREACHABLE("constant folding of FuncCallIntVal"); } - return astBuilder->getOrCreate<ConstantIntVal>(resultType, resultValue); + return astBuilder->getIntVal(resultType, resultValue); } return nullptr; } diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 8e5192536..49aec8c5e 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -399,24 +399,6 @@ class DynamicSubtypeWitness : public SubtypeWitness SLANG_AST_CLASS(DynamicSubtypeWitness) }; - /// A witness of the fact that any type can be viewed as a subtype of DifferentialBottom. -class DifferentialBottomSubtypeWitness : public SubtypeWitness -{ - SLANG_AST_CLASS(DifferentialBottomSubtypeWitness) - - DifferentialBottomSubtypeWitness(Type* inSub, Type* inSup) - { - sub = inSub; - sup = inSup; - } - - // Overrides should be public so base classes can access - bool _equalsValOverride(Val* val); - void _toTextOverride(StringBuilder& out); - HashCode _getHashCodeOverride(); - Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); -}; - /// A witness that `T : L & R` because `T : L` and `T : R` class ConjunctionSubtypeWitness : public SubtypeWitness { diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 3a50897de..9bf343876 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -176,7 +176,6 @@ namespace Slang *link = extractWitness; link = (SubtypeWitness**) &extractWitness->conjunctionWitness; } - // Move on with the list. bb = bb->prev; } diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index c6daf5e86..6decce625 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -275,10 +275,11 @@ namespace Slang // TODO(tfoley): If we can compute the size of the array statically, // then we want to check that there aren't too many initializers present - auto toElementType = toArrayType->baseType; - - if(auto toElementCount = toArrayType->arrayLength) + auto toElementType = toArrayType->getElementType(); + if(!toArrayType->isUnsized()) { + auto toElementCount = toArrayType->getElementCount(); + // In the case of a sized array, we need to check that the number // of elements being initialized matches what was declared. // @@ -349,7 +350,7 @@ namespace Slang // We have a new type for the conversion, based on what // we learned. toType = m_astBuilder->getArrayType(toElementType, - m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)); + m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount)); } } else if(auto toMatrixType = as<MatrixExpressionType>(toType)) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 5add89312..9bda6c3e7 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -955,17 +955,14 @@ namespace Slang } } - static bool isUnsizedArrayType(Type* type) + bool isUnsizedArrayType(Type* type) { // Not an array? auto arrayType = as<ArrayExpressionType>(type); if (!arrayType) return false; // Explicit element count given? - auto elementCount = arrayType->arrayLength; - if (elementCount) return true; - - return true; + return arrayType->isUnsized(); } bool SemanticsVisitor::shouldSkipChecking(Decl* decl, DeclCheckState state) @@ -3304,7 +3301,7 @@ namespace Slang { arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar)); } - auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->baseType, innerLeft, _Move(args), nestingLevel + 1); + auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->getElementType(), innerLeft, _Move(args), nestingLevel + 1); synth.popScope(); if (!assignStmt) return nullptr; @@ -5986,17 +5983,19 @@ namespace Slang if (!arrayType) return; // Explicit element count given? - auto elementCount = arrayType->arrayLength; - if (elementCount) return; + if (!isUnsizedArrayType(arrayType)) + return; // No initializer? auto initExpr = varDecl->initExpr; if(!initExpr) return; + IntVal* elementCount = nullptr; + // Is the type of the initializer an array type? if(auto arrayInitType = as<ArrayExpressionType>(initExpr->type)) { - elementCount = arrayInitType->arrayLength; + elementCount = arrayInitType->getElementCount(); } else { @@ -6008,7 +6007,7 @@ namespace Slang // and install it into our type. varDecl->type.type = getArrayType( m_astBuilder, - arrayType->baseType, + arrayType->getElementType(), elementCount); } @@ -6017,8 +6016,7 @@ namespace Slang auto arrayType = as<ArrayExpressionType>(varDecl->type); if (!arrayType) return; - auto elementCount = arrayType->arrayLength; - if (!elementCount) + if (arrayType->isUnsized()) { // Note(tfoley): For now we allow arrays of unspecified size // everywhere, because some source languages (e.g., GLSL) @@ -6030,6 +6028,7 @@ namespace Slang } // TODO(tfoley): How to handle the case where bound isn't known? + auto elementCount = arrayType->getElementCount(); if (GetMinBound(elementCount) <= 0) { getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2853c1eb9..d99114e4f 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -853,11 +853,11 @@ namespace Slang } else if (auto arrayType = as<ArrayExpressionType>(type)) { - auto baseDiffType = tryGetDifferentialType(builder, arrayType->baseType); + auto baseDiffType = tryGetDifferentialType(builder, arrayType->getElementType()); if (!baseDiffType) return nullptr; return builder->getArrayType( baseDiffType, - arrayType->arrayLength); + arrayType->getElementCount()); } if (auto declRefType = as<DeclRefType>(type)) @@ -946,8 +946,8 @@ namespace Slang if (auto arrayType = as<ArrayExpressionType>(type)) { - maybeRegisterDifferentiableType(builder, arrayType->baseType); - return; + maybeRegisterDifferentiableType(builder, arrayType->getElementType()); + // Fall through to register the array type itself. } if (auto declRefType = as<DeclRefType>(type)) @@ -990,8 +990,8 @@ namespace Slang if (auto arrayType = as<ArrayExpressionType>(type)) { - maybeRegisterDifferentiableTypeRecursive(builder, arrayType->baseType, workingSet); - return; + maybeRegisterDifferentiableTypeRecursive(builder, arrayType->getElementType(), workingSet); + // Fall through to register the array type itself. } if (auto declRefType = as<DeclRefType>(type)) @@ -1204,7 +1204,7 @@ namespace Slang IntVal* SemanticsVisitor::getIntVal(IntegerLiteralExpr* expr) { - return m_astBuilder->getOrCreate<ConstantIntVal>(expr->type.type, expr->value); + return m_astBuilder->getIntVal(expr->type.type, expr->value); } IntVal* SemanticsVisitor::tryConstantFoldExpr( @@ -1433,7 +1433,7 @@ namespace Slang } } - IntVal* result = m_astBuilder->getOrCreate<ConstantIntVal>(invokeExpr.getExpr()->type.type, resultValue); + IntVal* result = m_astBuilder->getIntVal(invokeExpr.getExpr()->type.type, resultValue); return result; } @@ -1517,7 +1517,7 @@ namespace Slang { // If it's a boolean, we allow promotion to int. const IntegerLiteralValue value = IntegerLiteralValue(boolLitExpr.getExpr()->value); - return m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getBoolType(), value); + return m_astBuilder->getIntVal(m_astBuilder->getBoolType(), value); } if (auto arrayLengthExpr = expr.as<GetArrayLengthExpr>()) @@ -1527,8 +1527,11 @@ namespace Slang auto type = arrayLengthExpr.getExpr()->arrayExpr->type.type->substitute(m_astBuilder, expr.getSubsts()); if (auto arrayType = as<ArrayExpressionType>(type)) { - if (auto val = as<IntVal>(arrayType->arrayLength)) - return val; + if (!arrayType->isUnsized()) + { + if (auto val = as<IntVal>(arrayType->getElementCount())) + return val; + } } } } @@ -1734,7 +1737,7 @@ namespace Slang { return CheckSimpleSubscriptExpr( subscriptExpr, - baseArrayType->baseType); + baseArrayType->getElementType()); } else if (auto vecType = as<VectorExpressionType>(baseType)) { @@ -2146,12 +2149,14 @@ namespace Slang // Get a reference to the builtin 'IDifferentiable' interface auto differentiableInterface = m_astBuilder->getDifferentiableInterface(); - - auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface)); + + SubtypeWitness* conformanceWitness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface)); // Check if the provided type inherits from IDifferentiable. // If not, return the original type. if (conformanceWitness) + { return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); + } else return primalType; } @@ -2200,15 +2205,24 @@ namespace Slang for (UInt i = 0; i < originalType->getParamCount(); i++) { - if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + if (auto outType = as<OutType>(originalType->getParamType(i))) { - // Using inout type on all the derivative parameters - if (auto outType = as<OutType>(derivType)) + auto diffElementType = + tryGetDifferentialType(m_astBuilder, outType->getValueType()); + if (diffElementType) + { + type->paramTypes.add(diffElementType); + } + else { - derivType = outType->getValueType(); + continue; } - else if (as<DifferentialPairType>(derivType)) + } + else if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + { + if (as<DifferentialPairType>(derivType)) { + // Using inout type on all the derivative parameters derivType = m_astBuilder->getInOutType(derivType); } type->paramTypes.add(derivType); @@ -2216,7 +2230,9 @@ namespace Slang } // Last parameter is the initial derivative of the original return type - type->paramTypes.add(getDifferentialType(m_astBuilder, originalType->resultType, SourceLoc())); + auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->resultType); + if (dOutType) + type->paramTypes.add(dOutType); return type; } @@ -2407,7 +2423,7 @@ namespace Slang if (auto arrType = as<ArrayExpressionType>(expr->arrayExpr->type)) { expr->type = m_astBuilder->getIntType(); - if (!arrType->arrayLength) + if (arrType->isUnsized()) { getSink()->diagnose(expr, Diagnostics::invalidArraySize); } @@ -2823,7 +2839,7 @@ namespace Slang // here if the input type had a sugared name... swizExpr->type = QualType(createVectorType( baseElementType, - m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount))); + m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount))); } // A swizzle can be used as an l-value as long as there @@ -2948,7 +2964,7 @@ namespace Slang // here if the input type had a sugared name... swizExpr->type = QualType(createVectorType( baseElementType, - m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), elementCount))); + m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount))); } // A swizzle can be used as an l-value as long as there diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 1b59094e2..ac5fc8392 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2038,4 +2038,6 @@ namespace Slang void checkModule(ModuleDecl* programNode); }; + + bool isUnsizedArrayType(Type* type); } diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 3a64f3c8f..46e39e4c0 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -62,7 +62,7 @@ namespace Slang // while( auto arrayType = as<ArrayExpressionType>(type) ) { - type = arrayType->baseType; + type = arrayType->getElementType(); } if( auto parameterGroupType = as<ParameterGroupType>(type) ) @@ -1125,7 +1125,7 @@ namespace Slang if(!intVal) { sink->diagnose(param.loc, Diagnostics::expectedValueOfTypeForSpecializationArg, paramDecl->getType(), paramDecl); - intVal = getLinkage()->getASTBuilder()->getOrCreate<ConstantIntVal>(m_astBuilder->getIntType(), 0); + intVal = getLinkage()->getASTBuilder()->getIntVal(m_astBuilder->getIntType(), 0); } ModuleSpecializationInfo::GenericArgInfo expandedArg; diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 0f450340f..8049c1230 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -206,8 +206,7 @@ namespace Slang } else { - ConstantIntVal* rangeBeginConst = m_astBuilder->getOrCreate<ConstantIntVal>(); - rangeBeginConst->value = 0; + ConstantIntVal* rangeBeginConst = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 0); rangeBeginVal = rangeBeginConst; } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 05a5f8f56..91374e006 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -194,27 +194,19 @@ IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRBuilder* b return table; } -IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) +IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType) { - return builder->getDifferentialPairType( - (IRType*)primalType, - witness); -} - -IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType) -{ - auto primalType = lookupPrimalInst(builder, originalType, nullptr); - SLANG_RELEASE_ASSERT(primalType); - - IRInst* witness = + IRInst* witness = differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)originalType); if (witness) { witness = lookupPrimalInst(builder, witness, nullptr); - SLANG_RELEASE_ASSERT(witness); + SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(originalType)); } if (!witness) { + auto primalType = lookupPrimalInst(builder, originalType, nullptr); + SLANG_RELEASE_ASSERT(primalType); if (auto primalPairType = as<IRDifferentialPairType>(primalType)) { witness = getDifferentialPairWitness(builder, originalType, primalPairType); @@ -224,6 +216,23 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI differentiateExtractExistentialType(builder, extractExistential, witness); } } + return witness; +} + +IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) +{ + return builder->getDifferentialPairType( + (IRType*)primalType, + witness); +} + +IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType) +{ + auto primalType = lookupPrimalInst(builder, originalType, nullptr); + SLANG_RELEASE_ASSERT(primalType); + + IRInst* witness = tryGetDifferentiableWitness(builder, originalType); + SLANG_RELEASE_ASSERT(witness); return builder->getDifferentialPairType( (IRType*)primalType, diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index a870dc815..e6a525dee 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -101,6 +101,8 @@ struct AutoDiffTranscriberBase // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. IRWitnessTable* getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType); + IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType); + IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness); IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index ae8359251..8952f9756 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -534,31 +534,7 @@ void stripNoDiffTypeAttribute(IRModule* module) bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst) { - HashSet<IRInst*> processedSet; - for (;typeInst;) - { - if (as<IRArrayTypeBase>(typeInst) || as<IRPtrTypeBase>(typeInst)) - { - typeInst = typeInst->getOperand(0); - if (!processedSet.Add(typeInst)) - return false; - } - else - { - break; - } - } - if (!typeInst) - return false; - switch (typeInst->getOp()) - { - case kIROp_FloatType: - case kIROp_DifferentialPairType: - return true; - default: - break; - } - if (context.lookUpConformanceForType(typeInst)) + if (context.isDifferentiableType((IRType*)typeInst)) return true; // Look for equivalent types. for (auto type : context.differentiableWitnessDictionary) diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 6da4ea6a6..2258ff753 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -190,6 +190,30 @@ struct DifferentiableTypeConformanceContext } } + bool isDifferentiableType(IRType* origType) + { + for (; origType;) + { + switch (origType->getOp()) + { + case kIROp_FloatType: + case kIROp_HalfType: + case kIROp_DoubleType: + return true; + case kIROp_VectorType: + case kIROp_ArrayType: + case kIROp_PtrType: + case kIROp_OutType: + case kIROp_InOutType: + origType = (IRType*)origType->getOperand(0); + continue; + default: + return lookUpConformanceForType(origType) != nullptr; + } + } + return false; + } + IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) { auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 605ac62db..149f5f6b9 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1443,11 +1443,6 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(diff); } - LoweredValInfo visitDifferentialBottomSubtypeWitness(DifferentialBottomSubtypeWitness*) - { - return LoweredValInfo(); - } - LoweredValInfo visitTaggedUnionSubtypeWitness( TaggedUnionSubtypeWitness* val) { @@ -1861,10 +1856,10 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower IRType* visitArrayExpressionType(ArrayExpressionType* type) { - auto elementType = lowerType(context, type->baseType); - if (type->arrayLength) + auto elementType = lowerType(context, type->getElementType()); + if (!type->isUnsized()) { - auto elementCount = lowerSimpleVal(context, type->arrayLength); + auto elementCount = lowerSimpleVal(context, type->getElementCount()); return getBuilder()->getArrayType( elementType, elementCount); @@ -3390,18 +3385,10 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> } else if (auto arrayType = as<ArrayExpressionType>(type)) { - UInt elementCount = (UInt) getIntVal(arrayType->arrayLength); - - auto irDefaultElement = getSimpleVal(context, getDefaultVal(arrayType->baseType)); - - List<IRInst*> args; - for(UInt ee = 0; ee < elementCount; ++ee) - { - args.add(irDefaultElement); - } + auto irDefaultElement = getSimpleVal(context, getDefaultVal(arrayType->getElementType())); return LoweredValInfo::simple( - getBuilder()->emitMakeArray(irType, args.getCount(), args.getBuffer())); + getBuilder()->emitMakeArrayFromElement(irType, irDefaultElement)); } else if (auto declRefType = as<DeclRefType>(type)) { @@ -3470,7 +3457,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // fill in the appropriate field of the result if (auto arrayType = as<ArrayExpressionType>(type)) { - UInt elementCount = (UInt) getIntVal(arrayType->arrayLength); + UInt elementCount = (UInt) getIntVal(arrayType->getElementCount()); for (UInt ee = 0; ee < argCount; ++ee) { @@ -3480,7 +3467,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> } if(elementCount > argCount) { - auto irDefaultValue = getSimpleVal(context, getDefaultVal(arrayType->baseType)); + auto irDefaultValue = getSimpleVal(context, getDefaultVal(arrayType->getElementType())); for(UInt ee = argCount; ee < elementCount; ++ee) { args.add(irDefaultValue); diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index 58d6aaae3..da5099934 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -217,8 +217,8 @@ namespace Slang else if (auto arrType = dynamicCast<ArrayExpressionType>(type)) { emitRaw(context, "a"); - emitSimpleIntVal(context, arrType->arrayLength); - emitType(context, arrType->baseType); + emitSimpleIntVal(context, arrType->getElementCount()); + emitType(context, arrType->getElementType()); } else if( auto taggedUnionType = dynamicCast<TaggedUnionType>(type) ) { diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index db323ff6e..f3eabf613 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -1889,16 +1889,18 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( // Note: Bad Things will happen if we have an array input // without a semantic already being enforced. - auto elementCount = (UInt) getIntVal(arrayType->arrayLength); + auto elementCount = (UInt) getIntVal(arrayType->getElementCount()); + if (arrayType->isUnsized()) + elementCount = 0; // We use the first element to derive the layout for the element type - auto elementTypeLayout = processEntryPointVaryingParameter(context, arrayType->baseType, state, varLayout); + auto elementTypeLayout = processEntryPointVaryingParameter(context, arrayType->getElementType(), state, varLayout); // We still walk over subsequent elements to make sure they consume resources // as needed for( UInt ii = 1; ii < elementCount; ++ii ) { - processEntryPointVaryingParameter(context, arrayType->baseType, state, nullptr); + processEntryPointVaryingParameter(context, arrayType->getElementType(), state, nullptr); } RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 5c5773fec..9c1d48a28 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -149,7 +149,7 @@ static SlangParameterCategory maybeRemapParameterCategory( // of this variable? Type* type = typeLayout->getType(); while (auto arrayType = as<ArrayExpressionType>(type)) - type = arrayType->baseType; + type = arrayType->getElementType(); switch (spReflectionType_GetKind(convert(type))) { case SLANG_TYPE_KIND_CONSTANT_BUFFER: @@ -462,7 +462,7 @@ SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType) if(auto arrayType = as<ArrayExpressionType>(type)) { - return arrayType->arrayLength ? (size_t) getIntVal(arrayType->arrayLength) : 0; + return !arrayType->isUnsized() ? (size_t)getIntVal(arrayType->getElementCount()) : 0; } else if( auto vectorType = as<VectorExpressionType>(type)) { @@ -479,7 +479,7 @@ SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionTy if(auto arrayType = as<ArrayExpressionType>(type)) { - return (SlangReflectionType*) arrayType->baseType; + return (SlangReflectionType*) arrayType->getElementType(); } else if( auto parameterGroupType = as<ParameterGroupType>(type)) { @@ -631,7 +631,7 @@ SLANG_API SlangResourceShape spReflectionType_GetResourceShape(SlangReflectionTy while(auto arrayType = as<ArrayExpressionType>(type)) { - type = arrayType->baseType; + type = arrayType->getElementType(); } if(auto textureType = as<TextureTypeBase>(type)) @@ -667,7 +667,7 @@ SLANG_API SlangResourceAccess spReflectionType_GetResourceAccess(SlangReflection while(auto arrayType = as<ArrayExpressionType>(type)) { - type = arrayType->baseType; + type = arrayType->getElementType(); } if(auto textureType = as<TextureTypeBase>(type)) @@ -763,7 +763,7 @@ SLANG_API SlangReflectionType* spReflectionType_GetResourceResultType(SlangRefle while(auto arrayType = as<ArrayExpressionType>(type)) { - type = arrayType->baseType; + type = arrayType->getElementType(); } if (auto textureType = as<TextureTypeBase>(type)) @@ -1492,9 +1492,9 @@ namespace Slang LayoutSize elementCount = LayoutSize::infinite(); if( auto arrayType = as<ArrayExpressionType>(arrayTypeLayout->type) ) { - if( auto elementCountVal = arrayType->arrayLength ) + if( !arrayType->isUnsized()) { - elementCount = LayoutSize::RawValue(getIntVal(elementCountVal)); + elementCount = LayoutSize::RawValue(getIntVal(arrayType->getElementCount())); } } addRangesRec(elementTypeLayout, path, multiplier * elementCount); diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 4e5db17c0..6076a41ca 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -513,6 +513,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt vecType->elementCount = ExtractGenericArgInteger(subst->getArgs()[1]); return vecType; } + else if (magicMod->magicName == "ArrayType") + { + SLANG_ASSERT(subst && subst->getArgs().getCount() == 2); + auto vecType = astBuilder->getOrCreate<ArrayExpressionType>(ExtractGenericArgType(subst->getArgs()[0]), ExtractGenericArgInteger(subst->getArgs()[1])); + vecType->declRef = declRef; + return vecType; + } else if (magicMod->magicName == "Matrix") { SLANG_ASSERT(subst && subst->getArgs().getCount() == 3); @@ -1097,19 +1104,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt Type* elementType, IntVal* elementCount) { - auto arrayType = astBuilder->create<ArrayExpressionType>(); - arrayType->baseType = elementType; - arrayType->arrayLength = elementCount; - return arrayType; + return astBuilder->getArrayType(elementType, elementCount); } ArrayExpressionType* getArrayType( ASTBuilder* astBuilder, Type* elementType) { - auto arrayType = astBuilder->create<ArrayExpressionType>(); - arrayType->baseType = elementType; - return arrayType; + return astBuilder->getArrayType(elementType, nullptr); } NamedExpressionType* getNamedType( diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index e36ee944c..dd119cf3a 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -326,6 +326,8 @@ namespace Slang All = 7 }; + const int kUnsizedArrayMagicLength = 0x7FFFFFFF; + /// Get the module dclaration that a declaration is associated with, if any. ModuleDecl* getModuleDecl(Decl* decl); diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 45bacaa44..3da0cc95a 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1467,6 +1467,8 @@ static LayoutSize GetElementCount(IntVal* val) if (auto constantVal = as<ConstantIntVal>(val)) { + if (constantVal->value == kUnsizedArrayMagicLength) + return LayoutSize::infinite(); return LayoutSize(LayoutSize::RawValue(constantVal->value)); } else if( auto varRefVal = as<GenericParamIntVal>(val) ) @@ -3669,7 +3671,7 @@ static TypeLayoutResult _createTypeLayout( } else if (auto arrayType = as<ArrayExpressionType>(type)) { - return createArrayLikeTypeLayout(context, arrayType, arrayType->baseType, arrayType->arrayLength); + return createArrayLikeTypeLayout(context, arrayType, arrayType->getElementType(), arrayType->getElementCount()); } else if (auto declRefType = as<DeclRefType>(type)) { |
