diff options
| author | Yong He <yonghe@outlook.com> | 2023-10-02 03:33:58 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-10-02 18:33:58 +0800 |
| commit | ccf2611c024ab12dcccd978f3f501d4ee9fc52bc (patch) | |
| tree | f4df843e3b46886005d6bfbae34dc3bcc6fb8321 /source/slang/slang-ir-spirv-legalize.cpp | |
| parent | 6138de5f084cafdc98381237c2d8bed7c8804f1c (diff) | |
Add SPIRV intrinsics for ShaderExecutionReordering and RW/Buffer. (#3252)
* Add SPIRV intrinsics for ShaderExecutionReordering.
* Add intrinsics for `Buffer` and `RWBuffer`.
* Various spirv fixes.
* Marshal bool vector type.
* Inline global constants + OpFOrdNotEqual->OpFUnordNotEqual.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-spirv-legalize.cpp')
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 224 |
1 files changed, 211 insertions, 13 deletions
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 121452533..5d4673981 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -694,6 +694,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase List<WriteBackPair> writeBacks; IRBuilder builder(inst); builder.setInsertBefore(inst); + auto funcType = as<IRFuncType>(funcValue->getFullType()); for (UInt i = 0; i < inst->getArgCount(); i++) { auto arg = inst->getArg(i); @@ -727,8 +728,17 @@ struct SPIRVLegalizationContext : public SourceEmitterBase switch (root->getOp()) { case kIROp_RWStructuredBufferGetElementPtr: - newArgs.add(arg); - continue; + if (funcType) + { + if (funcType->getParamCount() > i && as<IRRefType>(funcType->getParamType(i))) + { + // If we are passing an address from a structured buffer as a + // ref argument, pass the original pointer as is. + // This is to support stdlib atomic functions. + newArgs.add(arg); + continue; + } + } } } @@ -1181,31 +1191,165 @@ struct SPIRVLegalizationContext : public SourceEmitterBase void legalizeSPIRVEntryPoint(IRFunc* func, IREntryPointDecoration* entryPointDecor) { - if (entryPointDecor->getProfile().getStage() == Stage::Geometry) + auto stage = entryPointDecor->getProfile().getStage(); + switch (stage) { + case Stage::Geometry: if (!func->findDecoration<IRInstanceDecoration>()) { IRBuilder builder(func); builder.addDecoration(func, kIROp_InstanceDecoration, builder.getIntValue(builder.getUIntType(), 1)); } + break; + case Stage::Compute: + if (!func->findDecoration<IRNumThreadsDecoration>()) + { + IRBuilder builder(func); + auto one = builder.getIntValue(builder.getUIntType(), 1); + IRInst* args[3] = { one, one, one }; + builder.addDecoration(func, kIROp_NumThreadsDecoration, args, 3); + } + break; } + } - void processModule() + // Opcodes that can exist in global scope, as long as the operands are. + bool isLegalGlobalInst(IRInst* inst) { - // Process global params before anything else, so we don't generate inefficient - // array marhalling code for array-typed global params. - for (auto globalInst : m_module->getGlobalInsts()) + switch (inst->getOp()) { - if (auto globalParam = as<IRGlobalParam>(globalInst)) - { - processGlobalParam(globalParam); - } - else + case kIROp_MakeStruct: + case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: + case kIROp_MakeVector: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MakeVectorFromScalar: + return true; + default: + return false; + } + } + + // Opcodes that can be inlined into function bodies. + bool isInlinableGlobalInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_FRem: + case kIROp_IRem: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_And: + case kIROp_Or: + case kIROp_Not: + case kIROp_Neg: + case kIROp_FieldExtract: + case kIROp_FieldAddress: + case kIROp_GetElement: + case kIROp_GetElementPtr: + case kIROp_UpdateElement: + case kIROp_MakeTuple: + case kIROp_GetTupleElement: + case kIROp_MakeStruct: + case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: + case kIROp_MakeVector: + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MakeVectorFromScalar: + case kIROp_swizzle: + case kIROp_swizzleSet: + case kIROp_MatrixReshape: + case kIROp_MakeString: + case kIROp_MakeResultError: + case kIROp_MakeResultValue: + case kIROp_GetResultError: + case kIROp_GetResultValue: + case kIROp_CastFloatToInt: + case kIROp_CastIntToFloat: + case kIROp_CastIntToPtr: + case kIROp_CastPtrToBool: + case kIROp_CastPtrToInt: + case kIROp_BitAnd: + case kIROp_BitNot: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_BitCast: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + case kIROp_Neq: + case kIROp_Eql: + case kIROp_Call: + return true; + default: + return false; + } + } + + bool shouldInlineInst(IRInst* inst) + { + if (!isInlinableGlobalInst(inst)) + return false; + if (isLegalGlobalInst(inst)) + { + for (UInt i = 0; i < inst->getOperandCount(); i++) + if (shouldInlineInst(inst->getOperand(i))) + return true; + return false; + } + return true; + } + + /// Inline `inst` in the local function body so they can be emitted as a local inst. + /// + IRInst* maybeInlineGlobalValue(IRBuilder& builder, IRInst* inst, IRCloneEnv& cloneEnv) + { + if (!shouldInlineInst(inst)) + { + switch (inst->getOp()) { - addToWorkList(globalInst); + case kIROp_Func: + case kIROp_Specialize: + case kIROp_Generic: + case kIROp_LookupWitness: + return inst; } + if (as<IRType>(inst)) + return inst; + + // If we encounter a global value that shouldn't be inlined, e.g. a const literal, + // we should insert a GlobalValueRef() inst to wrap around it, so all the dependent uses + // can be pinned to the function body. + auto result = builder.emitGlobalValueRef(inst); + cloneEnv.mapOldValToNew[inst] = result; + return result; + } + + // If the global value is inlinable, we make all its operands avaialble locally, and then copy it + // to the local scope. + ShortList<IRInst*> args; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto operand = inst->getOperand(i); + auto inlinedOperand = maybeInlineGlobalValue(builder, operand, cloneEnv); + args.add(inlinedOperand); } + auto result = cloneInst(&cloneEnv, &builder, inst); + cloneEnv.mapOldValToNew[inst] = result; + return result; + } + + void processWorkList() + { while (workList.getCount() != 0) { @@ -1284,6 +1428,39 @@ struct SPIRVLegalizationContext : public SourceEmitterBase break; } } + } + + void setInsertBeforeOutsideASM(IRBuilder& builder, IRInst* beforeInst) + { + auto parent = beforeInst->getParent(); + while (parent) + { + if (as<IRSPIRVAsm>(parent)) + { + builder.setInsertBefore(parent); + return; + } + parent = parent->getParent(); + } + builder.setInsertBefore(beforeInst); + } + + void processModule() + { + // Process global params before anything else, so we don't generate inefficient + // array marhalling code for array-typed global params. + for (auto globalInst : m_module->getGlobalInsts()) + { + if (auto globalParam = as<IRGlobalParam>(globalInst)) + { + processGlobalParam(globalParam); + } + else + { + addToWorkList(globalInst); + } + } + processWorkList(); // Translate types. List<IRHLSLStructuredBufferTypeBase*> instsToProcess; @@ -1302,6 +1479,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase t->replaceUsesWith(builder.getPtrType(kIROp_PtrType, lowered.structType, SpvStorageClassStorageBuffer)); } + List<IRUse*> globalInstUsesToInline; + for (auto globalInst : m_module->getGlobalInsts()) { if (auto func = as<IRFunc>(globalInst)) @@ -1314,8 +1493,27 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // After legalizing the control flow, we need to sort our blocks to ensure this is true. sortBlocksInFunc(func); } + + if (isInlinableGlobalInst(globalInst)) + { + for (auto use = globalInst->firstUse; use; use = use->nextUse) + { + if (getParentFunc(use->getUser()) != nullptr) + globalInstUsesToInline.add(use); + } + } } + for (auto use : globalInstUsesToInline) + { + auto user = use->getUser(); + IRBuilder builder(user); + setInsertBeforeOutsideASM(builder, user); + IRCloneEnv cloneEnv; + auto val = maybeInlineGlobalValue(builder, use->get(), cloneEnv); + if (val != use->get()) + builder.replaceOperand(use, val); + } } }; |
