summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-06-11 15:49:41 -0700
committerGitHub <noreply@github.com>2024-06-11 15:49:41 -0700
commit5da06d43bb0997455211ca56597c4302b09909ab (patch)
tree66897108a3cff9175ed025bedd05b705706a7606 /source
parent7e796692065060dea34b9e5b7eb224be444f5dee (diff)
Fix global value inlining for spirv_asm blocks. (#4339)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-insts.h5
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp294
2 files changed, 177 insertions, 122 deletions
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index a4e9906f8..d63c6878e 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3190,7 +3190,10 @@ struct IRSPIRVAsmInst : IRInst
IRSPIRVAsmOperand* getOpcodeOperand()
{
- const auto opcodeOperand = cast<IRSPIRVAsmOperand>(getOperand(0));
+ auto operand = getOperand(0);
+ if (auto globalRef = as<IRGlobalValueRef>(operand))
+ operand = globalRef->getValue();
+ const auto opcodeOperand = cast<IRSPIRVAsmOperand>(operand);
// This must be either:
// - An enum, such as 'OpNop'
// - The __truncate pseudo-instruction
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index fb002122a..44c5bae7b 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -1690,143 +1690,194 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
- // Opcodes that can exist in global scope, as long as the operands are.
- bool isLegalGlobalInst(IRInst* inst)
+ struct GlobalInstInliningContext
{
- switch (inst->getOp())
+ Dictionary<IRInst*, bool> m_mapGlobalInstToShouldInline;
+
+ // Opcodes that can exist in global scope, as long as the operands are.
+ bool isLegalGlobalInst(IRInst* inst)
{
- case kIROp_MakeStruct:
- case kIROp_MakeArray:
- case kIROp_MakeArrayFromElement:
- case kIROp_MakeVector:
- case kIROp_MakeMatrix:
- case kIROp_MakeMatrixFromScalar:
- case kIROp_MakeVectorFromScalar:
- return true;
- default:
- return false;
+ switch (inst->getOp())
+ {
+ case kIROp_MakeStruct:
+ case kIROp_MakeArray:
+ case kIROp_MakeArrayFromElement:
+ case kIROp_MakeVector:
+ case kIROp_MakeMatrix:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_MakeVectorFromScalar:
+ return true;
+ default:
+ if (as<IRConstant>(inst))
+ return true;
+ if (as<IRSPIRVAsmOperand>(inst))
+ return true;
+ return false;
+ }
}
- }
- // Opcodes that can be inlined into function bodies.
- bool isInlinableGlobalInst(IRInst* inst)
- {
- switch (inst->getOp())
+ // Opcodes that can be inlined into function bodies.
+ bool isInlinableGlobalInst(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_FRem:
+ case kIROp_IRem:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_Not:
+ case kIROp_Neg:
+ case kIROp_Div:
+ case kIROp_FieldExtract:
+ case kIROp_FieldAddress:
+ case kIROp_GetElement:
+ case kIROp_GetElementPtr:
+ case kIROp_GetOffsetPtr:
+ case kIROp_UpdateElement:
+ case kIROp_MakeTuple:
+ case kIROp_GetTupleElement:
+ case kIROp_MakeStruct:
+ case kIROp_MakeArray:
+ case kIROp_MakeArrayFromElement:
+ case kIROp_MakeVector:
+ case kIROp_MakeMatrix:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_MakeVectorFromScalar:
+ case kIROp_swizzle:
+ case kIROp_swizzleSet:
+ case kIROp_MatrixReshape:
+ case kIROp_MakeString:
+ case kIROp_MakeResultError:
+ case kIROp_MakeResultValue:
+ case kIROp_GetResultError:
+ case kIROp_GetResultValue:
+ case kIROp_CastFloatToInt:
+ case kIROp_CastIntToFloat:
+ case kIROp_CastIntToPtr:
+ case kIROp_PtrCast:
+ case kIROp_CastPtrToBool:
+ case kIROp_CastPtrToInt:
+ case kIROp_BitAnd:
+ case kIROp_BitNot:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_BitCast:
+ case kIROp_IntCast:
+ case kIROp_FloatCast:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ case kIROp_Neq:
+ case kIROp_Eql:
+ case kIROp_Call:
+ case kIROp_SPIRVAsm:
+ return true;
+ default:
+ if (as<IRSPIRVAsmInst>(inst))
+ return true;
+ if (as<IRSPIRVAsmOperand>(inst))
+ return true;
+ return false;
+ }
+ }
+
+ bool shouldInlineInstImpl(IRInst* inst)
{
- case kIROp_Add:
- case kIROp_Sub:
- case kIROp_Mul:
- case kIROp_FRem:
- case kIROp_IRem:
- case kIROp_Lsh:
- case kIROp_Rsh:
- case kIROp_And:
- case kIROp_Or:
- case kIROp_Not:
- case kIROp_Neg:
- case kIROp_Div:
- case kIROp_FieldExtract:
- case kIROp_FieldAddress:
- case kIROp_GetElement:
- case kIROp_GetElementPtr:
- case kIROp_GetOffsetPtr:
- case kIROp_UpdateElement:
- case kIROp_MakeTuple:
- case kIROp_GetTupleElement:
- case kIROp_MakeStruct:
- case kIROp_MakeArray:
- case kIROp_MakeArrayFromElement:
- case kIROp_MakeVector:
- case kIROp_MakeMatrix:
- case kIROp_MakeMatrixFromScalar:
- case kIROp_MakeVectorFromScalar:
- case kIROp_swizzle:
- case kIROp_swizzleSet:
- case kIROp_MatrixReshape:
- case kIROp_MakeString:
- case kIROp_MakeResultError:
- case kIROp_MakeResultValue:
- case kIROp_GetResultError:
- case kIROp_GetResultValue:
- case kIROp_CastFloatToInt:
- case kIROp_CastIntToFloat:
- case kIROp_CastIntToPtr:
- case kIROp_PtrCast:
- case kIROp_CastPtrToBool:
- case kIROp_CastPtrToInt:
- case kIROp_BitAnd:
- case kIROp_BitNot:
- case kIROp_BitOr:
- case kIROp_BitXor:
- case kIROp_BitCast:
- case kIROp_IntCast:
- case kIROp_FloatCast:
- case kIROp_Greater:
- case kIROp_Less:
- case kIROp_Geq:
- case kIROp_Leq:
- case kIROp_Neq:
- case kIROp_Eql:
- case kIROp_Call:
- case kIROp_SPIRVAsm:
+ if (!isInlinableGlobalInst(inst))
+ return false;
+ if (isLegalGlobalInst(inst))
+ {
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ if (shouldInlineInst(inst->getOperand(i)))
+ return true;
+ return false;
+ }
return true;
- default:
- return false;
}
- }
- bool shouldInlineInst(IRInst* inst)
- {
- if (!isInlinableGlobalInst(inst))
- return false;
- if (isLegalGlobalInst(inst))
+ bool shouldInlineInst(IRInst* inst)
{
- for (UInt i = 0; i < inst->getOperandCount(); i++)
- if (shouldInlineInst(inst->getOperand(i)))
- return true;
- return false;
+ bool result = false;
+ if (m_mapGlobalInstToShouldInline.tryGetValue(inst, result))
+ return result;
+ result = shouldInlineInstImpl(inst);
+ m_mapGlobalInstToShouldInline[inst] = result;
+ return result;
}
- return true;
- }
- /// Inline `inst` in the local function body so they can be emitted as a local inst.
- ///
- IRInst* maybeInlineGlobalValue(IRBuilder& builder, IRInst* inst, IRCloneEnv& cloneEnv)
- {
- if (!shouldInlineInst(inst))
+ IRInst* inlineInst(IRBuilder& builder, IRCloneEnv& cloneEnv, IRInst* inst)
{
- switch (inst->getOp())
+ IRInst* result;
+ if (cloneEnv.mapOldValToNew.tryGetValue(inst, result))
+ return result;
+
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
{
- case kIROp_Func:
- case kIROp_Specialize:
- case kIROp_Generic:
- case kIROp_LookupWitness:
- return inst;
- }
- if (as<IRType>(inst))
- return inst;
-
- // If we encounter a global value that shouldn't be inlined, e.g. a const literal,
- // we should insert a GlobalValueRef() inst to wrap around it, so all the dependent uses
- // can be pinned to the function body.
- auto result = builder.emitGlobalValueRef(inst);
+ auto operand = inst->getOperand(i);
+ IRBuilder operandBuilder(builder);
+ setInsertBeforeOutsideASM(operandBuilder, builder.getInsertLoc().getInst());
+ maybeInlineGlobalValue(operandBuilder, inst, operand, cloneEnv);
+ }
+ result = cloneInstAndOperands(&cloneEnv, &builder, inst);
cloneEnv.mapOldValToNew[inst] = result;
+ IRBuilder subBuilder(builder);
+ subBuilder.setInsertInto(result);
+ for (auto child : inst->getDecorations())
+ {
+ cloneInst(&cloneEnv, &subBuilder, child);
+ }
+ for (auto child : inst->getChildren())
+ {
+ inlineInst(subBuilder, cloneEnv, child);
+ }
return result;
}
- // If the global value is inlinable, we make all its operands avaialble locally, and then copy it
- // to the local scope.
- ShortList<IRInst*> args;
- for (UInt i = 0; i < inst->getOperandCount(); i++)
+ /// Inline `inst` in the local function body so they can be emitted as a local inst.
+ ///
+ IRInst* maybeInlineGlobalValue(IRBuilder& builder, IRInst* user, IRInst* inst, IRCloneEnv& cloneEnv)
{
- auto operand = inst->getOperand(i);
- auto inlinedOperand = maybeInlineGlobalValue(builder, operand, cloneEnv);
- args.add(inlinedOperand);
+ if (!shouldInlineInst(inst))
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Func:
+ case kIROp_Specialize:
+ case kIROp_Generic:
+ case kIROp_LookupWitness:
+ return inst;
+ }
+ if (as<IRType>(inst))
+ return inst;
+
+ // If we encounter a global value that shouldn't be inlined, e.g. a const literal,
+ // we should insert a GlobalValueRef() inst to wrap around it, so all the dependent uses
+ // can be pinned to the function body.
+ auto result = inst;
+ bool shouldWrapGlobalRef = true;
+ if (!isLegalGlobalInst(user) && !getIROpInfo(user->getOp()).isHoistable())
+ shouldWrapGlobalRef = false;
+ else if (as<IRSPIRVAsmOperand>(user) && as<IRSPIRVAsmOperandInst>(user))
+ shouldWrapGlobalRef = false;
+ else if (as<IRSPIRVAsmInst>(user))
+ shouldWrapGlobalRef = false;
+ if (shouldWrapGlobalRef)
+ result = builder.emitGlobalValueRef(inst);
+ cloneEnv.mapOldValToNew[inst] = result;
+ return result;
+ }
+
+ // If the global value is inlinable, we make all its operands avaialble locally, and then copy it
+ // to the local scope.
+ return inlineInst(builder, cloneEnv, inst);
}
- auto result = cloneInst(&cloneEnv, &builder, inst);
- cloneEnv.mapOldValToNew[inst] = result;
- return result;
- }
+ };
void processBranch(IRInst* branch)
{
@@ -2079,7 +2130,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}
- void setInsertBeforeOutsideASM(IRBuilder& builder, IRInst* beforeInst)
+ static void setInsertBeforeOutsideASM(IRBuilder& builder, IRInst* beforeInst)
{
auto parent = beforeInst->getParent();
while (parent)
@@ -2234,6 +2285,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// Inline global values that can't represented by SPIRV constant inst
// to their use sites.
List<IRUse*> globalInstUsesToInline;
+ GlobalInstInliningContext globalInstInliningContext;
for (auto globalInst : m_module->getGlobalInsts())
{
@@ -2248,7 +2300,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
sortBlocksInFunc(func);
}
- if (isInlinableGlobalInst(globalInst))
+ if (globalInstInliningContext.isInlinableGlobalInst(globalInst))
{
for (auto use = globalInst->firstUse; use; use = use->nextUse)
{
@@ -2264,7 +2316,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(user);
setInsertBeforeOutsideASM(builder, user);
IRCloneEnv cloneEnv;
- auto val = maybeInlineGlobalValue(builder, use->get(), cloneEnv);
+ auto val = globalInstInliningContext.maybeInlineGlobalValue(builder, use->getUser(), use->get(), cloneEnv);
if (val != use->get())
builder.replaceOperand(use, val);
}