summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDietrich Geisler <dag368@cornell.edu>2020-06-02 12:12:35 -0400
committerGitHub <noreply@github.com>2020-06-02 09:12:35 -0700
commit926a0c51071f6cf5718c77958cc801030ce9d404 (patch)
treec02e84cd402afc6383db2e169c08d05c2a12fbc6
parent8acb704ecabc10c31e664de3814c544572e3945f (diff)
Working matrix swizzle (#1354)
* Working matrix swizzle. Supports one and zero indexing and multiple elements. Performs semantic checking of the swizzle. Matrix swizzles are transformed into a vector of indexing operations during lowering to the IR. This change does not handle matrix swizzle as lvalues. * Renaming * Added missing semicolon * Initialize variable for gcc * Added the expect file for diagnostics * Matrix swizzle updated per PR feedback * Stylistic fix * Formatting fixes * Fix compiling with AST change. Change indentation. Co-authored-by: jsmall-nvidia <jsmall@nvidia.com>
-rw-r--r--source/slang/slang-ast-dump.cpp9
-rw-r--r--source/slang/slang-ast-expr.h17
-rw-r--r--source/slang/slang-check-expr.cpp166
-rw-r--r--source/slang/slang-check-impl.h13
-rw-r--r--source/slang/slang-lower-to-ir.cpp53
-rw-r--r--tests/diagnostics/matrix-swizzle.slang24
-rw-r--r--tests/diagnostics/matrix-swizzle.slang.expected18
-rw-r--r--tests/language-feature/swizzles/matrix-swizzles.slang34
-rw-r--r--tests/language-feature/swizzles/matrix-swizzles.slang.expected.txt4
9 files changed, 338 insertions, 0 deletions
diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp
index 2157b7c19..d0b2497bb 100644
--- a/source/slang/slang-ast-dump.cpp
+++ b/source/slang/slang-ast-dump.cpp
@@ -450,6 +450,15 @@ struct Context
m_writer->emit("}");
}
+ void dump(const MatrixCoord& coord)
+ {
+ m_writer->emit("(");
+ m_writer->emit(coord.row);
+ m_writer->emit(", ");
+ m_writer->emit(coord.col);
+ m_writer->emit(")\n");
+ }
+
void dump(const LookupResult& result)
{
auto& nonConstResult = const_cast<LookupResult&>(result);
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 13c3527de..e7111e631 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -176,6 +176,23 @@ class StaticMemberExpr: public DeclRefExpr
RefPtr<Expr> baseExpression;
};
+struct MatrixCoord
+{
+ bool operator==(const MatrixCoord& rhs) const { return row == rhs.row && col == rhs.col; };
+ bool operator!=(const MatrixCoord& rhs) const { return !(*this == rhs); };
+ // Rows and columns are zero indexed
+ int row;
+ int col;
+};
+
+class MatrixSwizzleExpr : public Expr
+{
+ SLANG_CLASS(MatrixSwizzleExpr)
+ RefPtr<Expr> base;
+ int elementCount;
+ MatrixCoord elementCoords[4];
+};
+
class SwizzleExpr: public Expr
{
SLANG_CLASS(SwizzleExpr)
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 2be64777b..d8eef571e 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1356,6 +1356,164 @@ namespace Slang
}
}
+ RefPtr<Expr> SemanticsVisitor::CheckMatrixSwizzleExpr(
+ MemberExpr* memberRefExpr,
+ RefPtr<Type> baseElementType,
+ IntegerLiteralValue baseElementRowCount,
+ IntegerLiteralValue baseElementColCount)
+ {
+ RefPtr<MatrixSwizzleExpr> swizExpr = m_astBuilder->create<MatrixSwizzleExpr>();
+ swizExpr->loc = memberRefExpr->loc;
+ swizExpr->base = memberRefExpr->baseExpression;
+
+ // We can have up to 4 swizzles of two elements each
+ MatrixCoord elementCoords[4];
+ int elementCount = 0;
+
+ bool anyDuplicates = false;
+ int zeroIndexOffset = -1;
+
+ String swizzleText = getText(memberRefExpr->name);
+ auto cursor = swizzleText.begin();
+
+ // The contents of the string are 0-terminated
+ // Every update to cursor corresponds to a check against 0-termination
+ while (*cursor)
+ {
+ // Throw out swizzling with more than 4 output elements
+ if (elementCount >= 4)
+ {
+ getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString());
+ return CreateErrorExpr(memberRefExpr);
+ }
+ MatrixCoord elementCoord = { 0, 0 };
+
+ // Check for the preceding underscore
+ if (*cursor++ != '_')
+ {
+ getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString());
+ return CreateErrorExpr(memberRefExpr);
+ }
+
+ // Check for one or zero indexing
+ if (*cursor == 'm')
+ {
+ // Can't mix one and zero indexing
+ if (zeroIndexOffset == 1)
+ {
+ getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString());
+ return CreateErrorExpr(memberRefExpr);
+ }
+ zeroIndexOffset = 0;
+ // Increment the index since we saw 'm'
+ cursor++;
+ }
+ else
+ {
+ // Can't mix one and zero indexing
+ if (zeroIndexOffset == 0)
+ {
+ getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString());
+ return CreateErrorExpr(memberRefExpr);
+ }
+ zeroIndexOffset = 1;
+ }
+
+ // Check for the ij components
+ for (Index j = 0; j < 2; j++)
+ {
+ auto ch = *cursor++;
+
+ if (ch < '0' || ch > '4')
+ {
+ // An invalid character in the swizzle is an error
+ getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString());
+ return CreateErrorExpr(memberRefExpr);
+ }
+ const int subIndex = ch - '0' - zeroIndexOffset;
+
+ // Check the limit for either the row or column, depending on the step
+ IntegerLiteralValue elementLimit;
+ if (j == 0)
+ {
+ elementLimit = baseElementRowCount;
+ elementCoord.row = subIndex;
+ }
+ else
+ {
+ elementLimit = baseElementColCount;
+ elementCoord.col = subIndex;
+ }
+ // Make sure the index is in range for the source type
+ // Account for off-by-one and reject 0 if oneIndexed
+ if (subIndex >= elementLimit || subIndex < 0)
+ {
+ getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->toString());
+ return CreateErrorExpr(memberRefExpr);
+ }
+ }
+ // Check if we've seen this index before
+ for (int ee = 0; ee < elementCount; ee++)
+ {
+ if (elementCoords[ee] == elementCoord)
+ anyDuplicates = true;
+ }
+
+ // add to our list...
+ elementCoords[elementCount] = elementCoord;
+ elementCount++;
+ }
+
+ // Store our list in the actual AST node
+ for (int ee = 0; ee < elementCount; ++ee)
+ {
+ swizExpr->elementCoords[ee] = elementCoords[ee];
+ }
+ swizExpr->elementCount = elementCount;
+
+ if (elementCount == 1)
+ {
+ // single-component swizzle produces a scalar
+ //
+ // Note(tfoley): the official HLSL rules seem to be that it produces
+ // a one-component vector, which is then implicitly convertible to
+ // a scalar, but that seems like it just adds complexity.
+ swizExpr->type = QualType(baseElementType);
+ }
+ else
+ {
+ // TODO(tfoley): would be nice to "re-sugar" type
+ // here if the input type had a sugared name...
+ swizExpr->type = QualType(createVectorType(
+ baseElementType,
+ m_astBuilder->create<ConstantIntVal>(elementCount)));
+ }
+
+ // A swizzle can be used as an l-value as long as there
+ // were no duplicates in the list of components
+ swizExpr->type.isLeftValue = !anyDuplicates;
+
+ return swizExpr;
+ }
+
+ RefPtr<Expr> SemanticsVisitor::CheckMatrixSwizzleExpr(
+ MemberExpr* memberRefExpr,
+ RefPtr<Type> baseElementType,
+ RefPtr<IntVal> baseRowCount,
+ RefPtr<IntVal> baseColCount)
+ {
+ if (auto constantRowCount = as<ConstantIntVal>(baseRowCount))
+ {
+ if (auto constantColCount = as<ConstantIntVal>(baseColCount))
+ {
+ return CheckMatrixSwizzleExpr(memberRefExpr, baseElementType,
+ constantRowCount->value, constantColCount->value);
+ }
+ }
+ getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on matrix of unknown size");
+ return CreateErrorExpr(memberRefExpr);
+ }
+
RefPtr<Expr> SemanticsVisitor::CheckSwizzleExpr(
MemberExpr* memberRefExpr,
RefPtr<Type> baseElementType,
@@ -1674,6 +1832,14 @@ namespace Slang
// members via extension, for vector or scalar types.
//
// TODO: Matrix swizzles probably need to be handled at some point.
+ if (auto baseMatrixType = as<MatrixExpressionType>(baseType))
+ {
+ return CheckMatrixSwizzleExpr(
+ expr,
+ baseMatrixType->getElementType(),
+ baseMatrixType->getRowCount(),
+ baseMatrixType->getColumnCount());
+ }
if (auto baseVecType = as<VectorExpressionType>(baseType))
{
return CheckSwizzleExpr(
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 52f460488..42edb5df7 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1288,6 +1288,18 @@ namespace Slang
RefPtr<Expr> MaybeDereference(RefPtr<Expr> inExpr);
+ RefPtr<Expr> CheckMatrixSwizzleExpr(
+ MemberExpr* memberRefExpr,
+ RefPtr<Type> baseElementType,
+ IntegerLiteralValue baseElementRowCount,
+ IntegerLiteralValue baseElementColCount);
+
+ RefPtr<Expr> CheckMatrixSwizzleExpr(
+ MemberExpr* memberRefExpr,
+ RefPtr<Type> baseElementType,
+ RefPtr<IntVal> baseElementRowCount,
+ RefPtr<IntVal> baseElementColCount);
+
RefPtr<Expr> CheckSwizzleExpr(
MemberExpr* memberRefExpr,
RefPtr<Type> baseElementType,
@@ -1364,6 +1376,7 @@ namespace Slang
}
CASE(DerefExpr)
+ CASE(MatrixSwizzleExpr)
CASE(SwizzleExpr)
CASE(OverloadedExpr)
CASE(OverloadedExpr2)
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index bd05ae69a..f369729d2 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -2895,6 +2895,13 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVisitor>
{
// When visiting a swizzle expression in an l-value context,
+ // we need to construct a "swizzled l-value."
+ LoweredValInfo visitMatrixSwizzleExpr(MatrixSwizzleExpr*)
+ {
+ SLANG_UNIMPLEMENTED_X("matrix swizzle lvalue case");
+ }
+
+ // When visiting a swizzle expression in an l-value context,
// we need to construct a "sizzled l-value."
LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr)
{
@@ -2964,6 +2971,52 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis
struct RValueExprLoweringVisitor : ExprLoweringVisitorBase<RValueExprLoweringVisitor>
{
+ // A matrix swizzle in an r-value context can save time by just
+ // emitting the matrix swizzle instructions directly.
+ LoweredValInfo visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr)
+ {
+ auto resultType = lowerType(context, expr->type);
+ auto base = lowerSubExpr(expr->base);
+ auto matType = as<MatrixExpressionType>(expr->base->type.type);
+ if (!matType)
+ SLANG_UNEXPECTED("Expected a matrix type in matrix swizzle");
+ auto subscript2 = lowerType(context, matType->getElementType());
+ auto subscript1 = lowerType(context, matType->getRowType());
+
+ auto builder = getBuilder();
+
+ auto irIntType = getIntType(context);
+
+ UInt elementCount = (UInt)expr->elementCount;
+ IRInst* irExtracts[4];
+ for (UInt ii = 0; ii < elementCount; ++ii)
+ {
+ auto index1 = builder->getIntValue(
+ irIntType,
+ (IRIntegerValue)expr->elementCoords[ii].row);
+ auto index2 = builder->getIntValue(
+ irIntType,
+ (IRIntegerValue)expr->elementCoords[ii].col);
+ // First index expression
+ auto irExtract1 = subscriptValue(
+ subscript1,
+ base,
+ index1);
+ // Second index expression
+ irExtracts[ii] = getSimpleVal(context, subscriptValue(
+ subscript2,
+ irExtract1,
+ index2));
+ }
+ auto irVector = builder->emitMakeVector(
+ resultType,
+ elementCount,
+ irExtracts
+ );
+
+ return LoweredValInfo::simple(irVector);
+ }
+
// A swizzle in an r-value context can save time by just
// emitting the swizzle instructions directly.
LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr)
diff --git a/tests/diagnostics/matrix-swizzle.slang b/tests/diagnostics/matrix-swizzle.slang
new file mode 100644
index 000000000..d9331d89a
--- /dev/null
+++ b/tests/diagnostics/matrix-swizzle.slang
@@ -0,0 +1,24 @@
+//DIAGNOSTIC_TEST:SIMPLE:
+
+int doSomething(int a)
+{
+ int2x3 m1 = int2x3(0, 1, 2, 3, 4, 5);
+ int3x2 m2 = int3x2(0, 1, 2, 3, 4, 5);
+
+ int c = m1._14; // Out of bounds
+ c = m1._32;
+ c = m2._m22;
+ c = m2._; // unfinished
+ c = m2._m;
+ c = m2._1;
+ c = m2._m1;
+ c = m2._m12_;
+ int2 c2 = m1._m11_11; // Mixing of 1 and 0-indexing
+ c = m1._11_11_11_11_11; // More than 4 elements
+ c = m1.x; // Invalid character
+ c = m1._x;
+ c = m1.x123;
+
+ return m1._11;
+}
+
diff --git a/tests/diagnostics/matrix-swizzle.slang.expected b/tests/diagnostics/matrix-swizzle.slang.expected
new file mode 100644
index 000000000..8d349a2ed
--- /dev/null
+++ b/tests/diagnostics/matrix-swizzle.slang.expected
@@ -0,0 +1,18 @@
+result code = -1
+standard error = {
+tests/diagnostics/matrix-swizzle.slang(8): error 30052: invalid swizzle pattern '_14' on type 'int'
+tests/diagnostics/matrix-swizzle.slang(9): error 30052: invalid swizzle pattern '_32' on type 'int'
+tests/diagnostics/matrix-swizzle.slang(10): error 30052: invalid swizzle pattern '_m22' on type 'int'
+tests/diagnostics/matrix-swizzle.slang(11): error 30052: invalid swizzle pattern '_' on type 'int'
+tests/diagnostics/matrix-swizzle.slang(12): error 30052: invalid swizzle pattern '_m' on type 'int'
+tests/diagnostics/matrix-swizzle.slang(13): error 30052: invalid swizzle pattern '_1' on type 'int'
+tests/diagnostics/matrix-swizzle.slang(14): error 30052: invalid swizzle pattern '_m1' on type 'int'
+tests/diagnostics/matrix-swizzle.slang(15): error 30052: invalid swizzle pattern '_m12_' on type 'int'
+tests/diagnostics/matrix-swizzle.slang(16): error 30052: invalid swizzle pattern '_m11_11' on type 'int'
+tests/diagnostics/matrix-swizzle.slang(17): error 30052: invalid swizzle pattern '_11_11_11_11_11' on type 'int'
+tests/diagnostics/matrix-swizzle.slang(18): error 30052: invalid swizzle pattern 'x' on type 'int'
+tests/diagnostics/matrix-swizzle.slang(19): error 30052: invalid swizzle pattern '_x' on type 'int'
+tests/diagnostics/matrix-swizzle.slang(20): error 30052: invalid swizzle pattern 'x123' on type 'int'
+}
+standard output = {
+}
diff --git a/tests/language-feature/swizzles/matrix-swizzles.slang b/tests/language-feature/swizzles/matrix-swizzles.slang
new file mode 100644
index 000000000..e1a2d7473
--- /dev/null
+++ b/tests/language-feature/swizzles/matrix-swizzles.slang
@@ -0,0 +1,34 @@
+// matrix-swizzle.slang
+
+//TEST(compute):COMPARE_COMPUTE:
+
+// Test that matrix swizzle works correctly
+// Matrix swizzles can either be one or zero indexed
+// Reference: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math
+
+int test(int val)
+{
+ float2x2 worldMatrix = float2x2(val + 0, val + 1, val + 2, val + 3);
+ float2 tempVector1;
+ float2 tempVector2;
+
+ // TODO: make left-hand side matrix swizzles work
+ tempVector1 = worldMatrix._m00_m11;
+ tempVector2 = worldMatrix._12_21;
+
+ // return tempMatrix[0][0] + tempMatrix[0][1] = val + 0 + val + 1
+ return tempVector1.x + tempVector2.x;
+}
+
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = tid;
+ int outVal = test(inVal);
+ outputBuffer[tid] = outVal;
+}
diff --git a/tests/language-feature/swizzles/matrix-swizzles.slang.expected.txt b/tests/language-feature/swizzles/matrix-swizzles.slang.expected.txt
new file mode 100644
index 000000000..9b4237ab1
--- /dev/null
+++ b/tests/language-feature/swizzles/matrix-swizzles.slang.expected.txt
@@ -0,0 +1,4 @@
+1
+3
+5
+7