diff options
| author | Yong He <yonghe@outlook.com> | 2024-06-11 15:49:41 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-06-11 15:49:41 -0700 |
| commit | 5da06d43bb0997455211ca56597c4302b09909ab (patch) | |
| tree | 66897108a3cff9175ed025bedd05b705706a7606 /source/slang | |
| parent | 7e796692065060dea34b9e5b7eb224be444f5dee (diff) | |
Fix global value inlining for spirv_asm blocks. (#4339)
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ir-insts.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 294 |
2 files changed, 177 insertions, 122 deletions
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index a4e9906f8..d63c6878e 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3190,7 +3190,10 @@ struct IRSPIRVAsmInst : IRInst IRSPIRVAsmOperand* getOpcodeOperand() { - const auto opcodeOperand = cast<IRSPIRVAsmOperand>(getOperand(0)); + auto operand = getOperand(0); + if (auto globalRef = as<IRGlobalValueRef>(operand)) + operand = globalRef->getValue(); + const auto opcodeOperand = cast<IRSPIRVAsmOperand>(operand); // This must be either: // - An enum, such as 'OpNop' // - The __truncate pseudo-instruction diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index fb002122a..44c5bae7b 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -1690,143 +1690,194 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } - // Opcodes that can exist in global scope, as long as the operands are. - bool isLegalGlobalInst(IRInst* inst) + struct GlobalInstInliningContext { - switch (inst->getOp()) + Dictionary<IRInst*, bool> m_mapGlobalInstToShouldInline; + + // Opcodes that can exist in global scope, as long as the operands are. + bool isLegalGlobalInst(IRInst* inst) { - case kIROp_MakeStruct: - case kIROp_MakeArray: - case kIROp_MakeArrayFromElement: - case kIROp_MakeVector: - case kIROp_MakeMatrix: - case kIROp_MakeMatrixFromScalar: - case kIROp_MakeVectorFromScalar: - return true; - default: - return false; + switch (inst->getOp()) + { + case kIROp_MakeStruct: + case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: + case kIROp_MakeVector: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MakeVectorFromScalar: + return true; + default: + if (as<IRConstant>(inst)) + return true; + if (as<IRSPIRVAsmOperand>(inst)) + return true; + return false; + } } - } - // Opcodes that can be inlined into function bodies. - bool isInlinableGlobalInst(IRInst* inst) - { - switch (inst->getOp()) + // Opcodes that can be inlined into function bodies. + bool isInlinableGlobalInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_FRem: + case kIROp_IRem: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_And: + case kIROp_Or: + case kIROp_Not: + case kIROp_Neg: + case kIROp_Div: + case kIROp_FieldExtract: + case kIROp_FieldAddress: + case kIROp_GetElement: + case kIROp_GetElementPtr: + case kIROp_GetOffsetPtr: + case kIROp_UpdateElement: + case kIROp_MakeTuple: + case kIROp_GetTupleElement: + case kIROp_MakeStruct: + case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: + case kIROp_MakeVector: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MakeVectorFromScalar: + case kIROp_swizzle: + case kIROp_swizzleSet: + case kIROp_MatrixReshape: + case kIROp_MakeString: + case kIROp_MakeResultError: + case kIROp_MakeResultValue: + case kIROp_GetResultError: + case kIROp_GetResultValue: + case kIROp_CastFloatToInt: + case kIROp_CastIntToFloat: + case kIROp_CastIntToPtr: + case kIROp_PtrCast: + case kIROp_CastPtrToBool: + case kIROp_CastPtrToInt: + case kIROp_BitAnd: + case kIROp_BitNot: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_BitCast: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + case kIROp_Neq: + case kIROp_Eql: + case kIROp_Call: + case kIROp_SPIRVAsm: + return true; + default: + if (as<IRSPIRVAsmInst>(inst)) + return true; + if (as<IRSPIRVAsmOperand>(inst)) + return true; + return false; + } + } + + bool shouldInlineInstImpl(IRInst* inst) { - case kIROp_Add: - case kIROp_Sub: - case kIROp_Mul: - case kIROp_FRem: - case kIROp_IRem: - case kIROp_Lsh: - case kIROp_Rsh: - case kIROp_And: - case kIROp_Or: - case kIROp_Not: - case kIROp_Neg: - case kIROp_Div: - case kIROp_FieldExtract: - case kIROp_FieldAddress: - case kIROp_GetElement: - case kIROp_GetElementPtr: - case kIROp_GetOffsetPtr: - case kIROp_UpdateElement: - case kIROp_MakeTuple: - case kIROp_GetTupleElement: - case kIROp_MakeStruct: - case kIROp_MakeArray: - case kIROp_MakeArrayFromElement: - case kIROp_MakeVector: - case kIROp_MakeMatrix: - case kIROp_MakeMatrixFromScalar: - case kIROp_MakeVectorFromScalar: - case kIROp_swizzle: - case kIROp_swizzleSet: - case kIROp_MatrixReshape: - case kIROp_MakeString: - case kIROp_MakeResultError: - case kIROp_MakeResultValue: - case kIROp_GetResultError: - case kIROp_GetResultValue: - case kIROp_CastFloatToInt: - case kIROp_CastIntToFloat: - case kIROp_CastIntToPtr: - case kIROp_PtrCast: - case kIROp_CastPtrToBool: - case kIROp_CastPtrToInt: - case kIROp_BitAnd: - case kIROp_BitNot: - case kIROp_BitOr: - case kIROp_BitXor: - case kIROp_BitCast: - case kIROp_IntCast: - case kIROp_FloatCast: - case kIROp_Greater: - case kIROp_Less: - case kIROp_Geq: - case kIROp_Leq: - case kIROp_Neq: - case kIROp_Eql: - case kIROp_Call: - case kIROp_SPIRVAsm: + if (!isInlinableGlobalInst(inst)) + return false; + if (isLegalGlobalInst(inst)) + { + for (UInt i = 0; i < inst->getOperandCount(); i++) + if (shouldInlineInst(inst->getOperand(i))) + return true; + return false; + } return true; - default: - return false; } - } - bool shouldInlineInst(IRInst* inst) - { - if (!isInlinableGlobalInst(inst)) - return false; - if (isLegalGlobalInst(inst)) + bool shouldInlineInst(IRInst* inst) { - for (UInt i = 0; i < inst->getOperandCount(); i++) - if (shouldInlineInst(inst->getOperand(i))) - return true; - return false; + bool result = false; + if (m_mapGlobalInstToShouldInline.tryGetValue(inst, result)) + return result; + result = shouldInlineInstImpl(inst); + m_mapGlobalInstToShouldInline[inst] = result; + return result; } - return true; - } - /// Inline `inst` in the local function body so they can be emitted as a local inst. - /// - IRInst* maybeInlineGlobalValue(IRBuilder& builder, IRInst* inst, IRCloneEnv& cloneEnv) - { - if (!shouldInlineInst(inst)) + IRInst* inlineInst(IRBuilder& builder, IRCloneEnv& cloneEnv, IRInst* inst) { - switch (inst->getOp()) + IRInst* result; + if (cloneEnv.mapOldValToNew.tryGetValue(inst, result)) + return result; + + for (UInt i = 0; i < inst->getOperandCount(); i++) { - case kIROp_Func: - case kIROp_Specialize: - case kIROp_Generic: - case kIROp_LookupWitness: - return inst; - } - if (as<IRType>(inst)) - return inst; - - // If we encounter a global value that shouldn't be inlined, e.g. a const literal, - // we should insert a GlobalValueRef() inst to wrap around it, so all the dependent uses - // can be pinned to the function body. - auto result = builder.emitGlobalValueRef(inst); + auto operand = inst->getOperand(i); + IRBuilder operandBuilder(builder); + setInsertBeforeOutsideASM(operandBuilder, builder.getInsertLoc().getInst()); + maybeInlineGlobalValue(operandBuilder, inst, operand, cloneEnv); + } + result = cloneInstAndOperands(&cloneEnv, &builder, inst); cloneEnv.mapOldValToNew[inst] = result; + IRBuilder subBuilder(builder); + subBuilder.setInsertInto(result); + for (auto child : inst->getDecorations()) + { + cloneInst(&cloneEnv, &subBuilder, child); + } + for (auto child : inst->getChildren()) + { + inlineInst(subBuilder, cloneEnv, child); + } return result; } - // If the global value is inlinable, we make all its operands avaialble locally, and then copy it - // to the local scope. - ShortList<IRInst*> args; - for (UInt i = 0; i < inst->getOperandCount(); i++) + /// Inline `inst` in the local function body so they can be emitted as a local inst. + /// + IRInst* maybeInlineGlobalValue(IRBuilder& builder, IRInst* user, IRInst* inst, IRCloneEnv& cloneEnv) { - auto operand = inst->getOperand(i); - auto inlinedOperand = maybeInlineGlobalValue(builder, operand, cloneEnv); - args.add(inlinedOperand); + if (!shouldInlineInst(inst)) + { + switch (inst->getOp()) + { + case kIROp_Func: + case kIROp_Specialize: + case kIROp_Generic: + case kIROp_LookupWitness: + return inst; + } + if (as<IRType>(inst)) + return inst; + + // If we encounter a global value that shouldn't be inlined, e.g. a const literal, + // we should insert a GlobalValueRef() inst to wrap around it, so all the dependent uses + // can be pinned to the function body. + auto result = inst; + bool shouldWrapGlobalRef = true; + if (!isLegalGlobalInst(user) && !getIROpInfo(user->getOp()).isHoistable()) + shouldWrapGlobalRef = false; + else if (as<IRSPIRVAsmOperand>(user) && as<IRSPIRVAsmOperandInst>(user)) + shouldWrapGlobalRef = false; + else if (as<IRSPIRVAsmInst>(user)) + shouldWrapGlobalRef = false; + if (shouldWrapGlobalRef) + result = builder.emitGlobalValueRef(inst); + cloneEnv.mapOldValToNew[inst] = result; + return result; + } + + // If the global value is inlinable, we make all its operands avaialble locally, and then copy it + // to the local scope. + return inlineInst(builder, cloneEnv, inst); } - auto result = cloneInst(&cloneEnv, &builder, inst); - cloneEnv.mapOldValToNew[inst] = result; - return result; - } + }; void processBranch(IRInst* branch) { @@ -2079,7 +2130,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } - void setInsertBeforeOutsideASM(IRBuilder& builder, IRInst* beforeInst) + static void setInsertBeforeOutsideASM(IRBuilder& builder, IRInst* beforeInst) { auto parent = beforeInst->getParent(); while (parent) @@ -2234,6 +2285,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // Inline global values that can't represented by SPIRV constant inst // to their use sites. List<IRUse*> globalInstUsesToInline; + GlobalInstInliningContext globalInstInliningContext; for (auto globalInst : m_module->getGlobalInsts()) { @@ -2248,7 +2300,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase sortBlocksInFunc(func); } - if (isInlinableGlobalInst(globalInst)) + if (globalInstInliningContext.isInlinableGlobalInst(globalInst)) { for (auto use = globalInst->firstUse; use; use = use->nextUse) { @@ -2264,7 +2316,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(user); setInsertBeforeOutsideASM(builder, user); IRCloneEnv cloneEnv; - auto val = maybeInlineGlobalValue(builder, use->get(), cloneEnv); + auto val = globalInstInliningContext.maybeInlineGlobalValue(builder, use->getUser(), use->get(), cloneEnv); if (val != use->get()) builder.replaceOperand(use, val); } |
