From e312d5c7dfde80941d96e522079a5d70f7d00649 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 31 Jan 2023 03:26:59 -0500 Subject: Patched support for multi-return and fallthrough if-else with break stmts (#2617) --- source/slang/slang-ir-autodiff-cfg-norm.cpp | 103 ++++++++++----------- source/slang/slang-ir-autodiff-cfg-norm.h | 1 + source/slang/slang-ir-autodiff-fwd.cpp | 16 +++- source/slang/slang-ir-autodiff-rev.cpp | 4 +- .../slang/slang-ir-autodiff-transcriber-base.cpp | 6 +- source/slang/slang-ir-autodiff-transpose.h | 11 ++- source/slang/slang-ir-autodiff-unzip.cpp | 5 +- tests/autodiff/reverse-multi-return.slang | 50 ++++++++++ .../reverse-multi-return.slang.expected.txt | 7 ++ tests/autodiff/reverse-single-iter-loop.slang | 8 +- 10 files changed, 139 insertions(+), 72 deletions(-) create mode 100644 tests/autodiff/reverse-multi-return.slang create mode 100644 tests/autodiff/reverse-multi-return.slang.expected.txt diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 4e0a413db..2b201466b 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -97,14 +97,6 @@ struct CFGNormalizationPass builder->setInsertInto(afterBlock); unreachInst->removeAndDeallocate(); - /* - HashSet predecessorSet; - for (auto predecessor : parentAfterBlock->getPredecessors()) - predecessorSet.Add(predecessor); - - SLANG_ASSERT(predecessorSet.Count() <= 1); - */ - builder->emitBranch(parentAfterBlock); } } @@ -169,6 +161,45 @@ struct CFGNormalizationPass IRBlock* parentAfterBlock = afterBlocks[0]; + auto addBreakBypassBranch = [&](IRBlock* block) + { + // We could arrive at the after-block before or + // after encountering a break statement. + // To handle this, we'll split the flow by checking the break flag + // + builder.setInsertAfter(block); + + auto preAfterSplitBlock = builder.emitBlock(); + preAfterSplitBlock->insertBefore(block); + + auto afterSplitBlock = builder.emitBlock(); + afterSplitBlock->insertBefore(block); + + block->replaceUsesWith(preAfterSplitBlock); + + builder.setInsertInto(preAfterSplitBlock); + builder.emitBranch(afterSplitBlock); + + // Converging block for the split that we're making. + auto afterSplitAfterBlock = builder.emitBlock(); + + builder.setInsertInto(afterSplitBlock); + auto breakFlagValue = builder.emitLoad(parentRegion->breakVar); + + builder.emitIfElse( + breakFlagValue, + block, + afterSplitAfterBlock, + afterSplitAfterBlock); + + // At this point, we need to place afterSplitAfterBlock between + // at the _end_ of this region, but we aren't there yet (and + // don't know which block is the end of this region) + // Therefore, we'll defer this step and add it to a list for later. + // + pendingAfterBlocks.add(afterSplitAfterBlock); + }; + // Follow this thread of execution till we hit an // acceptable after block. // @@ -210,12 +241,15 @@ struct CFGNormalizationPass auto afterBlock = ifElse->getAfterBlock(); // Trivial case, both end-points branch into the after block - if (trueTargetBlock == afterBlock && + /*if (trueTargetBlock == afterBlock && falseTargetBlock == afterBlock) { + if () + addBreakBypassBranch(afterBlock); currentBlock = afterBlock; + // TODO: Need to split block. break; - } + }*/ auto afterBreakRegion = false; auto afterBaseRegion = false; @@ -281,41 +315,7 @@ struct CFGNormalizationPass // Do we need to split the after region? if (afterBaseRegion && afterBreakRegion) { - // We could arrive at the after-block before or - // after encountering a break statement. - // To handle this, we'll split the flow by checking the break flag - // - builder.setInsertAfter(afterBlock); - - auto preAfterSplitBlock = builder.emitBlock(); - preAfterSplitBlock->insertBefore(afterBlock); - - auto afterSplitBlock = builder.emitBlock(); - afterSplitBlock->insertBefore(afterBlock); - - afterBlock->replaceUsesWith(preAfterSplitBlock); - - builder.setInsertInto(preAfterSplitBlock); - builder.emitBranch(afterSplitBlock); - - // Converging block for the split that we're making. - auto afterSplitAfterBlock = builder.emitBlock(); - - builder.setInsertInto(afterSplitBlock); - auto breakFlagValue = builder.emitLoad(parentRegion->breakVar); - - builder.emitIfElse( - breakFlagValue, - afterBlock, - afterSplitAfterBlock, - afterSplitAfterBlock); - - // At this point, we need to place afterSplitAfterBlock between - // at the _end_ of this region, but we aren't there yet (and - // don't know which block is the end of this region) - // Therefore, we'll defer this step and add it to a list for later. - // - pendingAfterBlocks.add(afterSplitAfterBlock); + addBreakBypassBranch(afterBlock); // Update current block. currentBlock = afterBlock; @@ -419,12 +419,6 @@ struct CFGNormalizationPass if (isLoopTrivial(as(branchInst))) { auto firstLoopBlock = as(branchInst)->getTargetBlock(); - auto terminator = firstLoopBlock->getTerminator(); - - // We really shouldn't see a conditional branch on a trivial loop - // but if we hit this assert, handle this case. - // - SLANG_RELEASE_ASSERT(as(terminator)); // Normalize the region from the first loop block till break. auto preBreakEndPoint = getNormalizedRegionEndpoint( @@ -583,6 +577,7 @@ struct CFGNormalizationPass }; void normalizeCFG( + SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func, IRCFGNormalizationPass const& options) { @@ -591,9 +586,7 @@ void normalizeCFG( // eliminatePhisInFunc(LivenessMode::Disabled, func->getModule(), func); - SharedIRBuilder sharedBuilder(func->getModule()); - sharedBuilder.deduplicateAndRebuildGlobalNumberingMap(); - CFGNormalizationContext context = {&sharedBuilder, options.sink}; + CFGNormalizationContext context = {sharedBuilder, options.sink}; CFGNormalizationPass cfgPass(context); List workList; @@ -622,7 +615,7 @@ void normalizeCFG( } disableIRValidationAtInsert(); - constructSSA(&sharedBuilder, func); + constructSSA(sharedBuilder, func); enableIRValidationAtInsert(); } diff --git a/source/slang/slang-ir-autodiff-cfg-norm.h b/source/slang/slang-ir-autodiff-cfg-norm.h index 2a39f7695..f256d8ce8 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.h +++ b/source/slang/slang-ir-autodiff-cfg-norm.h @@ -19,6 +19,7 @@ namespace Slang /// "after" block. /// void normalizeCFG( + SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func, IRCFGNormalizationPass const& options = IRCFGNormalizationPass()); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index abe3f718c..f60412efb 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -527,6 +527,9 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns auto diffArg = lookupDiffInst(origArg, nullptr); if (diffArg) newArgs.add(diffArg); + else + newArgs.add( + getDifferentialZeroOfType(builder, origArg->getDataType())); } } @@ -576,16 +579,15 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns return InstPair(nullptr, nullptr); } -InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder* builder, IRInst* origInst) +InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder*, IRInst* origInst) { switch(origInst->getOp()) { case kIROp_FloatLit: - return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f)); + case kIROp_IntLit: + return InstPair(origInst, nullptr); case kIROp_VoidLit: return InstPair(origInst, origInst); - case kIROp_IntLit: - return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0)); } getSink()->diagnose( @@ -943,9 +945,15 @@ InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* build SLANG_ASSERT(primalVal); auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue()); SLANG_ASSERT(diffPrimalVal); + auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue()); + if (!primalDiffVal) + primalDiffVal = getDifferentialZeroOfType(builder, origInst->getPrimalValue()->getDataType()); SLANG_ASSERT(primalDiffVal); + auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue()); + if (!diffDiffVal) + diffDiffVal = getDifferentialZeroOfType(builder, origInst->getDifferentialValue()->getDataType()); SLANG_ASSERT(diffDiffVal); auto primalPairType = findOrTranscribePrimalInst(builder, origInst->getFullType()); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 8f218293d..0f2ceceb4 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -510,7 +510,7 @@ namespace Slang eliminateMultiLevelBreakForFunc(func->getModule(), func); IRCFGNormalizationPass cfgPass = {this->getSink()}; - normalizeCFG(func); + normalizeCFG(autoDiffSharedContext->sharedBuilder, func); AutoDiffAddressConversionPolicy cvtPolicty; cvtPolicty.diffTypeContext = &diffTypeContext; @@ -544,6 +544,8 @@ namespace Slang // reversible. if (SLANG_FAILED(prepareFuncForBackwardDiff(primalFunc))) return diffPropagateFunc; + + autoDiffSharedContext->sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); // Forward transcribe the clone of the original func. ForwardDiffTranscriber& fwdTranscriber = *static_cast( diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 91374e006..520c6d276 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -911,12 +911,14 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst // Tag the differential inst using a decoration (if it doesn't have one) if (!pair.differential->findDecoration() && - !pair.differential->findDecoration()) + !pair.differential->findDecoration() && + !as(pair.differential)) { // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential // instead. // - builder->markInstAsDifferential(pair.differential, as(pair.primal->getDataType())); + auto primalType = as(pair.primal->getDataType()); + builder->markInstAsDifferential(pair.differential, primalType); } break; diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index f87aa7751..5aad6e3a3 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1694,9 +1694,6 @@ struct DiffTransposePass IRInst* newInst = builder->emitMakeVector(targetType, operands.getCount(), operands.getBuffer()); - if (isDifferentialInst(inst)) - builder->markInstAsDifferential(newInst); - return newInst; } @@ -1725,6 +1722,11 @@ struct DiffTransposePass builder->setInsertAfter(operand); IRInst* newOperand = promoteToType(builder, targetType, operand); + + if (isDifferentialInst(operand)) + builder->markInstAsDifferential( + newOperand, tryGetPrimalTypeFromDiffInst(fwdInst)); + newOperands.add(newOperand); needNewInst = true; @@ -1747,7 +1749,8 @@ struct DiffTransposePass builder->setInsertLoc(oldLoc); if (isDifferentialInst(fwdInst)) - builder->markInstAsDifferential(newInst); + builder->markInstAsDifferential( + newInst, tryGetPrimalTypeFromDiffInst(fwdInst)); return newInst; } diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 44cb2aa09..daf6e44d4 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -322,7 +322,10 @@ struct ExtractPrimalFuncContext { if (shouldStoreInst(inst)) { - builder.setInsertAfter(inst); + if (as(inst)) + builder.setInsertBefore(block->getFirstOrdinaryInst()); + else + builder.setInsertAfter(inst); storeInst(builder, inst, outIntermediary); } else if (inst->getOp() == kIROp_Var) diff --git a/tests/autodiff/reverse-multi-return.slang b/tests/autodiff/reverse-multi-return.slang new file mode 100644 index 000000000..ee8bb9a4c --- /dev/null +++ b/tests/autodiff/reverse-multi-return.slang @@ -0,0 +1,50 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +typedef DifferentialPair dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +float test_multi_return(float y) +{ + if (y > 0.6) + { + if (y > 0.8) + { + return y * 10.0f; + } + else + { + return y * 4.0f; + } + } + return y * 6.0f; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(test_multi_return)(dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 10.0 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_multi_return)(dpa, 1.0f); + outputBuffer[1] = dpa.d; // Expect: 6.0 + } + + { + dpfloat dpa = dpfloat(0.7, 0.0); + + __bwd_diff(test_multi_return)(dpa, 1.0f); + outputBuffer[2] = dpa.d; // Expect: 4.0 + } +} diff --git a/tests/autodiff/reverse-multi-return.slang.expected.txt b/tests/autodiff/reverse-multi-return.slang.expected.txt new file mode 100644 index 000000000..115191b13 --- /dev/null +++ b/tests/autodiff/reverse-multi-return.slang.expected.txt @@ -0,0 +1,7 @@ +type: float +10.000000 +6.000000 +4.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/reverse-single-iter-loop.slang b/tests/autodiff/reverse-single-iter-loop.slang index 47232147a..20c26e000 100644 --- a/tests/autodiff/reverse-single-iter-loop.slang +++ b/tests/autodiff/reverse-single-iter-loop.slang @@ -49,11 +49,9 @@ float test_nested_if_else_single_iter_loop(float y) break; } } - else - { - x = y * 6.0f; - break; - } + + x = y * 6.0f; + break; } return x; -- cgit v1.2.3