diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 104 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 105 |
4 files changed, 166 insertions, 53 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index d721bec25..b690a5910 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -3022,6 +3022,7 @@ T mul(vector<T, N> x, vector<T, N> y) __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "($1 * $0)") +__target_intrinsic(spirv, "OpMatrixTimesVector resultType resultId _1 _0") [__readNone] vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right) { @@ -3078,6 +3079,7 @@ vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right) __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "($1 * $0)") +__target_intrinsic(spirv, "OpVectorTimesMatrix resultType resultId _1 _0") [__readNone] vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right) { @@ -3134,6 +3136,7 @@ vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right) __generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "($1 * $0)") +__target_intrinsic(spirv, "OpMatrixTimesMatrix resultType resultId _1 _0") [__readNone] matrix<T,R,C> mul(matrix<T,R,N> right, matrix<T,N,C> left) { diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 3ba06b55c..9a7f5ad31 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1240,9 +1240,9 @@ struct SPIRVEmitContext auto matrixType = static_cast<IRMatrixType*>(inst); auto vectorSpvType = ensureVectorType( static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), - static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(), + static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(), nullptr); - const auto columnCount = static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(); + const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); auto matrixSPVType = emitOpTypeMatrix( inst, vectorSpvType, @@ -2152,6 +2152,16 @@ struct SPIRVEmitContext getSection(SpvLogicalSectionID::Annotations), decoration, dstID, + SpvDecorationBufferBlock + ); + } + break; + case kIROp_SPIRVBlockDecoration: + { + emitOpDecorate( + getSection(SpvLogicalSectionID::Annotations), + decoration, + dstID, SpvDecorationBlock ); } @@ -2206,25 +2216,55 @@ struct SPIRVEmitContext } if (matrixType) { - IRSizeAndAlignment matrixSize; - getSizeAndAlignment(IRTypeLayoutRules::get(layoutRuleName), matrixType, &matrixSize); + // SPIRV sepc on MatrixStride: + // Applies only to a member of a structure type.Only valid on a + // matrix or array whose most basic element is a matrix.Matrix + // Stride is an unsigned 32 - bit integer specifying the stride + // of the rows in a RowMajor - decorated matrix or columns in a + // ColMajor - decorated matrix. + IRIntegerValue matrixStride = 0; + auto rule = IRTypeLayoutRules::get(layoutRuleName); + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment(rule, matrixType->getElementType(), &elementSizeAlignment); + // Reminder: the meaning of row/column major layout // in our semantics is the *opposite* of what GLSL/SPIRV // calls them, because what they call "columns" // are what we call "rows." // - emitOpMemberDecorate( - getSection(SpvLogicalSectionID::Annotations), - nullptr, - spvStructID, - SpvLiteralInteger::from32(id), - getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR? SpvDecorationRowMajor : SpvDecorationColMajor); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + emitOpMemberDecorate( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + spvStructID, + SpvLiteralInteger::from32(id), + SpvDecorationRowMajor); + + auto vectorSize = rule->getVectorSizeAndAlignment(elementSizeAlignment, getIntVal(matrixType->getRowCount())); + vectorSize = rule->alignCompositeElement(vectorSize); + matrixStride = vectorSize.getStride(); + } + else + { + emitOpMemberDecorate( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + spvStructID, + SpvLiteralInteger::from32(id), + SpvDecorationColMajor); + + auto vectorSize = rule->getVectorSizeAndAlignment(elementSizeAlignment, getIntVal(matrixType->getColumnCount())); + vectorSize = rule->alignCompositeElement(vectorSize); + matrixStride = vectorSize.getStride(); + } + emitOpMemberDecorateMatrixStride( getSection(SpvLogicalSectionID::Annotations), nullptr, spvStructID, SpvLiteralInteger::from32(id), - SpvLiteralInteger::from32((int32_t)matrixSize.getStride())); + SpvLiteralInteger::from32((int32_t)matrixStride)); } id++; } @@ -2835,6 +2875,7 @@ struct SPIRVEmitContext SpvInst* emitGetElement(SpvInstParent* parent, IRGetElement* inst) { + // Note: SPIRV only supports the case where `index` is constant. auto base = inst->getBase(); const auto baseTy = base->getDataType(); SLANG_ASSERT( @@ -2845,19 +2886,14 @@ struct SPIRVEmitContext IRBuilder builder(m_irModule); builder.setInsertBefore(inst); + auto index = getIntVal(inst->getIndex()); - auto ptr = emitOpAccessChain( - parent, - nullptr, - builder.getPtrType(inst->getFullType()), - inst->getBase(), - makeArray(inst->getIndex()) - ); - return emitOpLoad( + return emitOpCompositeExtract( parent, inst, inst->getFullType(), - ptr + inst->getBase(), + makeArray(SpvLiteralInteger::from32((int32_t)index)) ); } @@ -3291,17 +3327,31 @@ SlangResult emitSPIRVFromIR( auto targetRequest = codeGenContext->getTargetReq(); auto sink = codeGenContext->getSink(); +#if 0 + { + DiagnosticSinkWriter writer(codeGenContext->getSink()); + dumpIR( + irModule, + { IRDumpOptions::Mode::Simplified, 0 }, + "BEFORE SPIR-V LEGALIZE", + codeGenContext->getSourceManager(), + &writer); + } +#endif + SPIRVEmitContext context(irModule, targetRequest, sink); legalizeIRForSPIRV(&context, irModule, irEntryPoints, codeGenContext); #if 0 - DiagnosticSinkWriter writer(codeGenContext->getSink()); - dumpIR( - irModule, - {IRDumpOptions::Mode::Simplified, 0}, - "BEFORE SPIR-V EMIT", - codeGenContext->getSourceManager(), - &writer); + { + DiagnosticSinkWriter writer(codeGenContext->getSink()); + dumpIR( + irModule, + { IRDumpOptions::Mode::Simplified, 0 }, + "BEFORE SPIR-V EMIT", + codeGenContext->getSourceManager(), + &writer); + } #endif context.emitFrontMatter(); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index a8fdd8202..64663120d 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -904,11 +904,14 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// Overrides the floating mode for the target function INST(FloatingPointModeOverrideDecoration, FloatingPointModeOverride, 1, 0) - /// Marks a struct type as being used as a structured buffer block. /// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration. INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0) - INST_RANGE(Decoration, HighLevelDeclDecoration, SPIRVBufferBlockDecoration) + /// Recognized by SPIRV-emit pass so we can emit a SPIRV `Block` decoration. + INST(SPIRVBlockDecoration, spvBlock, 0, 0) + + + INST_RANGE(Decoration, HighLevelDeclDecoration, SPIRVBlockDecoration) // diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 637592357..f6294e2ba 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -10,6 +10,7 @@ #include "slang-glsl-extension-tracker.h" #include "slang-ir-lower-buffer-element-type.h" #include "slang-ir-layout.h" +#include "slang-ir-util.h" namespace Slang { @@ -52,6 +53,36 @@ struct SPIRVLegalizationContext : public SourceEmitterBase { } + // Wraps the element type of a constant buffer or parameter block in a struct if it is not already a struct, + // returns the newly created struct type. + IRType* wrapConstantBufferElement(IRInst* cbParamInst) + { + auto innerType = as<IRParameterGroupType>(cbParamInst->getDataType())->getElementType(); + IRBuilder builder(cbParamInst); + builder.setInsertBefore(cbParamInst); + auto structType = builder.createStructType(); + StringBuilder sb; + sb << "cbuffer_"; + getTypeNameHint(sb, innerType); + sb << "_t"; + builder.addNameHintDecoration(structType, sb.produceString().getUnownedSlice()); + auto key = builder.createStructKey(); + builder.createStructField(structType, key, innerType); + builder.setInsertBefore(cbParamInst); + auto newCbType = builder.getType(cbParamInst->getDataType()->getOp(), structType); + cbParamInst->setFullType(newCbType); + auto rules = getTypeLayoutRuleForBuffer(m_sharedContext->m_targetRequest, cbParamInst->getDataType()); + IRSizeAndAlignment sizeAlignment; + getSizeAndAlignment(rules, structType, &sizeAlignment); + traverseUses(cbParamInst, [&](IRUse* use) + { + builder.setInsertBefore(use->getUser()); + auto addr = builder.emitFieldAddress(builder.getPtrType(kIROp_PtrType, innerType, SpvStorageClassUniform), cbParamInst, key); + use->set(addr); + }); + return structType; + } + void processGlobalParam(IRGlobalParam* inst) { // If the global param is not a pointer type, make it so and insert explicit load insts. @@ -86,34 +117,45 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } // Strip any HLSL wrappers + IRBuilder builder(m_sharedContext->m_irModule); + bool needLoad = true; auto innerType = inst->getFullType(); - if(const auto constantBufferType = as<IRConstantBufferType>(innerType)) - { - innerType = constantBufferType->getElementType(); - storageClass = SpvStorageClassUniform; - } - else if (auto paramBlockType = as<IRParameterBlockType>(innerType)) + if (as<IRConstantBufferType>(innerType) || as<IRParameterBlockType>(innerType)) { - innerType = paramBlockType->getElementType(); + innerType = as<IRUniformParameterGroupType>(innerType)->getElementType(); storageClass = SpvStorageClassUniform; + // Constant buffer is already treated like a pointer type, and + // we are not adding another layer of indirection when replacing it + // with a pointer type. Therefore we don't need to insert a load at + // use sites. + needLoad = false; + // If inner element type is not a struct type, we need to wrap it with + // a struct. + if (!as<IRStructType>(innerType)) + { + innerType = wrapConstantBufferElement(inst); + } + builder.addDecoration(innerType, kIROp_SPIRVBlockDecoration); } // Make a pointer type of storageClass. - IRBuilder builder(m_sharedContext->m_irModule); builder.setInsertBefore(inst); ptrType = builder.getPtrType(kIROp_PtrType, innerType, storageClass); inst->setFullType(ptrType); - // Insert an explicit load at each use site. - List<IRUse*> uses; - for (auto use = inst->firstUse; use; use = use->nextUse) - { - uses.add(use); - } - for (auto use : uses) + if (needLoad) { - builder.setInsertBefore(use->getUser()); - auto loadedValue = builder.emitLoad(inst); - use->set(loadedValue); + // Insert an explicit load at each use site. + List<IRUse*> uses; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + uses.add(use); + } + for (auto use : uses) + { + builder.setInsertBefore(use->getUser()); + auto loadedValue = builder.emitLoad(inst); + use->set(loadedValue); + } } } processGlobalVar(inst); @@ -206,22 +248,37 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } - // Replace getElement(x, i) with, y = store(x); p = getElementPtr(y, i); load(p) - // SPIR-V has no support for dynamic indexing into values like we do. + Dictionary<IRInst*, IRInst*> m_mapArrayValueToVar; + + // Replace getElement(x, i) with, y = store(x); p = getElementPtr(y, i); load(p), + // when i is not a constant. SPIR-V has no support for dynamic indexing into values like we do. // It may be advantageous however to do this further up the pipeline void processGetElement(IRGetElement* inst) { - const auto x = inst->getBase(); + IRInst* x = nullptr; List<IRInst*> indices; IRGetElement* c = inst; do { + if (as<IRIntLit>(c->getIndex())) + break; + x = c->getBase(); indices.add(c->getIndex()); } while(c = as<IRGetElement>(c->getBase()), c); + + if (!x) + return; + IRBuilder builder(m_sharedContext->m_irModule); + IRInst* y = nullptr; + if (!m_mapArrayValueToVar.tryGetValue(x, y)) + { + setInsertAfterOrdinaryInst(&builder, x); + y = builder.emitVar(x->getDataType(), SpvStorageClassFunction); + builder.emitStore(y, x); + m_mapArrayValueToVar.set(x, y); + } builder.setInsertBefore(inst); - IRInst* y = builder.emitVar(x->getDataType(), SpvStorageClassFunction); - builder.emitStore(y, x); for(Index i = indices.getCount() - 1; i >= 0; --i) y = builder.emitElementAddress(y, indices[i]); const auto newInst = builder.emitLoad(y); @@ -367,7 +424,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase break; } builder.addNameHintDecoration(structType, nameSb.getUnownedSlice()); - builder.addDecoration(structType, kIROp_SPIRVBufferBlockDecoration); + builder.addDecoration(structType, kIROp_SPIRVBlockDecoration); inst->replaceUsesWith(ptrType); inst->removeAndDeallocate(); addUsersToWorkList(ptrType); |
