summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-spirv-legalize.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-10-02 03:33:58 -0700
committerGitHub <noreply@github.com>2023-10-02 18:33:58 +0800
commitccf2611c024ab12dcccd978f3f501d4ee9fc52bc (patch)
treef4df843e3b46886005d6bfbae34dc3bcc6fb8321 /source/slang/slang-ir-spirv-legalize.cpp
parent6138de5f084cafdc98381237c2d8bed7c8804f1c (diff)
Add SPIRV intrinsics for ShaderExecutionReordering and RW/Buffer. (#3252)
* Add SPIRV intrinsics for ShaderExecutionReordering. * Add intrinsics for `Buffer` and `RWBuffer`. * Various spirv fixes. * Marshal bool vector type. * Inline global constants + OpFOrdNotEqual->OpFUnordNotEqual. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-spirv-legalize.cpp')
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp224
1 files changed, 211 insertions, 13 deletions
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 121452533..5d4673981 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -694,6 +694,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
List<WriteBackPair> writeBacks;
IRBuilder builder(inst);
builder.setInsertBefore(inst);
+ auto funcType = as<IRFuncType>(funcValue->getFullType());
for (UInt i = 0; i < inst->getArgCount(); i++)
{
auto arg = inst->getArg(i);
@@ -727,8 +728,17 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
switch (root->getOp())
{
case kIROp_RWStructuredBufferGetElementPtr:
- newArgs.add(arg);
- continue;
+ if (funcType)
+ {
+ if (funcType->getParamCount() > i && as<IRRefType>(funcType->getParamType(i)))
+ {
+ // If we are passing an address from a structured buffer as a
+ // ref argument, pass the original pointer as is.
+ // This is to support stdlib atomic functions.
+ newArgs.add(arg);
+ continue;
+ }
+ }
}
}
@@ -1181,31 +1191,165 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
void legalizeSPIRVEntryPoint(IRFunc* func, IREntryPointDecoration* entryPointDecor)
{
- if (entryPointDecor->getProfile().getStage() == Stage::Geometry)
+ auto stage = entryPointDecor->getProfile().getStage();
+ switch (stage)
{
+ case Stage::Geometry:
if (!func->findDecoration<IRInstanceDecoration>())
{
IRBuilder builder(func);
builder.addDecoration(func, kIROp_InstanceDecoration, builder.getIntValue(builder.getUIntType(), 1));
}
+ break;
+ case Stage::Compute:
+ if (!func->findDecoration<IRNumThreadsDecoration>())
+ {
+ IRBuilder builder(func);
+ auto one = builder.getIntValue(builder.getUIntType(), 1);
+ IRInst* args[3] = { one, one, one };
+ builder.addDecoration(func, kIROp_NumThreadsDecoration, args, 3);
+ }
+ break;
}
+
}
- void processModule()
+ // Opcodes that can exist in global scope, as long as all of their operands can as well.
+ bool isLegalGlobalInst(IRInst* inst)
{
- // Process global params before anything else, so we don't generate inefficient
- // array marhalling code for array-typed global params.
- for (auto globalInst : m_module->getGlobalInsts())
+ switch (inst->getOp())
{
- if (auto globalParam = as<IRGlobalParam>(globalInst))
- {
- processGlobalParam(globalParam);
- }
- else
+ case kIROp_MakeStruct:
+ case kIROp_MakeArray:
+ case kIROp_MakeArrayFromElement:
+ case kIROp_MakeVector:
+ case kIROp_MakeMatrix:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_MakeVectorFromScalar:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ // Opcodes that can be inlined into function bodies.
+ bool isInlinableGlobalInst(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_FRem:
+ case kIROp_IRem:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_Not:
+ case kIROp_Neg:
+ case kIROp_FieldExtract:
+ case kIROp_FieldAddress:
+ case kIROp_GetElement:
+ case kIROp_GetElementPtr:
+ case kIROp_UpdateElement:
+ case kIROp_MakeTuple:
+ case kIROp_GetTupleElement:
+ case kIROp_MakeStruct:
+ case kIROp_MakeArray:
+ case kIROp_MakeArrayFromElement:
+ case kIROp_MakeVector:
+ case kIROp_MakeMatrix:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_MakeVectorFromScalar:
+ case kIROp_swizzle:
+ case kIROp_swizzleSet:
+ case kIROp_MatrixReshape:
+ case kIROp_MakeString:
+ case kIROp_MakeResultError:
+ case kIROp_MakeResultValue:
+ case kIROp_GetResultError:
+ case kIROp_GetResultValue:
+ case kIROp_CastFloatToInt:
+ case kIROp_CastIntToFloat:
+ case kIROp_CastIntToPtr:
+ case kIROp_CastPtrToBool:
+ case kIROp_CastPtrToInt:
+ case kIROp_BitAnd:
+ case kIROp_BitNot:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_BitCast:
+ case kIROp_IntCast:
+ case kIROp_FloatCast:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ case kIROp_Neq:
+ case kIROp_Eql:
+ case kIROp_Call:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ bool shouldInlineInst(IRInst* inst)
+ {
+ if (!isInlinableGlobalInst(inst))
+ return false;
+ if (isLegalGlobalInst(inst))
+ {
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ if (shouldInlineInst(inst->getOperand(i)))
+ return true;
+ return false;
+ }
+ return true;
+ }
+
+ /// Inline `inst` into the local function body so that it can be emitted as a local inst.
+ ///
+ IRInst* maybeInlineGlobalValue(IRBuilder& builder, IRInst* inst, IRCloneEnv& cloneEnv)
+ {
+ if (!shouldInlineInst(inst))
+ {
+ switch (inst->getOp())
{
- addToWorkList(globalInst);
+ case kIROp_Func:
+ case kIROp_Specialize:
+ case kIROp_Generic:
+ case kIROp_LookupWitness:
+ return inst;
}
+ if (as<IRType>(inst))
+ return inst;
+
+ // If we encounter a global value that shouldn't be inlined, e.g. a const literal,
+ // we should insert a GlobalValueRef() inst to wrap around it, so all the dependent uses
+ // can be pinned to the function body.
+ auto result = builder.emitGlobalValueRef(inst);
+ cloneEnv.mapOldValToNew[inst] = result;
+ return result;
+ }
+
+ // If the global value is inlinable, we make all its operands available locally, and then copy it
+ // to the local scope.
+ ShortList<IRInst*> args;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto operand = inst->getOperand(i);
+ auto inlinedOperand = maybeInlineGlobalValue(builder, operand, cloneEnv);
+ args.add(inlinedOperand);
}
+ auto result = cloneInst(&cloneEnv, &builder, inst);
+ cloneEnv.mapOldValToNew[inst] = result;
+ return result;
+ }
+
+ void processWorkList()
+ {
while (workList.getCount() != 0)
{
@@ -1284,6 +1428,39 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
break;
}
}
+ }
+
+ void setInsertBeforeOutsideASM(IRBuilder& builder, IRInst* beforeInst)
+ {
+ auto parent = beforeInst->getParent();
+ while (parent)
+ {
+ if (as<IRSPIRVAsm>(parent))
+ {
+ builder.setInsertBefore(parent);
+ return;
+ }
+ parent = parent->getParent();
+ }
+ builder.setInsertBefore(beforeInst);
+ }
+
+ void processModule()
+ {
+ // Process global params before anything else, so we don't generate inefficient
+ // array marshalling code for array-typed global params.
+ for (auto globalInst : m_module->getGlobalInsts())
+ {
+ if (auto globalParam = as<IRGlobalParam>(globalInst))
+ {
+ processGlobalParam(globalParam);
+ }
+ else
+ {
+ addToWorkList(globalInst);
+ }
+ }
+ processWorkList();
// Translate types.
List<IRHLSLStructuredBufferTypeBase*> instsToProcess;
@@ -1302,6 +1479,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
t->replaceUsesWith(builder.getPtrType(kIROp_PtrType, lowered.structType, SpvStorageClassStorageBuffer));
}
+ List<IRUse*> globalInstUsesToInline;
+
for (auto globalInst : m_module->getGlobalInsts())
{
if (auto func = as<IRFunc>(globalInst))
@@ -1314,8 +1493,27 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// After legalizing the control flow, we need to sort our blocks to ensure this is true.
sortBlocksInFunc(func);
}
+
+ if (isInlinableGlobalInst(globalInst))
+ {
+ for (auto use = globalInst->firstUse; use; use = use->nextUse)
+ {
+ if (getParentFunc(use->getUser()) != nullptr)
+ globalInstUsesToInline.add(use);
+ }
+ }
}
+ for (auto use : globalInstUsesToInline)
+ {
+ auto user = use->getUser();
+ IRBuilder builder(user);
+ setInsertBeforeOutsideASM(builder, user);
+ IRCloneEnv cloneEnv;
+ auto val = maybeInlineGlobalValue(builder, use->get(), cloneEnv);
+ if (val != use->get())
+ builder.replaceOperand(use, val);
+ }
}
};