diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-01-31 03:26:59 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-31 00:26:59 -0800 |
| commit | e312d5c7dfde80941d96e522079a5d70f7d00649 (patch) | |
| tree | cf600a7f49117a77336ad55e59816f5c323cd705 /source | |
| parent | 77cdbb2101f4e27bf1800d4bc1077c0510668c25 (diff) | |
Patched support for multi-return and fallthrough if-else with break stmts (#2617)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.cpp | 103 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 5 |
7 files changed, 79 insertions, 67 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 4e0a413db..2b201466b 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -97,14 +97,6 @@ struct CFGNormalizationPass builder->setInsertInto(afterBlock); unreachInst->removeAndDeallocate(); - /* - HashSet<IRBlock*> predecessorSet; - for (auto predecessor : parentAfterBlock->getPredecessors()) - predecessorSet.Add(predecessor); - - SLANG_ASSERT(predecessorSet.Count() <= 1); - */ - builder->emitBranch(parentAfterBlock); } } @@ -169,6 +161,45 @@ struct CFGNormalizationPass IRBlock* parentAfterBlock = afterBlocks[0]; + auto addBreakBypassBranch = [&](IRBlock* block) + { + // We could arrive at the after-block before or + // after encountering a break statement. + // To handle this, we'll split the flow by checking the break flag + // + builder.setInsertAfter(block); + + auto preAfterSplitBlock = builder.emitBlock(); + preAfterSplitBlock->insertBefore(block); + + auto afterSplitBlock = builder.emitBlock(); + afterSplitBlock->insertBefore(block); + + block->replaceUsesWith(preAfterSplitBlock); + + builder.setInsertInto(preAfterSplitBlock); + builder.emitBranch(afterSplitBlock); + + // Converging block for the split that we're making. + auto afterSplitAfterBlock = builder.emitBlock(); + + builder.setInsertInto(afterSplitBlock); + auto breakFlagValue = builder.emitLoad(parentRegion->breakVar); + + builder.emitIfElse( + breakFlagValue, + block, + afterSplitAfterBlock, + afterSplitAfterBlock); + + // At this point, we need to place afterSplitAfterBlock between + // at the _end_ of this region, but we aren't there yet (and + // don't know which block is the end of this region) + // Therefore, we'll defer this step and add it to a list for later. + // + pendingAfterBlocks.add(afterSplitAfterBlock); + }; + // Follow this thread of execution till we hit an // acceptable after block. // @@ -210,12 +241,15 @@ struct CFGNormalizationPass auto afterBlock = ifElse->getAfterBlock(); // Trivial case, both end-points branch into the after block - if (trueTargetBlock == afterBlock && + /*if (trueTargetBlock == afterBlock && falseTargetBlock == afterBlock) { + if () + addBreakBypassBranch(afterBlock); currentBlock = afterBlock; + // TODO: Need to split block. break; - } + }*/ auto afterBreakRegion = false; auto afterBaseRegion = false; @@ -281,41 +315,7 @@ struct CFGNormalizationPass // Do we need to split the after region? if (afterBaseRegion && afterBreakRegion) { - // We could arrive at the after-block before or - // after encountering a break statement. - // To handle this, we'll split the flow by checking the break flag - // - builder.setInsertAfter(afterBlock); - - auto preAfterSplitBlock = builder.emitBlock(); - preAfterSplitBlock->insertBefore(afterBlock); - - auto afterSplitBlock = builder.emitBlock(); - afterSplitBlock->insertBefore(afterBlock); - - afterBlock->replaceUsesWith(preAfterSplitBlock); - - builder.setInsertInto(preAfterSplitBlock); - builder.emitBranch(afterSplitBlock); - - // Converging block for the split that we're making. - auto afterSplitAfterBlock = builder.emitBlock(); - - builder.setInsertInto(afterSplitBlock); - auto breakFlagValue = builder.emitLoad(parentRegion->breakVar); - - builder.emitIfElse( - breakFlagValue, - afterBlock, - afterSplitAfterBlock, - afterSplitAfterBlock); - - // At this point, we need to place afterSplitAfterBlock between - // at the _end_ of this region, but we aren't there yet (and - // don't know which block is the end of this region) - // Therefore, we'll defer this step and add it to a list for later. - // - pendingAfterBlocks.add(afterSplitAfterBlock); + addBreakBypassBranch(afterBlock); // Update current block. currentBlock = afterBlock; @@ -419,12 +419,6 @@ struct CFGNormalizationPass if (isLoopTrivial(as<IRLoop>(branchInst))) { auto firstLoopBlock = as<IRLoop>(branchInst)->getTargetBlock(); - auto terminator = firstLoopBlock->getTerminator(); - - // We really shouldn't see a conditional branch on a trivial loop - // but if we hit this assert, handle this case. - // - SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(terminator)); // Normalize the region from the first loop block till break. auto preBreakEndPoint = getNormalizedRegionEndpoint( @@ -583,6 +577,7 @@ struct CFGNormalizationPass }; void normalizeCFG( + SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func, IRCFGNormalizationPass const& options) { @@ -591,9 +586,7 @@ void normalizeCFG( // eliminatePhisInFunc(LivenessMode::Disabled, func->getModule(), func); - SharedIRBuilder sharedBuilder(func->getModule()); - sharedBuilder.deduplicateAndRebuildGlobalNumberingMap(); - CFGNormalizationContext context = {&sharedBuilder, options.sink}; + CFGNormalizationContext context = {sharedBuilder, options.sink}; CFGNormalizationPass cfgPass(context); List<IRBlock*> workList; @@ -622,7 +615,7 @@ void normalizeCFG( } disableIRValidationAtInsert(); - constructSSA(&sharedBuilder, func); + constructSSA(sharedBuilder, func); enableIRValidationAtInsert(); } diff --git a/source/slang/slang-ir-autodiff-cfg-norm.h b/source/slang/slang-ir-autodiff-cfg-norm.h index 2a39f7695..f256d8ce8 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.h +++ b/source/slang/slang-ir-autodiff-cfg-norm.h @@ -19,6 +19,7 @@ namespace Slang /// "after" block. /// void normalizeCFG( + SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func, IRCFGNormalizationPass const& options = IRCFGNormalizationPass()); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index abe3f718c..f60412efb 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -527,6 +527,9 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns auto diffArg = lookupDiffInst(origArg, nullptr); if (diffArg) newArgs.add(diffArg); + else + newArgs.add( + getDifferentialZeroOfType(builder, origArg->getDataType())); } } @@ -576,16 +579,15 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns return InstPair(nullptr, nullptr); } -InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder* builder, IRInst* origInst) +InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder*, IRInst* origInst) { switch(origInst->getOp()) { case kIROp_FloatLit: - return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f)); + case kIROp_IntLit: + return InstPair(origInst, nullptr); case kIROp_VoidLit: return InstPair(origInst, origInst); - case kIROp_IntLit: - return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0)); } getSink()->diagnose( @@ -943,9 +945,15 @@ InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* build SLANG_ASSERT(primalVal); auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue()); SLANG_ASSERT(diffPrimalVal); + auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue()); + if (!primalDiffVal) + primalDiffVal = getDifferentialZeroOfType(builder, origInst->getPrimalValue()->getDataType()); SLANG_ASSERT(primalDiffVal); + auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue()); + if (!diffDiffVal) + diffDiffVal = getDifferentialZeroOfType(builder, origInst->getDifferentialValue()->getDataType()); SLANG_ASSERT(diffDiffVal); auto primalPairType = findOrTranscribePrimalInst(builder, origInst->getFullType()); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 8f218293d..0f2ceceb4 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -510,7 +510,7 @@ namespace Slang eliminateMultiLevelBreakForFunc(func->getModule(), func); IRCFGNormalizationPass cfgPass = {this->getSink()}; - normalizeCFG(func); + normalizeCFG(autoDiffSharedContext->sharedBuilder, func); AutoDiffAddressConversionPolicy cvtPolicty; cvtPolicty.diffTypeContext = &diffTypeContext; @@ -544,6 +544,8 @@ namespace Slang // reversible. if (SLANG_FAILED(prepareFuncForBackwardDiff(primalFunc))) return diffPropagateFunc; + + autoDiffSharedContext->sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); // Forward transcribe the clone of the original func. ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>( diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 91374e006..520c6d276 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -911,12 +911,14 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst // Tag the differential inst using a decoration (if it doesn't have one) if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() && - !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>()) + !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>() && + !as<IRConstant>(pair.differential)) { // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential // instead. // - builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType())); + auto primalType = as<IRType>(pair.primal->getDataType()); + builder->markInstAsDifferential(pair.differential, primalType); } break; diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index f87aa7751..5aad6e3a3 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1694,9 +1694,6 @@ struct DiffTransposePass IRInst* newInst = builder->emitMakeVector(targetType, operands.getCount(), operands.getBuffer()); - if (isDifferentialInst(inst)) - builder->markInstAsDifferential(newInst); - return newInst; } @@ -1725,6 +1722,11 @@ struct DiffTransposePass builder->setInsertAfter(operand); IRInst* newOperand = promoteToType(builder, targetType, operand); + + if (isDifferentialInst(operand)) + builder->markInstAsDifferential( + newOperand, tryGetPrimalTypeFromDiffInst(fwdInst)); + newOperands.add(newOperand); needNewInst = true; @@ -1747,7 +1749,8 @@ struct DiffTransposePass builder->setInsertLoc(oldLoc); if (isDifferentialInst(fwdInst)) - builder->markInstAsDifferential(newInst); + builder->markInstAsDifferential( + newInst, tryGetPrimalTypeFromDiffInst(fwdInst)); return newInst; } diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 44cb2aa09..daf6e44d4 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -322,7 +322,10 @@ struct ExtractPrimalFuncContext { if (shouldStoreInst(inst)) { - builder.setInsertAfter(inst); + if (as<IRParam>(inst)) + builder.setInsertBefore(block->getFirstOrdinaryInst()); + else + builder.setInsertAfter(inst); storeInst(builder, inst, outIntermediary); } else if (inst->getOp() == kIROp_Var) |
