summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-17 21:39:16 -0700
committerGitHub <noreply@github.com>2023-08-18 12:39:16 +0800
commit8a292cff17adbaaca92a8c0de9d41a77d2a13294 (patch)
tree1b14b2bd86a939ef3808c225f660f1e6701c478f /source/slang/slang-emit-spirv.cpp
parent80c8f13e369b0bf0b86d2b19a4902594e6d67e5c (diff)
SPIRV: Fix matrix layout tests. (#3137)
* SPIRV: Fix matrix layout tests. * Remove spaces. * Disable debug output. * Fix. * Update expected-failure list. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp104
1 files changed, 77 insertions, 27 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 3ba06b55c..9a7f5ad31 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1240,9 +1240,9 @@ struct SPIRVEmitContext
auto matrixType = static_cast<IRMatrixType*>(inst);
auto vectorSpvType = ensureVectorType(
static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(),
- static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(),
+ static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(),
nullptr);
- const auto columnCount = static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue();
+ const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue();
auto matrixSPVType = emitOpTypeMatrix(
inst,
vectorSpvType,
@@ -2152,6 +2152,16 @@ struct SPIRVEmitContext
getSection(SpvLogicalSectionID::Annotations),
decoration,
dstID,
+ SpvDecorationBufferBlock
+ );
+ }
+ break;
+ case kIROp_SPIRVBlockDecoration:
+ {
+ emitOpDecorate(
+ getSection(SpvLogicalSectionID::Annotations),
+ decoration,
+ dstID,
SpvDecorationBlock
);
}
@@ -2206,25 +2216,55 @@ struct SPIRVEmitContext
}
if (matrixType)
{
- IRSizeAndAlignment matrixSize;
- getSizeAndAlignment(IRTypeLayoutRules::get(layoutRuleName), matrixType, &matrixSize);
+ // SPIRV sepc on MatrixStride:
+ // Applies only to a member of a structure type.Only valid on a
+ // matrix or array whose most basic element is a matrix.Matrix
+ // Stride is an unsigned 32 - bit integer specifying the stride
+ // of the rows in a RowMajor - decorated matrix or columns in a
+ // ColMajor - decorated matrix.
+ IRIntegerValue matrixStride = 0;
+ auto rule = IRTypeLayoutRules::get(layoutRuleName);
+ IRSizeAndAlignment elementSizeAlignment;
+ getSizeAndAlignment(rule, matrixType->getElementType(), &elementSizeAlignment);
+
// Reminder: the meaning of row/column major layout
// in our semantics is the *opposite* of what GLSL/SPIRV
// calls them, because what they call "columns"
// are what we call "rows."
//
- emitOpMemberDecorate(
- getSection(SpvLogicalSectionID::Annotations),
- nullptr,
- spvStructID,
- SpvLiteralInteger::from32(id),
- getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR? SpvDecorationRowMajor : SpvDecorationColMajor);
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+ {
+ emitOpMemberDecorate(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ spvStructID,
+ SpvLiteralInteger::from32(id),
+ SpvDecorationRowMajor);
+
+ auto vectorSize = rule->getVectorSizeAndAlignment(elementSizeAlignment, getIntVal(matrixType->getRowCount()));
+ vectorSize = rule->alignCompositeElement(vectorSize);
+ matrixStride = vectorSize.getStride();
+ }
+ else
+ {
+ emitOpMemberDecorate(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ spvStructID,
+ SpvLiteralInteger::from32(id),
+ SpvDecorationColMajor);
+
+ auto vectorSize = rule->getVectorSizeAndAlignment(elementSizeAlignment, getIntVal(matrixType->getColumnCount()));
+ vectorSize = rule->alignCompositeElement(vectorSize);
+ matrixStride = vectorSize.getStride();
+ }
+
emitOpMemberDecorateMatrixStride(
getSection(SpvLogicalSectionID::Annotations),
nullptr,
spvStructID,
SpvLiteralInteger::from32(id),
- SpvLiteralInteger::from32((int32_t)matrixSize.getStride()));
+ SpvLiteralInteger::from32((int32_t)matrixStride));
}
id++;
}
@@ -2835,6 +2875,7 @@ struct SPIRVEmitContext
SpvInst* emitGetElement(SpvInstParent* parent, IRGetElement* inst)
{
+ // Note: SPIRV only supports the case where `index` is constant.
auto base = inst->getBase();
const auto baseTy = base->getDataType();
SLANG_ASSERT(
@@ -2845,19 +2886,14 @@ struct SPIRVEmitContext
IRBuilder builder(m_irModule);
builder.setInsertBefore(inst);
+ auto index = getIntVal(inst->getIndex());
- auto ptr = emitOpAccessChain(
- parent,
- nullptr,
- builder.getPtrType(inst->getFullType()),
- inst->getBase(),
- makeArray(inst->getIndex())
- );
- return emitOpLoad(
+ return emitOpCompositeExtract(
parent,
inst,
inst->getFullType(),
- ptr
+ inst->getBase(),
+ makeArray(SpvLiteralInteger::from32((int32_t)index))
);
}
@@ -3291,17 +3327,31 @@ SlangResult emitSPIRVFromIR(
auto targetRequest = codeGenContext->getTargetReq();
auto sink = codeGenContext->getSink();
+#if 0
+ {
+ DiagnosticSinkWriter writer(codeGenContext->getSink());
+ dumpIR(
+ irModule,
+ { IRDumpOptions::Mode::Simplified, 0 },
+ "BEFORE SPIR-V LEGALIZE",
+ codeGenContext->getSourceManager(),
+ &writer);
+ }
+#endif
+
SPIRVEmitContext context(irModule, targetRequest, sink);
legalizeIRForSPIRV(&context, irModule, irEntryPoints, codeGenContext);
#if 0
- DiagnosticSinkWriter writer(codeGenContext->getSink());
- dumpIR(
- irModule,
- {IRDumpOptions::Mode::Simplified, 0},
- "BEFORE SPIR-V EMIT",
- codeGenContext->getSourceManager(),
- &writer);
+ {
+ DiagnosticSinkWriter writer(codeGenContext->getSink());
+ dumpIR(
+ irModule,
+ { IRDumpOptions::Mode::Simplified, 0 },
+ "BEFORE SPIR-V EMIT",
+ codeGenContext->getSourceManager(),
+ &writer);
+ }
#endif
context.emitFrontMatter();