From 8a292cff17adbaaca92a8c0de9d41a77d2a13294 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 17 Aug 2023 21:39:16 -0700 Subject: SPIRV: Fix matrix layout tests. (#3137) * SPIRV: Fix matrix layout tests. * Remove spaces. * Disable debug output. * Fix. * Update expected-failure list. --------- Co-authored-by: Yong He --- source/slang/slang-emit-spirv.cpp | 104 ++++++++++++++++++++++++++++---------- 1 file changed, 77 insertions(+), 27 deletions(-) (limited to 'source/slang/slang-emit-spirv.cpp') diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 3ba06b55c..9a7f5ad31 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1240,9 +1240,9 @@ struct SPIRVEmitContext auto matrixType = static_cast(inst); auto vectorSpvType = ensureVectorType( static_cast(matrixType->getElementType())->getBaseType(), - static_cast(matrixType->getRowCount())->getValue(), + static_cast(matrixType->getColumnCount())->getValue(), nullptr); - const auto columnCount = static_cast(matrixType->getColumnCount())->getValue(); + const auto columnCount = static_cast(matrixType->getRowCount())->getValue(); auto matrixSPVType = emitOpTypeMatrix( inst, vectorSpvType, @@ -2147,6 +2147,16 @@ struct SPIRVEmitContext break; case kIROp_SPIRVBufferBlockDecoration: + { + emitOpDecorate( + getSection(SpvLogicalSectionID::Annotations), + decoration, + dstID, + SpvDecorationBufferBlock + ); + } + break; + case kIROp_SPIRVBlockDecoration: { emitOpDecorate( getSection(SpvLogicalSectionID::Annotations), @@ -2206,25 +2216,55 @@ struct SPIRVEmitContext } if (matrixType) { - IRSizeAndAlignment matrixSize; - getSizeAndAlignment(IRTypeLayoutRules::get(layoutRuleName), matrixType, &matrixSize); + // SPIRV sepc on MatrixStride: + // Applies only to a member of a structure type.Only valid on a + // matrix or array whose most basic element is a matrix.Matrix + // Stride is an unsigned 32 - bit integer specifying the stride + // of the rows in a RowMajor - decorated matrix or columns in a + // ColMajor - decorated matrix. + IRIntegerValue matrixStride = 0; + auto rule = IRTypeLayoutRules::get(layoutRuleName); + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment(rule, matrixType->getElementType(), &elementSizeAlignment); + // Reminder: the meaning of row/column major layout // in our semantics is the *opposite* of what GLSL/SPIRV // calls them, because what they call "columns" // are what we call "rows." // - emitOpMemberDecorate( - getSection(SpvLogicalSectionID::Annotations), - nullptr, - spvStructID, - SpvLiteralInteger::from32(id), - getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR? SpvDecorationRowMajor : SpvDecorationColMajor); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + emitOpMemberDecorate( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + spvStructID, + SpvLiteralInteger::from32(id), + SpvDecorationRowMajor); + + auto vectorSize = rule->getVectorSizeAndAlignment(elementSizeAlignment, getIntVal(matrixType->getRowCount())); + vectorSize = rule->alignCompositeElement(vectorSize); + matrixStride = vectorSize.getStride(); + } + else + { + emitOpMemberDecorate( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + spvStructID, + SpvLiteralInteger::from32(id), + SpvDecorationColMajor); + + auto vectorSize = rule->getVectorSizeAndAlignment(elementSizeAlignment, getIntVal(matrixType->getColumnCount())); + vectorSize = rule->alignCompositeElement(vectorSize); + matrixStride = vectorSize.getStride(); + } + emitOpMemberDecorateMatrixStride( getSection(SpvLogicalSectionID::Annotations), nullptr, spvStructID, SpvLiteralInteger::from32(id), - SpvLiteralInteger::from32((int32_t)matrixSize.getStride())); + SpvLiteralInteger::from32((int32_t)matrixStride)); } id++; } @@ -2835,6 +2875,7 @@ struct SPIRVEmitContext SpvInst* emitGetElement(SpvInstParent* parent, IRGetElement* inst) { + // Note: SPIRV only supports the case where `index` is constant. auto base = inst->getBase(); const auto baseTy = base->getDataType(); SLANG_ASSERT( @@ -2845,19 +2886,14 @@ struct SPIRVEmitContext IRBuilder builder(m_irModule); builder.setInsertBefore(inst); + auto index = getIntVal(inst->getIndex()); - auto ptr = emitOpAccessChain( - parent, - nullptr, - builder.getPtrType(inst->getFullType()), - inst->getBase(), - makeArray(inst->getIndex()) - ); - return emitOpLoad( + return emitOpCompositeExtract( parent, inst, inst->getFullType(), - ptr + inst->getBase(), + makeArray(SpvLiteralInteger::from32((int32_t)index)) ); } @@ -3291,17 +3327,31 @@ SlangResult emitSPIRVFromIR( auto targetRequest = codeGenContext->getTargetReq(); auto sink = codeGenContext->getSink(); +#if 0 + { + DiagnosticSinkWriter writer(codeGenContext->getSink()); + dumpIR( + irModule, + { IRDumpOptions::Mode::Simplified, 0 }, + "BEFORE SPIR-V LEGALIZE", + codeGenContext->getSourceManager(), + &writer); + } +#endif + SPIRVEmitContext context(irModule, targetRequest, sink); legalizeIRForSPIRV(&context, irModule, irEntryPoints, codeGenContext); #if 0 - DiagnosticSinkWriter writer(codeGenContext->getSink()); - dumpIR( - irModule, - {IRDumpOptions::Mode::Simplified, 0}, - "BEFORE SPIR-V EMIT", - codeGenContext->getSourceManager(), - &writer); + { + DiagnosticSinkWriter writer(codeGenContext->getSink()); + dumpIR( + irModule, + { IRDumpOptions::Mode::Simplified, 0 }, + "BEFORE SPIR-V EMIT", + codeGenContext->getSourceManager(), + &writer); + } #endif context.emitFrontMatter(); -- cgit v1.2.3