summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2023-04-13 23:49:00 +0800
committerGitHub <noreply@github.com>2023-04-13 08:49:00 -0700
commitc7e5601bb67d2a5ebadb7f84c6968b5912e7566d (patch)
tree311903d485d42617288198ea45f46c50e150f082 /source
parent6fbd892a0e015fd07d1c41983676713aa6f09333 (diff)
Matrix swizzle writes (#2713)
* Add a bunch of builder emit wrappers for constant indices To avoid cluttering any calling code with int instruction construction * Matrix swizzle stores Closes https://github.com/shader-slang/slang/issues/2512 * Matrix swizzle store tests * Squash vs warnings * Select scalar for singular swizzles * Test singular swizzle materialization * Use IRIntegerValue over UInt for IR wrappers * Correct size of swizzle vector type * Remove variable shadowing
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-insts.h13
-rw-r--r--source/slang/slang-ir.cpp26
-rw-r--r--source/slang/slang-lower-to-ir.cpp337
3 files changed, 334 insertions, 42 deletions
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index bf0a5d4cd..fe658566c 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2810,6 +2810,10 @@ public:
IRType* elementType,
IRInst* elementCount);
+ IRVectorType* getVectorType(
+ IRType* elementType,
+ IRIntegerValue elementCount);
+
IRMatrixType* getMatrixType(
IRType* elementType,
IRInst* rowCount,
@@ -3324,6 +3328,10 @@ public:
IRInst* emitElementExtract(
IRInst* base,
+ IRIntegerValue index);
+
+ IRInst* emitElementExtract(
+ IRInst* base,
const ArrayView<IRInst*>& accessChain);
IRInst* emitElementAddress(
@@ -3337,9 +3345,14 @@ public:
IRInst* emitElementAddress(
IRInst* basePtr,
+ IRIntegerValue index);
+
+ IRInst* emitElementAddress(
+ IRInst* basePtr,
const ArrayView<IRInst*>& accessChain);
IRInst* emitUpdateElement(IRInst* base, IRInst* index, IRInst* newElement);
+ IRInst* emitUpdateElement(IRInst* base, IRIntegerValue index, IRInst* newElement);
IRInst* emitUpdateElement(IRInst* base, const List<IRInst*>& accessChain, IRInst* newElement);
IRInst* emitGetAddress(
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index bd2271953..e624ef1fd 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2809,6 +2809,13 @@ namespace Slang
operands);
}
+ IRVectorType* IRBuilder::getVectorType(
+ IRType* elementType,
+ IRIntegerValue elementCount)
+ {
+ return getVectorType(elementType, getIntValue(getIntType(), elementCount));
+ }
+
IRMatrixType* IRBuilder::getMatrixType(
IRType* elementType,
IRInst* rowCount,
@@ -4607,6 +4614,13 @@ namespace Slang
IRInst* IRBuilder::emitElementExtract(
IRInst* base,
+ IRIntegerValue index)
+ {
+ return emitElementExtract(base, getIntValue(getIntType(), index));
+ }
+
+ IRInst* IRBuilder::emitElementExtract(
+ IRInst* base,
const ArrayView<IRInst*>& accessChain)
{
for (auto access : accessChain)
@@ -4653,6 +4667,13 @@ namespace Slang
IRInst* IRBuilder::emitElementAddress(
IRInst* basePtr,
+ IRIntegerValue index)
+ {
+ return emitElementAddress(basePtr, getIntValue(getIntType(), index));
+ }
+
+ IRInst* IRBuilder::emitElementAddress(
+ IRInst* basePtr,
IRInst* index)
{
IRType* type = nullptr;
@@ -4726,6 +4747,11 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitUpdateElement(IRInst* base, IRIntegerValue index, IRInst* newElement)
+ {
+ return emitUpdateElement(base, getIntValue(getIntType(), index), newElement);
+ }
+
IRInst* IRBuilder::emitUpdateElement(IRInst* base, const List<IRInst*>& accessChain, IRInst* newElement)
{
List<IRInst*> args;
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ada2e043e..1fba0e2f8 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -90,6 +90,7 @@ struct SubscriptInfo : ExtendedValueInfo
struct BoundStorageInfo;
struct BoundMemberInfo;
struct SwizzledLValueInfo;
+struct SwizzledMatrixLValueInfo;
struct CopiedValInfo;
struct ExtractedExistentialValInfo;
@@ -125,6 +126,9 @@ struct LoweredValInfo
// The result of applying swizzling to an l-value
SwizzledLValue,
+ // The result of applying swizzling to an l-value matrix
+ SwizzledMatrixLValue,
+
// The value extracted from an opened existential
ExtractedExistential,
};
@@ -194,12 +198,21 @@ struct LoweredValInfo
static LoweredValInfo swizzledLValue(
SwizzledLValueInfo* extInfo);
+ static LoweredValInfo swizzledMatrixLValue(
+ SwizzledMatrixLValueInfo* extInfo);
+
SwizzledLValueInfo* getSwizzledLValueInfo()
{
SLANG_ASSERT(flavor == Flavor::SwizzledLValue);
return (SwizzledLValueInfo*)ext;
}
+ SwizzledMatrixLValueInfo* getSwizzledMatrixLValueInfo()
+ {
+ SLANG_ASSERT(flavor == Flavor::SwizzledMatrixLValue);
+ return (SwizzledMatrixLValueInfo*)ext;
+ }
+
static LoweredValInfo extractedExistential(
ExtractedExistentialValInfo* extInfo);
@@ -295,6 +308,23 @@ struct SwizzledLValueInfo : ExtendedValueInfo
UInt elementIndices[4];
};
+// Represents the result of a matrix swizzle operation in an l-value context.
+// The same non-contiguous and no-duplicate rules as above apply.
+struct SwizzledMatrixLValueInfo : ExtendedValueInfo
+{
+ // The type of the expression.
+ IRType* type;
+
+ // The base expression (this should be an l-value)
+ LoweredValInfo base;
+
+ // The number of elements in the swizzle
+ UInt elementCount;
+
+ // The coords for the elements being swizzled, zero indexed
+ MatrixCoord elementCoords[4];
+};
+
// Represents the results of extractng a value of
// some (statically unknown) concrete type from
// an existential, in an l-value context.
@@ -352,6 +382,15 @@ LoweredValInfo LoweredValInfo::swizzledLValue(
return info;
}
+LoweredValInfo LoweredValInfo::swizzledMatrixLValue(
+ SwizzledMatrixLValueInfo* extInfo)
+{
+ LoweredValInfo info;
+ info.flavor = Flavor::SwizzledMatrixLValue;
+ info.ext = extInfo;
+ return info;
+}
+
LoweredValInfo LoweredValInfo::extractedExistential(
ExtractedExistentialValInfo* extInfo)
{
@@ -1036,6 +1075,32 @@ top:
swizzleInfo->elementIndices));
}
+ case LoweredValInfo::Flavor::SwizzledMatrixLValue:
+ {
+ auto swizzleInfo = lowered.getSwizzledMatrixLValueInfo();
+ auto base = getSimpleVal(context, swizzleInfo->base);
+ if(const auto type = as<IRMatrixType>(base->getDataType()))
+ {
+ IRInst* components[4];
+ for(UInt i = 0; i < swizzleInfo->elementCount; ++i)
+ {
+ components[i] = builder->emitElementExtract(
+ builder->emitElementExtract(base,swizzleInfo->elementCoords[i].row),
+ swizzleInfo->elementCoords[i].col);
+ }
+ return swizzleInfo->elementCount == 1
+ ? LoweredValInfo::simple(components[0])
+ : LoweredValInfo::simple(builder->emitMakeVector(
+ builder->getVectorType(type->getElementType(), swizzleInfo->elementCount),
+ swizzleInfo->elementCount,
+ components));
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Expected a matrix type in matrix swizzle");
+ }
+ }
+
case LoweredValInfo::Flavor::ExtractedExistential:
{
auto info = lowered.getExtractedExistentialValInfo();
@@ -2468,6 +2533,7 @@ void addInArg(
case LoweredValInfo::Flavor::Simple:
case LoweredValInfo::Flavor::Ptr:
case LoweredValInfo::Flavor::SwizzledLValue:
+ case LoweredValInfo::Flavor::SwizzledMatrixLValue:
case LoweredValInfo::Flavor::BoundStorage:
case LoweredValInfo::Flavor::BoundMember:
case LoweredValInfo::Flavor::ExtractedExistential:
@@ -4571,31 +4637,55 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis
// When visiting a swizzle expression in an l-value context,
// we need to construct a "swizzled l-value."
- LoweredValInfo visitMatrixSwizzleExpr(MatrixSwizzleExpr*)
- {
- SLANG_UNIMPLEMENTED_X("matrix swizzle lvalue case");
- }
-
- // When visiting a swizzle expression in an l-value context,
- // we need to construct a "sizzled l-value."
- LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr)
+ LoweredValInfo visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr)
{
auto irType = lowerType(context, expr->type);
auto loweredBase = lowerRValueExpr(context, expr->base);
- RefPtr<SwizzledLValueInfo> swizzledLValue = new SwizzledLValueInfo();
+ RefPtr<SwizzledMatrixLValueInfo> swizzledLValue = new SwizzledMatrixLValueInfo();
swizzledLValue->type = irType;
UInt elementCount = (UInt)expr->elementCount;
swizzledLValue->elementCount = elementCount;
- // As a small optimization, we will detect if the base expression
- // has also lowered into a swizzle and only return a single
- // swizzle instead of nested swizzles.
+ // In the default case, we can just copy the indices being
+ // used for the swizzle over directly from the expression,
+ // and use the base as-is.
+ //
+ swizzledLValue->base = loweredBase;
+ for (UInt ii = 0; ii < elementCount; ++ii)
+ {
+ swizzledLValue->elementCoords[ii] = expr->elementCoords[ii];
+ }
+
+ context->shared->extValues.add(swizzledLValue);
+ return LoweredValInfo::swizzledMatrixLValue(swizzledLValue);
+ }
+
+ // When visiting a swizzle expression in an l-value context,
+ // we need to construct a "swizzled l-value."
+ LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr)
+ {
+ auto irType = lowerType(context, expr->type);
+ auto loweredBase = lowerLValueExpr(context, expr->base);
+ UInt elementCount = (UInt)expr->elementCount;
+
+ // Assign to 'bs' the elements from 'as' according to the first 'n' indices in 'is'
+ auto backpermute = [](UInt n, const auto* as, const int* is, auto* bs)
+ {
+ for(UInt i = 0; i < n; ++i)
+ {
+ bs[i] = as[is[i]];
+ }
+ };
+
+ // As required by the implementation of 'assign' and as a small
+ // optimization, we will detect if the base expression has also lowered
+ // into a swizzle and only return a single swizzle instead of nested
+ // swizzles.
//
// E.g., if we have input like `foo[i].zw.y` we should optimize it
// down to just `foo[i].w`.
- //
if(loweredBase.flavor == LoweredValInfo::Flavor::SwizzledLValue)
{
auto baseSwizzleInfo = loweredBase.getSwizzledLValueInfo();
@@ -4604,43 +4694,66 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis
// `foo[i]` in our example above), but will need to remap
// the swizzle indices it uses.
//
+
+ RefPtr<SwizzledLValueInfo> swizzledLValue = new SwizzledLValueInfo;
+ swizzledLValue->type = irType;
swizzledLValue->base = baseSwizzleInfo->base;
- for (UInt ii = 0; ii < elementCount; ++ii)
- {
- // First we get the swizzle element of the "outer" swizzle,
- // as it was written by the user. In our running example of
- // `foo[i].zw.y` this is the `y` element reference.
- //
- UInt originalElementIndex = UInt(expr->elementIndices[ii]);
+ swizzledLValue->elementCount = elementCount;
- // Next we will use that original element index to figure
- // out which of the elements of the original swizzle this
- // should map to.
- //
- // In our example, `y` means index 1, and so we fetch
- // element 1 from the inner swizzle sequence `zw`, to get `w`.
- //
- SLANG_ASSERT(originalElementIndex < baseSwizzleInfo->elementCount);
- UInt remappedElementIndex = baseSwizzleInfo->elementIndices[originalElementIndex];
+ // Take the swizzle element of the "outer" swizzle, as it was
+ // written by the user. In our running example of `foo[i].zw.y`
+ // this is the `y` element reference.
+ //
+ // Use that original element index to figure out which of the
+ // elements of the original swizzle this should map to.
+ backpermute(
+ swizzledLValue->elementCount,
+ baseSwizzleInfo->elementIndices,
+ expr->elementIndices,
+ swizzledLValue->elementIndices);
- swizzledLValue->elementIndices[ii] = remappedElementIndex;
- }
+ context->shared->extValues.add(swizzledLValue);
+ return LoweredValInfo::swizzledLValue(swizzledLValue);
+ }
+ else if(loweredBase.flavor == LoweredValInfo::Flavor::SwizzledMatrixLValue)
+ {
+ auto baseSwizzleInfo = loweredBase.getSwizzledMatrixLValueInfo();
+
+ RefPtr<SwizzledMatrixLValueInfo> swizzledLValue = new SwizzledMatrixLValueInfo();
+ swizzledLValue->type = irType;
+ swizzledLValue->base = baseSwizzleInfo->base;
+ swizzledLValue->elementCount = elementCount;
+
+ // Use the index of our swizzle to permute the index of the base
+ // swizzle as above
+ backpermute(
+ swizzledLValue->elementCount,
+ baseSwizzleInfo->elementCoords,
+ expr->elementIndices,
+ swizzledLValue->elementCoords);
+
+ context->shared->extValues.add(swizzledLValue);
+ return LoweredValInfo::swizzledMatrixLValue(swizzledLValue);
}
else
{
+ RefPtr<SwizzledLValueInfo> swizzledLValue = new SwizzledLValueInfo;
+ swizzledLValue->type = irType;
+ swizzledLValue->base = loweredBase;
+ swizzledLValue->elementCount = elementCount;
+
// In the default case, we can just copy the indices being
// used for the swizzle over directly from the expression,
// and use the base as-is.
//
- swizzledLValue->base = loweredBase;
for (UInt ii = 0; ii < elementCount; ++ii)
{
swizzledLValue->elementIndices[ii] = (UInt) expr->elementIndices[ii];
}
- }
- context->shared->extValues.add(swizzledLValue);
- return LoweredValInfo::swizzledLValue(swizzledLValue);
+ context->shared->extValues.add(swizzledLValue);
+ return LoweredValInfo::swizzledLValue(swizzledLValue);
+ }
}
};
@@ -5845,6 +5958,30 @@ LoweredValInfo tryGetAddress(
}
break;
+ // TODO(Ellie): There's an uncomfortable level of duplication here...
+ case LoweredValInfo::Flavor::SwizzledMatrixLValue:
+ {
+ auto originalSwizzleInfo = val.getSwizzledMatrixLValueInfo();
+ auto originalBase = originalSwizzleInfo->base;
+
+ UInt elementCount = originalSwizzleInfo->elementCount;
+
+ auto newBase = tryGetAddress(context, originalBase, TryGetAddressMode::Aggressive);
+ RefPtr<SwizzledMatrixLValueInfo> newSwizzleInfo = new SwizzledMatrixLValueInfo();
+ context->shared->extValues.add(newSwizzleInfo);
+
+ newSwizzleInfo->base = newBase;
+ newSwizzleInfo->type = originalSwizzleInfo->type;
+ newSwizzleInfo->elementCount = elementCount;
+ for(UInt ee = 0; ee < elementCount; ++ee)
+ {
+ newSwizzleInfo->elementCoords[ee] = originalSwizzleInfo->elementCoords[ee];
+ }
+
+ return LoweredValInfo::swizzledMatrixLValue(newSwizzleInfo);
+ }
+ break;
+
// TODO: are there other cases we need to handled here?
default:
@@ -5889,6 +6026,26 @@ void assign(
auto builder = context->irBuilder;
+ // If there's a single element, just emit a regular store, otherwise
+ // proceed with a swizzle store
+ auto swizzledStore = [builder](
+ IRInst* dest,
+ IRInst* source,
+ UInt elementCount,
+ UInt const* elementIndices){
+ if(elementCount == 1)
+ {
+ return builder->emitStore(
+ builder->emitElementAddress(dest, elementIndices[0]),
+ source);
+ }
+ return builder->emitSwizzledStore(
+ dest,
+ source,
+ elementCount,
+ elementIndices);
+ };
+
top:
switch (left.flavor)
{
@@ -5967,14 +6124,8 @@ top:
// it around, in comparison to a simpler model where
// we simply form a pointer to each of the vector
// elements and write to them individually.
- //
- // TODO: we might also consider just special-casing
- // single-element swizzles so that the common case
- // can turn into a simple `store` instead of a
- // `swizzledStore`.
- //
IRInst* irRightVal = getSimpleVal(context, right);
- builder->emitSwizzledStore(
+ swizzledStore(
loweredBase.val,
irRightVal,
swizzleInfo->elementCount,
@@ -5985,6 +6136,108 @@ top:
}
break;
+ case LoweredValInfo::Flavor::SwizzledMatrixLValue:
+ {
+ // The `left` value is of the form `<base>.<swizzleElements>`.
+ // How we will handle this depends on what `base` looks like:
+ auto swizzleInfo = left.getSwizzledMatrixLValueInfo();
+ auto loweredBase = swizzleInfo->base;
+
+ IRInst* irRightVal = getSimpleVal(context, right);
+
+ const UInt maxRowIndex = 4;
+ const UInt maxCols = 4; // swizzleInfo->elementCount;
+
+ // Sort the swizzle elements according to the row to which they
+ // write.
+ // Using row-major terminology
+
+ // The number of element writes in each row
+ UInt rowSizes[maxRowIndex] = {};
+ // The columns being written to in each row
+ UInt rowWrites[maxRowIndex][maxCols];
+ // The RHS element indices being written in each row
+ UInt rowIndices[maxRowIndex][maxCols];
+ for(UInt i = 0; i < swizzleInfo->elementCount; ++i)
+ {
+ const auto& c = swizzleInfo->elementCoords[i];
+ auto& rowSize = rowSizes[c.row];
+ rowWrites[c.row][rowSize] = c.col;
+ rowIndices[c.row][rowSize] = i;
+ ++rowSize;
+ }
+
+ const auto rElemType =
+ composeGetters<IRType>(irRightVal, &IRInst::getDataType, &IRVectorType::getElementType);
+
+ switch( loweredBase.flavor )
+ {
+ case LoweredValInfo::Flavor::Ptr:
+ {
+ // Matrix swizzle writes are implemented as several vector swizzle writes
+ for(UInt r = 0; r < maxRowIndex; ++r)
+ {
+ // Skip if we have nothing in this row
+ if(rowSizes[r] == 0)
+ {
+ continue;
+ }
+ const auto rowAddr = builder->emitElementAddress(loweredBase.val, r);
+ // Only select the RHS elements if it's a vector
+ const auto rSwizzled = rElemType
+ ? builder->emitSwizzle(
+ builder->getVectorType(rElemType, rowSizes[r]),
+ irRightVal,
+ rowSizes[r],
+ rowIndices[r])
+ : irRightVal;
+ swizzledStore(
+ rowAddr,
+ rSwizzled,
+ rowSizes[r],
+ rowWrites[r]);
+ }
+ }
+ break;
+ default:
+ {
+ // As above, our fallback position is to lower via a
+ // temporary, e.g.:
+ //
+ // float4x3 tmp = <base>;
+ // tmp[0].xzy = float3(...);
+ // tmp[1].yxz = float3(...);
+ // tmp[4].yzx = float3(...);
+ // <base> = tmp;
+ //
+ // Create a variable, and use the ptr writing matrix
+ // swizzle assignment above to fill it, then write that back
+ // to the l value. This approach generates the neatest IR
+ const auto beforeLValue = getSimpleVal(context, loweredBase);
+ const auto type = beforeLValue->getDataType();
+
+ // Store our initial lvalue in tmp
+ const auto tmpVar = builder->emitVar(type);
+ builder->emitStore(tmpVar, beforeLValue);
+
+ // Make a new swizzle write to write into this pointer
+ auto nextSwizzleInfo = left.getSwizzledMatrixLValueInfo();
+ SwizzledMatrixLValueInfo nextInfo = *nextSwizzleInfo;
+ nextInfo.base = LoweredValInfo::ptr(tmpVar);
+
+ // Perform that swizzling assignment
+ assign(context, LoweredValInfo::swizzledMatrixLValue(&nextInfo), right);
+
+ // Write (non-swizzled) into the l value
+ left = loweredBase;
+ right = LoweredValInfo::ptr(tmpVar);
+ goto top;
+ }
+ break;
+ }
+ }
+ break;
+
case LoweredValInfo::Flavor::BoundStorage:
{
// The `left` value refers to a subscript operation on