diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-17 21:39:16 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-18 12:39:16 +0800 |
| commit | 8a292cff17adbaaca92a8c0de9d41a77d2a13294 (patch) | |
| tree | 1b14b2bd86a939ef3808c225f660f1e6701c478f /source/slang/slang-emit-spirv.cpp | |
| parent | 80c8f13e369b0bf0b86d2b19a4902594e6d67e5c (diff) | |
SPIRV: Fix matrix layout tests. (#3137)
* SPIRV: Fix matrix layout tests.
* Remove spaces.
* Disable debug output.
* Fix.
* Update expected-failure list.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 104 |
1 files changed, 77 insertions, 27 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 3ba06b55c..9a7f5ad31 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1240,9 +1240,9 @@ struct SPIRVEmitContext auto matrixType = static_cast<IRMatrixType*>(inst); auto vectorSpvType = ensureVectorType( static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), - static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(), + static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(), nullptr); - const auto columnCount = static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(); + const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); auto matrixSPVType = emitOpTypeMatrix( inst, vectorSpvType, @@ -2152,6 +2152,16 @@ struct SPIRVEmitContext getSection(SpvLogicalSectionID::Annotations), decoration, dstID, + SpvDecorationBufferBlock + ); + } + break; + case kIROp_SPIRVBlockDecoration: + { + emitOpDecorate( + getSection(SpvLogicalSectionID::Annotations), + decoration, + dstID, SpvDecorationBlock ); } @@ -2206,25 +2216,55 @@ struct SPIRVEmitContext } if (matrixType) { - IRSizeAndAlignment matrixSize; - getSizeAndAlignment(IRTypeLayoutRules::get(layoutRuleName), matrixType, &matrixSize); + // SPIRV sepc on MatrixStride: + // Applies only to a member of a structure type.Only valid on a + // matrix or array whose most basic element is a matrix.Matrix + // Stride is an unsigned 32 - bit integer specifying the stride + // of the rows in a RowMajor - decorated matrix or columns in a + // ColMajor - decorated matrix. + IRIntegerValue matrixStride = 0; + auto rule = IRTypeLayoutRules::get(layoutRuleName); + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment(rule, matrixType->getElementType(), &elementSizeAlignment); + // Reminder: the meaning of row/column major layout // in our semantics is the *opposite* of what GLSL/SPIRV // calls them, because what they call "columns" // are what we call "rows." // - emitOpMemberDecorate( - getSection(SpvLogicalSectionID::Annotations), - nullptr, - spvStructID, - SpvLiteralInteger::from32(id), - getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR? SpvDecorationRowMajor : SpvDecorationColMajor); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + emitOpMemberDecorate( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + spvStructID, + SpvLiteralInteger::from32(id), + SpvDecorationRowMajor); + + auto vectorSize = rule->getVectorSizeAndAlignment(elementSizeAlignment, getIntVal(matrixType->getRowCount())); + vectorSize = rule->alignCompositeElement(vectorSize); + matrixStride = vectorSize.getStride(); + } + else + { + emitOpMemberDecorate( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + spvStructID, + SpvLiteralInteger::from32(id), + SpvDecorationColMajor); + + auto vectorSize = rule->getVectorSizeAndAlignment(elementSizeAlignment, getIntVal(matrixType->getColumnCount())); + vectorSize = rule->alignCompositeElement(vectorSize); + matrixStride = vectorSize.getStride(); + } + emitOpMemberDecorateMatrixStride( getSection(SpvLogicalSectionID::Annotations), nullptr, spvStructID, SpvLiteralInteger::from32(id), - SpvLiteralInteger::from32((int32_t)matrixSize.getStride())); + SpvLiteralInteger::from32((int32_t)matrixStride)); } id++; } @@ -2835,6 +2875,7 @@ struct SPIRVEmitContext SpvInst* emitGetElement(SpvInstParent* parent, IRGetElement* inst) { + // Note: SPIRV only supports the case where `index` is constant. auto base = inst->getBase(); const auto baseTy = base->getDataType(); SLANG_ASSERT( @@ -2845,19 +2886,14 @@ struct SPIRVEmitContext IRBuilder builder(m_irModule); builder.setInsertBefore(inst); + auto index = getIntVal(inst->getIndex()); - auto ptr = emitOpAccessChain( - parent, - nullptr, - builder.getPtrType(inst->getFullType()), - inst->getBase(), - makeArray(inst->getIndex()) - ); - return emitOpLoad( + return emitOpCompositeExtract( parent, inst, inst->getFullType(), - ptr + inst->getBase(), + makeArray(SpvLiteralInteger::from32((int32_t)index)) ); } @@ -3291,17 +3327,31 @@ SlangResult emitSPIRVFromIR( auto targetRequest = codeGenContext->getTargetReq(); auto sink = codeGenContext->getSink(); +#if 0 + { + DiagnosticSinkWriter writer(codeGenContext->getSink()); + dumpIR( + irModule, + { IRDumpOptions::Mode::Simplified, 0 }, + "BEFORE SPIR-V LEGALIZE", + codeGenContext->getSourceManager(), + &writer); + } +#endif + SPIRVEmitContext context(irModule, targetRequest, sink); legalizeIRForSPIRV(&context, irModule, irEntryPoints, codeGenContext); #if 0 - DiagnosticSinkWriter writer(codeGenContext->getSink()); - dumpIR( - irModule, - {IRDumpOptions::Mode::Simplified, 0}, - "BEFORE SPIR-V EMIT", - codeGenContext->getSourceManager(), - &writer); + { + DiagnosticSinkWriter writer(codeGenContext->getSink()); + dumpIR( + irModule, + { IRDumpOptions::Mode::Simplified, 0 }, + "BEFORE SPIR-V EMIT", + codeGenContext->getSourceManager(), + &writer); + } #endif context.emitFrontMatter(); |
