diff options
| -rw-r--r-- | source/slang/slang-ir-inline.cpp | 109 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 15 | ||||
| -rw-r--r-- | tests/bugs/inlining/global-const-inline.slang | 20 | ||||
| -rw-r--r-- | tests/bugs/inlining/global-const-inline.slang.expected.txt | 4 | ||||
| -rw-r--r-- | tests/experimental/liveness/liveness-3.slang.expected | 37 | ||||
| -rw-r--r-- | tests/experimental/liveness/liveness-4.slang.expected | 15 | ||||
| -rw-r--r-- | tests/experimental/liveness/liveness-5.slang.expected | 31 | ||||
| -rw-r--r-- | tests/experimental/liveness/liveness-6.slang.expected | 35 | ||||
| -rw-r--r-- | tests/experimental/liveness/liveness.slang.expected | 33 |
10 files changed, 239 insertions, 80 deletions
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index ef01de47e..7fc977170 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -393,6 +393,75 @@ struct InliningPassBase return clonedInst; } + /// Inline the body of the callee for `callSite`, for a callee that has only + /// a single basic block. + /// + void inlineSingleBlockFuncBody( + CallSiteInfo const& callSite, IRCloneEnv* env, IRBuilder* builder) + { + auto call = callSite.call; + auto callee = callSite.callee; + + // The callee had better have only a single basic block. + // + auto firstBlock = callee->getFirstBlock(); + SLANG_ASSERT(!firstBlock->getNextBlock()); + + // We will loop over the instructions in the block and clone + // them into the same basic block as the `call`. + // + builder->setInsertBefore(call); + + // Along the way, we will detect any `return` instruction, + // and remember the (clone of the) returned value. + // + IRInst* returnVal = nullptr; + + for (auto inst : firstBlock->getChildren()) + { + switch (inst->getOp()) + { + default: + // In the common case we just clone the instruction as-is + _cloneInstWithSourceLoc(callSite, env, builder, inst); + break; + + case kIROp_Param: + // Parameters of the first block are the parameters of + // the function itself, so we skip them rather than + // clone them. + // + break; + + case kIROp_Return: + // We expect to see only a single `return` instruction, + // and when we see it we note the value being returned. + // + SLANG_ASSERT(!returnVal); + returnVal = findCloneForOperand(env, inst->getOperand(0)); + break; + } + } + + // We are going to remove the original `call` now that the callee + // has been inlined, but before we do that we need to replace + // all uses of the `call` with whatever value was produced by the + // inlined body of the callee. + // + if (returnVal) + { + call->replaceUsesWith(returnVal); + } + else + { + call->replaceUsesWith(builder->getVoidValue()); + } + + // Once the `call` has no uses, we can safely remove it. + // + call->removeAndDeallocate(); + } + /// Inline the body of the callee for `callSite`. void inlineFuncBody( CallSiteInfo const& callSite, IRCloneEnv* env, IRBuilder* builder) @@ -400,11 +469,45 @@ struct InliningPassBase auto call = callSite.call; auto callee = callSite.callee; - // Break the basic block containing the call inst into two basic blocks. + // If the callee consists of a single basic block *and* that block + // ends with a `return` instruction, then we can apply a simple approach + // to inlining that is compatible with any call site (including those + // at the global scope). + // + auto firstBlock = callee->getFirstBlock(); + SLANG_ASSERT(firstBlock); + if(!firstBlock->getNextBlock() && as<IRReturn>(firstBlock->getTerminator())) + { + inlineSingleBlockFuncBody(callSite, env, builder); + return; + } + + // If the callee has any non-trivial control flow (multiple basic blocks + // and terminators other than `return`), we will need to split the control + // flow of the caller at the block that contains `call`. + // + // For any of this to work, we have to assume that the `call` appears + // in a basic block inside of a function (not, e.g., at the global scope). + // auto callerBlock = callSite.call->getParent(); - builder->setInsertInto(callerBlock->getParent()); + SLANG_ASSERT(as<IRBlock>(callerBlock)); + auto callerFunc = callerBlock->getParent(); + SLANG_ASSERT(callerFunc); + + // As a fail-safe for release builds, if the above expectations are somehow + // *not* met, we will fall back to not inlining the call at all. + // + if (!callerFunc) + { + return; + } + + // We will create a new basic block block in the parent function that + // will contain all the instructions that come *after* the `call`. + // + builder->setInsertInto(callerFunc); auto afterBlock = builder->createBlock(); - + // Many operations (e.g. `cloneInst`) has define-before-use assumptions on the IR. // It is important to make sure we keep the ordering of blocks by inserting the // second half of the basic block right after `callerBlock`. diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index 2415f1388..0bd5c6e9f 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -1221,6 +1221,26 @@ bool constructSSA(IRModule* module, IRInst* globalVal) case kIROp_GlobalVar: return constructSSA(module, (IRGlobalValueWithCode*)globalVal); + case kIROp_Generic: + { + // The above cases handle the actual code-bearing declarations + // that can contian basic blocks with local variables, but + // we would also like to perform SSA simplifications on + // *generic* functions, and so we will also process any + // instruction that is produced by an `IRGeneric`. + // + // TODO: At some point we may simply want to apply this pass + // recursively to *all* instructions, in order to make it + // robust to the presence of nested functions in general. + + auto generic = cast<IRGeneric>(globalVal); + auto returnVal = findInnerMostGenericReturnVal(generic); + if(!returnVal) + return false; + + return constructSSA(module, returnVal); + } + default: break; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 9378a69e8..383067363 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -9102,9 +9102,24 @@ RefPtr<IRModule> generateIRForTranslationUnit( // normal `call` + `ifElse`, etc. lowerErrorHandling(module, compileRequest->getSink()); + // Next, attempt to promote local variables to SSA + // temporaries and do basic simplifications. + // + constructSSA(module); + simplifyCFG(module); + applySparseConditionalConstantPropagation(module); + // Next, inline calls to any functions that have been // marked for mandatory "early" inlining. // + // Note: We performed certain critical simplifications + // above, before this step, so that the body of functions + // subject to mandatory inlining can be simplified ahead + // of time. By simplifying the body before inlining it, + // we can make sure that things like superfluous temporaries + // are eliminated from the callee, and not copied into + // call sites. + // performMandatoryEarlyInlining(module); // Next, attempt to promote local variables to SSA diff --git a/tests/bugs/inlining/global-const-inline.slang b/tests/bugs/inlining/global-const-inline.slang new file mode 100644 index 000000000..629031e87 --- /dev/null +++ b/tests/bugs/inlining/global-const-inline.slang @@ -0,0 +1,20 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj +//TEST(compute,vulkan):COMPARE_COMPUTE_EX:-vk -slang -compute -shaderobj + +static const float3 CONSTANT = float3(16); + +int test(int value) +{ + return value*int(CONSTANT.x) + value; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + let index = int(dispatchThreadID.x); + let value = test(index); + outputBuffer[index] = value; +}
\ No newline at end of file diff --git a/tests/bugs/inlining/global-const-inline.slang.expected.txt b/tests/bugs/inlining/global-const-inline.slang.expected.txt new file mode 100644 index 000000000..d4cb1cc00 --- /dev/null +++ b/tests/bugs/inlining/global-const-inline.slang.expected.txt @@ -0,0 +1,4 @@ +0 +11 +22 +33 diff --git a/tests/experimental/liveness/liveness-3.slang.expected b/tests/experimental/liveness/liveness-3.slang.expected index 3c124ed92..58f562d86 100644 --- a/tests/experimental/liveness/liveness-3.slang.expected +++ b/tests/experimental/liveness/liveness-3.slang.expected @@ -87,18 +87,17 @@ int calcThing_0(int offset_0) livenessStart_1(_S7, 0); _S7 = _S10; } - int _S11 = _S7 + i_0; - idx_0[modRange_0] = idx_0[modRange_0] + _S11; + idx_0[modRange_0] = idx_0[modRange_0] + (_S7 + i_0); i_0 = i_0 + 1; livenessStart_1(_S5, 0); - int _S12 = _S7; + int _S11 = _S7; livenessEnd_0(_S7, 0); - _S5 = _S12; + _S5 = _S11; } livenessEnd_0(i_0, 0); livenessEnd_0(_S2, 0); - int _S13 = (k_0 + 7) % 5; - if(_S13 == 4) + int _S12 = (k_0 + 7) % 5; + if(_S12 == 4) { livenessEnd_0(_S5, 0); livenessEnd_1(idx_0, 0); @@ -106,39 +105,39 @@ int calcThing_0(int offset_0) livenessEnd_2(another_0, 0); return total_0; } - int _S14 = idx_0[0] + idx_0[1]; - int _S15 = idx_0[2]; + int _S13 = idx_0[0] + idx_0[1]; + int _S14 = idx_0[2]; livenessEnd_1(idx_0, 0); - int _S16 = _S14 + _S15; - int _S17 = total_0; + int _S15 = _S13 + _S14; + int _S16 = total_0; livenessEnd_0(total_0, 0); - int total_1 = _S17 + _S16; + int total_1 = _S16 + _S15; k_0 = k_0 + 1; livenessStart_1(_S2, 0); - int _S18 = _S5; + int _S17 = _S5; livenessEnd_0(_S5, 0); - _S2 = _S18; + _S2 = _S17; livenessStart_1(total_0, 0); total_0 = total_1; } livenessEnd_0(_S2, 0); livenessEnd_0(k_0, 0); livenessEnd_2(another_0, 0); - int _S19 = total_0; + int _S18 = total_0; livenessEnd_0(total_0, 0); - return - _S19; + return - _S18; } -layout(std430, binding = 0) buffer _S20 { +layout(std430, binding = 0) buffer _S19 { int _data[]; } outputBuffer_0; layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; void main() { int index_0 = int(gl_GlobalInvocationID.x); - uint _S21 = uint(index_0); - int _S22 = calcThing_0(index_0); - ((outputBuffer_0)._data[(_S21)]) = _S22; + uint _S20 = uint(index_0); + int _S21 = calcThing_0(index_0); + ((outputBuffer_0)._data[(_S20)]) = _S21; return; } diff --git a/tests/experimental/liveness/liveness-4.slang.expected b/tests/experimental/liveness/liveness-4.slang.expected index 38f42c02a..52c6ebb32 100644 --- a/tests/experimental/liveness/liveness-4.slang.expected +++ b/tests/experimental/liveness/liveness-4.slang.expected @@ -48,13 +48,12 @@ int calcThing_0(int offset_0) { break; } - int _S2 = k_0 + i_0; - another_0[i_0 & 1] = another_0[i_0 & 1] + _S2; + another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); i_0 = i_0 + 1; } livenessEnd_0(i_0, 0); - int _S3 = (k_0 + 7) % 5; - if(_S3 == 4) + int _S2 = (k_0 + 7) % 5; + if(_S2 == 4) { livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); @@ -67,16 +66,16 @@ int calcThing_0(int offset_0) return -2; } -layout(std430, binding = 0) buffer _S4 { +layout(std430, binding = 0) buffer _S3 { int _data[]; } outputBuffer_0; layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; void main() { int index_0 = int(gl_GlobalInvocationID.x); - uint _S5 = uint(index_0); - int _S6 = calcThing_0(index_0); - ((outputBuffer_0)._data[(_S5)]) = _S6; + uint _S4 = uint(index_0); + int _S5 = calcThing_0(index_0); + ((outputBuffer_0)._data[(_S4)]) = _S5; return; } diff --git a/tests/experimental/liveness/liveness-5.slang.expected b/tests/experimental/liveness/liveness-5.slang.expected index 920e05b59..ea6e37036 100644 --- a/tests/experimental/liveness/liveness-5.slang.expected +++ b/tests/experimental/liveness/liveness-5.slang.expected @@ -51,17 +51,16 @@ int calcThing_0(int offset_0) { break; } - int _S2 = k_0 + i_0; - another_0[i_0 & 1] = another_0[i_0 & 1] + _S2; + another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); i_0 = i_0 + 1; } livenessEnd_0(i_0, 0); - int _S3 = another_0[k_0 & 1]; - int _S4 = total_0; + int _S2 = another_0[k_0 & 1]; + int _S3 = total_0; livenessEnd_0(total_0, 0); - int total_1 = _S4 + _S3; - int _S5 = (k_0 + 7) % 5; - if(_S5 == 4) + int total_1 = _S3 + _S2; + int _S4 = (k_0 + 7) % 5; + if(_S4 == 4) { livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); @@ -76,32 +75,32 @@ int calcThing_0(int offset_0) int total_2; if(total_0 > 4) { - int _S6 = total_0; + int _S5 = total_0; livenessEnd_0(total_0, 0); - int _S7 = - _S6; + int _S6 = - _S5; livenessStart_1(total_2, 0); - total_2 = _S7; + total_2 = _S6; } else { - int _S8 = total_0; + int _S7 = total_0; livenessEnd_0(total_0, 0); livenessStart_1(total_2, 0); - total_2 = _S8; + total_2 = _S7; } return total_2; } -layout(std430, binding = 0) buffer _S9 { +layout(std430, binding = 0) buffer _S8 { int _data[]; } outputBuffer_0; layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; void main() { int index_0 = int(gl_GlobalInvocationID.x); - uint _S10 = uint(index_0); - int _S11 = calcThing_0(index_0); - ((outputBuffer_0)._data[(_S10)]) = _S11; + uint _S9 = uint(index_0); + int _S10 = calcThing_0(index_0); + ((outputBuffer_0)._data[(_S9)]) = _S10; return; } diff --git a/tests/experimental/liveness/liveness-6.slang.expected b/tests/experimental/liveness/liveness-6.slang.expected index 91ee98f8e..ac1894f95 100644 --- a/tests/experimental/liveness/liveness-6.slang.expected +++ b/tests/experimental/liveness/liveness-6.slang.expected @@ -55,21 +55,20 @@ int calcThing_0(int offset_0) { break; } - int _S3 = k_0 + i_0; - another_0[i_0 & 1] = another_0[i_0 & 1] + _S3; + another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); arr_0[k_0 & 1] = arr_0[k_0 & 1] + i_0; i_0 = i_0 + 1; } livenessEnd_0(i_0, 0); - int _S4 = another_0[k_0 & 1]; - int _S5 = total_0; + int _S3 = another_0[k_0 & 1]; + int _S4 = total_0; livenessEnd_0(total_0, 0); - int total_1 = _S5 + _S4; - int _S6 = arr_0[k_0 & 1]; + int total_1 = _S4 + _S3; + int _S5 = arr_0[k_0 & 1]; livenessEnd_1(arr_0, 0); - int total_2 = total_1 + _S6; - int _S7 = (k_0 + 7) % 5; - if(_S7 == 4) + int total_2 = total_1 + _S5; + int _S6 = (k_0 + 7) % 5; + if(_S6 == 4) { livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); @@ -84,32 +83,32 @@ int calcThing_0(int offset_0) int total_3; if(total_0 > 4) { - int _S8 = total_0; + int _S7 = total_0; livenessEnd_0(total_0, 0); - int _S9 = - _S8; + int _S8 = - _S7; livenessStart_1(total_3, 0); - total_3 = _S9; + total_3 = _S8; } else { - int _S10 = total_0; + int _S9 = total_0; livenessEnd_0(total_0, 0); livenessStart_1(total_3, 0); - total_3 = _S10; + total_3 = _S9; } return total_3; } -layout(std430, binding = 0) buffer _S11 { +layout(std430, binding = 0) buffer _S10 { int _data[]; } outputBuffer_0; layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; void main() { int index_0 = int(gl_GlobalInvocationID.x); - uint _S12 = uint(index_0); - int _S13 = calcThing_0(index_0); - ((outputBuffer_0)._data[(_S12)]) = _S13; + uint _S11 = uint(index_0); + int _S12 = calcThing_0(index_0); + ((outputBuffer_0)._data[(_S11)]) = _S12; return; } diff --git a/tests/experimental/liveness/liveness.slang.expected b/tests/experimental/liveness/liveness.slang.expected index 0a39f0225..e50ecac5a 100644 --- a/tests/experimental/liveness/liveness.slang.expected +++ b/tests/experimental/liveness/liveness.slang.expected @@ -7,10 +7,10 @@ standard output = { layout(row_major) uniform; layout(row_major) buffer; spirv_instruction(id = 256) -void livenessStart_0(spirv_by_reference int _0, spirv_literal int _1); +void livenessStart_0(spirv_by_reference uint _0, spirv_literal int _1); spirv_instruction(id = 256) -void livenessStart_1(spirv_by_reference uint _0, spirv_literal int _1); +void livenessStart_1(spirv_by_reference int _0, spirv_literal int _1); spirv_instruction(id = 257) void livenessEnd_0(spirv_by_reference uint _0, spirv_literal int _1); @@ -21,12 +21,12 @@ void livenessEnd_1(spirv_by_reference int _0, spirv_literal int _1); int someSlowFunc_0(int a_0) { uint _S1 = uint(a_0); - int i_0; uint v_0; - livenessStart_0(i_0, 0); - i_0 = 0; - livenessStart_1(v_0, 0); + int i_0; + livenessStart_0(v_0, 0); v_0 = _S1; + livenessStart_1(i_0, 0); + i_0 = 0; for(;;) { if(i_0 < a_0 * 20) @@ -40,9 +40,10 @@ int someSlowFunc_0(int a_0) uint _S3 = v_0; livenessEnd_0(v_0, 0); uint _S4 = (_S2 | _S3 << 31) * uint(i_0); - i_0 = i_0 + 1; - livenessStart_1(v_0, 0); + int i_1 = i_0 + 1; + livenessStart_0(v_0, 0); v_0 = _S4; + i_0 = i_1; } livenessEnd_1(i_0, 0); return int(v_0); @@ -89,15 +90,15 @@ layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; void main() { int index_0 = int(gl_GlobalInvocationID.x); - int i_1; + int i_2; int res_0; - livenessStart_0(i_1, 0); - i_1 = 0; - livenessStart_0(res_0, 0); + livenessStart_1(i_2, 0); + i_2 = 0; + livenessStart_1(res_0, 0); res_0 = index_0; for(;;) { - if(i_1 < index_0) + if(i_2 < index_0) { } else @@ -148,11 +149,11 @@ void main() int _S20 = res_0; livenessEnd_1(res_0, 0); int res_1 = _S20 + _S19; - i_1 = i_1 + 1; - livenessStart_0(res_0, 0); + i_2 = i_2 + 1; + livenessStart_1(res_0, 0); res_0 = res_1; } - livenessEnd_1(i_1, 0); + livenessEnd_1(i_2, 0); int _S21 = res_0; livenessEnd_1(res_0, 0); ((outputBuffer_0)._data[(uint(index_0))]) = _S21; |
