summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-19 15:03:56 -0700
committerGitHub <noreply@github.com>2024-08-19 15:03:56 -0700
commit453683bf44f2112719802eaac2b332d49eebd640 (patch)
treed399db4c9cba90c11980186d3df1ffcc4d423b5a /source/slang
parentecf85df6eee3da76ef54b14e4ab083f22da89e46 (diff)
Tuple swizzling, concat, comparison and `countof`. (#4856)
* Tuple swizzling and element access. * Update proposal status. * Cleanup. * Fix merrge error. * Address review.
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/core.meta.slang79
-rw-r--r--source/slang/slang-ast-builder.cpp19
-rw-r--r--source/slang/slang-ast-dump.cpp25
-rw-r--r--source/slang/slang-ast-expr.h8
-rw-r--r--source/slang/slang-ast-support-types.h3
-rw-r--r--source/slang/slang-ast-val.cpp75
-rw-r--r--source/slang/slang-ast-val.h24
-rw-r--r--source/slang/slang-check-expr.cpp237
-rw-r--r--source/slang/slang-check-impl.h6
-rw-r--r--source/slang/slang-check-overload.cpp42
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp1
-rw-r--r--source/slang/slang-ir-inline.cpp11
-rw-r--r--source/slang/slang-ir-inst-defs.h6
-rw-r--r--source/slang/slang-ir-insts.h20
-rw-r--r--source/slang/slang-ir-lower-tuple-types.cpp174
-rw-r--r--source/slang/slang-ir-peephole.cpp1
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp16
-rw-r--r--source/slang/slang-ir-specialize.cpp210
-rw-r--r--source/slang/slang-ir-ssa-simplification.cpp2
-rw-r--r--source/slang/slang-ir-ssa-simplification.h1
-rw-r--r--source/slang/slang-ir.cpp33
-rw-r--r--source/slang/slang-ir.h18
-rw-r--r--source/slang/slang-language-server-ast-lookup.cpp26
-rw-r--r--source/slang/slang-language-server-completion.cpp14
-rw-r--r--source/slang/slang-language-server.cpp42
-rw-r--r--source/slang/slang-lookup.cpp105
-rw-r--r--source/slang/slang-lower-to-ir.cpp91
-rw-r--r--source/slang/slang-parser.cpp22
-rw-r--r--source/slang/slang-serialize-type-info.h25
-rw-r--r--source/slang/slang-serialize.h35
-rw-r--r--source/slang/slang-type-layout.cpp71
32 files changed, 1257 insertions, 187 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 6c51ccef0..84e1b8168 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -874,10 +874,85 @@ __generic<each T>
__magic_type(TupleType)
struct Tuple
{
- __intrinsic_op($(0))
+ __intrinsic_op($(kIROp_MakeTuple))
__init(expand each T);
}
+__intrinsic_op($(kIROp_MakeTuple))
+Tuple<T> makeTuple<each T>(T v);
+
+Tuple<T, U> concat<each T, each U>(Tuple<T> t, Tuple<U> u)
+{
+ return makeTuple(expand each t, expand each u);
+}
+
+
+[__unsafeForceInlineEarly]
+bool __assign(inout bool v, bool newVal)
+{
+ v = newVal;
+ return newVal;
+}
+
+[__unsafeForceInlineEarly]
+void __tupleLessKernel<T : IComparable>(inout bool result, inout bool exit, T a, T b)
+{
+ if (!exit)
+ {
+ if (a.lessThan(b))
+ {
+ result = true;
+ exit = true;
+ }
+ else if (!a.equals(b))
+ {
+ exit = true;
+ }
+ }
+}
+
+[__unsafeForceInlineEarly]
+void __tupleGreaterKernel<T : IComparable>(inout bool result, inout bool exit, T a, T b)
+{
+ if (!exit)
+ {
+ if (!a.lessThanOrEquals(b))
+ {
+ result = true;
+ exit = true;
+ }
+ else if (!a.equals(b))
+ {
+ exit = true;
+ }
+ }
+}
+
+__generic<each T : IComparable>
+extension Tuple<T> : IComparable
+{
+ bool lessThan(Tuple<T> other)
+ {
+ bool result = false;
+ bool exit = false;
+ expand __tupleLessKernel(result, exit, each this, each other);
+ return result;
+ }
+ bool lessThanOrEquals(Tuple<T> other)
+ {
+ bool result = false;
+ bool exit = false;
+ expand __tupleGreaterKernel(result, exit, each this, each other);
+ return !result;
+ }
+ bool equals(Tuple<T> other)
+ {
+ bool result = true;
+ expand result && __assign(result, result && (each this).equals(each other));
+ return result;
+ }
+}
+
__generic<T>
__magic_type(NativeRefType)
__intrinsic_type($(kIROp_NativePtrType))
@@ -2181,7 +2256,7 @@ __generic<T : IComparable>
[OverloadRank(-10)]
bool operator >=(T v0, T v1)
{
- return v1.lessThan(v1);
+ return v1.lessThanOrEquals(v0);
}
__generic<T : IComparable>
[__unsafeForceInlineEarly]
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 3a2b2933d..a13e13851 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -572,9 +572,26 @@ Type* ASTBuilder::getExpandType(Type* pattern, ArrayView<Type*> capturedPacks)
return getOrCreate<ExpandType>(pattern, capturedPacks);
}
+void flattenTypeList(ShortList<Type*>& flattenedList, Type* type)
+{
+ if (auto typePack = as<ConcreteTypePack>(type))
+ {
+ for (Index i = 0; i < typePack->getTypeCount(); i++)
+ flattenTypeList(flattenedList, typePack->getElementType(i));
+ }
+ else
+ {
+ flattenedList.add(type);
+ }
+}
+
ConcreteTypePack* ASTBuilder::getTypePack(ArrayView<Type*> types)
{
- return getOrCreate<ConcreteTypePack>(types);
+ // Flatten all type packs in the type list.
+ ShortList<Type*> flattenedTypes;
+ for (auto type : types)
+ flattenTypeList(flattenedTypes, type);
+ return getOrCreate<ConcreteTypePack>(flattenedTypes.getArrayView().arrayView);
}
TypeEqualityWitness* ASTBuilder::getTypeEqualityWitness(
diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp
index 8b2494310..a1ab7a5c8 100644
--- a/source/slang/slang-ast-dump.cpp
+++ b/source/slang/slang-ast-dump.cpp
@@ -203,6 +203,27 @@ struct ASTDumpContext
m_writer->emit("}");
}
+ template <typename T, int n>
+ void dump(const ShortList<T, n>& list)
+ {
+ m_writer->emit(" { \n");
+ m_writer->indent();
+ for (Index i = 0; i < list.getCount(); ++i)
+ {
+ dump(list[i]);
+ if (i < list.getCount() - 1)
+ {
+ m_writer->emit(",\n");
+ }
+ else
+ {
+ m_writer->emit("\n");
+ }
+ }
+ m_writer->dedent();
+ m_writer->emit("}");
+ }
+
void dump(SourceLoc sourceLoc)
{
if (m_dumpFlags & ASTDumpUtil::Flag::HideSourceLoc)
@@ -285,6 +306,10 @@ struct ASTDumpContext
{
m_writer->emit(UInt(v));
}
+ void dump(UInt v)
+ {
+ m_writer->emit(v);
+ }
void dump(int32_t v)
{
m_writer->emit(v);
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index e6edce8f9..c07f7f5b9 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -291,8 +291,7 @@ class SwizzleExpr: public Expr
{
SLANG_AST_CLASS(SwizzleExpr)
Expr* base = nullptr;
- int elementCount;
- int elementIndices[4];
+ ShortList<UInt, 4> elementIndices;
SourceLoc memberOpLoc;
};
@@ -425,6 +424,11 @@ class AlignOfExpr : public SizeOfLikeExpr
SLANG_AST_CLASS(AlignOfExpr);
};
+class CountOfExpr : public SizeOfLikeExpr
+{
+ SLANG_AST_CLASS(CountOfExpr);
+};
+
class MakeOptionalExpr : public Expr
{
SLANG_AST_CLASS(MakeOptionalExpr)
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index f74c169e7..21948dc04 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -147,6 +147,9 @@ namespace Slang
// Cost of converting an integer to a half type
kConversionCost_IntegerToHalfConversion = 500,
+ // Cost of using a concrete argument pack
+ kConversionCost_ParameterPack = 500,
+
// Default case (usable for user-defined conversions)
kConversionCost_Default = 500,
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index e222e86e1..e8020aa04 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -1534,6 +1534,81 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio
return this;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! CountOfIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+void CountOfIntVal::_toTextOverride(StringBuilder& out)
+{
+ out << "countof(";
+ getTypeArg()->toText(out);
+ out << ")";
+}
+
+Val* CountOfIntVal::tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType)
+{
+ if (auto typePack = as<ConcreteTypePack>(newType))
+ {
+ bool anyAbstract = false;
+ for (Index i = 0; i < typePack->getTypeCount(); i++)
+ {
+ if (isAbstractTypePack(typePack->getElementType(i)))
+ {
+ anyAbstract = true;
+ break;
+ }
+ }
+ if (!anyAbstract)
+ {
+ auto result = astBuilder->getIntVal(intType, typePack->getTypeCount());
+ return result;
+ }
+ }
+ else if (auto tupleType = as<TupleType>(newType))
+ {
+ bool anyAbstract = false;
+ for (Index i = 0; i < tupleType->getMemberCount(); i++)
+ {
+ if (isAbstractTypePack(tupleType->getMember(i)))
+ {
+ anyAbstract = true;
+ break;
+ }
+ }
+ if (!anyAbstract)
+ {
+ auto result = astBuilder->getIntVal(intType, tupleType->getMemberCount());
+ return result;
+ }
+ }
+ return nullptr;
+}
+
+Val* CountOfIntVal::tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType)
+{
+ if (auto result = tryFoldOrNull(astBuilder, intType, newType))
+ return result;
+ auto result = astBuilder->getOrCreate<CountOfIntVal>(intType, newType);
+ return result;
+}
+
+Val* CountOfIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+ auto newType = as<Type>(getTypeArg()->substituteImpl(astBuilder, subst, &diff));
+ if (!diff)
+ return this;
+
+ (*ioDiff)++;
+ return tryFold(astBuilder, getType(), newType);
+}
+
+Val* CountOfIntVal::_resolveImplOverride()
+{
+ auto resolvedTypeArg = getTypeArg()->resolve();
+ if (resolvedTypeArg == getTypeArg())
+ return this;
+ return tryFold(getCurrentASTBuilder(), getType(), as<Type>(resolvedTypeArg));
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! WitnessLookupIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void WitnessLookupIntVal::_toTextOverride(StringBuilder& out)
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 74e0f27c1..2599ce46a 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -248,6 +248,30 @@ class FuncCallIntVal : public IntVal
Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map);
};
+class CountOfIntVal : public IntVal
+{
+ SLANG_AST_CLASS(CountOfIntVal)
+
+ CountOfIntVal(Type* inType, Type* typeArg)
+ {
+ setOperands(inType, typeArg);
+ }
+
+ Val* getTypeArg() { return getOperand(1); }
+
+ void _toTextOverride(StringBuilder& out);
+ Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ Val* _resolveImplOverride();
+ bool _isLinkTimeValOverride()
+ {
+ return false;
+ }
+
+ static Val* tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType);
+
+ static Val* tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType);
+};
+
class WitnessLookupIntVal : public IntVal
{
SLANG_AST_CLASS(WitnessLookupIntVal)
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 965d25bd5..fe43a4f8f 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -16,6 +16,7 @@
#include "slang-lookup.h"
#include "slang-lookup-spirv.h"
#include "slang-ast-print.h"
+#include "core/slang-char-util.h"
namespace Slang
{
@@ -1866,6 +1867,13 @@ namespace Slang
}
}
+ if (auto countOfExpr = expr.as<CountOfExpr>())
+ {
+ auto type = as<Type>(countOfExpr.getExpr()->sizedType->substitute(m_astBuilder, expr.getSubsts()));
+ if (type)
+ return as<IntVal>(CountOfIntVal::tryFold(m_astBuilder, expr.getExpr()->type.type, type));
+ }
+
// it is possible that we are referring to a generic value param
if (auto declRefExpr = expr.as<DeclRefExpr>())
{
@@ -3205,6 +3213,31 @@ namespace Slang
return false;
}
+ static bool _isCountOfType(Type* type)
+ {
+ if (!type)
+ {
+ return false;
+ }
+
+ if (isTypePack(type))
+ {
+ return true;
+ }
+
+ if (as<TupleType>(type))
+ {
+ return true;
+ }
+
+ if (as<ArrayExpressionType>(type))
+ {
+ return true;
+ }
+
+ return false;
+ }
+
Expr* SemanticsExprVisitor::visitSizeOfLikeExpr(SizeOfLikeExpr* sizeOfLikeExpr)
{
auto valueExpr = dispatch(sizeOfLikeExpr->value);
@@ -3229,12 +3262,25 @@ namespace Slang
type = properType.type;
}
- if (!_isSizeOfType(type))
+ if (as<CountOfExpr>(sizeOfLikeExpr))
{
- getSink()->diagnose(sizeOfLikeExpr, Diagnostics::sizeOfArgumentIsInvalid);
+ if (!_isCountOfType(type))
+ {
+ getSink()->diagnose(sizeOfLikeExpr, Diagnostics::countOfArgumentIsInvalid);
- sizeOfLikeExpr->type = m_astBuilder->getErrorType();
- return sizeOfLikeExpr;
+ sizeOfLikeExpr->type = m_astBuilder->getErrorType();
+ return sizeOfLikeExpr;
+ }
+ }
+ else
+ {
+ if (!_isSizeOfType(type))
+ {
+ getSink()->diagnose(sizeOfLikeExpr, Diagnostics::sizeOfArgumentIsInvalid);
+
+ sizeOfLikeExpr->type = m_astBuilder->getErrorType();
+ return sizeOfLikeExpr;
+ }
}
sizeOfLikeExpr->sizedType = type;
@@ -3815,6 +3861,108 @@ namespace Slang
return CreateErrorExpr(memberRefExpr);
}
+ Expr* SemanticsVisitor::checkTupleSwizzleExpr(MemberExpr* memberExpr, TupleType* baseTupleType)
+ {
+ UInt tupleElementCount = (UInt)baseTupleType->getMemberCount();
+ if (tupleElementCount == 0)
+ return checkGeneralMemberLookupExpr(memberExpr, baseTupleType);
+
+ if (memberExpr->name == getSession()->getCompletionRequestTokenName())
+ {
+ auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions;
+ suggestions.clear();
+ suggestions.scopeKind = CompletionSuggestions::ScopeKind::Swizzle;
+ suggestions.swizzleBaseType =
+ memberExpr->baseExpression ? memberExpr->baseExpression->type : nullptr;
+ suggestions.elementCount[0] = (Index)tupleElementCount;
+ suggestions.elementCount[1] = 0;
+ return memberExpr;
+ }
+
+ String swizzleText = getText(memberExpr->name);
+ auto span = swizzleText.getUnownedSlice();
+ Index pos = 0;
+
+ ShortList<UInt> elementCoords;
+
+ bool anyDuplicates = false;
+
+ // The contents of the string are 0-terminated
+ // Every update to cursor corresponds to a check against 0-termination
+ while (pos < span.getLength())
+ {
+ UInt elementCoord;
+
+ // Check for the preceding underscore
+ if (span[pos] != '_')
+ {
+ return checkGeneralMemberLookupExpr(memberExpr, baseTupleType);
+ }
+ pos++;
+
+ // Parse index.
+ if (pos >= span.getLength())
+ {
+ // Unexpected end of swizzle string, fallback to
+ // member lookup.
+ return checkGeneralMemberLookupExpr(memberExpr, baseTupleType);
+ }
+
+ auto ch = span[pos];
+
+ if (!CharUtil::isDigit(ch))
+ {
+ // An invalid character in the swizzle is an error, fallback to
+ // member lookup.
+ return checkGeneralMemberLookupExpr(memberExpr, baseTupleType);
+ }
+ elementCoord = (UInt)StringUtil::parseIntAndAdvancePos(span, pos);
+
+ if (elementCoord >= tupleElementCount)
+ {
+ getSink()->diagnose(memberExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseTupleType);
+ return CreateErrorExpr(memberExpr);
+ }
+
+ // Check if we've seen this index before
+ for (int ee = 0; ee < elementCoords.getCount(); ee++)
+ {
+ if (elementCoords[ee] == elementCoord)
+ anyDuplicates = true;
+ }
+
+ // add to our list...
+ elementCoords.add(elementCoord);
+ }
+
+ SwizzleExpr* swizExpr = m_astBuilder->create<SwizzleExpr>();
+ swizExpr->loc = memberExpr->loc;
+ swizExpr->base = memberExpr->baseExpression;
+ swizExpr->elementIndices = _Move(elementCoords);
+ swizExpr->memberOpLoc = memberExpr->memberOperatorLoc;
+
+ if (swizExpr->elementIndices.getCount() == 1)
+ {
+ // single-component swizzle produces a scalar
+ //
+ swizExpr->type = QualType(baseTupleType->getMember(swizExpr->elementIndices[0]));
+ }
+ else
+ {
+ List<Type*> types;
+ for (auto index : swizExpr->elementIndices)
+ {
+ types.add(baseTupleType->getMember(index));
+ }
+ swizExpr->type = QualType(m_astBuilder->getTupleType(types));
+ }
+
+ // A swizzle can be used as an l-value as long as there
+ // were no duplicates in the list of components
+ swizExpr->type.isLeftValue = !anyDuplicates;
+ return swizExpr;
+ }
+
Expr* SemanticsVisitor::CheckSwizzleExpr(
MemberExpr* memberRefExpr,
Type* baseElementType,
@@ -3823,15 +3971,10 @@ namespace Slang
SwizzleExpr* swizExpr = m_astBuilder->create<SwizzleExpr>();
swizExpr->loc = memberRefExpr->loc;
swizExpr->base = memberRefExpr->baseExpression;
- swizExpr->elementIndices[0] = 0;
- swizExpr->elementIndices[1] = 0;
- swizExpr->elementIndices[2] = 0;
- swizExpr->elementIndices[3] = 0;
swizExpr->memberOpLoc = memberRefExpr->memberOperatorLoc;
IntegerLiteralValue limitElement = baseElementCount;
- int elementIndices[4];
- int elementCount = 0;
+ ShortList<UInt,4> elementIndices;
bool anyDuplicates = false;
bool anyError = false;
@@ -3875,35 +4018,31 @@ namespace Slang
// If elementCount is already at 4 stop trying to assign a swizzle element and send an error,
// we cannot have more valid swizzle elements than 4.
- if (elementCount >= 4)
+ if (elementIndices.getCount() >= 4)
{
anyError = true;
break;
}
// Check if we've seen this index before
- for (int ee = 0; ee < elementCount; ee++)
+ for (int ee = 0; ee < elementIndices.getCount(); ee++)
{
- if (elementIndices[ee] == elementIndex)
+ if (elementIndices[ee] == (UInt)elementIndex)
anyDuplicates = true;
}
// add to our list...
- elementIndices[elementCount++] = elementIndex;
+ elementIndices.add(elementIndex);
}
- for (int ee = 0; ee < elementCount; ++ee)
- {
- swizExpr->elementIndices[ee] = elementIndices[ee];
- }
- swizExpr->elementCount = elementCount;
+ swizExpr->elementIndices = _Move(elementIndices);
if (anyError)
{
getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString());
return CreateErrorExpr(memberRefExpr);
}
- else if (elementCount == 1)
+ else if (swizExpr->elementIndices.getCount() == 1)
{
// single-component swizzle produces a scalar
//
@@ -3918,7 +4057,7 @@ namespace Slang
// here if the input type had a sugared name...
swizExpr->type = QualType(createVectorType(
baseElementType,
- m_astBuilder->getIntVal(m_astBuilder->getIntType(), elementCount)));
+ m_astBuilder->getIntVal(m_astBuilder->getIntType(), swizExpr->elementIndices.getCount())));
}
// A swizzle can be used as an l-value as long as there
@@ -4255,6 +4394,32 @@ namespace Slang
return baseExpr;
}
+ Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType)
+ {
+ LookupResult lookupResult = lookUpMember(
+ m_astBuilder,
+ this,
+ expr->name,
+ baseType,
+ m_outerScope);
+ bool diagnosed = false;
+ lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed);
+ if (!lookupResult.isValid())
+ {
+ return lookupMemberResultFailure(expr, baseType, diagnosed);
+ }
+ if (expr->name == getSession()->getCompletionRequestTokenName())
+ {
+ suggestCompletionItems(CompletionSuggestions::ScopeKind::Member, lookupResult);
+ }
+ return createLookupResultExpr(
+ expr->name,
+ lookupResult,
+ expr->baseExpression,
+ expr->loc,
+ expr);
+ }
+
Expr* SemanticsExprVisitor::visitMemberExpr(MemberExpr * expr)
{
bool needDeref = false;
@@ -4279,10 +4444,9 @@ namespace Slang
// because vectors are also declaration reference types...
//
// Also note: the way this is done right now means that the ability
- // to swizzle vectors interferes with any chance o<f looking up
+ // to swizzle vectors interferes with any chance of looking up
// members via extension, for vector or scalar types.
//
- // TODO: Matrix swizzles probably need to be handled at some point.
if (auto baseMatrixType = as<MatrixExpressionType>(baseType))
{
return CheckMatrixSwizzleExpr(
@@ -4322,34 +4486,17 @@ namespace Slang
{
return _lookupStaticMember(expr, expr->baseExpression);
}
+ else if (auto baseTupleType = as<TupleType>(baseType))
+ {
+ return checkTupleSwizzleExpr(expr, baseTupleType);
+ }
else if (as<ErrorType>(baseType))
{
return CreateErrorExpr(expr);
}
else
{
- LookupResult lookupResult = lookUpMember(
- m_astBuilder,
- this,
- expr->name,
- baseType.Ptr(),
- m_outerScope);
- bool diagnosed = false;
- lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed);
- if (!lookupResult.isValid())
- {
- return lookupMemberResultFailure(expr, baseType, diagnosed);
- }
- if (expr->name == getSession()->getCompletionRequestTokenName())
- {
- suggestCompletionItems(CompletionSuggestions::ScopeKind::Member, lookupResult);
- }
- return createLookupResultExpr(
- expr->name,
- lookupResult,
- expr->baseExpression,
- expr->loc,
- expr);
+ return checkGeneralMemberLookupExpr(expr, baseType);
}
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 7565b5472..fea81f68c 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2637,6 +2637,8 @@ namespace Slang
IntVal* baseElementRowCount,
IntVal* baseElementColCount);
+ Expr* checkTupleSwizzleExpr(MemberExpr* memberExpr, TupleType* baseTupleType);
+
Expr* CheckSwizzleExpr(
MemberExpr* memberRefExpr,
Type* baseElementType,
@@ -2647,6 +2649,10 @@ namespace Slang
Type* baseElementType,
IntVal* baseElementCount);
+ // Check a member expr as a general member lookup.
+ // This is the default/fallback behavior if the base type isn't swizzlable.
+ Expr* checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType);
+
/// Perform semantic checking of an assignment where the operands have already been checked.
Expr* checkAssignWithCheckedOperands(AssignExpr* expr);
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 16f6ed7da..29cb74c23 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -680,6 +680,10 @@ namespace Slang
packArg->type = paramType;
resultArgs.add(packArg);
}
+
+ // Always add a flat cost for using an argument pack,
+ // so that we prefer non-pack overloads when possible.
+ candidate.conversionCostSum += kConversionCost_ParameterPack;
}
else
{
@@ -1669,26 +1673,6 @@ namespace Slang
// Try to match the variadic part.
// Is the corresponding argument a expand expr? If so it will map 1:1 to the type pack param.
auto astBuilder = semantics->getASTBuilder();
- while (remainingArgCount > 0)
- {
- auto argType = getArgType(fixedParamCount);
- if (auto typeType = as<TypeType>(argType))
- {
- argType = typeType->getType();
- }
- if (isAbstractTypePack(argType))
- {
- MatchedArg arg;
- arg.argExpr = getArg(fixedParamCount);
- arg.argType = getArgType(fixedParamCount);
- outMatchedArgs.add(arg);
- fixedParamCount++;
- remainingArgCount--;
- typePackCount--;
- continue;
- }
- break;
- }
if (remainingArgCount <= 0)
return true;
@@ -1704,6 +1688,24 @@ namespace Slang
Index typePackSize = remainingArgCount / typePackCount;
for (Index i = 0; i < typePackCount; ++i)
{
+ // If type pack size is 1, we may not need to wrap things in a PackExpr,
+ // if the argument is already a pack.
+ if (typePackSize == 1)
+ {
+ auto argType = getArgType(fixedParamCount + i);
+ if (auto typeType = as<TypeType>(argType))
+ {
+ argType = typeType->getType();
+ }
+ if (isTypePack(argType))
+ {
+ MatchedArg arg;
+ arg.argExpr = getArg(fixedParamCount + i);
+ arg.argType = getArgType(fixedParamCount + i);
+ outMatchedArgs.add(arg);
+ continue;
+ }
+ }
PackExpr* packExpr = nullptr;
if (mode == Mode::ForReal)
{
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index a58b59c0c..7288befe8 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -359,6 +359,8 @@ DIAGNOSTIC(30097, Error, functionNotMarkedAsDifferentiable, "function '$0' is no
DIAGNOSTIC(30098, Error, nonStaticMemberFunctionNotAllowedAsDiffOperand, "non-static function reference '$0' is not allowed here.")
DIAGNOSTIC(30099, Error, sizeOfArgumentIsInvalid, "argument to sizeof is invalid")
+DIAGNOSTIC(30083, Error, countOfArgumentIsInvalid, "argument to countof can only be a type pack or tuple")
+
DIAGNOSTIC(30101, Error, readingFromWriteOnly, "cannot read from writeonly, check modifiers.")
DIAGNOSTIC(30102, Error, differentiableMemberShouldHaveCorrespondingFieldInDiffType, "differentiable member '$0' should have a corresponding field in '$1'. Use [DerivativeMember($1.<field-name>)] or mark as no_diff")
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 0a9b2d691..9adbe42d5 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1656,6 +1656,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
{
disableIRValidationAtInsert();
auto simplifyOptions = IRSimplificationOptions::getDefault(nullptr);
+ simplifyOptions.removeRedundancy = true;
simplifyFunc(autoDiffSharedContext->targetProgram, func, simplifyOptions);
enableIRValidationAtInsert();
}
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index 0d5cb1c70..9b2b59cd9 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -277,6 +277,17 @@ struct InliningPassBase
if(!isDefinition(calleeFunc))
return false;
+ // We cannot inline a call inside an `IRExpand`.
+ // Because this will make the cfg inside the `IRExpand` too complex,
+ // and our expand specialization logic isn't general enough to deal
+ // with that yet.
+ for (auto parent = call->getParent(); parent; parent = parent->getParent())
+ {
+ if (as<IRExpand>(parent))
+ return false;
+ if (as<IRGlobalValueWithCode>(parent))
+ break;
+ }
return true;
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index a4225d041..80c810620 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -245,6 +245,7 @@ INST(RTTIType, rtti_type, 0, HOISTABLE)
INST(RTTIHandleType, rtti_handle_type, 0, HOISTABLE)
INST(TupleType, tuple_type, 0, HOISTABLE)
INST(TargetTupleType, TargetTuple, 0, HOISTABLE)
+INST(TypePack, TypePack, 0, HOISTABLE)
INST(ExpandTypeOrVal, ExpandTypeOrVal, 1, HOISTABLE)
// A type that identifies it's contained type as being emittable as `spirv_literal.
@@ -281,6 +282,8 @@ INST(StructKey, key, 0, GLOBAL)
INST(GlobalGenericParam, global_generic_param, 0, GLOBAL)
INST(WitnessTable, witness_table, 0, 0)
+INST(IndexedFieldKey, indexedFieldKey, 2, HOISTABLE)
+
// A placeholder witness that ThisType implements the enclosing interface.
// Used only in interface definitions.
INST(ThisTypeWitness, thisTypeWitness, 1, 0)
@@ -343,6 +346,7 @@ INST(MakeArrayFromElement, makeArrayFromElement, 1, 0)
INST(MakeStruct, makeStruct, 0, 0)
INST(MakeTuple, makeTuple, 0, 0)
INST(MakeTargetTuple, makeTuple, 0, 0)
+INST(MakeValuePack, makeValuePack, 0, 0)
INST(GetTargetTupleElement, getTargetTupleElement, 0, 0)
INST(GetTupleElement, getTupleElement, 2, 0)
INST(MakeWitnessPack, MakeWitnessPack, 0, HOISTABLE)
@@ -1117,6 +1121,8 @@ INST(TreatAsDynamicUniform, TreatAsDynamicUniform, 1, 0)
INST(SizeOf, sizeOf, 1, 0)
INST(AlignOf, alignOf, 1, 0)
+INST(CountOf, countOf, 1, 0)
+
INST(GetArrayLength, GetArrayLength, 1, 0)
INST(IsType, IsType, 3, 0)
INST(TypeEquals, TypeEquals, 2, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index dc5fb2744..4111fb983 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2857,6 +2857,11 @@ struct IRMakeTuple : IRInst
IR_LEAF_ISA(MakeTuple)
};
+struct IRMakeValuePack : IRInst
+{
+ IR_LEAF_ISA(MakeValuePack)
+};
+
struct IRMakeWitnessPack : IRInst
{
IR_LEAF_ISA(MakeWitnessPack)
@@ -3509,6 +3514,8 @@ public:
IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2);
IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2, IRType* type3);
+ IRTypePack* getTypePack(UInt count, IRType* const* types);
+
IRExpandType* getExpandTypeOrVal(IRType* type, IRInst* pattern, ArrayView<IRInst*> capture);
IRResultType* getResultType(IRType* valueType, IRType* errorType);
@@ -3670,6 +3677,14 @@ public:
return getAttributedType(baseType, attributes.getCount(), attributes.getBuffer());
}
+ IRInst* getIndexedFieldKey(
+ IRInst* baseType,
+ UInt fieldIndex)
+ {
+ IRInst* args[] = { baseType, getIntValue(getIntType(), fieldIndex) };
+ return emitIntrinsicInst(getVoidType(), kIROp_IndexedFieldKey, 2, args);
+ }
+
IRMetalMeshGridPropertiesType* getMetalMeshGridPropertiesType()
{
return (IRMetalMeshGridPropertiesType*)getType(kIROp_MetalMeshGridPropertiesType);
@@ -3872,6 +3887,9 @@ public:
return emitMakeTuple(SLANG_COUNT_OF(args), args);
}
+ IRInst* emitMakeValuePack(IRType* type, UInt count, IRInst* const* args);
+ IRInst* emitMakeValuePack(UInt count, IRInst* const* args);
+
IRInst* emitMakeWitnessPack(IRType* type, ArrayView<IRInst*> args)
{
return emitIntrinsicInst(type, kIROp_MakeWitnessPack, (UInt)args.getCount(), args.getBuffer());
@@ -4364,6 +4382,8 @@ public:
IRInst* emitAlignOf(
IRInst* sizedType);
+ IRInst* emitCountOf(IRType* type, IRInst* sizedType);
+
IRInst* emitCastPtrToBool(IRInst* val);
IRInst* emitCastPtrToInt(IRInst* val);
IRInst* emitCastIntToPtr(IRType* ptrType, IRInst* val);
diff --git a/source/slang/slang-ir-lower-tuple-types.cpp b/source/slang/slang-ir-lower-tuple-types.cpp
index caa031d85..6177cfec2 100644
--- a/source/slang/slang-ir-lower-tuple-types.cpp
+++ b/source/slang/slang-ir-lower-tuple-types.cpp
@@ -85,7 +85,7 @@ namespace Slang
workListSet.add(inst);
}
- void processMakeTuple(IRMakeTuple* inst)
+ void processMakeTuple(IRInst* inst)
{
IRBuilder builderStorage(module);
auto builder = &builderStorage;
@@ -121,6 +121,124 @@ namespace Slang
inst->removeAndDeallocate();
}
+ void processGetElementPtr(IRGetElementPtr* inst)
+ {
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+
+ auto base = inst->getBase();
+ auto baseValueType = tryGetPointedToType(&builder, base->getDataType());
+ auto loweredTupleInfo = getLoweredTupleType(&builder, baseValueType);
+ if (!loweredTupleInfo)
+ return;
+
+ auto elementIndex = getIntVal(inst->getIndex());
+ SLANG_ASSERT((Index)elementIndex < loweredTupleInfo->fields.getCount());
+
+ auto field = loweredTupleInfo->fields[(Index)elementIndex];
+ auto getElement = builder.emitFieldAddress(builder.getPtrType(field->getFieldType()), base, field->getKey());
+ inst->replaceUsesWith(getElement);
+ inst->removeAndDeallocate();
+ }
+
+ void processSwizzle(IRSwizzle* inst)
+ {
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+
+ auto base = inst->getBase();
+ auto loweredTupleInfo = getLoweredTupleType(&builder, base->getDataType());
+
+ if (!loweredTupleInfo)
+ return;
+
+ if (inst->getElementCount() == 1)
+ {
+ auto elementIndex = getIntVal(inst->getElementIndex(0));
+ SLANG_ASSERT((Index)elementIndex < loweredTupleInfo->fields.getCount());
+
+ auto field = loweredTupleInfo->fields[(Index)elementIndex];
+ auto getElement = builder.emitFieldExtract(field->getFieldType(), base, field->getKey());
+ inst->replaceUsesWith(getElement);
+ inst->removeAndDeallocate();
+ }
+ else
+ {
+ List<IRInst*> elements;
+ for (UInt i = 0; i < inst->getElementCount(); i++)
+ {
+ auto elementIndex = getIntVal(inst->getElementIndex(i));
+ SLANG_ASSERT((Index)elementIndex < loweredTupleInfo->fields.getCount());
+
+ auto field = loweredTupleInfo->fields[(Index)elementIndex];
+ auto getElement = builder.emitFieldExtract(field->getFieldType(), base, field->getKey());
+ elements.add(getElement);
+ }
+ auto resultTypeInfo = getLoweredTupleType(&builder, inst->getDataType());
+ auto makeStruct = builder.emitMakeStruct(resultTypeInfo->structType, elements);
+ inst->replaceUsesWith(makeStruct);
+ inst->removeAndDeallocate();
+ }
+ }
+
+ void processSwizzleSet(IRSwizzleSet* inst)
+ {
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+
+ auto base = inst->getBase();
+ auto loweredTupleInfo = getLoweredTupleType(&builder, base->getDataType());
+ auto sourceTupleInfo = getLoweredTupleType(&builder, inst->getSource()->getDataType());
+ if (!loweredTupleInfo)
+ return;
+
+ List<IRInst*> elements;
+ for (Index i = 0; i < loweredTupleInfo->fields.getCount(); i++)
+ {
+ auto field = loweredTupleInfo->fields[i];
+ auto getElement = builder.emitFieldExtract(field->getFieldType(), base, field->getKey());
+ elements.add(getElement);
+ }
+
+ for (UInt i = 0; i < inst->getElementCount(); i++)
+ {
+ auto baseIndex = getIntVal(inst->getElementIndex(i));
+ auto sourceElement = sourceTupleInfo
+ ? builder.emitFieldExtract(sourceTupleInfo->fields[i]->getFieldType(), inst->getSource(), sourceTupleInfo->fields[i]->getKey())
+ : inst->getSource();
+ elements[baseIndex] = sourceElement;
+ }
+ auto resultTypeInfo = getLoweredTupleType(&builder, inst->getDataType());
+ auto makeStruct = builder.emitMakeStruct(resultTypeInfo->structType, elements);
+ inst->replaceUsesWith(makeStruct);
+ inst->removeAndDeallocate();
+ }
+
+ void processSwizzledStore(IRSwizzledStore* inst)
+ {
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+
+ auto dest = inst->getDest();
+ auto destValueType = tryGetPointedToType(&builder, dest->getDataType());
+ auto loweredTupleInfo = getLoweredTupleType(&builder, destValueType);
+ auto sourceTupleInfo = getLoweredTupleType(&builder, inst->getSource()->getDataType());
+ if (!loweredTupleInfo)
+ return;
+
+ for (UInt i = 0; i < inst->getElementCount(); i++)
+ {
+ auto baseIndex = getIntVal(inst->getElementIndex(i));
+ auto destField = loweredTupleInfo->fields[baseIndex];
+ auto destFieldPtr = builder.emitFieldAddress(builder.getPtrType(destField->getFieldType()), dest, destField->getKey());
+ auto sourceElement = sourceTupleInfo
+ ? builder.emitFieldExtract(sourceTupleInfo->fields[i]->getFieldType(), inst->getSource(), sourceTupleInfo->fields[i]->getKey())
+ : inst->getSource();
+ builder.emitStore(destFieldPtr, sourceElement);
+ }
+ inst->removeAndDeallocate();
+ }
+
void processTupleType(IRTupleType* inst)
{
IRBuilder builderStorage(module);
@@ -132,19 +250,47 @@ namespace Slang
SLANG_UNUSED(loweredTupleInfo);
}
+ void processIndexedFieldKey(IRIndexedFieldKey* inst)
+ {
+ IRBuilder builder(module);
+ auto loweredTupleInfo = getLoweredTupleType(&builder, inst->getBaseType());
+ if (!loweredTupleInfo)
+ return;
+ auto fieldIndex = getIntVal(inst->getIndex());
+ SLANG_ASSERT(fieldIndex >= 0 && (Index)fieldIndex < loweredTupleInfo->fields.getCount());
+ inst->replaceUsesWith(loweredTupleInfo->fields[fieldIndex]->getKey());
+ inst->removeAndDeallocate();
+ }
+
void processInst(IRInst* inst)
{
switch (inst->getOp())
{
case kIROp_MakeTuple:
+ case kIROp_MakeValuePack:
processMakeTuple((IRMakeTuple*)inst);
break;
case kIROp_GetTupleElement:
processGetTupleElement((IRGetTupleElement*)inst);
break;
+ case kIROp_GetElementPtr:
+ processGetElementPtr((IRGetElementPtr*)inst);
+ break;
+ case kIROp_swizzle:
+ processSwizzle((IRSwizzle*)inst);
+ break;
+ case kIROp_swizzleSet:
+ processSwizzleSet((IRSwizzleSet*)inst);
+ break;
+ case kIROp_SwizzledStore:
+ processSwizzledStore((IRSwizzledStore*)inst);
+ break;
case kIROp_TupleType:
processTupleType((IRTupleType*)inst);
break;
+ case kIROp_IndexedFieldKey:
+ processIndexedFieldKey((IRIndexedFieldKey*)inst);
+ break;
default:
break;
}
@@ -152,6 +298,32 @@ namespace Slang
void processModule()
{
+ // First, we want to replace all TypePack with TupleType.
+
+ List<IRInst*> typePacks;
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (inst->getOp() == kIROp_TypePack)
+ {
+ typePacks.add(inst);
+ }
+ }
+ IRBuilder builder(module);
+ for (auto inst : typePacks)
+ {
+ builder.setInsertBefore(inst);
+ ShortList<IRType*> types;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ types.add((IRType*)inst->getOperand(i));
+ }
+ auto tupleType = builder.getTupleType((UInt)types.getCount(), types.getArrayView().getBuffer());
+ inst->replaceUsesWith(tupleType);
+ inst->removeAndDeallocate();
+ }
+
+ // Next, lower all tuples to structs.
+
addToWorkList(module->getModuleInst());
while (workList.getCount() != 0)
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index aa8dfddab..b5f5edb05 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -327,6 +327,7 @@ struct PeepholeContext : InstPassBase
switch (inst->getOperand(0)->getOp())
{
case kIROp_MakeTuple:
+ case kIROp_MakeValuePack:
case kIROp_MakeWitnessPack:
{
auto element = inst->getOperand(1);
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index b679d8438..fa3cd444a 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -8,6 +8,19 @@ namespace Slang
struct RedundancyRemovalContext
{
RefPtr<IRDominatorTree> dom;
+ bool isSingleIterationLoop(IRLoop* loop)
+ {
+ int useCount = 0;
+ for (auto use = loop->getBreakBlock()->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser() == loop)
+ continue;
+ useCount++;
+ if (useCount > 1)
+ return false;
+ }
+ return true;
+ }
bool tryHoistInstToOuterMostLoop(IRGlobalValueWithCode* func, IRInst* inst)
{
@@ -17,7 +30,8 @@ struct RedundancyRemovalContext
parentBlock = dom->getImmediateDominator(parentBlock))
{
auto terminatorInst = parentBlock->getTerminator();
- if (terminatorInst->getOp() == kIROp_loop)
+ if (terminatorInst->getOp() == kIROp_loop &&
+ !isSingleIterationLoop(as<IRLoop>(terminatorInst)))
{
// Consider hoisting the inst into this block.
// This is only possible if all operands of the inst are dominating `parentBlock`.
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 2eb16112f..c9e94352e 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -594,7 +594,165 @@ struct SpecializationContext
case kIROp_GetTupleElement:
return maybeSpecializeFoldableInst(inst);
+
+ case kIROp_TypePack:
+ case kIROp_TupleType:
+ return maybeSpecializeTypePackOrTupleType(inst);
+
+ case kIROp_MakeValuePack:
+ case kIROp_MakeTuple:
+ return maybeSpecializeMakeValuePackOrTuple(inst);
+
+ case kIROp_CountOf:
+ return maybeSpecializeCountOf(inst);
+ }
+ }
+
+
+ void flattenPackOperand(ShortList<IRInst*>& flattenedList, IRInst* inst)
+ {
+ if (auto makeValuePack = as<IRMakeValuePack>(inst))
+ {
+ for (UInt i = 0; i < makeValuePack->getOperandCount(); i++)
+ {
+ flattenPackOperand(flattenedList, makeValuePack->getOperand(i));
+ }
+ }
+ else if (auto typePack = as<IRTypePack>(inst))
+ {
+ for (UInt i = 0; i < typePack->getOperandCount(); i++)
+ {
+ flattenPackOperand(flattenedList, typePack->getOperand(i));
+ }
+ }
+ else
+ {
+ SLANG_ASSERT(inst);
+ flattenedList.add(inst);
+ }
+ }
+
+ bool maybeSpecializeTypePackOrTupleType(IRInst* inst)
+ {
+ // If any element of the type pack or tuple is a TypePack, we want to
+ // flatten that type pack into the current type pack or tuple.
+
+ bool needProcess = false;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (as<IRTypePack>(inst->getOperand(i)))
+ {
+ needProcess = true;
+ break;
+ }
+ }
+ // If none of the operands are MakeValuePack, there is no need to flatten anything.
+ if (!needProcess)
+ return false;
+
+ // We will recursively flatten all MakeValuePack operands.
+ ShortList<IRInst*> flattendOperands;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto operand = inst->getOperand(i);
+ flattenPackOperand(flattendOperands, operand);
+ }
+
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+ IRInst* newInst;
+ if (inst->getOp() == kIROp_TypePack)
+ newInst = builder.getTypePack(flattendOperands.getCount(), (IRType* const*)flattendOperands.getArrayView().getBuffer());
+ else
+ newInst = builder.getTupleType(flattendOperands.getCount(), (IRType* const*)flattendOperands.getArrayView().getBuffer());
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ addUsersToWorkList(newInst);
+ return true;
+ }
+
+ bool maybeSpecializeMakeValuePackOrTuple(IRInst* inst)
+ {
+ // If any element of the value pack or tuple is a ValuePack, we want to
+ // flatten that value pack into the current value pack or tuple.
+
+ bool needProcess = false;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (as<IRMakeValuePack>(inst->getOperand(i)))
+ {
+ needProcess = true;
+ break;
+ }
}
+ // If none of the operands are MakeValuePack, there is no need to flatten anything.
+ if (!needProcess)
+ return false;
+
+ // We will recursively flatten all MakeValuePack operands.
+ ShortList<IRInst*> flattendOperands;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto operand = inst->getOperand(i);
+ flattenPackOperand(flattendOperands, operand);
+ }
+
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+ IRInst* newInst = nullptr;
+ if (inst->getOp() == kIROp_MakeValuePack)
+ newInst = builder.emitMakeValuePack(inst->getFullType(), flattendOperands.getCount(), flattendOperands.getArrayView().getBuffer());
+ else
+ newInst = builder.emitMakeTuple(inst->getFullType(), flattendOperands.getCount(), flattendOperands.getArrayView().getBuffer());
+
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ addUsersToWorkList(newInst);
+ return true;
+ }
+
+ bool maybeSpecializeCountOf(IRInst* inst)
+ {
+ auto operand = inst->getOperand(0);
+
+ // If operand is a value, make sure we are working on its type.
+
+ switch (operand->getOp())
+ {
+ case kIROp_MakeValuePack:
+ case kIROp_MakeTuple:
+ operand = operand->getDataType();
+ break;
+ }
+
+ // We can only figure out the count of a type pack or tuple type.
+ switch (operand->getOp())
+ {
+ case kIROp_TypePack:
+ case kIROp_TupleType:
+ break;
+ default:
+ return false;
+ }
+
+ // If none of the element type is a TypePack, we can just return the count.
+ for (UInt i = 0; i < operand->getOperandCount(); i++)
+ {
+ switch (operand->getOperand(i)->getOp())
+ {
+ case kIROp_Param:
+ case kIROp_TypePack:
+ case kIROp_ExpandTypeOrVal:
+ return false;
+ }
+ }
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+ auto newInst = builder.getIntValue(inst->getDataType(), operand->getOperandCount());
+ addUsersToWorkList(inst);
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ return true;
}
// Specializing lookup on witness tables is a general
@@ -606,7 +764,7 @@ struct SpecializationContext
{
// Note: While we currently have named the instruction
// `lookup_witness_method`, the `method` part is a misnomer
- // and the same instruction can look up *any* interfacemay
+ // and the same instruction can look up *any* interface
// requirement based on the witness table that provides
// a conformance, and the "key" that indicates the interface
// requirement.
@@ -2268,7 +2426,7 @@ struct SpecializationContext
for (UInt i = 0; i < expandInst->getCaptureCount(); i++)
{
- if (!as<IRTupleType>(expandInst->getCapture(i)))
+ if (!as<IRTypePack>(expandInst->getCapture(i)))
return false;
}
@@ -2276,16 +2434,16 @@ struct SpecializationContext
builder.setInsertBefore(expandInst);
List<IRInst*> elements;
UInt elementCount = 0;
- if (auto firstTupleType = as<IRTupleType>(expandInst->getCapture(0)))
+ if (auto firstTypePack = as<IRTypePack>(expandInst->getCapture(0)))
{
- elementCount = firstTupleType->getOperandCount();
+ elementCount = firstTypePack->getOperandCount();
}
if (elementCount == 0)
{
- auto resultTuple = builder.emitMakeTuple(0, (IRInst*const*)nullptr);
- expandInst->replaceUsesWith(resultTuple);
+ auto resultValuePack = builder.emitMakeValuePack(0, (IRInst*const*)nullptr);
+ expandInst->replaceUsesWith(resultValuePack);
expandInst->removeAndDeallocate();
- addUsersToWorkList(resultTuple);
+ addUsersToWorkList(resultValuePack);
return true;
}
@@ -2328,7 +2486,7 @@ struct SpecializationContext
}
}
- auto resultTuple = builder.emitMakeTuple(elements);
+ auto resultValuePack = builder.emitMakeValuePack((UInt)elements.getCount(), elements.getBuffer());
auto currentBlock = builder.getBlock();
for (auto nextInst = expandInst->next; nextInst;)
{
@@ -2337,7 +2495,7 @@ struct SpecializationContext
nextInst = next;
}
addUsersToWorkList(expandInst);
- expandInst->replaceUsesWith(resultTuple);
+ expandInst->replaceUsesWith(resultValuePack);
expandInst->removeAndDeallocate();
return true;
}
@@ -2355,15 +2513,15 @@ struct SpecializationContext
{
auto eachInst = as<IREach>(val);
auto packInst = eachInst->getElement();
- if (auto tuple = as<IRTupleType>(packInst))
+ if (auto typePack = as<IRTypePack>(packInst))
{
- SLANG_RELEASE_ASSERT(indexInPack < tuple->getOperandCount());
- return tuple->getOperand(indexInPack);
+ SLANG_RELEASE_ASSERT(indexInPack < typePack->getOperandCount());
+ return typePack->getOperand(indexInPack);
}
- else if (auto makeTuple = as<IRMakeTuple>(packInst))
+ else if (auto makeValuePack = as<IRMakeValuePack>(packInst))
{
- SLANG_RELEASE_ASSERT(indexInPack < makeTuple->getOperandCount());
- return makeTuple->getOperand(indexInPack);
+ SLANG_RELEASE_ASSERT(indexInPack < makeValuePack->getOperandCount());
+ return makeValuePack->getOperand(indexInPack);
}
else if (!as<IRTypeKind>(packInst->getDataType()))
{
@@ -2413,24 +2571,18 @@ struct SpecializationContext
if (expandInst->getCaptureCount() == 0)
return false;
- bool anyAbstractPack = false;
for (UInt i = 0; i < expandInst->getCaptureCount(); i++)
{
- if (!as<IRTupleType>(expandInst->getCaptureType(i)))
- {
- anyAbstractPack = true;
- break;
- }
+ if (!as<IRTypePack>(expandInst->getCaptureType(i)))
+ return false;
}
- if (anyAbstractPack)
- return false;
IRBuilder builder(expandInst);
builder.setInsertBefore(expandInst);
List<IRInst*> elements;
UInt elementCount = 0;
- if (auto firstTupleType = as<IRTupleType>(expandInst->getCaptureType(0)))
+ if (auto firstTypePack = as<IRTypePack>(expandInst->getCaptureType(0)))
{
- elementCount = firstTupleType->getOperandCount();
+ elementCount = firstTypePack->getOperandCount();
}
for (UInt i = 0; i < elementCount; i++)
{
@@ -2444,16 +2596,16 @@ struct SpecializationContext
List<IRType*> types;
for (auto element : elements)
types.add(element->getDataType());
- auto newTupleType = builder.getTupleType(types);
- auto result = builder.emitMakeWitnessPack(newTupleType, elements.getArrayView());
+ auto newTypePack = builder.getTypePack(elements.getCount(), types.getBuffer());
+ auto result = builder.emitMakeWitnessPack(newTypePack, elements.getArrayView());
expandInst->replaceUsesWith(result);
expandInst->removeAndDeallocate();
return true;
}
else
{
- auto newTupleType = builder.getTupleType(elements.getCount(), (IRType*const*)elements.getBuffer());
- expandInst->replaceUsesWith(newTupleType);
+ auto newTypePack = builder.getTypePack(elements.getCount(), (IRType*const*)elements.getBuffer());
+ expandInst->replaceUsesWith(newTypePack);
expandInst->removeAndDeallocate();
return true;
}
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
index b8d6360ad..cd0f67186 100644
--- a/source/slang/slang-ir-ssa-simplification.cpp
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -77,7 +77,7 @@ namespace Slang
funcChanged = false;
funcChanged |= applySparseConditionalConstantPropagation(func, sink);
funcChanged |= peepholeOptimize(target, func);
- if (!options.minimalOptimization)
+ if (options.removeRedundancy)
funcChanged |= removeRedundancyInFunc(func);
funcChanged |= simplifyCFG(func, options.cfgOptions);
// Note: we disregard the `changed` state from dead code elimination pass since
diff --git a/source/slang/slang-ir-ssa-simplification.h b/source/slang/slang-ir-ssa-simplification.h
index d524241ae..446b2d5b2 100644
--- a/source/slang/slang-ir-ssa-simplification.h
+++ b/source/slang/slang-ir-ssa-simplification.h
@@ -19,6 +19,7 @@ namespace Slang
IRDeadCodeEliminationOptions deadCodeElimOptions;
bool minimalOptimization = false;
+ bool removeRedundancy = false;
static IRSimplificationOptions getDefault(TargetProgram* targetProgram);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index c97c04f88..1441b0567 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2802,6 +2802,11 @@ namespace Slang
return getTupleType(SLANG_COUNT_OF(operands), operands);
}
+ IRTypePack* IRBuilder::getTypePack(UInt count, IRType* const* types)
+ {
+ return (IRTypePack*)getType(kIROp_TypePack, count, (IRInst* const*)types);
+ }
+
IRExpandType* IRBuilder::getExpandTypeOrVal(IRType* type, IRInst* pattern, ArrayView<IRInst*> capture)
{
ShortList<IRInst*> args;
@@ -4046,6 +4051,21 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitMakeValuePack(IRType* type, UInt count, IRInst* const* args)
+ {
+ return emitIntrinsicInst(type, kIROp_MakeValuePack, count, args);
+ }
+
+ IRInst* IRBuilder::emitMakeValuePack(UInt count, IRInst* const* args)
+ {
+ ShortList<IRType*> types;
+ for (UInt i = 0; i < count; ++i)
+ types.add(args[i]->getFullType());
+
+ auto type = getTypePack((UInt)types.getCount(), types.getArrayView().getBuffer());
+ return emitIntrinsicInst(type, kIROp_MakeValuePack, count, args);
+ }
+
IRInst* IRBuilder::emitMakeTuple(IRType* type, UInt count, IRInst* const* args)
{
return emitIntrinsicInst(type, kIROp_MakeTuple, count, args);
@@ -5778,6 +5798,19 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitCountOf(
+ IRType* type,
+ IRInst* sizedType)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_CountOf,
+ type,
+ sizedType);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitBitCast(
IRType* type,
IRInst* val)
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index ececdad43..b1c2b001e 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1940,6 +1940,24 @@ struct IRTupleType : IRType
IR_LEAF_ISA(TupleType)
};
+
+/// Represents a type pack. Type packs behave like tuples, but they have a
+/// "flattening" semantics, so that MakeTypePack(MakeTypePack(T1,T2), T3) is
+/// MakeTypePack(T1,T2,T3).
+struct IRTypePack : IRType
+{
+ IR_LEAF_ISA(TypePack)
+};
+
+// A placeholder struct key for tuple type layouts that will be replaced with
+// the actual struct key when the tuple type is materialized into a struct type.
+struct IRIndexedFieldKey : IRInst
+{
+ IR_LEAF_ISA(IndexedFieldKey)
+ IRInst* getBaseType() { return getOperand(0); }
+ IRInst* getIndex() { return getOperand(1); }
+};
+
/// Represents a tuple in target language. TargetTupleType will not be lowered to structs.
struct IRTargetTupleType : IRType
{
diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp
index 5c73d2ab9..2d6ed2568 100644
--- a/source/slang/slang-language-server-ast-lookup.cpp
+++ b/source/slang/slang-language-server-ast-lookup.cpp
@@ -118,6 +118,27 @@ public:
return dispatchIfNotNull(subscriptExpr->baseExpression);
}
+ bool visitSizeOfLikeExpr(SizeOfLikeExpr* expr)
+ {
+ int tokenLength = 0;
+ if (as<CountOfExpr>(expr))
+ tokenLength = 7; // strlen("countof");
+ else if (as<SizeOfExpr>(expr))
+ tokenLength = 6; // strlen("sizeof");
+ else if (as<AlignOfExpr>(expr))
+ tokenLength = 7; // strlen("alignof");
+
+ if (_isLocInRange(context, expr->loc, tokenLength))
+ {
+ ASTLookupResult result;
+ result.path = context->nodePath;
+ result.path.add(expr);
+ context->results.add(result);
+ return true;
+ }
+ return dispatchIfNotNull(expr->value);
+ }
+
bool visitParenExpr(ParenExpr* expr)
{
return dispatchIfNotNull(expr->base);
@@ -225,7 +246,10 @@ public:
}
bool visitSwizzleExpr(SwizzleExpr* expr)
{
- if (_isLocInRange(context, expr->memberOpLoc, 0))
+ Index tokenLength = expr->elementIndices.getCount();
+ if (expr->base && as<TupleType>(expr->base->type))
+ tokenLength *= 2;
+ if (_isLocInRange(context, expr->loc, tokenLength))
{
ASTLookupResult result;
result.path = context->nodePath;
diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp
index b723e14b8..bee8f088a 100644
--- a/source/slang/slang-language-server-completion.cpp
+++ b/source/slang/slang-language-server-completion.cpp
@@ -674,6 +674,20 @@ List<LanguageServerProtocol::CompletionItem> CompletionContext::createSwizzleCan
}
}
}
+ else if (auto tupleType = as<TupleType>(type))
+ {
+ auto count = Math::Min((int)elementCount[0], 4);
+ for (int i = 0; i < count; i++)
+ {
+ LanguageServerProtocol::CompletionItem item;
+ item.data = 0;
+ if (tupleType->getMember(i))
+ item.detail = tupleType->getMember(i)->toString();
+ item.kind = LanguageServerProtocol::kCompletionItemKindVariable;
+ item.label = String("_") + String(i);
+ result.add(item);
+ }
+ }
for (auto& item : result)
{
for (auto ch : getCommitChars())
diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp
index 81c7d24cf..9588e7284 100644
--- a/source/slang/slang-language-server.cpp
+++ b/source/slang/slang-language-server.cpp
@@ -701,6 +701,40 @@ SlangResult LanguageServer::hover(
}
fillLoc(expr->loc);
}
+ else if (auto swizzleExpr = as<SwizzleExpr>(expr))
+ {
+ if (expr->type && swizzleExpr->base && swizzleExpr->base->type)
+ {
+ bool isTupleType = as<TupleType>(swizzleExpr->base->type) != nullptr;
+ sb << "```\n";
+ swizzleExpr->type->toText(sb);
+ sb << " ";
+ swizzleExpr->base->type->toText(sb);
+ sb << ".";
+ for (auto index : swizzleExpr->elementIndices)
+ {
+ if (isTupleType || index > 4)
+ sb << "_" << index;
+ else
+ sb << "xyzw"[index];
+ }
+ sb << "\n```\n";
+ fillLoc(expr->loc);
+ }
+ }
+ else if (auto countOfExpr = as<CountOfExpr>(expr))
+ {
+ if (countOfExpr->sizedType)
+ {
+ if (auto foldedVal = as<ConstantIntVal>(CountOfIntVal::tryFoldOrNull(version->linkage->getASTBuilder(), expr->type.type, countOfExpr->sizedType)))
+ {
+ sb << "```\n" << "countof(";
+ countOfExpr->sizedType->toText(sb);
+ sb << ") = " << foldedVal->getValue() << "\n```\n";
+ fillLoc(expr->loc);
+ }
+ }
+ }
if (const auto higherOrderExpr = as<HigherOrderInvokeExpr>(expr))
{
String documentation;
@@ -740,6 +774,14 @@ SlangResult LanguageServer::hover(
{
fillExprHoverInfo(thisExprExpr);
}
+ else if (auto countOfExpr = as<CountOfExpr>(leafNode))
+ {
+ fillExprHoverInfo(countOfExpr);
+ }
+ else if (auto swizzleExpr = as<SwizzleExpr>(leafNode))
+ {
+ fillExprHoverInfo(swizzleExpr);
+ }
else if (auto importDecl = as<ImportDecl>(leafNode))
{
auto moduleLoc = getModuleLoc(version->linkage->getSourceManager(), importDecl->importedModuleDecl);
diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp
index c1895d754..7acaa030b 100644
--- a/source/slang/slang-lookup.cpp
+++ b/source/slang/slang-lookup.cpp
@@ -387,53 +387,16 @@ static void _lookUpMembersInSuperType(
_lookUpMembersInSuperTypeImpl(astBuilder, name, leafType, superType, leafIsSuperWitness, request, ioResult, &breadcrumb);
}
-static void _lookUpMembersInSuperTypeDeclImpl(
- ASTBuilder* astBuilder,
+static void _lookupMembersInSuperTypeFacets(ASTBuilder* astBuilder,
Name* name,
- DeclRef<Decl> declRef,
+ Type* selfType,
+ InheritanceInfo const& inheritanceInfo,
LookupRequest const& request,
LookupResult& ioResult,
BreadcrumbInfo* inBreadcrumbs)
{
- auto semantics = request.semantics;
- if (!as<InterfaceDecl>(declRef.getDecl()) && getText(name) == "This")
- {
- // If we are looking for `This` in anything other than an InterfaceDecl,
- // we just need to return the declRef itself.
- AddToLookupResult(ioResult, CreateLookupResultItem(declRef, inBreadcrumbs));
- return;
- }
-
- // If the semantics context hasn't been established yet (e.g. when looking up during parsing),
- // we simply do a direct lookup without considering subtypes or extensions.
- //
- if (!semantics)
- {
- // In this case we can only lookup in an aggregate type.
- if (auto aggTypeDeclBaseRef = declRef.as<AggTypeDeclBase>())
- {
- _lookUpDirectAndTransparentMembers(astBuilder, name, aggTypeDeclBaseRef.getDecl(), aggTypeDeclBaseRef, request, ioResult, inBreadcrumbs);
- }
- return;
- }
-
- ensureDecl(semantics, declRef.getDecl(), DeclCheckState::ReadyForLookup);
-
- // With semantics context, we can do a comprehensive lookup by scanning through
- // the linearized inheritance list.
+
- auto selfType = DeclRefType::create(astBuilder, declRef);
- InheritanceInfo inheritanceInfo;
- if (auto extDeclRef = declRef.as<ExtensionDecl>())
- {
- inheritanceInfo = semantics->getShared()->getInheritanceInfo(extDeclRef);
- }
- else
- {
- selfType = selfType->getCanonicalType();
- inheritanceInfo = semantics->getShared()->getInheritanceInfo(selfType);
- }
-
for (auto facet : inheritanceInfo.facets)
{
auto containerDeclRef = facet->getDeclRef().as<ContainerDecl>();
@@ -457,12 +420,12 @@ static void _lookUpMembersInSuperTypeDeclImpl(
continue;
}
// If we are looking up only immediate members, ignore non "Self" facets or extension to "Self"
- else if (int(request.options) & int(LookupOptions::IgnoreInheritance)
+ else if (int(request.options) & int(LookupOptions::IgnoreInheritance)
&& (facet.getImpl()->directness != Facet::Directness::Self
&& (!extensionFacet || !extensionFacet->targetType.type->equals(selfType))
))
{
- continue;
+ continue;
}
// Some things that are syntactically `InheritanceDecl`s don't actually
@@ -527,6 +490,56 @@ static void _lookUpMembersInSuperTypeDeclImpl(
}
}
+static void _lookUpMembersInSuperTypeDeclImpl(
+ ASTBuilder* astBuilder,
+ Name* name,
+ DeclRef<Decl> declRef,
+ LookupRequest const& request,
+ LookupResult& ioResult,
+ BreadcrumbInfo* inBreadcrumbs)
+{
+ auto semantics = request.semantics;
+ if (!as<InterfaceDecl>(declRef.getDecl()) && getText(name) == "This")
+ {
+ // If we are looking for `This` in anything other than an InterfaceDecl,
+ // we just need to return the declRef itself.
+ AddToLookupResult(ioResult, CreateLookupResultItem(declRef, inBreadcrumbs));
+ return;
+ }
+
+ // If the semantics context hasn't been established yet (e.g. when looking up during parsing),
+ // we simply do a direct lookup without considering subtypes or extensions.
+ //
+ if (!semantics)
+ {
+ // In this case we can only lookup in an aggregate type.
+ if (auto aggTypeDeclBaseRef = declRef.as<AggTypeDeclBase>())
+ {
+ _lookUpDirectAndTransparentMembers(astBuilder, name, aggTypeDeclBaseRef.getDecl(), aggTypeDeclBaseRef, request, ioResult, inBreadcrumbs);
+ }
+ return;
+ }
+
+ ensureDecl(semantics, declRef.getDecl(), DeclCheckState::ReadyForLookup);
+
+ // With semantics context, we can do a comprehensive lookup by scanning through
+ // the linearized inheritance list.
+
+ auto selfType = DeclRefType::create(astBuilder, declRef);
+ InheritanceInfo inheritanceInfo;
+ if (auto extDeclRef = declRef.as<ExtensionDecl>())
+ {
+ inheritanceInfo = semantics->getShared()->getInheritanceInfo(extDeclRef);
+ }
+ else
+ {
+ selfType = selfType->getCanonicalType();
+ inheritanceInfo = semantics->getShared()->getInheritanceInfo(selfType);
+ }
+
+ _lookupMembersInSuperTypeFacets(astBuilder, name, selfType, inheritanceInfo, request, ioResult, inBreadcrumbs);
+}
+
static void _lookUpMembersInSuperTypeImpl(
ASTBuilder* astBuilder,
Name* name,
@@ -565,6 +578,12 @@ static void _lookUpMembersInSuperTypeImpl(
_lookUpMembersInSuperTypeDeclImpl(astBuilder, name, declRef, request, ioResult, inBreadcrumbs);
}
+ else if (auto eachType = as<EachType>(superType))
+ {
+ auto canEachType = eachType->getCanonicalType();
+ InheritanceInfo inheritanceInfo = request.semantics->getShared()->getInheritanceInfo(canEachType);
+ _lookupMembersInSuperTypeFacets(astBuilder, name, canEachType, inheritanceInfo, request, ioResult, inBreadcrumbs);
+ }
else if (auto extractExistentialType = as<ExtractExistentialType>(superType))
{
// We want lookup to be performed on the underlying interface type of the existential,
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 9ceb3074a..44c62e858 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -336,11 +336,8 @@ struct SwizzledLValueInfo : ExtendedValueInfo
// The base expression (this should be an l-value)
LoweredValInfo base;
- // The number of elements in the swizzle
- UInt elementCount;
-
// THe indices for the elements being swizzled
- UInt elementIndices[4];
+ ShortList<UInt, 4> elementIndices;
};
// Represents the result of a matrix swizzle operation in an l-value context.
@@ -1183,8 +1180,8 @@ top:
return LoweredValInfo::simple(builder->emitSwizzle(
swizzleInfo->type,
getSimpleVal(context, swizzleInfo->base),
- swizzleInfo->elementCount,
- swizzleInfo->elementIndices));
+ swizzleInfo->elementIndices.getCount(),
+ swizzleInfo->elementIndices.getArrayView().getBuffer()));
}
case LoweredValInfo::Flavor::SwizzledMatrixLValue:
@@ -1656,6 +1653,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(resultVal);
}
+ LoweredValInfo visitCountOfIntVal(CountOfIntVal* val)
+ {
+ auto irBuilder = getBuilder();
+ auto type = lowerType(context, val->getType());
+ auto typeArg = lowerType(context, as<Type>(val->getTypeArg()));
+ auto count = irBuilder->emitCountOf(type, typeArg);
+ return LoweredValInfo::simple(count);
+ }
+
LoweredValInfo visitConcreteTypePack(ConcreteTypePack* typePack)
{
ShortList<IRType*> types;
@@ -1665,7 +1671,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
types.add(loweredType);
}
auto irBuilder = getBuilder();
- IRType* irTypePack = irBuilder->getTupleType((UInt)types.getCount(), types.getArrayView().getBuffer());
+ IRType* irTypePack = irBuilder->getTypePack((UInt)types.getCount(), types.getArrayView().getBuffer());
return LoweredValInfo::simple(irTypePack);
}
@@ -4038,6 +4044,7 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
const auto size = naturalLayoutContext.calcSize(sizeOfLikeExpr->sizedType);
auto builder = getBuilder();
+ auto resultType = lowerType(context, sizeOfLikeExpr->type);
if (!size)
{
@@ -4051,10 +4058,15 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
{
inst = builder->emitAlignOf(sizedType);
}
- else
+ else if (as<SizeOfExpr>(sizeOfLikeExpr))
{
inst = builder->emitSizeOf(sizedType);
}
+ else
+ {
+
+ inst = builder->emitCountOf(resultType, sizedType);
+ }
return LoweredValInfo::simple(inst);
}
@@ -4064,7 +4076,7 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
size.size :
size.alignment;
- return LoweredValInfo::simple(getBuilder()->getIntValue(builder->getUIntType(), value));
+ return LoweredValInfo::simple(getBuilder()->getIntValue(resultType, value));
}
LoweredValInfo visitOverloadedExpr(OverloadedExpr* /*expr*/)
@@ -4422,7 +4434,7 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
{
irArgs.add(getSimpleVal(context, lowerSubExpr(arg)));
}
- auto irMakeTuple = getBuilder()->emitMakeTuple(irArgs);
+ auto irMakeTuple = getBuilder()->emitMakeValuePack((UInt)irArgs.getCount(), irArgs.getBuffer());
return LoweredValInfo::simple(irMakeTuple);
}
@@ -5367,10 +5379,10 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis
{
auto irType = lowerType(context, expr->type);
auto loweredBase = lowerLValueExpr(context, expr->base);
- UInt elementCount = (UInt)expr->elementCount;
+ UInt elementCount = (UInt)expr->elementIndices.getCount();
// Assign to 'bs' the elements from 'as' according to the first 'n' indices in 'is'
- auto backpermute = [](UInt n, const auto* as, const int* is, auto* bs)
+ auto backpermute = [](UInt n, const auto as, const auto is, auto bs)
{
for(UInt i = 0; i < n; ++i)
{
@@ -5397,7 +5409,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis
RefPtr<SwizzledLValueInfo> swizzledLValue = new SwizzledLValueInfo;
swizzledLValue->type = irType;
swizzledLValue->base = baseSwizzleInfo->base;
- swizzledLValue->elementCount = elementCount;
+ swizzledLValue->elementIndices = elementCount;
// Take the swizzle element of the "outer" swizzle, as it was
// written by the user. In our running example of `foo[i].zw.y`
@@ -5406,7 +5418,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis
// Use that original element index to figure out which of the
// elements of the original swizzle this should map to.
backpermute(
- swizzledLValue->elementCount,
+ swizzledLValue->elementIndices.getCount(),
baseSwizzleInfo->elementIndices,
expr->elementIndices,
swizzledLValue->elementIndices);
@@ -5439,17 +5451,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis
RefPtr<SwizzledLValueInfo> swizzledLValue = new SwizzledLValueInfo;
swizzledLValue->type = irType;
swizzledLValue->base = loweredBase;
- swizzledLValue->elementCount = elementCount;
-
- // In the default case, we can just copy the indices being
- // used for the swizzle over directly from the expression,
- // and use the base as-is.
- //
- for (UInt ii = 0; ii < elementCount; ++ii)
- {
- swizzledLValue->elementIndices[ii] = (UInt) expr->elementIndices[ii];
- }
-
+ swizzledLValue->elementIndices = expr->elementIndices;
context->shared->extValues.add(swizzledLValue);
return LoweredValInfo::swizzledLValue(swizzledLValue);
}
@@ -5460,8 +5462,6 @@ struct RValueExprLoweringVisitor : public ExprLoweringVisitorBase<RValueExprLowe
{
static bool _isLValueContext() { return false; }
- // A matrix swizzle in an r-value context can save time by just
- // emitting the matrix swizzle instructions directly.
LoweredValInfo visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr)
{
auto resultType = lowerType(context, expr->type);
@@ -5517,9 +5517,9 @@ struct RValueExprLoweringVisitor : public ExprLoweringVisitorBase<RValueExprLowe
auto irIntType = getIntType(context);
- UInt elementCount = (UInt)expr->elementCount;
- IRInst* irElementIndices[4];
- for (UInt ii = 0; ii < elementCount; ++ii)
+ ShortList<IRInst*, 4> irElementIndices;
+ irElementIndices.setCount(expr->elementIndices.getCount());
+ for (UInt ii = 0; ii < (UInt)expr->elementIndices.getCount(); ++ii)
{
irElementIndices[ii] = builder->getIntValue(
irIntType,
@@ -5529,7 +5529,7 @@ struct RValueExprLoweringVisitor : public ExprLoweringVisitorBase<RValueExprLowe
auto irSwizzle = builder->emitSwizzle(
irType,
irBase,
- elementCount,
+ (UInt)irElementIndices.getCount(),
&irElementIndices[0]);
return LoweredValInfo::simple(irSwizzle);
@@ -6926,7 +6926,7 @@ LoweredValInfo tryGetAddress(
auto originalSwizzleInfo = val.getSwizzledLValueInfo();
auto originalBase = originalSwizzleInfo->base;
- UInt elementCount = originalSwizzleInfo->elementCount;
+ UInt elementCount = (UInt)originalSwizzleInfo->elementIndices.getCount();
auto newBase = tryGetAddress(context, originalBase, TryGetAddressMode::Aggressive);
if (newBase.flavor == LoweredValInfo::Flavor::Ptr && elementCount == 1)
@@ -6942,7 +6942,7 @@ LoweredValInfo tryGetAddress(
newSwizzleInfo->base = newBase;
newSwizzleInfo->type = originalSwizzleInfo->type;
- newSwizzleInfo->elementCount = elementCount;
+ newSwizzleInfo->elementIndices.setCount(elementCount);
for(UInt ee = 0; ee < elementCount; ++ee)
newSwizzleInfo->elementIndices[ee] = originalSwizzleInfo->elementIndices[ee];
@@ -7124,8 +7124,8 @@ top:
irLeftVal->getDataType(),
irLeftVal,
irRightVal,
- swizzleInfo->elementCount,
- swizzleInfo->elementIndices);
+ (UInt)swizzleInfo->elementIndices.getCount(),
+ swizzleInfo->elementIndices.getArrayView().getBuffer());
// And finally, store the value back where we got it.
//
@@ -7158,8 +7158,8 @@ top:
swizzledStore(
loweredBase.val,
irRightVal,
- swizzleInfo->elementCount,
- swizzleInfo->elementIndices);
+ (UInt)swizzleInfo->elementIndices.getCount(),
+ swizzleInfo->elementIndices.getArrayView().getBuffer());
}
break;
}
@@ -11525,7 +11525,7 @@ IRTypeLayout* lowerTypeLayout(
else if( auto structTypeLayout = as<StructTypeLayout>(typeLayout) )
{
IRStructTypeLayout::Builder builder(context->irBuilder);
-
+ int fieldIndex = 0;
for( auto fieldLayout : structTypeLayout->fields )
{
auto fieldDecl = fieldLayout->varDecl;
@@ -11573,11 +11573,24 @@ IRTypeLayout* lowerTypeLayout(
context->mapEntryPointParamToKey.add(paramDecl.getDecl(), irFieldKey);
}
}
- else
+ else if (fieldDecl.getDecl())
{
irFieldKey = getSimpleVal(context,
ensureDecl(context, fieldDecl.getDecl()));
}
+ else
+ {
+ // If we don't have a concrete field decl for the field in the layout,
+ // it could be that the field in the layout is for a member of a tuple
+ // type that hasn't been materialized into a struct decl yet.
+ // We will use a `IndexFieldKey(type, memberIndex)` inst as a placeholder
+ // for the field key.
+ // This placeholder can be replaced with the actual field key when the
+ // tuple type is materialized into a struct type.
+ auto irType = lowerType(context, typeLayout->getType());
+ irFieldKey = context->irBuilder->getIndexedFieldKey(irType, fieldIndex);
+ }
+ fieldIndex++;
SLANG_ASSERT(irFieldKey);
auto irFieldLayout = lowerVarLayout(context, fieldLayout);
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index c1f0f91c2..5881e5796 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -6281,6 +6281,23 @@ namespace Slang
return alignOfExpr;
}
+ static NodeBase* parseCountOfExpr(Parser* parser, void* /*userData*/)
+ {
+ // We could have a type or a variable or an expression
+ CountOfExpr* countOfExpr = parser->astBuilder->create<CountOfExpr>();
+
+ parser->ReadMatchingToken(TokenType::LParent);
+
+ // The return type is always an Int
+ countOfExpr->type = parser->astBuilder->getIntType();
+
+ countOfExpr->value = parser->ParseExpression();
+
+ parser->ReadMatchingToken(TokenType::RParent);
+
+ return countOfExpr;
+ }
+
static NodeBase* parseTryExpr(Parser* parser, void* /*userData*/)
{
auto tryExpr = parser->astBuilder->create<TryExpr>();
@@ -7549,7 +7566,7 @@ namespace Slang
{
ExpandExpr* expandExpr = parser->astBuilder->create<ExpandExpr>();
expandExpr->loc = loc;
- expandExpr->baseExpr = parser->ParseExpression();
+ expandExpr->baseExpr = parser->ParseArgExpr();
return expandExpr;
}
@@ -7557,7 +7574,7 @@ namespace Slang
{
EachExpr* eachExpr = parser->astBuilder->create<EachExpr>();
eachExpr->loc = loc;
- eachExpr->baseExpr = parser->ParseExpression();
+ eachExpr->baseExpr = parser->ParseLeafExpression();
return eachExpr;
}
@@ -8515,6 +8532,7 @@ namespace Slang
_makeParseExpr("__dispatch_kernel", parseDispatchKernel),
_makeParseExpr("sizeof", parseSizeOfExpr),
_makeParseExpr("alignof", parseAlignOfExpr),
+ _makeParseExpr("countof", parseCountOfExpr),
};
ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos()
diff --git a/source/slang/slang-serialize-type-info.h b/source/slang/slang-serialize-type-info.h
index 40129b083..6d06d2400 100644
--- a/source/slang/slang-serialize-type-info.h
+++ b/source/slang/slang-serialize-type-info.h
@@ -242,6 +242,31 @@ struct SerialTypeInfo<List<T, ALLOCATOR>>
}
};
+// ShortList
+template <typename T, int n, typename ALLOCATOR>
+struct SerialTypeInfo<ShortList<T, n, ALLOCATOR>>
+{
+ typedef ShortList<T, n, ALLOCATOR> NativeType;
+ typedef SerialIndex SerialType;
+
+ enum { SerialAlignment = SLANG_ALIGN_OF(SerialType) };
+
+ static void toSerial(SerialWriter* writer, const void* native, void* serial)
+ {
+ auto& src = *(const NativeType*)native;
+ auto& dst = *(SerialType*)serial;
+
+ dst = writer->addArray(src.getArrayView().getBuffer(), src.getCount());
+ }
+ static void toNative(SerialReader* reader, const void* serial, void* native)
+ {
+ auto& dst = *(NativeType*)native;
+ auto& src = *(const SerialType*)serial;
+
+ reader->getArray(src, dst);
+ }
+};
+
// String
template <>
struct SerialTypeInfo<String>
diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h
index a91ff21e9..e8786d561 100644
--- a/source/slang/slang-serialize.h
+++ b/source/slang/slang-serialize.h
@@ -241,6 +241,9 @@ public:
template <typename T>
void getArray(SerialIndex index, List<T>& out);
+ template <typename T, int n>
+ void getArray(SerialIndex index, ShortList<T, n>& out);
+
const void* getArray(SerialIndex index, Index& outCount);
SerialPointer getPointer(SerialIndex index);
@@ -334,6 +337,38 @@ void SerialReader::getArray(SerialIndex index, List<T>& out)
}
}
+template <typename T, int n>
+void SerialReader::getArray(SerialIndex index, ShortList<T, n>& out)
+{
+ typedef SerialTypeInfo<T> ElementTypeInfo;
+ typedef typename ElementTypeInfo::SerialType ElementSerialType;
+
+ Index count;
+ auto serialElements = (const ElementSerialType*)getArray(index, count);
+
+ if (count == 0)
+ {
+ out.clear();
+ return;
+ }
+
+ if (std::is_same<T, ElementSerialType>::value)
+ {
+ // If they are the same we can just write out
+ out.clear();
+ out.addRange((const T*)serialElements, count);
+ }
+ else
+ {
+ // Else we need to convert
+ out.setCount(count);
+ for (Index i = 0; i < count; ++i)
+ {
+ ElementTypeInfo::toNative(this, (const void*)&serialElements[i], (void*)&out[i]);
+ }
+ }
+}
+
/* This is a class used tby toSerial implementations to turn native type into the serial type */
class SerialWriter : public RefObject
{
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index f85fb2c5f..f5fbfafdf 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -4396,6 +4396,77 @@ static TypeLayoutResult _createTypeLayout(
type,
rules);
}
+ else if (auto tupleType = as<TupleType>(type))
+ {
+ // A `Tuple` type is laid out exactly the same way as a `struct` type,
+ // except that we want have a declref to the field.
+
+ StructTypeLayoutBuilder typeLayoutBuilder;
+ StructTypeLayoutBuilder pendingDataTypeLayoutBuilder;
+
+ typeLayoutBuilder.beginLayout(type, rules);
+ auto typeLayout = typeLayoutBuilder.getTypeLayout();
+
+ _addLayout(context, type, typeLayout);
+ for (Index i = 0; i < tupleType->getMemberCount(); i++)
+ {
+ // The members of a `Tuple` type may include existential (interface)
+ // types (including as nested sub-fields), and any types present
+ // in those fields will need to be specialized based on the
+ // input arguments being passed to `_createTypeLayout`.
+ //
+ // We won't know how many type slots each field consumes until
+ // we process it, but we can figure out the starting index for
+ // the slots its will consume by looking at the layout we've
+ // computed so far.
+ //
+ Int baseExistentialSlotIndex = 0;
+ if (auto resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::ExistentialTypeParam))
+ baseExistentialSlotIndex = Int(resInfo->count.getFiniteValue());
+ //
+ // When computing the layout for the field, we will give it access
+ // to all the incoming specialized type slots that haven't already
+ // been consumed/claimed by preceding fields.
+ //
+ auto fieldLayoutContext = context.withSpecializationArgsOffsetBy(baseExistentialSlotIndex);
+
+ auto elementType = tupleType->getMember(i);
+ TypeLayoutResult fieldResult = _createTypeLayout(
+ fieldLayoutContext,
+ elementType,
+ nullptr);
+ auto fieldTypeLayout = fieldResult.layout;
+
+ auto fieldVarLayout = typeLayoutBuilder.addField(DeclRef<VarDeclBase>(), fieldResult);
+
+ // If any of the members of the `Tuple` type had existential/interface
+ // type, then we need to compute a second `StructTypeLayout` that
+ // represents the layout and resource using for the "pending data"
+ // that this type needs to have stored somewhere, but which can't
+ // be laid out in the layout of the type itself.
+ //
+ if (auto fieldPendingDataTypeLayout = fieldTypeLayout->pendingDataTypeLayout)
+ {
+ // We only create this secondary layout on-demand, so that
+ // we don't end up with a bunch of empty structure type layouts
+ // created for no reason.
+ //
+ pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(type, rules);
+ auto fieldPendingVarLayout = pendingDataTypeLayoutBuilder.addField(DeclRef<VarDeclBase>(), fieldPendingDataTypeLayout);
+ fieldVarLayout->pendingVarLayout = fieldPendingVarLayout;
+ }
+ }
+
+ typeLayoutBuilder.endLayout();
+ pendingDataTypeLayoutBuilder.endLayout();
+
+ if (auto pendingDataTypeLayout = pendingDataTypeLayoutBuilder.getTypeLayout())
+ {
+ typeLayout->pendingDataTypeLayout = pendingDataTypeLayout;
+ }
+
+ return _updateLayout(context, type, typeLayoutBuilder.getTypeLayoutResult());
+ }
else if (auto declRefType = as<DeclRefType>(type))
{
auto declRef = declRefType->getDeclRef();