summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-17 21:39:16 -0700
committerGitHub <noreply@github.com>2023-08-18 12:39:16 +0800
commit8a292cff17adbaaca92a8c0de9d41a77d2a13294 (patch)
tree1b14b2bd86a939ef3808c225f660f1e6701c478f /source
parent80c8f13e369b0bf0b86d2b19a4902594e6d67e5c (diff)
SPIRV: Fix matrix layout tests. (#3137)
* SPIRV: Fix matrix layout tests. * Remove spaces. * Disable debug output. * Fix. * Update expected-failure list. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/hlsl.meta.slang3
-rw-r--r--source/slang/slang-emit-spirv.cpp104
-rw-r--r--source/slang/slang-ir-inst-defs.h7
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp105
4 files changed, 166 insertions, 53 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index d721bec25..b690a5910 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -3022,6 +3022,7 @@ T mul(vector<T, N> x, vector<T, N> y)
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "($1 * $0)")
+__target_intrinsic(spirv, "OpMatrixTimesVector resultType resultId _1 _0")
[__readNone]
vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right)
{
@@ -3078,6 +3079,7 @@ vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right)
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "($1 * $0)")
+__target_intrinsic(spirv, "OpVectorTimesMatrix resultType resultId _1 _0")
[__readNone]
vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right)
{
@@ -3134,6 +3136,7 @@ vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right)
__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "($1 * $0)")
+__target_intrinsic(spirv, "OpMatrixTimesMatrix resultType resultId _1 _0")
[__readNone]
matrix<T,R,C> mul(matrix<T,R,N> right, matrix<T,N,C> left)
{
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 3ba06b55c..9a7f5ad31 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1240,9 +1240,9 @@ struct SPIRVEmitContext
auto matrixType = static_cast<IRMatrixType*>(inst);
auto vectorSpvType = ensureVectorType(
static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(),
- static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(),
+ static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(),
nullptr);
- const auto columnCount = static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue();
+ const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue();
auto matrixSPVType = emitOpTypeMatrix(
inst,
vectorSpvType,
@@ -2152,6 +2152,16 @@ struct SPIRVEmitContext
getSection(SpvLogicalSectionID::Annotations),
decoration,
dstID,
+ SpvDecorationBufferBlock
+ );
+ }
+ break;
+ case kIROp_SPIRVBlockDecoration:
+ {
+ emitOpDecorate(
+ getSection(SpvLogicalSectionID::Annotations),
+ decoration,
+ dstID,
SpvDecorationBlock
);
}
@@ -2206,25 +2216,55 @@ struct SPIRVEmitContext
}
if (matrixType)
{
- IRSizeAndAlignment matrixSize;
- getSizeAndAlignment(IRTypeLayoutRules::get(layoutRuleName), matrixType, &matrixSize);
+ // SPIRV sepc on MatrixStride:
+ // Applies only to a member of a structure type.Only valid on a
+ // matrix or array whose most basic element is a matrix.Matrix
+ // Stride is an unsigned 32 - bit integer specifying the stride
+ // of the rows in a RowMajor - decorated matrix or columns in a
+ // ColMajor - decorated matrix.
+ IRIntegerValue matrixStride = 0;
+ auto rule = IRTypeLayoutRules::get(layoutRuleName);
+ IRSizeAndAlignment elementSizeAlignment;
+ getSizeAndAlignment(rule, matrixType->getElementType(), &elementSizeAlignment);
+
// Reminder: the meaning of row/column major layout
// in our semantics is the *opposite* of what GLSL/SPIRV
// calls them, because what they call "columns"
// are what we call "rows."
//
- emitOpMemberDecorate(
- getSection(SpvLogicalSectionID::Annotations),
- nullptr,
- spvStructID,
- SpvLiteralInteger::from32(id),
- getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR? SpvDecorationRowMajor : SpvDecorationColMajor);
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+ {
+ emitOpMemberDecorate(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ spvStructID,
+ SpvLiteralInteger::from32(id),
+ SpvDecorationRowMajor);
+
+ auto vectorSize = rule->getVectorSizeAndAlignment(elementSizeAlignment, getIntVal(matrixType->getRowCount()));
+ vectorSize = rule->alignCompositeElement(vectorSize);
+ matrixStride = vectorSize.getStride();
+ }
+ else
+ {
+ emitOpMemberDecorate(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ spvStructID,
+ SpvLiteralInteger::from32(id),
+ SpvDecorationColMajor);
+
+ auto vectorSize = rule->getVectorSizeAndAlignment(elementSizeAlignment, getIntVal(matrixType->getColumnCount()));
+ vectorSize = rule->alignCompositeElement(vectorSize);
+ matrixStride = vectorSize.getStride();
+ }
+
emitOpMemberDecorateMatrixStride(
getSection(SpvLogicalSectionID::Annotations),
nullptr,
spvStructID,
SpvLiteralInteger::from32(id),
- SpvLiteralInteger::from32((int32_t)matrixSize.getStride()));
+ SpvLiteralInteger::from32((int32_t)matrixStride));
}
id++;
}
@@ -2835,6 +2875,7 @@ struct SPIRVEmitContext
SpvInst* emitGetElement(SpvInstParent* parent, IRGetElement* inst)
{
+ // Note: SPIRV only supports the case where `index` is constant.
auto base = inst->getBase();
const auto baseTy = base->getDataType();
SLANG_ASSERT(
@@ -2845,19 +2886,14 @@ struct SPIRVEmitContext
IRBuilder builder(m_irModule);
builder.setInsertBefore(inst);
+ auto index = getIntVal(inst->getIndex());
- auto ptr = emitOpAccessChain(
- parent,
- nullptr,
- builder.getPtrType(inst->getFullType()),
- inst->getBase(),
- makeArray(inst->getIndex())
- );
- return emitOpLoad(
+ return emitOpCompositeExtract(
parent,
inst,
inst->getFullType(),
- ptr
+ inst->getBase(),
+ makeArray(SpvLiteralInteger::from32((int32_t)index))
);
}
@@ -3291,17 +3327,31 @@ SlangResult emitSPIRVFromIR(
auto targetRequest = codeGenContext->getTargetReq();
auto sink = codeGenContext->getSink();
+#if 0
+ {
+ DiagnosticSinkWriter writer(codeGenContext->getSink());
+ dumpIR(
+ irModule,
+ { IRDumpOptions::Mode::Simplified, 0 },
+ "BEFORE SPIR-V LEGALIZE",
+ codeGenContext->getSourceManager(),
+ &writer);
+ }
+#endif
+
SPIRVEmitContext context(irModule, targetRequest, sink);
legalizeIRForSPIRV(&context, irModule, irEntryPoints, codeGenContext);
#if 0
- DiagnosticSinkWriter writer(codeGenContext->getSink());
- dumpIR(
- irModule,
- {IRDumpOptions::Mode::Simplified, 0},
- "BEFORE SPIR-V EMIT",
- codeGenContext->getSourceManager(),
- &writer);
+ {
+ DiagnosticSinkWriter writer(codeGenContext->getSink());
+ dumpIR(
+ irModule,
+ { IRDumpOptions::Mode::Simplified, 0 },
+ "BEFORE SPIR-V EMIT",
+ codeGenContext->getSourceManager(),
+ &writer);
+ }
#endif
context.emitFrontMatter();
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index a8fdd8202..64663120d 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -904,11 +904,14 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// Overrides the floating mode for the target function
INST(FloatingPointModeOverrideDecoration, FloatingPointModeOverride, 1, 0)
- /// Marks a struct type as being used as a structured buffer block.
/// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration.
INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0)
- INST_RANGE(Decoration, HighLevelDeclDecoration, SPIRVBufferBlockDecoration)
+ /// Recognized by SPIRV-emit pass so we can emit a SPIRV `Block` decoration.
+ INST(SPIRVBlockDecoration, spvBlock, 0, 0)
+
+
+ INST_RANGE(Decoration, HighLevelDeclDecoration, SPIRVBlockDecoration)
//
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 637592357..f6294e2ba 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -10,6 +10,7 @@
#include "slang-glsl-extension-tracker.h"
#include "slang-ir-lower-buffer-element-type.h"
#include "slang-ir-layout.h"
+#include "slang-ir-util.h"
namespace Slang
{
@@ -52,6 +53,36 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
{
}
+ // Wraps the element type of a constant buffer or parameter block in a struct if it is not already a struct,
+ // returns the newly created struct type.
+ IRType* wrapConstantBufferElement(IRInst* cbParamInst)
+ {
+ auto innerType = as<IRParameterGroupType>(cbParamInst->getDataType())->getElementType();
+ IRBuilder builder(cbParamInst);
+ builder.setInsertBefore(cbParamInst);
+ auto structType = builder.createStructType();
+ StringBuilder sb;
+ sb << "cbuffer_";
+ getTypeNameHint(sb, innerType);
+ sb << "_t";
+ builder.addNameHintDecoration(structType, sb.produceString().getUnownedSlice());
+ auto key = builder.createStructKey();
+ builder.createStructField(structType, key, innerType);
+ builder.setInsertBefore(cbParamInst);
+ auto newCbType = builder.getType(cbParamInst->getDataType()->getOp(), structType);
+ cbParamInst->setFullType(newCbType);
+ auto rules = getTypeLayoutRuleForBuffer(m_sharedContext->m_targetRequest, cbParamInst->getDataType());
+ IRSizeAndAlignment sizeAlignment;
+ getSizeAndAlignment(rules, structType, &sizeAlignment);
+ traverseUses(cbParamInst, [&](IRUse* use)
+ {
+ builder.setInsertBefore(use->getUser());
+ auto addr = builder.emitFieldAddress(builder.getPtrType(kIROp_PtrType, innerType, SpvStorageClassUniform), cbParamInst, key);
+ use->set(addr);
+ });
+ return structType;
+ }
+
void processGlobalParam(IRGlobalParam* inst)
{
// If the global param is not a pointer type, make it so and insert explicit load insts.
@@ -86,34 +117,45 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
// Strip any HLSL wrappers
+ IRBuilder builder(m_sharedContext->m_irModule);
+ bool needLoad = true;
auto innerType = inst->getFullType();
- if(const auto constantBufferType = as<IRConstantBufferType>(innerType))
- {
- innerType = constantBufferType->getElementType();
- storageClass = SpvStorageClassUniform;
- }
- else if (auto paramBlockType = as<IRParameterBlockType>(innerType))
+ if (as<IRConstantBufferType>(innerType) || as<IRParameterBlockType>(innerType))
{
- innerType = paramBlockType->getElementType();
+ innerType = as<IRUniformParameterGroupType>(innerType)->getElementType();
storageClass = SpvStorageClassUniform;
+ // Constant buffer is already treated like a pointer type, and
+ // we are not adding another layer of indirection when replacing it
+ // with a pointer type. Therefore we don't need to insert a load at
+ // use sites.
+ needLoad = false;
+ // If inner element type is not a struct type, we need to wrap it with
+ // a struct.
+ if (!as<IRStructType>(innerType))
+ {
+ innerType = wrapConstantBufferElement(inst);
+ }
+ builder.addDecoration(innerType, kIROp_SPIRVBlockDecoration);
}
// Make a pointer type of storageClass.
- IRBuilder builder(m_sharedContext->m_irModule);
builder.setInsertBefore(inst);
ptrType = builder.getPtrType(kIROp_PtrType, innerType, storageClass);
inst->setFullType(ptrType);
- // Insert an explicit load at each use site.
- List<IRUse*> uses;
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- uses.add(use);
- }
- for (auto use : uses)
+ if (needLoad)
{
- builder.setInsertBefore(use->getUser());
- auto loadedValue = builder.emitLoad(inst);
- use->set(loadedValue);
+ // Insert an explicit load at each use site.
+ List<IRUse*> uses;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ uses.add(use);
+ }
+ for (auto use : uses)
+ {
+ builder.setInsertBefore(use->getUser());
+ auto loadedValue = builder.emitLoad(inst);
+ use->set(loadedValue);
+ }
}
}
processGlobalVar(inst);
@@ -206,22 +248,37 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}
- // Replace getElement(x, i) with, y = store(x); p = getElementPtr(y, i); load(p)
- // SPIR-V has no support for dynamic indexing into values like we do.
+ Dictionary<IRInst*, IRInst*> m_mapArrayValueToVar;
+
+ // Replace getElement(x, i) with, y = store(x); p = getElementPtr(y, i); load(p),
+ // when i is not a constant. SPIR-V has no support for dynamic indexing into values like we do.
// It may be advantageous however to do this further up the pipeline
void processGetElement(IRGetElement* inst)
{
- const auto x = inst->getBase();
+ IRInst* x = nullptr;
List<IRInst*> indices;
IRGetElement* c = inst;
do
{
+ if (as<IRIntLit>(c->getIndex()))
+ break;
+ x = c->getBase();
indices.add(c->getIndex());
} while(c = as<IRGetElement>(c->getBase()), c);
+
+ if (!x)
+ return;
+
IRBuilder builder(m_sharedContext->m_irModule);
+ IRInst* y = nullptr;
+ if (!m_mapArrayValueToVar.tryGetValue(x, y))
+ {
+ setInsertAfterOrdinaryInst(&builder, x);
+ y = builder.emitVar(x->getDataType(), SpvStorageClassFunction);
+ builder.emitStore(y, x);
+ m_mapArrayValueToVar.set(x, y);
+ }
builder.setInsertBefore(inst);
- IRInst* y = builder.emitVar(x->getDataType(), SpvStorageClassFunction);
- builder.emitStore(y, x);
for(Index i = indices.getCount() - 1; i >= 0; --i)
y = builder.emitElementAddress(y, indices[i]);
const auto newInst = builder.emitLoad(y);
@@ -367,7 +424,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
break;
}
builder.addNameHintDecoration(structType, nameSb.getUnownedSlice());
- builder.addDecoration(structType, kIROp_SPIRVBufferBlockDecoration);
+ builder.addDecoration(structType, kIROp_SPIRVBlockDecoration);
inst->replaceUsesWith(ptrType);
inst->removeAndDeallocate();
addUsersToWorkList(ptrType);