diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-23 06:58:50 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-23 21:58:50 +0800 |
| commit | c515bf9edf0ceefa9a0c9b36626ea7c8f72ce36f (patch) | |
| tree | 670a3a80f0f60b7be7fd50e40d9d088f5e7607a7 /source/slang | |
| parent | 6437c38e0a3c2c1daf36cb5e543dc0b467fa4b15 (diff) | |
Misc. SPIRV Fixes. (#3146)
* Lower all ByteAddressBuffer uses for SPIRV.
* Misc. SPIRV Fixes.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 16 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv-ops.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 157 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 166 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-snippet.cpp | 1 |
7 files changed, 338 insertions, 18 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index ec7809b90..73c20f2d0 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -2934,7 +2934,7 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_max($0, $1)") __target_intrinsic(cpp, "$P_max($0, $1)") -__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMax, UMax, SMax) _0") +__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMax, UMax, SMax) _0 _1") [__readNone] T max(T x, T y); // Note: a stdlib implementation of `max` (or `min`) will require splitting @@ -2945,7 +2945,7 @@ T max(T x, T y); __generic<T : __BuiltinIntegerType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) -__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMax, UMax, SMax) _0") +__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMax, UMax, SMax) _0 _1") [__readNone] vector<T, N> max(vector<T, N> x, vector<T, N> y) { @@ -2965,14 +2965,14 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_max($0, $1)") __target_intrinsic(cpp, "$P_max($0, $1)") -__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMax, UMax, SMax) _0") +__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMax, UMax, SMax) _0 _1") [__readNone] T max(T x, T y); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) -__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMax, UMax, SMax) _0") +__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMax, UMax, SMax) _0 _1") [__readNone] vector<T, N> max(vector<T, N> x, vector<T, N> y) { @@ -2993,14 +2993,14 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_min($0, $1)") __target_intrinsic(cpp, "$P_min($0, $1)") -__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMin, UMin, SMin) _0") +__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMin, UMin, SMin) _0 _1") [__readNone] T min(T x, T y); __generic<T : __BuiltinIntegerType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) -__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMin, UMin, SMin) _0") +__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMin, UMin, SMin) _0 _1") [__readNone] vector<T,N> min(vector<T,N> x, vector<T,N> y) { @@ -3020,14 +3020,14 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_min($0, $1)") __target_intrinsic(cpp, "$P_min($0, $1)") -__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMin, UMin, SMin) _0") +__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMin, UMin, SMin) _0 _1") [__readNone] T min(T x, T y); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) -__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMin, UMin, SMin) _0") +__target_intrinsic(spirv, "OpExtInst resultType resultId glsl450 fus(FMin, UMin, SMin) _0 _1") [__readNone] vector<T,N> min(vector<T,N> x, vector<T,N> y) { diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 253f90c22..60431b76c 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -1058,6 +1058,19 @@ SpvInst* emitOpCompositeConstruct( return emitInst(parent, inst, SpvOpCompositeConstruct, idResultType, kResultID, constituent1, constituent2); } +template<typename T, typename Ts> +SpvInst* emitOpConstantComposite( + SpvInstParent* parent, + IRInst* inst, + const T& idResultType, + const Ts& constituents +) +{ + static_assert(isSingular<T>); + static_assert(isPlural<Ts>); + return emitInst(parent, inst, SpvOpConstantComposite, idResultType, kResultID, constituents); +} + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCompositeExtract template<typename T1, typename T2, Index N> SpvInst* emitOpCompositeExtract( diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 9d8c4d89a..3febbd210 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1192,6 +1192,10 @@ struct SPIRVEmitContext case kIROp_DoubleType: { const FloatInfo i = getFloatingTypeInfo(as<IRType>(inst)); + if (inst->getOp() == kIROp_DoubleType) + requireSPIRVCapability(SpvCapabilityFloat64); + else if (inst->getOp() == kIROp_HalfType) + requireSPIRVCapability(SpvCapabilityFloat16); return emitOpTypeFloat(inst, SpvLiteralInteger::from32(int32_t(i.width))); } case kIROp_PtrType: @@ -1359,12 +1363,34 @@ struct SPIRVEmitContext // return emitFunc(as<IRFunc>(inst)); - case kIROp_BoolLit: - case kIROp_IntLit: - case kIROp_FloatLit: - case kIROp_StringLit: - return emitLit(inst); - + case kIROp_BoolLit: + case kIROp_IntLit: + case kIROp_FloatLit: + case kIROp_StringLit: + return emitLit(inst); + case kIROp_MakeVectorFromScalar: + { + const auto scalar = inst->getOperand(0); + const auto vecTy = as<IRVectorType>(inst->getDataType()); + SLANG_ASSERT(vecTy); + const auto numElems = as<IRIntLit>(vecTy->getElementCount()); + SLANG_ASSERT(numElems); + return emitSplat( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + inst, + scalar, + numElems->getValue()); + } + case kIROp_MakeVector: + case kIROp_MakeArray: + case kIROp_MakeStruct: + return emitCompositeConstruct(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst); + case kIROp_MakeArrayFromElement: + return emitMakeArrayFromElement(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst); + case kIROp_MakeMatrix: + return emitMakeMatrix(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst); + case kIROp_MakeMatrixFromScalar: + return emitMakeMatrixFromScalar(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst); case kIROp_GlobalParam: return emitGlobalParam(as<IRGlobalParam>(inst)); case kIROp_GlobalVar: @@ -1816,6 +1842,12 @@ struct SPIRVEmitContext return emitGetElement(parent, as<IRGetElement>(inst)); case kIROp_MakeStruct: return emitCompositeConstruct(parent, inst); + case kIROp_MakeArrayFromElement: + return emitMakeArrayFromElement(parent, inst); + case kIROp_MakeMatrixFromScalar: + return emitMakeMatrixFromScalar(parent, inst); + case kIROp_MakeMatrix: + return emitMakeMatrix(parent, inst); case kIROp_Load: return emitLoad(parent, as<IRLoad>(inst)); case kIROp_Store: @@ -1955,7 +1987,8 @@ struct SPIRVEmitContext } case kIROp_MakeArray: return emitConstruct(parent, inst); - + case kIROp_Select: + return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, OperandsOf(inst)); case kIROp_DebugLine: return emitDebugLine(parent, as<IRDebugLine>(inst)); } @@ -2415,6 +2448,33 @@ struct SPIRVEmitContext void emitLoopHeaderBlock(IRLoop* loopInst, SpvInst* loopHeaderBlock) { + bool hasBackJump = false; + for (auto use = loopInst->getTargetBlock()->firstUse; use; use = use->nextUse) + { + if (use->getUser() == loopInst) + continue; + hasBackJump = true; + break; + } + if (!hasBackJump) + { + // If the loop does not have a back jump, it is used as a breakable region. + // SPIRV does not allow loops without a back jump, so we are going to emit + // a switch instead. + IRBuilder builder(loopInst); + builder.setInsertBefore(loopInst); + emitOpSelectionMerge( + loopHeaderBlock, + nullptr, + getIRInstSpvID(loopInst->getBreakBlock()), + SpvSelectionControlMaskNone + ); + emitInst(loopHeaderBlock, nullptr, SpvOpSwitch, + emitIntConstant(0, builder.getIntType()), + getIRInstSpvID(loopInst->getTargetBlock())); + return; + } + SpvLoopControlMask loopControl = SpvLoopControlMaskNone; if (auto loopControlDecoration = loopInst->findDecoration<IRLoopControlDecoration>()) { @@ -3049,11 +3109,91 @@ struct SPIRVEmitContext : emitOpConvertFToU(parent, inst, toTypeV, inst->getOperand(0)); } + template<typename T, typename Ts> + SpvInst* emitCompositeConstruct( + SpvInstParent* parent, + IRInst* inst, + const T& idResultType, + const Ts& constituents) + { + if (parent == getSection(SpvLogicalSectionID::ConstantsAndTypes)) + return emitOpConstantComposite(parent, inst, idResultType, constituents); + return emitOpCompositeConstruct(parent, inst, idResultType, constituents); + } + SpvInst* emitCompositeConstruct(SpvInstParent* parent, IRInst* inst) { + if (parent == getSection(SpvLogicalSectionID::ConstantsAndTypes)) + return emitOpConstantComposite(parent, inst, inst->getDataType(), OperandsOf(inst)); return emitOpCompositeConstruct(parent, inst, inst->getDataType(), OperandsOf(inst)); } + SpvInst* emitMakeArrayFromElement(SpvInstParent* parent, IRInst* inst) + { + List<IRInst*> elements; + auto arrayType = as<IRArrayType>(inst->getDataType()); + auto elementCount = getIntVal(arrayType->getElementCount()); + for (IRIntegerValue i = 0; i < elementCount; i++) + { + elements.add(inst->getOperand(0)); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), elements); + } + + SpvInst* emitMakeMatrixFromScalar(SpvInstParent* parent, IRInst* inst) + { + List<SpvInst*> rowVectors; + auto matrixType = as<IRMatrixType>(inst->getDataType()); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + List<IRInst*> colElements; + for (IRIntegerValue i = 0; i < colCount; i++) + { + colElements.add(inst->getOperand(0)); + } + auto rowVector = emitCompositeConstruct(parent, nullptr, rowVectorType, colElements); + for (IRIntegerValue i = 0; i < rowCount; i++) + { + rowVectors.add(rowVector); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), rowVectors); + } + + SpvInst* emitMakeMatrix(SpvInstParent* parent, IRInst* inst) + { + // If operands are already row vectors, use CompositeConstruct directly. + if (as<IRVectorType>(inst->getOperand(0)->getDataType())) + { + return emitCompositeConstruct(parent, inst); + } + // Otherwise, operands are raw elements, we need to construct row vectors first, + // then construct matrix from row vectors. + List<SpvInst*> rowVectors; + auto matrixType = as<IRMatrixType>(inst->getDataType()); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + List<IRInst*> colElements; + UInt index = 0; + for (IRIntegerValue j = 0; j < rowCount; j++) + { + colElements.clear(); + for (IRIntegerValue i = 0; i < colCount; i++) + { + colElements.add(inst->getOperand(index)); + index++; + } + auto rowVector = emitCompositeConstruct(parent, nullptr, rowVectorType, colElements); + rowVectors.add(rowVector); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), rowVectors); + } + SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst) { if (as<IRBasicType>(inst->getDataType())) @@ -3093,7 +3233,7 @@ struct SPIRVEmitContext scalarTy->getBaseType(), numElems, nullptr); - return emitOpCompositeConstruct( + return emitCompositeConstruct( parent, inst, spvVecTy, @@ -3154,6 +3294,7 @@ struct SPIRVEmitContext { case BaseType::Float: case BaseType::Double: + case BaseType::Half: isFloatingPoint = true; break; case BaseType::Bool: diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 2857424f9..cb743c06a 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -83,7 +83,7 @@ struct ExtractPrimalFuncContext SLANG_RELEASE_ASSERT(originalFuncType); List<IRType*> paramTypes; - for (Index i = 0; i < ((Count) originalFuncType->getParamCount()) - 1; i++) + for (UInt i = 0; i < originalFuncType->getParamCount(); i++) paramTypes.add((IRType*)migrationContext.cloneInst(&builder, originalFuncType->getParamType(i))); paramTypes.add(builder.getOutType((IRType*)outIntermediateType)); auto resultType = (IRType*)migrationContext.cloneInst(&builder, originalFuncType->getResultType()); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 4b0cac182..22d8c5538 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2337,6 +2337,7 @@ struct IRSwitch : IRTerminatorInst UInt getCaseCount() { return (getOperandCount() - 3) / 2; } IRInst* getCaseValue(UInt index) { return getOperand(3 + index*2 + 0); } IRBlock* getCaseLabel(UInt index) { return (IRBlock*) getOperand(3 + index*2 + 1); } + IRUse* getCaseLabelUse(UInt index) { return getOperands() + 3 + index * 2 + 1; } }; struct IRThrow : IRTerminatorInst diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index f6294e2ba..36fdbd56a 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -11,6 +11,7 @@ #include "slang-ir-lower-buffer-element-type.h" #include "slang-ir-layout.h" #include "slang-ir-util.h" +#include "slang-ir-dominators.h" namespace Slang { @@ -245,7 +246,81 @@ struct SPIRVLegalizationContext : public SourceEmitterBase inst->removeAndDeallocate(); addUsersToWorkList(newCall); } + return; + } + + // According to SPIRV spec, the if the operands of a call has pointer + // type, then it can only be a memory-object. This means that if the + // pointer is a result of `getElementPtr`, we cannot use it as an + // argument. In this case, we have to allocate a temp var to pass the + // value, and write them back to the original pointer after the call. + // + // > SPIRV Spec section 2.16.1: + // > - Any pointer operand to an OpFunctionCall must be a memory object + // > declaration, or + // > - a pointer to an element in an array that is a memory object + // > declaration, where the element type is OpTypeSampler or OpTypeImage. + // + List<IRInst*> newArgs; + struct WriteBackPair { IRInst* originalAddrArg; IRInst* tempVar; }; + List<WriteBackPair> writeBacks; + IRBuilder builder(inst); + builder.setInsertBefore(inst); + for (UInt i = 0; i < inst->getArgCount(); i++) + { + auto arg = inst->getArg(i); + auto ptrType = as<IRPtrTypeBase>(arg->getDataType()); + if (!as<IRPtrTypeBase>(arg->getDataType())) + { + newArgs.add(arg); + continue; + } + // Is the arg already a memory-object by SPIRV definition? + // If so we don't need to allocate a temp var. + switch (arg->getOp()) + { + case kIROp_Var: + case kIROp_GlobalVar: + newArgs.add(arg); + continue; + case kIROp_Param: + if (arg->getParent() == getParentFunc(arg)->getFirstBlock()) + { + newArgs.add(arg); + continue; + } + break; + default: + break; + } + auto root = getRootAddr(arg); + if (root) + { + switch (root->getOp()) + { + case kIROp_RWStructuredBufferGetElementPtr: + newArgs.add(arg); + continue; + } + } + + // If we reach here, we need to allocate a temp var. + auto tempVar = builder.emitVar(ptrType->getValueType()); + auto load = builder.emitLoad(arg); + builder.emitStore(tempVar, load); + newArgs.add(tempVar); + writeBacks.add(WriteBackPair{ arg, tempVar }); + } + SLANG_ASSERT((UInt)newArgs.getCount() == inst->getArgCount()); + auto newCall = builder.emitCallInst(inst->getFullType(), inst->getCallee(), newArgs); + for (auto wb : writeBacks) + { + auto newVal = builder.emitLoad(wb.tempVar); + builder.emitStore(wb.originalAddrArg, newVal); } + inst->replaceUsesWith(newCall); + inst->removeAndDeallocate(); + addUsersToWorkList(newCall); } Dictionary<IRInst*, IRInst*> m_mapArrayValueToVar; @@ -430,6 +505,33 @@ struct SPIRVLegalizationContext : public SourceEmitterBase addUsersToWorkList(ptrType); } + void duplicateMergeBlockIfNeeded(IRUse* breakBlockUse) + { + auto breakBlock = as<IRBlock>(breakBlockUse->get()); + if (breakBlock->getFirstInst()->getOp() != kIROp_Unreachable) + { + return; + } + bool hasMoreThanOneUser = false; + for (auto use = breakBlock->firstUse; use; use = use->nextUse) + { + if (use->getUser() != breakBlockUse->getUser()) + { + hasMoreThanOneUser = true; + break; + } + } + if (!hasMoreThanOneUser) + return; + + // Create a duplicate block for this use. + IRBuilder builder(breakBlock); + builder.setInsertBefore(breakBlock); + auto block = builder.emitBlock(); + builder.emitUnreachable(); + breakBlockUse->set(block); + } + void processLoop(IRLoop* loop) { @@ -545,6 +647,56 @@ struct SPIRVLegalizationContext : public SourceEmitterBase builder.emitBranch(t, ps.getCount(), ps.getBuffer()); } } + duplicateMergeBlockIfNeeded(&loop->breakBlock); + } + + void processIfElse(IRIfElse* inst) + { + duplicateMergeBlockIfNeeded(&inst->afterBlock); + + // SPIRV does not allow using merge block directly as true/false block, + // so we need to create an intermediate block if this is the case. + IRBuilder builder(inst); + if (inst->getTrueBlock() == inst->getAfterBlock()) + { + builder.setInsertBefore(inst->getAfterBlock()); + auto newBlock = builder.emitBlock(); + builder.emitBranch(inst->getAfterBlock()); + inst->trueBlock.set(newBlock); + } + if (inst->getFalseBlock() == inst->getAfterBlock()) + { + builder.setInsertBefore(inst->getAfterBlock()); + auto newBlock = builder.emitBlock(); + builder.emitBranch(inst->getAfterBlock()); + inst->falseBlock.set(newBlock); + } + } + + void processSwitch(IRSwitch* inst) + { + duplicateMergeBlockIfNeeded(&inst->breakLabel); + + // SPIRV does not allow using merge block directly as case block, + // so we need to create an intermediate block if this is the case. + IRBuilder builder(inst); + if (inst->getDefaultLabel() == inst->getBreakLabel()) + { + builder.setInsertBefore(inst->getBreakLabel()); + auto newBlock = builder.emitBlock(); + builder.emitBranch(inst->getBreakLabel()); + inst->defaultLabel.set(newBlock); + } + for (UInt i = 0; i < inst->getCaseCount(); i++) + { + if (inst->getCaseLabel(i) == inst->getBreakLabel()) + { + builder.setInsertBefore(inst->getBreakLabel()); + auto newBlock = builder.emitBlock(); + builder.emitBranch(inst->getBreakLabel()); + inst->getCaseLabelUse(i)->set(newBlock); + } + } } void processModule() @@ -593,6 +745,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_loop: processLoop(as<IRLoop>(inst)); break; + case kIROp_ifElse: + processIfElse(as<IRIfElse>(inst)); + break; + case kIROp_Switch: + processSwitch(as<IRSwitch>(inst)); + break; default: for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) { @@ -601,6 +759,14 @@ struct SPIRVLegalizationContext : public SourceEmitterBase break; } } + + // SPIRV requires a dominator block to appear before dominated blocks. + // After legalizing the control flow, we need to sort our blocks to ensure this is true. + for (auto globalInst : m_module->getGlobalInsts()) + { + if (auto func = as<IRGlobalValueWithCode>(globalInst)) + sortBlocksInFunc(func); + } } }; diff --git a/source/slang/slang-ir-spirv-snippet.cpp b/source/slang/slang-ir-spirv-snippet.cpp index ecb8d1b5e..c19d50ece 100644 --- a/source/slang/slang-ir-spirv-snippet.cpp +++ b/source/slang/slang-ir-spirv-snippet.cpp @@ -147,7 +147,6 @@ RefPtr<SpvSnippet> SpvSnippet::parse(UnownedStringSlice definition) throw Misc::TextFormatException( "Text parsing error: Unrecognized SPIR-V GLSLstd450 opcode: " + opName); } - printf("BNBBB: %d\n", glslOpcode); return (SpvWord)glslOpcode; } } |
