diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-02-27 23:42:06 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-27 23:42:06 -0500 |
| commit | 10e2d9c7c532c204f26bb2c9f383f21b121b2ff2 (patch) | |
| tree | 9ae0dd84b505a7ecd3fb45de9dbde74f8dd1ebe9 | |
| parent | a3ba22b51c371d5a20d61aa4e35233ba4f4f68db (diff) | |
More fixes for reverse-mode on complicated loops (#2675)
* Multiple fixes to get various loop tests to pass.
* Create reverse-nested-loop.slang
* Fix variables becoming inaccessible during CFG normalization
* Removed comments and moved break-branch-normalization to eliminateMultiLevelBreaks
* Fix.
* Update expected outputs for liveness tests
17 files changed, 295 insertions, 33 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 9116f67e9..2199b0771 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -43,6 +43,7 @@ struct BreakableRegionInfo { IRVar* breakVar; IRBlock* breakBlock; + IRBlock* headerBlock; }; struct CFGNormalizationContext @@ -57,13 +58,39 @@ IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst) // For now, we're going to naively assume the next block is the condition block. // Add in more support for more cases as necessary. // - + auto firstBlock = loopInst->getTargetBlock(); - auto ifElse = as<IRIfElse>(firstBlock->getTerminator()); - SLANG_RELEASE_ASSERT(ifElse); + if (as<IRIfElse>(firstBlock->getTerminator())) + { + return firstBlock; + } + else + { + // If there isn't a condition we need to make one with a dummy condition that + // always evaluates to true + // - return firstBlock; + IRBuilder condBuilder(loopInst->getModule()); + + auto condBlock = condBuilder.emitBlock(); + condBlock->insertAfter(as<IRBlock>(loopInst->getParent())); + + // Make loop go into the condition block + firstBlock->replaceUsesWith(condBlock); + + // Emit a condition: true side goes to the loop body, and + // false side goes into the break block. 
+ // + condBuilder.setInsertInto(condBlock); + condBuilder.emitIfElse( + condBuilder.getBoolValue(true), + firstBlock, + loopInst->getBreakBlock(), + firstBlock); + + return condBlock; + } } struct CFGNormalizationPass @@ -133,6 +160,20 @@ struct CFGNormalizationPass return false; } + void _moveVarsToRegionHeader(BreakableRegionInfo* region, IRBlock* block) + { + for (auto child = block->getFirstChild(); child;) + { + auto nextChild = child->getNextInst(); + + if (as<IRVar>(child)) + { + child->insertBefore(region->headerBlock->getTerminator()); + } + + child = nextChild; + } + } RegionEndpoint getNormalizedRegionEndpoint( BreakableRegionInfo* parentRegion, @@ -140,6 +181,7 @@ struct CFGNormalizationPass List<IRBlock*> afterBlocks) { IRBlock* currentBlock = entryBlock; + _moveVarsToRegionHeader(parentRegion, currentBlock); // By default a region starts off with the 'base' control flow // and not in the 'break' control flow @@ -343,6 +385,8 @@ struct CFGNormalizationPass SLANG_UNEXPECTED("Unhandled control flow inst"); break; } + + _moveVarsToRegionHeader(parentRegion, currentBlock); } // Resolve all intermediate after-blocks @@ -399,6 +443,7 @@ struct CFGNormalizationPass { BreakableRegionInfo info; info.breakBlock = as<IRLoop>(branchInst)->getBreakBlock(); + info.headerBlock = as<IRBlock>(branchInst->getParent()); // Emit var into parent block. builder.setInsertBefore( @@ -426,7 +471,7 @@ struct CFGNormalizationPass &info, firstLoopBlock, List<IRBlock*>(info.breakBlock)); - + // Should not be empty.. but check anyway SLANG_RELEASE_ASSERT(!preBreakEndPoint.isRegionEmpty); @@ -495,7 +540,7 @@ struct CFGNormalizationPass // Add a test for the break variable into the condition. 
auto cond = ifElse->getCondition(); - builder.setInsertAfter(cond); + builder.setInsertBefore(ifElse); auto breakFlagVal = builder.emitLoad(info.breakVar); // Need to invert the break flag if the loop is diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 640f516ed..709968f77 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -14,6 +14,7 @@ #include "slang-ir-init-local-var.h" #include "slang-ir-redundancy-removal.h" #include "slang-ir-dominators.h" +#include "slang-ir-loop-unroll.h" namespace Slang { @@ -583,6 +584,9 @@ namespace Slang { convertFuncToSingleReturnForm(func->getModule(), func); } + + eliminateContinueBlocksInFunc(func->getModule(), func); + eliminateMultiLevelBreakForFunc(func->getModule(), func); IRCFGNormalizationPass cfgPass = {this->getSink()}; diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index a30826370..3678bd4b3 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -409,16 +409,14 @@ struct DiffUnzipPass for (auto region : indexRegions) { // Grab first primal block. 
- IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()]); - builder.setInsertBefore(firstPrimalBlock->getTerminator()); + IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->initBlock]); + builder.setInsertBefore(primalInitBlock->getTerminator()); // Make variable in the top-most block (so it's visible to diff blocks) region->primalCountLastVar = builder.emitVar(builder.getIntType()); builder.addNameHintDecoration(region->primalCountLastVar, UnownedStringSlice("_pc_last_var")); - { - IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->initBlock]); - + { auto primalCondBlock = as<IRUnconditionalBranch>( primalInitBlock->getTerminator())->getTargetBlock(); builder.setInsertBefore(primalCondBlock->getTerminator()); @@ -664,8 +662,8 @@ struct DiffUnzipPass storageVar); // 4. Store current value into the array and replace uses with a load. - // TODO: If an index is missing, use the 'last' value of the primal index. - + // If an index is missing, use the 'last' value of the primal index. 
+ { if (!isIntermediateContext) setInsertAfterOrdinaryInst(&builder, valueToStore); diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp index 3618e1326..e73fae982 100644 --- a/source/slang/slang-ir-eliminate-multilevel-break.cpp +++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp @@ -175,8 +175,62 @@ struct EliminateMultiLevelBreakContext } }; + + void insertBlockBetween(IRBlock* block, IRBlock* successor) + { + IRBuilder builder(block->getModule()); + + List<IRUse*> relevantUses; + for (auto use = successor->firstUse; use; use = use->nextUse) + { + if (auto terminator = as<IRTerminatorInst>(use->getUser())) + { + if (as<IRBlock>(terminator->getParent()) == block) + { + relevantUses.add(use); + } + } + } + + SLANG_RELEASE_ASSERT(relevantUses.getCount() == 1); + + builder.insertBlockAlongEdge(block->getModule(), IREdge(relevantUses[0])); + } + + bool normalizeBranchesIntoBreakBlocks(IRGlobalValueWithCode* func) + { + bool changed = false; + + List<IRBlock*> workList; + + for (auto block : func->getBlocks()) + workList.add(block); + + for (auto block : workList) + { + if (auto loop = as<IRLoop>(block->getTerminator())) + { + auto breakBlock = loop->getBreakBlock(); + + for (auto predecessor : breakBlock->getPredecessors()) + { + if (!as<IRUnconditionalBranch>(predecessor->getTerminator())) + { + insertBlockBetween(predecessor, breakBlock); + changed = true; + } + } + } + } + + return changed; + } + void processFunc(IRGlobalValueWithCode* func) { + + normalizeBranchesIntoBreakBlocks(func); + // If func does not have any multi-level breaks, return. 
{ FuncContext funcInfo; diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index 2f689ebde..4f9b8d272 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -603,6 +603,16 @@ void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst) builder.setInsertInto(innerBreakableRegionBreakBlock); _moveParams(innerBreakableRegionBreakBlock, continueBlock); builder.emitBranch(continueBlock); + + // If the original loop can be executed up to N times, the new loop may be executed + // upto N+1 times (although most insts are skipped in the last traversal) + // + if (auto maxItersDecoration = loopInst->findDecoration<IRLoopMaxItersDecoration>()) + { + auto maxIters = maxItersDecoration->getMaxIters(); + maxItersDecoration->removeAndDeallocate(); + builder.addLoopMaxItersDecoration(loopInst, maxIters + 1); + } } void eliminateContinueBlocksInFunc(IRModule* module, IRGlobalValueWithCode* func) diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang new file mode 100644 index 000000000..0f9502673 --- /dev/null +++ b/tests/autodiff/reverse-continue-loop.slang @@ -0,0 +1,43 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +float test_loop_with_continue(float y) +{ + float t = y; + + for (int i = 0; i < 3; i++) + { + if (t > 4.0) + continue; + + t = t * t; + } + + return t; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 0.0); + + 
__bwd_diff(test_loop_with_continue)(dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 32.0 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_loop_with_continue)(dpa, 1.0f); + outputBuffer[1] = dpa.d; // Expect: 0.0131072 + } +} diff --git a/tests/autodiff/reverse-continue-loop.slang.expected.txt b/tests/autodiff/reverse-continue-loop.slang.expected.txt new file mode 100644 index 000000000..17dbb061d --- /dev/null +++ b/tests/autodiff/reverse-continue-loop.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +32.000000 +0.013107 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/reverse-hybrid-control-flow.slang b/tests/autodiff/reverse-hybrid-control-flow.slang new file mode 100644 index 000000000..9379df4cf --- /dev/null +++ b/tests/autodiff/reverse-hybrid-control-flow.slang @@ -0,0 +1,47 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +float test_simple_loop(float y) +{ + float t = y; + + if (y > 0.5) + { + for (int i = 0; i < 3; i++) + { + t = t * t; + } + } + else + { + t = t * 10.f; + } + + return t; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(test_simple_loop)(dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 8.0 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_simple_loop)(dpa, 1.0f); + outputBuffer[1] = dpa.d; // Expect: 10.0 + } +} diff --git a/tests/autodiff/reverse-hybrid-control-flow.slang.expected.txt b/tests/autodiff/reverse-hybrid-control-flow.slang.expected.txt new file mode 100644 index 
000000000..3cb76c394 --- /dev/null +++ b/tests/autodiff/reverse-hybrid-control-flow.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +8.000000 +10.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/reverse-nested-loop.slang b/tests/autodiff/reverse-nested-loop.slang new file mode 100644 index 000000000..08cde5230 --- /dev/null +++ b/tests/autodiff/reverse-nested-loop.slang @@ -0,0 +1,43 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +float test_simple_nested_loop(float y) +{ + float t = y; + + for (int i = 0; i < 2; i++) + { + for (int j = 0; j < 2; j++) + { + t = t * (i + j + 1); + } + } + + return t; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(test_simple_nested_loop)(dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 12.0 * 1 + } + + { + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(test_simple_nested_loop)(dpa, 0.4f); + outputBuffer[1] = dpa.d; // Expect: 12 * 0.4 = 4.8 + } +} diff --git a/tests/autodiff/reverse-nested-loop.slang.expected.txt b/tests/autodiff/reverse-nested-loop.slang.expected.txt new file mode 100644 index 000000000..59c14cf1d --- /dev/null +++ b/tests/autodiff/reverse-nested-loop.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +12.000000 +4.800000 +0.000000 +0.000000 +0.000000 diff --git a/tests/experimental/liveness/liveness-2.slang.expected b/tests/experimental/liveness/liveness-2.slang.expected index 16883c1fd..e0486f8da 100644 --- a/tests/experimental/liveness/liveness-2.slang.expected +++ 
b/tests/experimental/liveness/liveness-2.slang.expected @@ -32,12 +32,12 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(i_0, 0); break; } idx_0[i_0] = offset_0 + i_0; i_0 = i_0 + 1; } - livenessEnd_0(i_0, 0); int _S1 = idx_0[0] + idx_0[1]; int _S2 = idx_0[2]; livenessEnd_1(idx_0, 0); diff --git a/tests/experimental/liveness/liveness-3.slang.expected b/tests/experimental/liveness/liveness-3.slang.expected index cb093a640..d4b417082 100644 --- a/tests/experimental/liveness/liveness-3.slang.expected +++ b/tests/experimental/liveness/liveness-3.slang.expected @@ -12,12 +12,12 @@ void livenessStart_0(spirv_by_reference int _0[2], spirv_literal int _1); spirv_instruction(id = 256) void livenessStart_1(spirv_by_reference int _0, spirv_literal int _1); -spirv_instruction(id = 256) -void livenessStart_2(spirv_by_reference int _0[3], spirv_literal int _1); - spirv_instruction(id = 257) void livenessEnd_0(spirv_by_reference int _0, spirv_literal int _1); +spirv_instruction(id = 256) +void livenessStart_2(spirv_by_reference int _0[3], spirv_literal int _1); + spirv_instruction(id = 257) void livenessEnd_1(spirv_by_reference int _0[3], spirv_literal int _1); @@ -46,6 +46,8 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(_S1, 0); + livenessEnd_0(k_0, 0); break; } int idx_0[3]; @@ -69,6 +71,7 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(i_0, 0); break; } int modRange_0 = i_0 % 3; @@ -97,7 +100,6 @@ int calcThing_0(int offset_0) livenessEnd_0(_S6, 0); _S4 = _S10; } - livenessEnd_0(i_0, 0); livenessEnd_0(_S1, 0); livenessEnd_0(k_0, 0); if(_S3) @@ -123,8 +125,6 @@ int calcThing_0(int offset_0) livenessStart_1(total_0, 0); total_0 = total_1; } - livenessEnd_0(_S1, 0); - livenessEnd_0(k_0, 0); livenessEnd_2(another_0, 0); int _S16 = total_0; livenessEnd_0(total_0, 0); diff --git a/tests/experimental/liveness/liveness-4.slang.expected b/tests/experimental/liveness/liveness-4.slang.expected index efc2e3846..483247ecd 100644 --- 
a/tests/experimental/liveness/liveness-4.slang.expected +++ b/tests/experimental/liveness/liveness-4.slang.expected @@ -34,6 +34,7 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(k_0, 0); break; } int _S1 = (k_0 + 7) % 5; @@ -49,12 +50,12 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(i_0, 0); break; } another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); i_0 = i_0 + 1; } - livenessEnd_0(i_0, 0); livenessEnd_0(k_0, 0); if(_S2) { @@ -64,7 +65,6 @@ int calcThing_0(int offset_0) livenessStart_1(k_0, 0); k_0 = k_1; } - livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); return -2; } diff --git a/tests/experimental/liveness/liveness-5.slang.expected b/tests/experimental/liveness/liveness-5.slang.expected index e9fe9d652..5e144a095 100644 --- a/tests/experimental/liveness/liveness-5.slang.expected +++ b/tests/experimental/liveness/liveness-5.slang.expected @@ -37,6 +37,7 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(k_0, 0); break; } int _S1 = (k_0 + 7) % 5; @@ -52,12 +53,12 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(i_0, 0); break; } another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); i_0 = i_0 + 1; } - livenessEnd_0(i_0, 0); livenessEnd_0(k_0, 0); int _S3 = total_0; livenessEnd_0(total_0, 0); @@ -72,7 +73,6 @@ int calcThing_0(int offset_0) livenessStart_1(total_0, 0); total_0 = total_1; } - livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); if(total_0 > 4) { diff --git a/tests/experimental/liveness/liveness-6.slang.expected b/tests/experimental/liveness/liveness-6.slang.expected index b661c09bf..0d2e997b2 100644 --- a/tests/experimental/liveness/liveness-6.slang.expected +++ b/tests/experimental/liveness/liveness-6.slang.expected @@ -37,6 +37,7 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(k_0, 0); break; } int arr_0[2]; @@ -57,13 +58,13 @@ int calcThing_0(int offset_0) } else { + livenessEnd_0(i_0, 0); break; } another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); arr_0[_S1] = arr_0[_S1] + i_0; i_0 = 
i_0 + 1; } - livenessEnd_0(i_0, 0); livenessEnd_0(k_0, 0); int _S4 = total_0; livenessEnd_0(total_0, 0); @@ -81,7 +82,6 @@ int calcThing_0(int offset_0) livenessStart_1(total_0, 0); total_0 = total_2; } - livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); if(total_0 > 4) { diff --git a/tests/experimental/liveness/liveness.slang.expected b/tests/experimental/liveness/liveness.slang.expected index 06809ffc3..b0017ea9d 100644 --- a/tests/experimental/liveness/liveness.slang.expected +++ b/tests/experimental/liveness/liveness.slang.expected @@ -13,10 +13,10 @@ spirv_instruction(id = 256) void livenessStart_1(spirv_by_reference int _0, spirv_literal int _1); spirv_instruction(id = 257) -void livenessEnd_0(spirv_by_reference uint _0, spirv_literal int _1); +void livenessEnd_0(spirv_by_reference int _0, spirv_literal int _1); spirv_instruction(id = 257) -void livenessEnd_1(spirv_by_reference int _0, spirv_literal int _1); +void livenessEnd_1(spirv_by_reference uint _0, spirv_literal int _1); int someSlowFunc_0(int a_0) { @@ -35,18 +35,18 @@ int someSlowFunc_0(int a_0) } else { + livenessEnd_0(i_0, 0); break; } uint _S3 = v_0 >> 1; uint _S4 = v_0; - livenessEnd_0(v_0, 0); + livenessEnd_1(v_0, 0); uint _S5 = (_S3 | _S4 << 31) * uint(i_0); int i_1 = i_0 + 1; livenessStart_0(v_0, 0); v_0 = _S5; i_0 = i_1; } - livenessEnd_1(i_0, 0); return int(v_0); } @@ -111,6 +111,7 @@ void main() } else { + livenessEnd_0(i_2, 0); break; } SomeStruct_0 s_3; @@ -153,15 +154,14 @@ void main() livenessEnd_2(s_3, 0); int _S22 = _S20 + _S21; int _S23 = res_0; - livenessEnd_1(res_0, 0); + livenessEnd_0(res_0, 0); int res_1 = _S23 + _S22; i_2 = i_2 + 1; livenessStart_1(res_0, 0); res_0 = res_1; } - livenessEnd_1(i_2, 0); int _S24 = res_0; - livenessEnd_1(res_0, 0); + livenessEnd_0(res_0, 0); ((outputBuffer_0)._data[(uint(index_0))]) = _S24; return; } |
