From 453683bf44f2112719802eaac2b332d49eebd640 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 19 Aug 2024 15:03:56 -0700 Subject: Tuple swizzling, concat, comparison and `countof`. (#4856) * Tuple swizzling and element access. * Update proposal status. * Cleanup. * Fix merrge error. * Address review. --- source/slang/core.meta.slang | 79 +++++++- source/slang/slang-ast-builder.cpp | 19 +- source/slang/slang-ast-dump.cpp | 25 +++ source/slang/slang-ast-expr.h | 8 +- source/slang/slang-ast-support-types.h | 3 + source/slang/slang-ast-val.cpp | 75 +++++++ source/slang/slang-ast-val.h | 24 +++ source/slang/slang-check-expr.cpp | 237 ++++++++++++++++++---- source/slang/slang-check-impl.h | 6 + source/slang/slang-check-overload.cpp | 42 ++-- source/slang/slang-diagnostic-defs.h | 2 + source/slang/slang-ir-autodiff-fwd.cpp | 1 + source/slang/slang-ir-inline.cpp | 11 + source/slang/slang-ir-inst-defs.h | 6 + source/slang/slang-ir-insts.h | 20 ++ source/slang/slang-ir-lower-tuple-types.cpp | 174 +++++++++++++++- source/slang/slang-ir-peephole.cpp | 1 + source/slang/slang-ir-redundancy-removal.cpp | 16 +- source/slang/slang-ir-specialize.cpp | 210 ++++++++++++++++--- source/slang/slang-ir-ssa-simplification.cpp | 2 +- source/slang/slang-ir-ssa-simplification.h | 1 + source/slang/slang-ir.cpp | 33 +++ source/slang/slang-ir.h | 18 ++ source/slang/slang-language-server-ast-lookup.cpp | 26 ++- source/slang/slang-language-server-completion.cpp | 14 ++ source/slang/slang-language-server.cpp | 42 ++++ source/slang/slang-lookup.cpp | 105 ++++++---- source/slang/slang-lower-to-ir.cpp | 91 +++++---- source/slang/slang-parser.cpp | 22 +- source/slang/slang-serialize-type-info.h | 25 +++ source/slang/slang-serialize.h | 35 ++++ source/slang/slang-type-layout.cpp | 71 +++++++ 32 files changed, 1257 insertions(+), 187 deletions(-) (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 6c51ccef0..84e1b8168 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -874,10 +874,85 @@ __generic __magic_type(TupleType) struct Tuple { - __intrinsic_op($(0)) + __intrinsic_op($(kIROp_MakeTuple)) __init(expand each T); } +__intrinsic_op($(kIROp_MakeTuple)) +Tuple makeTuple(T v); + +Tuple concat(Tuple t, Tuple u) +{ + return makeTuple(expand each t, expand each u); +} + + +[__unsafeForceInlineEarly] +bool __assign(inout bool v, bool newVal) +{ + v = newVal; + return newVal; +} + +[__unsafeForceInlineEarly] +void __tupleLessKernel(inout bool result, inout bool exit, T a, T b) +{ + if (!exit) + { + if (a.lessThan(b)) + { + result = true; + exit = true; + } + else if (!a.equals(b)) + { + exit = true; + } + } +} + +[__unsafeForceInlineEarly] +void __tupleGreaterKernel(inout bool result, inout bool exit, T a, T b) +{ + if (!exit) + { + if (!a.lessThanOrEquals(b)) + { + result = true; + exit = true; + } + else if (!a.equals(b)) + { + exit = true; + } + } +} + +__generic +extension Tuple : IComparable +{ + bool lessThan(Tuple other) + { + bool result = false; + bool exit = false; + expand __tupleLessKernel(result, exit, each this, each other); + return result; + } + bool lessThanOrEquals(Tuple other) + { + bool result = false; + bool exit = false; + expand __tupleGreaterKernel(result, exit, each this, each other); + return !result; + } + bool equals(Tuple other) + { + bool result = true; + expand result && __assign(result, result && (each this).equals(each other)); + return result; + } +} + __generic __magic_type(NativeRefType) __intrinsic_type($(kIROp_NativePtrType)) @@ -2181,7 +2256,7 @@ __generic [OverloadRank(-10)] bool operator >=(T v0, T v1) { - return v1.lessThan(v1); + return v1.lessThanOrEquals(v0); } __generic [__unsafeForceInlineEarly] diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 3a2b2933d..a13e13851 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -572,9 +572,26 @@ Type* ASTBuilder::getExpandType(Type* pattern, ArrayView capturedPacks) return getOrCreate(pattern, capturedPacks); } +void flattenTypeList(ShortList& flattenedList, Type* type) +{ + if (auto typePack = as(type)) + { + for (Index i = 0; i < typePack->getTypeCount(); i++) + flattenTypeList(flattenedList, typePack->getElementType(i)); + } + else + { + flattenedList.add(type); + } +} + ConcreteTypePack* ASTBuilder::getTypePack(ArrayView types) { - return getOrCreate(types); + // Flatten all type packs in the type list. + ShortList flattenedTypes; + for (auto type : types) + flattenTypeList(flattenedTypes, type); + return getOrCreate(flattenedTypes.getArrayView().arrayView); } TypeEqualityWitness* ASTBuilder::getTypeEqualityWitness( diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index 8b2494310..a1ab7a5c8 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -203,6 +203,27 @@ struct ASTDumpContext m_writer->emit("}"); } + template + void dump(const ShortList& list) + { + m_writer->emit(" { \n"); + m_writer->indent(); + for (Index i = 0; i < list.getCount(); ++i) + { + dump(list[i]); + if (i < list.getCount() - 1) + { + m_writer->emit(",\n"); + } + else + { + m_writer->emit("\n"); + } + } + m_writer->dedent(); + m_writer->emit("}"); + } + void dump(SourceLoc sourceLoc) { if (m_dumpFlags & ASTDumpUtil::Flag::HideSourceLoc) @@ -285,6 +306,10 @@ struct ASTDumpContext { m_writer->emit(UInt(v)); } + void dump(UInt v) + { + m_writer->emit(v); + } void dump(int32_t v) { m_writer->emit(v); diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index e6edce8f9..c07f7f5b9 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -291,8 +291,7 @@ class SwizzleExpr: public Expr { SLANG_AST_CLASS(SwizzleExpr) Expr* base = nullptr; - int elementCount; - int elementIndices[4]; + ShortList elementIndices; SourceLoc memberOpLoc; }; @@ -425,6 +424,11 @@ class AlignOfExpr : public SizeOfLikeExpr SLANG_AST_CLASS(AlignOfExpr); }; +class CountOfExpr : public SizeOfLikeExpr +{ + SLANG_AST_CLASS(CountOfExpr); +}; + class MakeOptionalExpr : public Expr { SLANG_AST_CLASS(MakeOptionalExpr) diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index f74c169e7..21948dc04 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -147,6 +147,9 @@ namespace Slang // Cost of converting an integer to a half type kConversionCost_IntegerToHalfConversion = 500, + // Cost of using a concrete argument pack + kConversionCost_ParameterPack = 500, + // Default case (usable for user-defined conversions) kConversionCost_Default = 500, diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index e222e86e1..e8020aa04 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -1534,6 +1534,81 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio return this; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! CountOfIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +void CountOfIntVal::_toTextOverride(StringBuilder& out) +{ + out << "countof("; + getTypeArg()->toText(out); + out << ")"; +} + +Val* CountOfIntVal::tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType) +{ + if (auto typePack = as(newType)) + { + bool anyAbstract = false; + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + if (isAbstractTypePack(typePack->getElementType(i))) + { + anyAbstract = true; + break; + } + } + if (!anyAbstract) + { + auto result = astBuilder->getIntVal(intType, typePack->getTypeCount()); + return result; + } + } + else if (auto tupleType = as(newType)) + { + bool anyAbstract = false; + for (Index i = 0; i < tupleType->getMemberCount(); i++) + { + if (isAbstractTypePack(tupleType->getMember(i))) + { + anyAbstract = true; + break; + } + } + if (!anyAbstract) + { + auto result = astBuilder->getIntVal(intType, tupleType->getMemberCount()); + return result; + } + } + return nullptr; +} + +Val* CountOfIntVal::tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType) +{ + if (auto result = tryFoldOrNull(astBuilder, intType, newType)) + return result; + auto result = astBuilder->getOrCreate(intType, newType); + return result; +} + +Val* CountOfIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + auto newType = as(getTypeArg()->substituteImpl(astBuilder, subst, &diff)); + if (!diff) + return this; + + (*ioDiff)++; + return tryFold(astBuilder, getType(), newType); +} + +Val* CountOfIntVal::_resolveImplOverride() +{ + auto resolvedTypeArg = getTypeArg()->resolve(); + if (resolvedTypeArg == getTypeArg()) + return this; + return tryFold(getCurrentASTBuilder(), getType(), as(resolvedTypeArg)); +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! WitnessLookupIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void WitnessLookupIntVal::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 74e0f27c1..2599ce46a 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -248,6 +248,30 @@ class FuncCallIntVal : public IntVal Val* _linkTimeResolveOverride(Dictionary& map); }; +class CountOfIntVal : public IntVal +{ + SLANG_AST_CLASS(CountOfIntVal) + + CountOfIntVal(Type* inType, Type* typeArg) + { + setOperands(inType, typeArg); + } + + Val* getTypeArg() { return getOperand(1); } + + void _toTextOverride(StringBuilder& out); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + Val* _resolveImplOverride(); + bool _isLinkTimeValOverride() + { + return false; + } + + static Val* tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType); + + static Val* tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType); +}; + class WitnessLookupIntVal : public IntVal { SLANG_AST_CLASS(WitnessLookupIntVal) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 965d25bd5..fe43a4f8f 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -16,6 +16,7 @@ #include "slang-lookup.h" #include "slang-lookup-spirv.h" #include "slang-ast-print.h" +#include "core/slang-char-util.h" namespace Slang { @@ -1866,6 +1867,13 @@ namespace Slang } } + if (auto countOfExpr = expr.as()) + { + auto type = as(countOfExpr.getExpr()->sizedType->substitute(m_astBuilder, expr.getSubsts())); + if (type) + return as(CountOfIntVal::tryFold(m_astBuilder, expr.getExpr()->type.type, type)); + } + // it is possible that we are referring to a generic value param if (auto declRefExpr = expr.as()) { @@ -3205,6 +3213,31 @@ namespace Slang return false; } + static bool _isCountOfType(Type* type) + { + if (!type) + { + return false; + } + + if (isTypePack(type)) + { + return true; + } + + if (as(type)) + { + return true; + } + + if (as(type)) + { + return true; + } + + return false; + } + Expr* SemanticsExprVisitor::visitSizeOfLikeExpr(SizeOfLikeExpr* sizeOfLikeExpr) { auto valueExpr = dispatch(sizeOfLikeExpr->value); @@ -3229,12 +3262,25 @@ namespace Slang type = properType.type; } - if (!_isSizeOfType(type)) + if (as(sizeOfLikeExpr)) { - getSink()->diagnose(sizeOfLikeExpr, Diagnostics::sizeOfArgumentIsInvalid); + if (!_isCountOfType(type)) + { + getSink()->diagnose(sizeOfLikeExpr, Diagnostics::countOfArgumentIsInvalid); - sizeOfLikeExpr->type = m_astBuilder->getErrorType(); - return sizeOfLikeExpr; + sizeOfLikeExpr->type = m_astBuilder->getErrorType(); + return sizeOfLikeExpr; + } + } + else + { + if (!_isSizeOfType(type)) + { + getSink()->diagnose(sizeOfLikeExpr, Diagnostics::sizeOfArgumentIsInvalid); + + sizeOfLikeExpr->type = m_astBuilder->getErrorType(); + return sizeOfLikeExpr; + } } sizeOfLikeExpr->sizedType = type; @@ -3815,6 +3861,108 @@ namespace Slang return CreateErrorExpr(memberRefExpr); } + Expr* SemanticsVisitor::checkTupleSwizzleExpr(MemberExpr* memberExpr, TupleType* baseTupleType) + { + UInt tupleElementCount = (UInt)baseTupleType->getMemberCount(); + if (tupleElementCount == 0) + return checkGeneralMemberLookupExpr(memberExpr, baseTupleType); + + if (memberExpr->name == getSession()->getCompletionRequestTokenName()) + { + auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; + suggestions.clear(); + suggestions.scopeKind = CompletionSuggestions::ScopeKind::Swizzle; + suggestions.swizzleBaseType = + memberExpr->baseExpression ? memberExpr->baseExpression->type : nullptr; + suggestions.elementCount[0] = (Index)tupleElementCount; + suggestions.elementCount[1] = 0; + return memberExpr; + } + + String swizzleText = getText(memberExpr->name); + auto span = swizzleText.getUnownedSlice(); + Index pos = 0; + + ShortList elementCoords; + + bool anyDuplicates = false; + + // The contents of the string are 0-terminated + // Every update to cursor corresponds to a check against 0-termination + while (pos < span.getLength()) + { + UInt elementCoord; + + // Check for the preceding underscore + if (span[pos] != '_') + { + return checkGeneralMemberLookupExpr(memberExpr, baseTupleType); + } + pos++; + + // Parse index. + if (pos >= span.getLength()) + { + // Unexpected end of swizzle string, fallback to + // member lookup. + return checkGeneralMemberLookupExpr(memberExpr, baseTupleType); + } + + auto ch = span[pos]; + + if (!CharUtil::isDigit(ch)) + { + // An invalid character in the swizzle is an error, fallback to + // member lookup. + return checkGeneralMemberLookupExpr(memberExpr, baseTupleType); + } + elementCoord = (UInt)StringUtil::parseIntAndAdvancePos(span, pos); + + if (elementCoord >= tupleElementCount) + { + getSink()->diagnose(memberExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseTupleType); + return CreateErrorExpr(memberExpr); + } + + // Check if we've seen this index before + for (int ee = 0; ee < elementCoords.getCount(); ee++) + { + if (elementCoords[ee] == elementCoord) + anyDuplicates = true; + } + + // add to our list... + elementCoords.add(elementCoord); + } + + SwizzleExpr* swizExpr = m_astBuilder->create(); + swizExpr->loc = memberExpr->loc; + swizExpr->base = memberExpr->baseExpression; + swizExpr->elementIndices = _Move(elementCoords); + swizExpr->memberOpLoc = memberExpr->memberOperatorLoc; + + if (swizExpr->elementIndices.getCount() == 1) + { + // single-component swizzle produces a scalar + // + swizExpr->type = QualType(baseTupleType->getMember(swizExpr->elementIndices[0])); + } + else + { + List types; + for (auto index : swizExpr->elementIndices) + { + types.add(baseTupleType->getMember(index)); + } + swizExpr->type = QualType(m_astBuilder->getTupleType(types)); + } + + // A swizzle can be used as an l-value as long as there + // were no duplicates in the list of components + swizExpr->type.isLeftValue = !anyDuplicates; + return swizExpr; + } + Expr* SemanticsVisitor::CheckSwizzleExpr( MemberExpr* memberRefExpr, Type* baseElementType, @@ -3823,15 +3971,10 @@ namespace Slang SwizzleExpr* swizExpr = m_astBuilder->create(); swizExpr->loc = memberRefExpr->loc; swizExpr->base = memberRefExpr->baseExpression; - swizExpr->elementIndices[0] = 0; - swizExpr->elementIndices[1] = 0; - swizExpr->elementIndices[2] = 0; - swizExpr->elementIndices[3] = 0; swizExpr->memberOpLoc = memberRefExpr->memberOperatorLoc; IntegerLiteralValue limitElement = baseElementCount; - int elementIndices[4]; - int elementCount = 0; + ShortList elementIndices; bool anyDuplicates = false; bool anyError = false; @@ -3875,35 +4018,31 @@ namespace Slang // If elementCount is already at 4 stop trying to assign a swizzle element and send an error, // we cannot have more valid swizzle elements than 4. - if (elementCount >= 4) + if (elementIndices.getCount() >= 4) { anyError = true; break; } // Check if we've seen this index before - for (int ee = 0; ee < elementCount; ee++) + for (int ee = 0; ee < elementIndices.getCount(); ee++) { - if (elementIndices[ee] == elementIndex) + if (elementIndices[ee] == (UInt)elementIndex) anyDuplicates = true; } // add to our list... - elementIndices[elementCount++] = elementIndex; + elementIndices.add(elementIndex); } - for (int ee = 0; ee < elementCount; ++ee) - { - swizExpr->elementIndices[ee] = elementIndices[ee]; - } - swizExpr->elementCount = elementCount; + swizExpr->elementIndices = _Move(elementIndices); if (anyError) { getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString()); return CreateErrorExpr(memberRefExpr); } - else if (elementCount == 1) + else if (swizExpr->elementIndices.getCount() == 1) { // single-component swizzle produces a scalar // @@ -3918,7 +4057,7 @@ namespace Slang // here if the input type had a sugared name... swizExpr->type = QualType(createVectorType( baseElementType, - m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount))); + m_astBuilder->getIntVal(m_astBuilder->getIntType(), swizExpr->elementIndices.getCount()))); } // A swizzle can be used as an l-value as long as there @@ -4255,6 +4394,32 @@ namespace Slang return baseExpr; } + Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType) + { + LookupResult lookupResult = lookUpMember( + m_astBuilder, + this, + expr->name, + baseType, + m_outerScope); + bool diagnosed = false; + lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed); + if (!lookupResult.isValid()) + { + return lookupMemberResultFailure(expr, baseType, diagnosed); + } + if (expr->name == getSession()->getCompletionRequestTokenName()) + { + suggestCompletionItems(CompletionSuggestions::ScopeKind::Member, lookupResult); + } + return createLookupResultExpr( + expr->name, + lookupResult, + expr->baseExpression, + expr->loc, + expr); + } + Expr* SemanticsExprVisitor::visitMemberExpr(MemberExpr * expr) { bool needDeref = false; @@ -4279,10 +4444,9 @@ namespace Slang // because vectors are also declaration reference types... // // Also note: the way this is done right now means that the ability - // to swizzle vectors interferes with any chance o(baseType)) { return CheckMatrixSwizzleExpr( @@ -4322,34 +4486,17 @@ namespace Slang { return _lookupStaticMember(expr, expr->baseExpression); } + else if (auto baseTupleType = as(baseType)) + { + return checkTupleSwizzleExpr(expr, baseTupleType); + } else if (as(baseType)) { return CreateErrorExpr(expr); } else { - LookupResult lookupResult = lookUpMember( - m_astBuilder, - this, - expr->name, - baseType.Ptr(), - m_outerScope); - bool diagnosed = false; - lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed); - if (!lookupResult.isValid()) - { - return lookupMemberResultFailure(expr, baseType, diagnosed); - } - if (expr->name == getSession()->getCompletionRequestTokenName()) - { - suggestCompletionItems(CompletionSuggestions::ScopeKind::Member, lookupResult); - } - return createLookupResultExpr( - expr->name, - lookupResult, - expr->baseExpression, - expr->loc, - expr); + return checkGeneralMemberLookupExpr(expr, baseType); } } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 7565b5472..fea81f68c 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2637,6 +2637,8 @@ namespace Slang IntVal* baseElementRowCount, IntVal* baseElementColCount); + Expr* checkTupleSwizzleExpr(MemberExpr* memberExpr, TupleType* baseTupleType); + Expr* CheckSwizzleExpr( MemberExpr* memberRefExpr, Type* baseElementType, @@ -2647,6 +2649,10 @@ namespace Slang Type* baseElementType, IntVal* baseElementCount); + // Check a member expr as a general member lookup. + // This is the default/fallback behavior if the base type isn't swizzlable. + Expr* checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType); + /// Perform semantic checking of an assignment where the operands have already been checked. Expr* checkAssignWithCheckedOperands(AssignExpr* expr); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 16f6ed7da..29cb74c23 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -680,6 +680,10 @@ namespace Slang packArg->type = paramType; resultArgs.add(packArg); } + + // Always add a flat cost for using an argument pack, + // so that we prefer non-pack overloads when possible. + candidate.conversionCostSum += kConversionCost_ParameterPack; } else { @@ -1669,26 +1673,6 @@ namespace Slang // Try to match the variadic part. // Is the corresponding argument a expand expr? If so it will map 1:1 to the type pack param. auto astBuilder = semantics->getASTBuilder(); - while (remainingArgCount > 0) - { - auto argType = getArgType(fixedParamCount); - if (auto typeType = as(argType)) - { - argType = typeType->getType(); - } - if (isAbstractTypePack(argType)) - { - MatchedArg arg; - arg.argExpr = getArg(fixedParamCount); - arg.argType = getArgType(fixedParamCount); - outMatchedArgs.add(arg); - fixedParamCount++; - remainingArgCount--; - typePackCount--; - continue; - } - break; - } if (remainingArgCount <= 0) return true; @@ -1704,6 +1688,24 @@ namespace Slang Index typePackSize = remainingArgCount / typePackCount; for (Index i = 0; i < typePackCount; ++i) { + // If type pack size is 1, we may not need to wrap things in a PackExpr, + // if the argument is already a pack. + if (typePackSize == 1) + { + auto argType = getArgType(fixedParamCount + i); + if (auto typeType = as(argType)) + { + argType = typeType->getType(); + } + if (isTypePack(argType)) + { + MatchedArg arg; + arg.argExpr = getArg(fixedParamCount + i); + arg.argType = getArgType(fixedParamCount + i); + outMatchedArgs.add(arg); + continue; + } + } PackExpr* packExpr = nullptr; if (mode == Mode::ForReal) { diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index a58b59c0c..7288befe8 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -359,6 +359,8 @@ DIAGNOSTIC(30097, Error, functionNotMarkedAsDifferentiable, "function '$0' is no DIAGNOSTIC(30098, Error, nonStaticMemberFunctionNotAllowedAsDiffOperand, "non-static function reference '$0' is not allowed here.") DIAGNOSTIC(30099, Error, sizeOfArgumentIsInvalid, "argument to sizeof is invalid") +DIAGNOSTIC(30083, Error, countOfArgumentIsInvalid, "argument to countof can only be a type pack or tuple") + DIAGNOSTIC(30101, Error, readingFromWriteOnly, "cannot read from writeonly, check modifiers.") DIAGNOSTIC(30102, Error, differentiableMemberShouldHaveCorrespondingFieldInDiffType, "differentiable member '$0' should have a corresponding field in '$1'. Use [DerivativeMember($1.)] or mark as no_diff") diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 0a9b2d691..9adbe42d5 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1656,6 +1656,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) { disableIRValidationAtInsert(); auto simplifyOptions = IRSimplificationOptions::getDefault(nullptr); + simplifyOptions.removeRedundancy = true; simplifyFunc(autoDiffSharedContext->targetProgram, func, simplifyOptions); enableIRValidationAtInsert(); } diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index 0d5cb1c70..9b2b59cd9 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -277,6 +277,17 @@ struct InliningPassBase if(!isDefinition(calleeFunc)) return false; + // We cannot inline a call inside an `IRExpand`. + // Because this will make the cfg inside the `IRExpand` too complex, + // and our expand specialization logic isn't general enough to deal + // with that yet. + for (auto parent = call->getParent(); parent; parent = parent->getParent()) + { + if (as(parent)) + return false; + if (as(parent)) + break; + } return true; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index a4225d041..80c810620 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -245,6 +245,7 @@ INST(RTTIType, rtti_type, 0, HOISTABLE) INST(RTTIHandleType, rtti_handle_type, 0, HOISTABLE) INST(TupleType, tuple_type, 0, HOISTABLE) INST(TargetTupleType, TargetTuple, 0, HOISTABLE) +INST(TypePack, TypePack, 0, HOISTABLE) INST(ExpandTypeOrVal, ExpandTypeOrVal, 1, HOISTABLE) // A type that identifies it's contained type as being emittable as `spirv_literal. @@ -281,6 +282,8 @@ INST(StructKey, key, 0, GLOBAL) INST(GlobalGenericParam, global_generic_param, 0, GLOBAL) INST(WitnessTable, witness_table, 0, 0) +INST(IndexedFieldKey, indexedFieldKey, 2, HOISTABLE) + // A placeholder witness that ThisType implements the enclosing interface. // Used only in interface definitions. INST(ThisTypeWitness, thisTypeWitness, 1, 0) @@ -343,6 +346,7 @@ INST(MakeArrayFromElement, makeArrayFromElement, 1, 0) INST(MakeStruct, makeStruct, 0, 0) INST(MakeTuple, makeTuple, 0, 0) INST(MakeTargetTuple, makeTuple, 0, 0) +INST(MakeValuePack, makeValuePack, 0, 0) INST(GetTargetTupleElement, getTargetTupleElement, 0, 0) INST(GetTupleElement, getTupleElement, 2, 0) INST(MakeWitnessPack, MakeWitnessPack, 0, HOISTABLE) @@ -1117,6 +1121,8 @@ INST(TreatAsDynamicUniform, TreatAsDynamicUniform, 1, 0) INST(SizeOf, sizeOf, 1, 0) INST(AlignOf, alignOf, 1, 0) +INST(CountOf, countOf, 1, 0) + INST(GetArrayLength, GetArrayLength, 1, 0) INST(IsType, IsType, 3, 0) INST(TypeEquals, TypeEquals, 2, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index dc5fb2744..4111fb983 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2857,6 +2857,11 @@ struct IRMakeTuple : IRInst IR_LEAF_ISA(MakeTuple) }; +struct IRMakeValuePack : IRInst +{ + IR_LEAF_ISA(MakeValuePack) +}; + struct IRMakeWitnessPack : IRInst { IR_LEAF_ISA(MakeWitnessPack) @@ -3509,6 +3514,8 @@ public: IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2); IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2, IRType* type3); + IRTypePack* getTypePack(UInt count, IRType* const* types); + IRExpandType* getExpandTypeOrVal(IRType* type, IRInst* pattern, ArrayView capture); IRResultType* getResultType(IRType* valueType, IRType* errorType); @@ -3670,6 +3677,14 @@ public: return getAttributedType(baseType, attributes.getCount(), attributes.getBuffer()); } + IRInst* getIndexedFieldKey( + IRInst* baseType, + UInt fieldIndex) + { + IRInst* args[] = { baseType, getIntValue(getIntType(), fieldIndex) }; + return emitIntrinsicInst(getVoidType(), kIROp_IndexedFieldKey, 2, args); + } + IRMetalMeshGridPropertiesType* getMetalMeshGridPropertiesType() { return (IRMetalMeshGridPropertiesType*)getType(kIROp_MetalMeshGridPropertiesType); @@ -3872,6 +3887,9 @@ public: return emitMakeTuple(SLANG_COUNT_OF(args), args); } + IRInst* emitMakeValuePack(IRType* type, UInt count, IRInst* const* args); + IRInst* emitMakeValuePack(UInt count, IRInst* const* args); + IRInst* emitMakeWitnessPack(IRType* type, ArrayView args) { return emitIntrinsicInst(type, kIROp_MakeWitnessPack, (UInt)args.getCount(), args.getBuffer()); @@ -4364,6 +4382,8 @@ public: IRInst* emitAlignOf( IRInst* sizedType); + IRInst* emitCountOf(IRType* type, IRInst* sizedType); + IRInst* emitCastPtrToBool(IRInst* val); IRInst* emitCastPtrToInt(IRInst* val); IRInst* emitCastIntToPtr(IRType* ptrType, IRInst* val); diff --git a/source/slang/slang-ir-lower-tuple-types.cpp b/source/slang/slang-ir-lower-tuple-types.cpp index caa031d85..6177cfec2 100644 --- a/source/slang/slang-ir-lower-tuple-types.cpp +++ b/source/slang/slang-ir-lower-tuple-types.cpp @@ -85,7 +85,7 @@ namespace Slang workListSet.add(inst); } - void processMakeTuple(IRMakeTuple* inst) + void processMakeTuple(IRInst* inst) { IRBuilder builderStorage(module); auto builder = &builderStorage; @@ -121,6 +121,124 @@ namespace Slang inst->removeAndDeallocate(); } + void processGetElementPtr(IRGetElementPtr* inst) + { + IRBuilder builder(module); + builder.setInsertBefore(inst); + + auto base = inst->getBase(); + auto baseValueType = tryGetPointedToType(&builder, base->getDataType()); + auto loweredTupleInfo = getLoweredTupleType(&builder, baseValueType); + if (!loweredTupleInfo) + return; + + auto elementIndex = getIntVal(inst->getIndex()); + SLANG_ASSERT((Index)elementIndex < loweredTupleInfo->fields.getCount()); + + auto field = loweredTupleInfo->fields[(Index)elementIndex]; + auto getElement = builder.emitFieldAddress(builder.getPtrType(field->getFieldType()), base, field->getKey()); + inst->replaceUsesWith(getElement); + inst->removeAndDeallocate(); + } + + void processSwizzle(IRSwizzle* inst) + { + IRBuilder builder(module); + builder.setInsertBefore(inst); + + auto base = inst->getBase(); + auto loweredTupleInfo = getLoweredTupleType(&builder, base->getDataType()); + + if (!loweredTupleInfo) + return; + + if (inst->getElementCount() == 1) + { + auto elementIndex = getIntVal(inst->getElementIndex(0)); + SLANG_ASSERT((Index)elementIndex < loweredTupleInfo->fields.getCount()); + + auto field = loweredTupleInfo->fields[(Index)elementIndex]; + auto getElement = builder.emitFieldExtract(field->getFieldType(), base, field->getKey()); + inst->replaceUsesWith(getElement); + inst->removeAndDeallocate(); + } + else + { + List elements; + for (UInt i = 0; i < inst->getElementCount(); i++) + { + auto elementIndex = getIntVal(inst->getElementIndex(i)); + SLANG_ASSERT((Index)elementIndex < loweredTupleInfo->fields.getCount()); + + auto field = loweredTupleInfo->fields[(Index)elementIndex]; + auto getElement = builder.emitFieldExtract(field->getFieldType(), base, field->getKey()); + elements.add(getElement); + } + auto resultTypeInfo = getLoweredTupleType(&builder, inst->getDataType()); + auto makeStruct = builder.emitMakeStruct(resultTypeInfo->structType, elements); + inst->replaceUsesWith(makeStruct); + inst->removeAndDeallocate(); + } + } + + void processSwizzleSet(IRSwizzleSet* inst) + { + IRBuilder builder(module); + builder.setInsertBefore(inst); + + auto base = inst->getBase(); + auto loweredTupleInfo = getLoweredTupleType(&builder, base->getDataType()); + auto sourceTupleInfo = getLoweredTupleType(&builder, inst->getSource()->getDataType()); + if (!loweredTupleInfo) + return; + + List elements; + for (Index i = 0; i < loweredTupleInfo->fields.getCount(); i++) + { + auto field = loweredTupleInfo->fields[i]; + auto getElement = builder.emitFieldExtract(field->getFieldType(), base, field->getKey()); + elements.add(getElement); + } + + for (UInt i = 0; i < inst->getElementCount(); i++) + { + auto baseIndex = getIntVal(inst->getElementIndex(i)); + auto sourceElement = sourceTupleInfo + ? builder.emitFieldExtract(sourceTupleInfo->fields[i]->getFieldType(), inst->getSource(), sourceTupleInfo->fields[i]->getKey()) + : inst->getSource(); + elements[baseIndex] = sourceElement; + } + auto resultTypeInfo = getLoweredTupleType(&builder, inst->getDataType()); + auto makeStruct = builder.emitMakeStruct(resultTypeInfo->structType, elements); + inst->replaceUsesWith(makeStruct); + inst->removeAndDeallocate(); + } + + void processSwizzledStore(IRSwizzledStore* inst) + { + IRBuilder builder(module); + builder.setInsertBefore(inst); + + auto dest = inst->getDest(); + auto destValueType = tryGetPointedToType(&builder, dest->getDataType()); + auto loweredTupleInfo = getLoweredTupleType(&builder, destValueType); + auto sourceTupleInfo = getLoweredTupleType(&builder, inst->getSource()->getDataType()); + if (!loweredTupleInfo) + return; + + for (UInt i = 0; i < inst->getElementCount(); i++) + { + auto baseIndex = getIntVal(inst->getElementIndex(i)); + auto destField = loweredTupleInfo->fields[baseIndex]; + auto destFieldPtr = builder.emitFieldAddress(builder.getPtrType(destField->getFieldType()), dest, destField->getKey()); + auto sourceElement = sourceTupleInfo + ? builder.emitFieldExtract(sourceTupleInfo->fields[i]->getFieldType(), inst->getSource(), sourceTupleInfo->fields[i]->getKey()) + : inst->getSource(); + builder.emitStore(destFieldPtr, sourceElement); + } + inst->removeAndDeallocate(); + } + void processTupleType(IRTupleType* inst) { IRBuilder builderStorage(module); @@ -132,19 +250,47 @@ namespace Slang SLANG_UNUSED(loweredTupleInfo); } + void processIndexedFieldKey(IRIndexedFieldKey* inst) + { + IRBuilder builder(module); + auto loweredTupleInfo = getLoweredTupleType(&builder, inst->getBaseType()); + if (!loweredTupleInfo) + return; + auto fieldIndex = getIntVal(inst->getIndex()); + SLANG_ASSERT(fieldIndex >= 0 && (Index)fieldIndex < loweredTupleInfo->fields.getCount()); + inst->replaceUsesWith(loweredTupleInfo->fields[fieldIndex]->getKey()); + inst->removeAndDeallocate(); + } + void processInst(IRInst* inst) { switch (inst->getOp()) { case kIROp_MakeTuple: + case kIROp_MakeValuePack: processMakeTuple((IRMakeTuple*)inst); break; case kIROp_GetTupleElement: processGetTupleElement((IRGetTupleElement*)inst); break; + case kIROp_GetElementPtr: + processGetElementPtr((IRGetElementPtr*)inst); + break; + case kIROp_swizzle: + processSwizzle((IRSwizzle*)inst); + break; + case kIROp_swizzleSet: + processSwizzleSet((IRSwizzleSet*)inst); + break; + case kIROp_SwizzledStore: + processSwizzledStore((IRSwizzledStore*)inst); + break; case kIROp_TupleType: processTupleType((IRTupleType*)inst); break; + case kIROp_IndexedFieldKey: + processIndexedFieldKey((IRIndexedFieldKey*)inst); + break; default: break; } @@ -152,6 +298,32 @@ namespace Slang void processModule() { + // First, we want to replace all TypePack with TupleType. + + List typePacks; + for (auto inst : module->getGlobalInsts()) + { + if (inst->getOp() == kIROp_TypePack) + { + typePacks.add(inst); + } + } + IRBuilder builder(module); + for (auto inst : typePacks) + { + builder.setInsertBefore(inst); + ShortList types; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + types.add((IRType*)inst->getOperand(i)); + } + auto tupleType = builder.getTupleType((UInt)types.getCount(), types.getArrayView().getBuffer()); + inst->replaceUsesWith(tupleType); + inst->removeAndDeallocate(); + } + + // Next, lower all tuples to structs. + addToWorkList(module->getModuleInst()); while (workList.getCount() != 0) diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index aa8dfddab..b5f5edb05 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -327,6 +327,7 @@ struct PeepholeContext : InstPassBase switch (inst->getOperand(0)->getOp()) { case kIROp_MakeTuple: + case kIROp_MakeValuePack: case kIROp_MakeWitnessPack: { auto element = inst->getOperand(1); diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index b679d8438..fa3cd444a 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -8,6 +8,19 @@ namespace Slang struct RedundancyRemovalContext { RefPtr dom; + bool isSingleIterationLoop(IRLoop* loop) + { + int useCount = 0; + for (auto use = loop->getBreakBlock()->firstUse; use; use = use->nextUse) + { + if (use->getUser() == loop) + continue; + useCount++; + if (useCount > 1) + return false; + } + return true; + } bool tryHoistInstToOuterMostLoop(IRGlobalValueWithCode* func, IRInst* inst) { @@ -17,7 +30,8 @@ struct RedundancyRemovalContext parentBlock = dom->getImmediateDominator(parentBlock)) { auto terminatorInst = parentBlock->getTerminator(); - if (terminatorInst->getOp() == kIROp_loop) + if (terminatorInst->getOp() == kIROp_loop && + !isSingleIterationLoop(as(terminatorInst))) { // Consider hoisting the inst into this block. // This is only possible if all operands of the inst are dominating `parentBlock`. diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 2eb16112f..c9e94352e 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -594,7 +594,165 @@ struct SpecializationContext case kIROp_GetTupleElement: return maybeSpecializeFoldableInst(inst); + + case kIROp_TypePack: + case kIROp_TupleType: + return maybeSpecializeTypePackOrTupleType(inst); + + case kIROp_MakeValuePack: + case kIROp_MakeTuple: + return maybeSpecializeMakeValuePackOrTuple(inst); + + case kIROp_CountOf: + return maybeSpecializeCountOf(inst); + } + } + + + void flattenPackOperand(ShortList& flattenedList, IRInst* inst) + { + if (auto makeValuePack = as(inst)) + { + for (UInt i = 0; i < makeValuePack->getOperandCount(); i++) + { + flattenPackOperand(flattenedList, makeValuePack->getOperand(i)); + } + } + else if (auto typePack = as(inst)) + { + for (UInt i = 0; i < typePack->getOperandCount(); i++) + { + flattenPackOperand(flattenedList, typePack->getOperand(i)); + } + } + else + { + SLANG_ASSERT(inst); + flattenedList.add(inst); + } + } + + bool maybeSpecializeTypePackOrTupleType(IRInst* inst) + { + // If any element of the type pack or tuple is a TypePack, we want to + // flatten that type pack into the current type pack or tuple. + + bool needProcess = false; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (as(inst->getOperand(i))) + { + needProcess = true; + break; + } + } + // If none of the operands are MakeValuePack, there is no need to flatten anything. + if (!needProcess) + return false; + + // We will recursively flatten all MakeValuePack operands. + ShortList flattendOperands; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto operand = inst->getOperand(i); + flattenPackOperand(flattendOperands, operand); + } + + IRBuilder builder(module); + builder.setInsertBefore(inst); + IRInst* newInst; + if (inst->getOp() == kIROp_TypePack) + newInst = builder.getTypePack(flattendOperands.getCount(), (IRType* const*)flattendOperands.getArrayView().getBuffer()); + else + newInst = builder.getTupleType(flattendOperands.getCount(), (IRType* const*)flattendOperands.getArrayView().getBuffer()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + addUsersToWorkList(newInst); + return true; + } + + bool maybeSpecializeMakeValuePackOrTuple(IRInst* inst) + { + // If any element of the value pack or tuple is a ValuePack, we want to + // flatten that value pack into the current value pack or tuple. + + bool needProcess = false; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (as(inst->getOperand(i))) + { + needProcess = true; + break; + } } + // If none of the operands are MakeValuePack, there is no need to flatten anything. + if (!needProcess) + return false; + + // We will recursively flatten all MakeValuePack operands. + ShortList flattendOperands; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto operand = inst->getOperand(i); + flattenPackOperand(flattendOperands, operand); + } + + IRBuilder builder(module); + builder.setInsertBefore(inst); + IRInst* newInst = nullptr; + if (inst->getOp() == kIROp_MakeValuePack) + newInst = builder.emitMakeValuePack(inst->getFullType(), flattendOperands.getCount(), flattendOperands.getArrayView().getBuffer()); + else + newInst = builder.emitMakeTuple(inst->getFullType(), flattendOperands.getCount(), flattendOperands.getArrayView().getBuffer()); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + addUsersToWorkList(newInst); + return true; + } + + bool maybeSpecializeCountOf(IRInst* inst) + { + auto operand = inst->getOperand(0); + + // If operand is a value, make sure we are working on its type. + + switch (operand->getOp()) + { + case kIROp_MakeValuePack: + case kIROp_MakeTuple: + operand = operand->getDataType(); + break; + } + + // We can only figure out the count of a type pack or tuple type. + switch (operand->getOp()) + { + case kIROp_TypePack: + case kIROp_TupleType: + break; + default: + return false; + } + + // If none of the element type is a TypePack, we can just return the count. + for (UInt i = 0; i < operand->getOperandCount(); i++) + { + switch (operand->getOperand(i)->getOp()) + { + case kIROp_Param: + case kIROp_TypePack: + case kIROp_ExpandTypeOrVal: + return false; + } + } + IRBuilder builder(module); + builder.setInsertBefore(inst); + auto newInst = builder.getIntValue(inst->getDataType(), operand->getOperandCount()); + addUsersToWorkList(inst); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; } // Specializing lookup on witness tables is a general @@ -606,7 +764,7 @@ struct SpecializationContext { // Note: While we currently have named the instruction // `lookup_witness_method`, the `method` part is a misnomer - // and the same instruction can look up *any* interfacemay + // and the same instruction can look up *any* interface // requirement based on the witness table that provides // a conformance, and the "key" that indicates the interface // requirement. @@ -2268,7 +2426,7 @@ struct SpecializationContext for (UInt i = 0; i < expandInst->getCaptureCount(); i++) { - if (!as(expandInst->getCapture(i))) + if (!as(expandInst->getCapture(i))) return false; } @@ -2276,16 +2434,16 @@ struct SpecializationContext builder.setInsertBefore(expandInst); List elements; UInt elementCount = 0; - if (auto firstTupleType = as(expandInst->getCapture(0))) + if (auto firstTypePack = as(expandInst->getCapture(0))) { - elementCount = firstTupleType->getOperandCount(); + elementCount = firstTypePack->getOperandCount(); } if (elementCount == 0) { - auto resultTuple = builder.emitMakeTuple(0, (IRInst*const*)nullptr); - expandInst->replaceUsesWith(resultTuple); + auto resultValuePack = builder.emitMakeValuePack(0, (IRInst*const*)nullptr); + expandInst->replaceUsesWith(resultValuePack); expandInst->removeAndDeallocate(); - addUsersToWorkList(resultTuple); + addUsersToWorkList(resultValuePack); return true; } @@ -2328,7 +2486,7 @@ struct SpecializationContext } } - auto resultTuple = builder.emitMakeTuple(elements); + auto resultValuePack = builder.emitMakeValuePack((UInt)elements.getCount(), elements.getBuffer()); auto currentBlock = builder.getBlock(); for (auto nextInst = expandInst->next; nextInst;) { @@ -2337,7 +2495,7 @@ struct SpecializationContext nextInst = next; } addUsersToWorkList(expandInst); - expandInst->replaceUsesWith(resultTuple); + expandInst->replaceUsesWith(resultValuePack); expandInst->removeAndDeallocate(); return true; } @@ -2355,15 +2513,15 @@ struct SpecializationContext { auto eachInst = as(val); auto packInst = eachInst->getElement(); - if (auto tuple = as(packInst)) + if (auto typePack = as(packInst)) { - SLANG_RELEASE_ASSERT(indexInPack < tuple->getOperandCount()); - return tuple->getOperand(indexInPack); + SLANG_RELEASE_ASSERT(indexInPack < typePack->getOperandCount()); + return typePack->getOperand(indexInPack); } - else if (auto makeTuple = as(packInst)) + else if (auto makeValuePack = as(packInst)) { - SLANG_RELEASE_ASSERT(indexInPack < makeTuple->getOperandCount()); - return makeTuple->getOperand(indexInPack); + SLANG_RELEASE_ASSERT(indexInPack < makeValuePack->getOperandCount()); + return makeValuePack->getOperand(indexInPack); } else if (!as(packInst->getDataType())) { @@ -2413,24 +2571,18 @@ struct SpecializationContext if (expandInst->getCaptureCount() == 0) return false; - bool anyAbstractPack = false; for (UInt i = 0; i < expandInst->getCaptureCount(); i++) { - if (!as(expandInst->getCaptureType(i))) - { - anyAbstractPack = true; - break; - } + if (!as(expandInst->getCaptureType(i))) + return false; } - if (anyAbstractPack) - return false; IRBuilder builder(expandInst); builder.setInsertBefore(expandInst); List elements; UInt elementCount = 0; - if (auto firstTupleType = as(expandInst->getCaptureType(0))) + if (auto firstTypePack = as(expandInst->getCaptureType(0))) { - elementCount = firstTupleType->getOperandCount(); + elementCount = firstTypePack->getOperandCount(); } for (UInt i = 0; i < elementCount; i++) { @@ -2444,16 +2596,16 @@ struct SpecializationContext List types; for (auto element : elements) types.add(element->getDataType()); - auto newTupleType = builder.getTupleType(types); - auto result = builder.emitMakeWitnessPack(newTupleType, elements.getArrayView()); + auto newTypePack = builder.getTypePack(elements.getCount(), types.getBuffer()); + auto result = builder.emitMakeWitnessPack(newTypePack, elements.getArrayView()); expandInst->replaceUsesWith(result); expandInst->removeAndDeallocate(); return true; } else { - auto newTupleType = builder.getTupleType(elements.getCount(), (IRType*const*)elements.getBuffer()); - expandInst->replaceUsesWith(newTupleType); + auto newTypePack = builder.getTypePack(elements.getCount(), (IRType*const*)elements.getBuffer()); + expandInst->replaceUsesWith(newTypePack); expandInst->removeAndDeallocate(); return true; } diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index b8d6360ad..cd0f67186 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -77,7 +77,7 @@ namespace Slang funcChanged = false; funcChanged |= applySparseConditionalConstantPropagation(func, sink); funcChanged |= peepholeOptimize(target, func); - if (!options.minimalOptimization) + if (options.removeRedundancy) funcChanged |= removeRedundancyInFunc(func); funcChanged |= simplifyCFG(func, options.cfgOptions); // Note: we disregard the `changed` state from dead code elimination pass since diff --git a/source/slang/slang-ir-ssa-simplification.h b/source/slang/slang-ir-ssa-simplification.h index d524241ae..446b2d5b2 100644 --- a/source/slang/slang-ir-ssa-simplification.h +++ b/source/slang/slang-ir-ssa-simplification.h @@ -19,6 +19,7 @@ namespace Slang IRDeadCodeEliminationOptions deadCodeElimOptions; bool minimalOptimization = false; + bool removeRedundancy = false; static IRSimplificationOptions getDefault(TargetProgram* targetProgram); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index c97c04f88..1441b0567 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2802,6 +2802,11 @@ namespace Slang return getTupleType(SLANG_COUNT_OF(operands), operands); } + IRTypePack* IRBuilder::getTypePack(UInt count, IRType* const* types) + { + return (IRTypePack*)getType(kIROp_TypePack, count, (IRInst* const*)types); + } + IRExpandType* IRBuilder::getExpandTypeOrVal(IRType* type, IRInst* pattern, ArrayView capture) { ShortList args; @@ -4046,6 +4051,21 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitMakeValuePack(IRType* type, UInt count, IRInst* const* args) + { + return emitIntrinsicInst(type, kIROp_MakeValuePack, count, args); + } + + IRInst* IRBuilder::emitMakeValuePack(UInt count, IRInst* const* args) + { + ShortList types; + for (UInt i = 0; i < count; ++i) + types.add(args[i]->getFullType()); + + auto type = getTypePack((UInt)types.getCount(), types.getArrayView().getBuffer()); + return emitIntrinsicInst(type, kIROp_MakeValuePack, count, args); + } + IRInst* IRBuilder::emitMakeTuple(IRType* type, UInt count, IRInst* const* args) { return emitIntrinsicInst(type, kIROp_MakeTuple, count, args); @@ -5778,6 +5798,19 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitCountOf( + IRType* type, + IRInst* sizedType) + { + auto inst = createInst( + this, + kIROp_CountOf, + type, + sizedType); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitBitCast( IRType* type, IRInst* val) diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index ececdad43..b1c2b001e 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1940,6 +1940,24 @@ struct IRTupleType : IRType IR_LEAF_ISA(TupleType) }; + +/// Represents a type pack. Type packs behave like tuples, but they have a +/// "flattening" semantics, so that MakeTypePack(MakeTypePack(T1,T2), T3) is +/// MakeTypePack(T1,T2,T3). +struct IRTypePack : IRType +{ + IR_LEAF_ISA(TypePack) +}; + +// A placeholder struct key for tuple type layouts that will be replaced with +// the actual struct key when the tuple type is materialized into a struct type. +struct IRIndexedFieldKey : IRInst +{ + IR_LEAF_ISA(IndexedFieldKey) + IRInst* getBaseType() { return getOperand(0); } + IRInst* getIndex() { return getOperand(1); } +}; + /// Represents a tuple in target language. TargetTupleType will not be lowered to structs. struct IRTargetTupleType : IRType { diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 5c73d2ab9..2d6ed2568 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -118,6 +118,27 @@ public: return dispatchIfNotNull(subscriptExpr->baseExpression); } + bool visitSizeOfLikeExpr(SizeOfLikeExpr* expr) + { + int tokenLength = 0; + if (as(expr)) + tokenLength = 7; // strlen("countof"); + else if (as(expr)) + tokenLength = 6; // strlen("sizeof"); + else if (as(expr)) + tokenLength = 7; // strlen("alignof"); + + if (_isLocInRange(context, expr->loc, tokenLength)) + { + ASTLookupResult result; + result.path = context->nodePath; + result.path.add(expr); + context->results.add(result); + return true; + } + return dispatchIfNotNull(expr->value); + } + bool visitParenExpr(ParenExpr* expr) { return dispatchIfNotNull(expr->base); @@ -225,7 +246,10 @@ public: } bool visitSwizzleExpr(SwizzleExpr* expr) { - if (_isLocInRange(context, expr->memberOpLoc, 0)) + Index tokenLength = expr->elementIndices.getCount(); + if (expr->base && as(expr->base->type)) + tokenLength *= 2; + if (_isLocInRange(context, expr->loc, tokenLength)) { ASTLookupResult result; result.path = context->nodePath; diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp index b723e14b8..bee8f088a 100644 --- a/source/slang/slang-language-server-completion.cpp +++ b/source/slang/slang-language-server-completion.cpp @@ -674,6 +674,20 @@ List CompletionContext::createSwizzleCan } } } + else if (auto tupleType = as(type)) + { + auto count = Math::Min((int)elementCount[0], 4); + for (int i = 0; i < count; i++) + { + LanguageServerProtocol::CompletionItem item; + item.data = 0; + if (tupleType->getMember(i)) + item.detail = tupleType->getMember(i)->toString(); + item.kind = LanguageServerProtocol::kCompletionItemKindVariable; + item.label = String("_") + String(i); + result.add(item); + } + } for (auto& item : result) { for (auto ch : getCommitChars()) diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index 81c7d24cf..9588e7284 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -701,6 +701,40 @@ SlangResult LanguageServer::hover( } fillLoc(expr->loc); } + else if (auto swizzleExpr = as(expr)) + { + if (expr->type && swizzleExpr->base && swizzleExpr->base->type) + { + bool isTupleType = as(swizzleExpr->base->type) != nullptr; + sb << "```\n"; + swizzleExpr->type->toText(sb); + sb << " "; + swizzleExpr->base->type->toText(sb); + sb << "."; + for (auto index : swizzleExpr->elementIndices) + { + if (isTupleType || index > 4) + sb << "_" << index; + else + sb << "xyzw"[index]; + } + sb << "\n```\n"; + fillLoc(expr->loc); + } + } + else if (auto countOfExpr = as(expr)) + { + if (countOfExpr->sizedType) + { + if (auto foldedVal = as(CountOfIntVal::tryFoldOrNull(version->linkage->getASTBuilder(), expr->type.type, countOfExpr->sizedType))) + { + sb << "```\n" << "countof("; + countOfExpr->sizedType->toText(sb); + sb << ") = " << foldedVal->getValue() << "\n```\n"; + fillLoc(expr->loc); + } + } + } if (const auto higherOrderExpr = as(expr)) { String documentation; @@ -740,6 +774,14 @@ SlangResult LanguageServer::hover( { fillExprHoverInfo(thisExprExpr); } + else if (auto countOfExpr = as(leafNode)) + { + fillExprHoverInfo(countOfExpr); + } + else if (auto swizzleExpr = as(leafNode)) + { + fillExprHoverInfo(swizzleExpr); + } else if (auto importDecl = as(leafNode)) { auto moduleLoc = getModuleLoc(version->linkage->getSourceManager(), importDecl->importedModuleDecl); diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp index c1895d754..7acaa030b 100644 --- a/source/slang/slang-lookup.cpp +++ b/source/slang/slang-lookup.cpp @@ -387,53 +387,16 @@ static void _lookUpMembersInSuperType( _lookUpMembersInSuperTypeImpl(astBuilder, name, leafType, superType, leafIsSuperWitness, request, ioResult, &breadcrumb); } -static void _lookUpMembersInSuperTypeDeclImpl( - ASTBuilder* astBuilder, +static void _lookupMembersInSuperTypeFacets(ASTBuilder* astBuilder, Name* name, - DeclRef declRef, + Type* selfType, + InheritanceInfo const& inheritanceInfo, LookupRequest const& request, LookupResult& ioResult, BreadcrumbInfo* inBreadcrumbs) { - auto semantics = request.semantics; - if (!as(declRef.getDecl()) && getText(name) == "This") - { - // If we are looking for `This` in anything other than an InterfaceDecl, - // we just need to return the declRef itself. - AddToLookupResult(ioResult, CreateLookupResultItem(declRef, inBreadcrumbs)); - return; - } - - // If the semantics context hasn't been established yet (e.g. when looking up during parsing), - // we simply do a direct lookup without considering subtypes or extensions. - // - if (!semantics) - { - // In this case we can only lookup in an aggregate type. - if (auto aggTypeDeclBaseRef = declRef.as()) - { - _lookUpDirectAndTransparentMembers(astBuilder, name, aggTypeDeclBaseRef.getDecl(), aggTypeDeclBaseRef, request, ioResult, inBreadcrumbs); - } - return; - } - - ensureDecl(semantics, declRef.getDecl(), DeclCheckState::ReadyForLookup); - - // With semantics context, we can do a comprehensive lookup by scanning through - // the linearized inheritance list. + - auto selfType = DeclRefType::create(astBuilder, declRef); - InheritanceInfo inheritanceInfo; - if (auto extDeclRef = declRef.as()) - { - inheritanceInfo = semantics->getShared()->getInheritanceInfo(extDeclRef); - } - else - { - selfType = selfType->getCanonicalType(); - inheritanceInfo = semantics->getShared()->getInheritanceInfo(selfType); - } - for (auto facet : inheritanceInfo.facets) { auto containerDeclRef = facet->getDeclRef().as(); @@ -457,12 +420,12 @@ static void _lookUpMembersInSuperTypeDeclImpl( continue; } // If we are looking up only immediate members, ignore non "Self" facets or extension to "Self" - else if (int(request.options) & int(LookupOptions::IgnoreInheritance) + else if (int(request.options) & int(LookupOptions::IgnoreInheritance) && (facet.getImpl()->directness != Facet::Directness::Self && (!extensionFacet || !extensionFacet->targetType.type->equals(selfType)) )) { - continue; + continue; } // Some things that are syntactically `InheritanceDecl`s don't actually @@ -527,6 +490,56 @@ static void _lookUpMembersInSuperTypeDeclImpl( } } +static void _lookUpMembersInSuperTypeDeclImpl( + ASTBuilder* astBuilder, + Name* name, + DeclRef declRef, + LookupRequest const& request, + LookupResult& ioResult, + BreadcrumbInfo* inBreadcrumbs) +{ + auto semantics = request.semantics; + if (!as(declRef.getDecl()) && getText(name) == "This") + { + // If we are looking for `This` in anything other than an InterfaceDecl, + // we just need to return the declRef itself. + AddToLookupResult(ioResult, CreateLookupResultItem(declRef, inBreadcrumbs)); + return; + } + + // If the semantics context hasn't been established yet (e.g. when looking up during parsing), + // we simply do a direct lookup without considering subtypes or extensions. + // + if (!semantics) + { + // In this case we can only lookup in an aggregate type. + if (auto aggTypeDeclBaseRef = declRef.as()) + { + _lookUpDirectAndTransparentMembers(astBuilder, name, aggTypeDeclBaseRef.getDecl(), aggTypeDeclBaseRef, request, ioResult, inBreadcrumbs); + } + return; + } + + ensureDecl(semantics, declRef.getDecl(), DeclCheckState::ReadyForLookup); + + // With semantics context, we can do a comprehensive lookup by scanning through + // the linearized inheritance list. + + auto selfType = DeclRefType::create(astBuilder, declRef); + InheritanceInfo inheritanceInfo; + if (auto extDeclRef = declRef.as()) + { + inheritanceInfo = semantics->getShared()->getInheritanceInfo(extDeclRef); + } + else + { + selfType = selfType->getCanonicalType(); + inheritanceInfo = semantics->getShared()->getInheritanceInfo(selfType); + } + + _lookupMembersInSuperTypeFacets(astBuilder, name, selfType, inheritanceInfo, request, ioResult, inBreadcrumbs); +} + static void _lookUpMembersInSuperTypeImpl( ASTBuilder* astBuilder, Name* name, @@ -565,6 +578,12 @@ static void _lookUpMembersInSuperTypeImpl( _lookUpMembersInSuperTypeDeclImpl(astBuilder, name, declRef, request, ioResult, inBreadcrumbs); } + else if (auto eachType = as(superType)) + { + auto canEachType = eachType->getCanonicalType(); + InheritanceInfo inheritanceInfo = request.semantics->getShared()->getInheritanceInfo(canEachType); + _lookupMembersInSuperTypeFacets(astBuilder, name, canEachType, inheritanceInfo, request, ioResult, inBreadcrumbs); + } else if (auto extractExistentialType = as(superType)) { // We want lookup to be performed on the underlying interface type of the existential, diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 9ceb3074a..44c62e858 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -336,11 +336,8 @@ struct SwizzledLValueInfo : ExtendedValueInfo // The base expression (this should be an l-value) LoweredValInfo base; - // The number of elements in the swizzle - UInt elementCount; - // THe indices for the elements being swizzled - UInt elementIndices[4]; + ShortList elementIndices; }; // Represents the result of a matrix swizzle operation in an l-value context. @@ -1183,8 +1180,8 @@ top: return LoweredValInfo::simple(builder->emitSwizzle( swizzleInfo->type, getSimpleVal(context, swizzleInfo->base), - swizzleInfo->elementCount, - swizzleInfo->elementIndices)); + swizzleInfo->elementIndices.getCount(), + swizzleInfo->elementIndices.getArrayView().getBuffer())); } case LoweredValInfo::Flavor::SwizzledMatrixLValue: @@ -1656,6 +1653,15 @@ struct ValLoweringVisitor : ValVisitorgetType()); + auto typeArg = lowerType(context, as(val->getTypeArg())); + auto count = irBuilder->emitCountOf(type, typeArg); + return LoweredValInfo::simple(count); + } + LoweredValInfo visitConcreteTypePack(ConcreteTypePack* typePack) { ShortList types; @@ -1665,7 +1671,7 @@ struct ValLoweringVisitor : ValVisitorgetTupleType((UInt)types.getCount(), types.getArrayView().getBuffer()); + IRType* irTypePack = irBuilder->getTypePack((UInt)types.getCount(), types.getArrayView().getBuffer()); return LoweredValInfo::simple(irTypePack); } @@ -4038,6 +4044,7 @@ struct ExprLoweringVisitorBase : public ExprVisitor const auto size = naturalLayoutContext.calcSize(sizeOfLikeExpr->sizedType); auto builder = getBuilder(); + auto resultType = lowerType(context, sizeOfLikeExpr->type); if (!size) { @@ -4051,10 +4058,15 @@ struct ExprLoweringVisitorBase : public ExprVisitor { inst = builder->emitAlignOf(sizedType); } - else + else if (as(sizeOfLikeExpr)) { inst = builder->emitSizeOf(sizedType); } + else + { + + inst = builder->emitCountOf(resultType, sizedType); + } return LoweredValInfo::simple(inst); } @@ -4064,7 +4076,7 @@ struct ExprLoweringVisitorBase : public ExprVisitor size.size : size.alignment; - return LoweredValInfo::simple(getBuilder()->getIntValue(builder->getUIntType(), value)); + return LoweredValInfo::simple(getBuilder()->getIntValue(resultType, value)); } LoweredValInfo visitOverloadedExpr(OverloadedExpr* /*expr*/) @@ -4422,7 +4434,7 @@ struct ExprLoweringVisitorBase : public ExprVisitor { irArgs.add(getSimpleVal(context, lowerSubExpr(arg))); } - auto irMakeTuple = getBuilder()->emitMakeTuple(irArgs); + auto irMakeTuple = getBuilder()->emitMakeValuePack((UInt)irArgs.getCount(), irArgs.getBuffer()); return LoweredValInfo::simple(irMakeTuple); } @@ -5367,10 +5379,10 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBasetype); auto loweredBase = lowerLValueExpr(context, expr->base); - UInt elementCount = (UInt)expr->elementCount; + UInt elementCount = (UInt)expr->elementIndices.getCount(); // Assign to 'bs' the elements from 'as' according to the first 'n' indices in 'is' - auto backpermute = [](UInt n, const auto* as, const int* is, auto* bs) + auto backpermute = [](UInt n, const auto as, const auto is, auto bs) { for(UInt i = 0; i < n; ++i) { @@ -5397,7 +5409,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase swizzledLValue = new SwizzledLValueInfo; swizzledLValue->type = irType; swizzledLValue->base = baseSwizzleInfo->base; - swizzledLValue->elementCount = elementCount; + swizzledLValue->elementIndices = elementCount; // Take the swizzle element of the "outer" swizzle, as it was // written by the user. In our running example of `foo[i].zw.y` @@ -5406,7 +5418,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBaseelementCount, + swizzledLValue->elementIndices.getCount(), baseSwizzleInfo->elementIndices, expr->elementIndices, swizzledLValue->elementIndices); @@ -5439,17 +5451,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase swizzledLValue = new SwizzledLValueInfo; swizzledLValue->type = irType; swizzledLValue->base = loweredBase; - swizzledLValue->elementCount = elementCount; - - // In the default case, we can just copy the indices being - // used for the swizzle over directly from the expression, - // and use the base as-is. - // - for (UInt ii = 0; ii < elementCount; ++ii) - { - swizzledLValue->elementIndices[ii] = (UInt) expr->elementIndices[ii]; - } - + swizzledLValue->elementIndices = expr->elementIndices; context->shared->extValues.add(swizzledLValue); return LoweredValInfo::swizzledLValue(swizzledLValue); } @@ -5460,8 +5462,6 @@ struct RValueExprLoweringVisitor : public ExprLoweringVisitorBasetype); @@ -5517,9 +5517,9 @@ struct RValueExprLoweringVisitor : public ExprLoweringVisitorBaseelementCount; - IRInst* irElementIndices[4]; - for (UInt ii = 0; ii < elementCount; ++ii) + ShortList irElementIndices; + irElementIndices.setCount(expr->elementIndices.getCount()); + for (UInt ii = 0; ii < (UInt)expr->elementIndices.getCount(); ++ii) { irElementIndices[ii] = builder->getIntValue( irIntType, @@ -5529,7 +5529,7 @@ struct RValueExprLoweringVisitor : public ExprLoweringVisitorBaseemitSwizzle( irType, irBase, - elementCount, + (UInt)irElementIndices.getCount(), &irElementIndices[0]); return LoweredValInfo::simple(irSwizzle); @@ -6926,7 +6926,7 @@ LoweredValInfo tryGetAddress( auto originalSwizzleInfo = val.getSwizzledLValueInfo(); auto originalBase = originalSwizzleInfo->base; - UInt elementCount = originalSwizzleInfo->elementCount; + UInt elementCount = (UInt)originalSwizzleInfo->elementIndices.getCount(); auto newBase = tryGetAddress(context, originalBase, TryGetAddressMode::Aggressive); if (newBase.flavor == LoweredValInfo::Flavor::Ptr && elementCount == 1) @@ -6942,7 +6942,7 @@ LoweredValInfo tryGetAddress( newSwizzleInfo->base = newBase; newSwizzleInfo->type = originalSwizzleInfo->type; - newSwizzleInfo->elementCount = elementCount; + newSwizzleInfo->elementIndices.setCount(elementCount); for(UInt ee = 0; ee < elementCount; ++ee) newSwizzleInfo->elementIndices[ee] = originalSwizzleInfo->elementIndices[ee]; @@ -7124,8 +7124,8 @@ top: irLeftVal->getDataType(), irLeftVal, irRightVal, - swizzleInfo->elementCount, - swizzleInfo->elementIndices); + (UInt)swizzleInfo->elementIndices.getCount(), + swizzleInfo->elementIndices.getArrayView().getBuffer()); // And finally, store the value back where we got it. // @@ -7158,8 +7158,8 @@ top: swizzledStore( loweredBase.val, irRightVal, - swizzleInfo->elementCount, - swizzleInfo->elementIndices); + (UInt)swizzleInfo->elementIndices.getCount(), + swizzleInfo->elementIndices.getArrayView().getBuffer()); } break; } @@ -11525,7 +11525,7 @@ IRTypeLayout* lowerTypeLayout( else if( auto structTypeLayout = as(typeLayout) ) { IRStructTypeLayout::Builder builder(context->irBuilder); - + int fieldIndex = 0; for( auto fieldLayout : structTypeLayout->fields ) { auto fieldDecl = fieldLayout->varDecl; @@ -11573,11 +11573,24 @@ IRTypeLayout* lowerTypeLayout( context->mapEntryPointParamToKey.add(paramDecl.getDecl(), irFieldKey); } } - else + else if (fieldDecl.getDecl()) { irFieldKey = getSimpleVal(context, ensureDecl(context, fieldDecl.getDecl())); } + else + { + // If we don't have a concrete field decl for the field in the layout, + // it could be that the field in the layout is for a member of a tuple + // type that hasn't been materialized into a struct decl yet. + // We will use a `IndexFieldKey(type, memberIndex)` inst as a placeholder + // for the field key. + // This placeholder can be replaced with the actual field key when the + // tuple type is materialized into a struct type. + auto irType = lowerType(context, typeLayout->getType()); + irFieldKey = context->irBuilder->getIndexedFieldKey(irType, fieldIndex); + } + fieldIndex++; SLANG_ASSERT(irFieldKey); auto irFieldLayout = lowerVarLayout(context, fieldLayout); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index c1f0f91c2..5881e5796 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -6281,6 +6281,23 @@ namespace Slang return alignOfExpr; } + static NodeBase* parseCountOfExpr(Parser* parser, void* /*userData*/) + { + // We could have a type or a variable or an expression + CountOfExpr* countOfExpr = parser->astBuilder->create(); + + parser->ReadMatchingToken(TokenType::LParent); + + // The return type is always an Int + countOfExpr->type = parser->astBuilder->getIntType(); + + countOfExpr->value = parser->ParseExpression(); + + parser->ReadMatchingToken(TokenType::RParent); + + return countOfExpr; + } + static NodeBase* parseTryExpr(Parser* parser, void* /*userData*/) { auto tryExpr = parser->astBuilder->create(); @@ -7549,7 +7566,7 @@ namespace Slang { ExpandExpr* expandExpr = parser->astBuilder->create(); expandExpr->loc = loc; - expandExpr->baseExpr = parser->ParseExpression(); + expandExpr->baseExpr = parser->ParseArgExpr(); return expandExpr; } @@ -7557,7 +7574,7 @@ namespace Slang { EachExpr* eachExpr = parser->astBuilder->create(); eachExpr->loc = loc; - eachExpr->baseExpr = parser->ParseExpression(); + eachExpr->baseExpr = parser->ParseLeafExpression(); return eachExpr; } @@ -8515,6 +8532,7 @@ namespace Slang _makeParseExpr("__dispatch_kernel", parseDispatchKernel), _makeParseExpr("sizeof", parseSizeOfExpr), _makeParseExpr("alignof", parseAlignOfExpr), + _makeParseExpr("countof", parseCountOfExpr), }; ConstArrayView getSyntaxParseInfos() diff --git a/source/slang/slang-serialize-type-info.h b/source/slang/slang-serialize-type-info.h index 40129b083..6d06d2400 100644 --- a/source/slang/slang-serialize-type-info.h +++ b/source/slang/slang-serialize-type-info.h @@ -242,6 +242,31 @@ struct SerialTypeInfo> } }; +// ShortList +template +struct SerialTypeInfo> +{ + typedef ShortList NativeType; + typedef SerialIndex SerialType; + + enum { SerialAlignment = SLANG_ALIGN_OF(SerialType) }; + + static void toSerial(SerialWriter* writer, const void* native, void* serial) + { + auto& src = *(const NativeType*)native; + auto& dst = *(SerialType*)serial; + + dst = writer->addArray(src.getArrayView().getBuffer(), src.getCount()); + } + static void toNative(SerialReader* reader, const void* serial, void* native) + { + auto& dst = *(NativeType*)native; + auto& src = *(const SerialType*)serial; + + reader->getArray(src, dst); + } +}; + // String template <> struct SerialTypeInfo diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h index a91ff21e9..e8786d561 100644 --- a/source/slang/slang-serialize.h +++ b/source/slang/slang-serialize.h @@ -241,6 +241,9 @@ public: template void getArray(SerialIndex index, List& out); + template + void getArray(SerialIndex index, ShortList& out); + const void* getArray(SerialIndex index, Index& outCount); SerialPointer getPointer(SerialIndex index); @@ -334,6 +337,38 @@ void SerialReader::getArray(SerialIndex index, List& out) } } +template +void SerialReader::getArray(SerialIndex index, ShortList& out) +{ + typedef SerialTypeInfo ElementTypeInfo; + typedef typename ElementTypeInfo::SerialType ElementSerialType; + + Index count; + auto serialElements = (const ElementSerialType*)getArray(index, count); + + if (count == 0) + { + out.clear(); + return; + } + + if (std::is_same::value) + { + // If they are the same we can just write out + out.clear(); + out.addRange((const T*)serialElements, count); + } + else + { + // Else we need to convert + out.setCount(count); + for (Index i = 0; i < count; ++i) + { + ElementTypeInfo::toNative(this, (const void*)&serialElements[i], (void*)&out[i]); + } + } +} + /* This is a class used tby toSerial implementations to turn native type into the serial type */ class SerialWriter : public RefObject { diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index f85fb2c5f..f5fbfafdf 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -4396,6 +4396,77 @@ static TypeLayoutResult _createTypeLayout( type, rules); } + else if (auto tupleType = as(type)) + { + // A `Tuple` type is laid out exactly the same way as a `struct` type, + // except that we want have a declref to the field. + + StructTypeLayoutBuilder typeLayoutBuilder; + StructTypeLayoutBuilder pendingDataTypeLayoutBuilder; + + typeLayoutBuilder.beginLayout(type, rules); + auto typeLayout = typeLayoutBuilder.getTypeLayout(); + + _addLayout(context, type, typeLayout); + for (Index i = 0; i < tupleType->getMemberCount(); i++) + { + // The members of a `Tuple` type may include existential (interface) + // types (including as nested sub-fields), and any types present + // in those fields will need to be specialized based on the + // input arguments being passed to `_createTypeLayout`. + // + // We won't know how many type slots each field consumes until + // we process it, but we can figure out the starting index for + // the slots its will consume by looking at the layout we've + // computed so far. + // + Int baseExistentialSlotIndex = 0; + if (auto resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam)) + baseExistentialSlotIndex = Int(resInfo->count.getFiniteValue()); + // + // When computing the layout for the field, we will give it access + // to all the incoming specialized type slots that haven't already + // been consumed/claimed by preceding fields. + // + auto fieldLayoutContext = context.withSpecializationArgsOffsetBy(baseExistentialSlotIndex); + + auto elementType = tupleType->getMember(i); + TypeLayoutResult fieldResult = _createTypeLayout( + fieldLayoutContext, + elementType, + nullptr); + auto fieldTypeLayout = fieldResult.layout; + + auto fieldVarLayout = typeLayoutBuilder.addField(DeclRef(), fieldResult); + + // If any of the members of the `Tuple` type had existential/interface + // type, then we need to compute a second `StructTypeLayout` that + // represents the layout and resource using for the "pending data" + // that this type needs to have stored somewhere, but which can't + // be laid out in the layout of the type itself. + // + if (auto fieldPendingDataTypeLayout = fieldTypeLayout->pendingDataTypeLayout) + { + // We only create this secondary layout on-demand, so that + // we don't end up with a bunch of empty structure type layouts + // created for no reason. + // + pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(type, rules); + auto fieldPendingVarLayout = pendingDataTypeLayoutBuilder.addField(DeclRef(), fieldPendingDataTypeLayout); + fieldVarLayout->pendingVarLayout = fieldPendingVarLayout; + } + } + + typeLayoutBuilder.endLayout(); + pendingDataTypeLayoutBuilder.endLayout(); + + if (auto pendingDataTypeLayout = pendingDataTypeLayoutBuilder.getTypeLayout()) + { + typeLayout->pendingDataTypeLayout = pendingDataTypeLayout; + } + + return _updateLayout(context, type, typeLayoutBuilder.getTypeLayoutResult()); + } else if (auto declRefType = as(type)) { auto declRef = declRefType->getDeclRef(); -- cgit v1.2.3