diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-13 20:31:30 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-13 20:31:30 -0700 |
| commit | 09408e32d7c0ccebf38fe31b5d2ddf4b1cd128e4 (patch) | |
| tree | 65e0f711de39f2d095aed8f15798668975375e08 /source | |
| parent | 27d7961db15ed5890d2ad0eff1218e26dcdaf82c (diff) | |
Allow multi-level breaks to break out of `switch` statements. (#2451)
* Allow multi-level breaks to break out of `switch` statements.
* Rename loop->region.
* Add `[ForceInline]` attribute.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | source/slang/core.meta.slang | 3 |
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 8 |
| -rw-r--r-- | source/slang/slang-emit.cpp | 6 |
| -rw-r--r-- | source/slang/slang-ir-eliminate-multilevel-break.cpp | 176 |
| -rw-r--r-- | source/slang/slang-ir-inline.cpp | 22 |
| -rw-r--r-- | source/slang/slang-ir-inline.h | 3 |
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 3 |
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 |
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 5 |
9 files changed, 156 insertions, 72 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index b54c70236..1711102da 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2711,6 +2711,9 @@ attribute_syntax [__extern] : ExternAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [__unsafeForceInlineEarly] : UnsafeForceInlineEarlyAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [ForceInline] : ForceInlineAttribute; + __attributeTarget(FuncDecl) attribute_syntax [DllImport(modulePath: String)] : DllImportAttribute; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 2d967445a..8868b7a1d 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -910,6 +910,14 @@ class UnsafeForceInlineEarlyAttribute : public Attribute SLANG_AST_CLASS(UnsafeForceInlineEarlyAttribute) }; +// A `[ForceInline]` attribute indicates that the callee should be inlined +// by the Slang compiler. +// +class ForceInlineAttribute : public Attribute +{ + SLANG_AST_CLASS(ForceInlineAttribute) +}; + /// An attribute that marks a type declaration as either allowing or /// disallowing the type to be inherited from in other modules. class InheritanceControlAttribute : public Attribute { SLANG_AST_CLASS(InheritanceControlAttribute) }; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 6d85e1ce0..4666e80d8 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -386,10 +386,8 @@ Result linkAndOptimizeIR( #endif validateIRModuleIfEnabled(codeGenContext, irModule); - // Inline calls to any functions marked with [__unsafeInlineEarly] again, - // since we may be missing out cases prevented by the generic constructs - // that we just lowered out. - performMandatoryEarlyInlining(irModule); + // Inline calls to any functions marked with [__unsafeInlineEarly] or [ForceInline]. 
+ performForceInlining(irModule); // Specialization can introduce dead code that could trip // up downstream passes like type legalization, so we diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp index 269b74aad..2330991c5 100644 --- a/source/slang/slang-ir-eliminate-multilevel-break.cpp +++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp @@ -13,20 +13,32 @@ struct EliminateMultiLevelBreakContext { IRModule* irModule; - struct LoopInfo : RefObject + struct BreakableRegionInfo : RefObject { - LoopInfo* parent = nullptr; + BreakableRegionInfo* parent = nullptr; int level = 0; - IRLoop* loopInst; + IRInst* headerInst; List<IRBlock*> blocks; HashSet<IRBlock*> blockSet; - List<RefPtr<LoopInfo>> childLoops; - IRBlock* getBreakBlock() { return loopInst->getBreakBlock(); } + List<RefPtr<BreakableRegionInfo>> childRegions; + IRBlock* getBreakBlock() + { + switch (headerInst->getOp()) + { + case kIROp_loop: + return as<IRLoop>(headerInst)->getBreakBlock(); + case kIROp_Switch: + return as<IRSwitch>(headerInst)->getBreakLabel(); + default: + SLANG_UNREACHABLE("Unknown breakable inst"); + } + } + template<typename Func> void forEach(const Func& f) { f(this); - for (auto child : childLoops) + for (auto child : childRegions) child->forEach(f); } }; @@ -34,44 +46,56 @@ struct EliminateMultiLevelBreakContext struct MultiLevelBreakInfo { IRUnconditionalBranch* breakInst; - LoopInfo* currentLoop; - LoopInfo* breakTargetLoop; + BreakableRegionInfo* currentRegion; + BreakableRegionInfo* breakTargetRegion; }; struct FuncContext { - List<RefPtr<LoopInfo>> loops; + List<RefPtr<BreakableRegionInfo>> regions; HashSet<IRBlock*> breakBlocks; - Dictionary<IRBlock*, LoopInfo*> mapBreakBlockToLoop; - Dictionary<IRBlock*, LoopInfo*> mapBlockToLoop; + Dictionary<IRBlock*, BreakableRegionInfo*> mapBreakBlockToRegion; + Dictionary<IRBlock*, BreakableRegionInfo*> mapBlockToRegion; HashSet<IRBlock*> processedBlocks; 
List<MultiLevelBreakInfo> multiLevelBreaks; - void collectLoopBlocks(LoopInfo& info) + void collectBreakableRegionBlocks(BreakableRegionInfo& info) { - auto startBlock = info.loopInst->getTargetBlock(); - info.blockSet.Add(startBlock); - info.blocks.add(startBlock); - breakBlocks.Add(info.loopInst->getBreakBlock()); + auto successors = as<IRBlock>(info.headerInst->getParent())->getSuccessors(); + for (auto successor : successors) + { + if (info.blockSet.Add(successor)) + info.blocks.add(successor); + } + // Push break block to a stack so we can easily check if a block is a break block in its + // parent regions. + breakBlocks.Add(info.getBreakBlock()); for (Index i = 0; i < info.blocks.getCount(); i++) { auto block = info.blocks[i]; if (!processedBlocks.Add(block)) continue; - if (auto loopInst = as<IRLoop>(block->getTerminator())) + switch (block->getTerminator()->getOp()) { - RefPtr<LoopInfo> childLoop = new LoopInfo(); - childLoop->loopInst = loopInst; - childLoop->parent = &info; - childLoop->level = info.level + 1; - collectLoopBlocks(*childLoop); - info.childLoops.add(childLoop); - block = loopInst->getBreakBlock(); - if (info.blockSet.Add(block)) + case kIROp_loop: + case kIROp_Switch: { - info.blocks.add(block); + // Both region and switch insts mark the start a breakable region. + RefPtr<BreakableRegionInfo> childRegion = new BreakableRegionInfo(); + childRegion->headerInst = block->getTerminator(); + childRegion->parent = &info; + childRegion->level = info.level + 1; + collectBreakableRegionBlocks(*childRegion); + info.childRegions.add(childRegion); + block = childRegion->getBreakBlock(); + if (info.blockSet.Add(block)) + { + info.blocks.add(block); + } + continue; } - continue; + default: + break; } for (auto succ : block->getSuccessors()) { @@ -82,7 +106,9 @@ struct EliminateMultiLevelBreakContext } } } - breakBlocks.Remove(info.loopInst->getBreakBlock()); + + // Pop the break block from stack since we are no longer processing the region. 
+ breakBlocks.Remove(info.getBreakBlock()); } void gatherInfo(IRGlobalValueWithCode* func) @@ -92,23 +118,30 @@ struct EliminateMultiLevelBreakContext if (processedBlocks.Contains(block)) continue; auto terminator = block->getTerminator(); - if (auto loop = as<IRLoop>(terminator)) + switch (terminator->getOp()) { - RefPtr<LoopInfo> loopInfo = new LoopInfo(); - loopInfo->loopInst = loop; - collectLoopBlocks(*loopInfo); - loops.add(loopInfo); + case kIROp_loop: + case kIROp_Switch: + { + RefPtr<BreakableRegionInfo> regionInfo = new BreakableRegionInfo(); + regionInfo->headerInst = terminator; + collectBreakableRegionBlocks(*regionInfo); + regions.add(regionInfo); + } + break; + default: + break; } } - for (auto& l : loops) + for (auto& l : regions) { l->forEach( - [&](LoopInfo* loop) + [&](BreakableRegionInfo* region) { - mapBreakBlockToLoop.Add(loop->loopInst->getBreakBlock(), loop); - for (auto block : loop->blocks) - mapBlockToLoop.Add(block, loop); + mapBreakBlockToRegion.Add(region->getBreakBlock(), region); + for (auto block : region->blocks) + mapBlockToRegion.Add(block, region); }); } @@ -119,18 +152,18 @@ struct EliminateMultiLevelBreakContext { if (as<IRLoop>(terminator)) continue; - LoopInfo* breakLoop = nullptr; - LoopInfo* currentLoop = nullptr; - if (!mapBreakBlockToLoop.TryGetValue(branch->getTargetBlock(), breakLoop)) + BreakableRegionInfo* breakTargetRegion = nullptr; + BreakableRegionInfo* currentRegion = nullptr; + if (!mapBreakBlockToRegion.TryGetValue(branch->getTargetBlock(), breakTargetRegion)) continue; - if (mapBlockToLoop.TryGetValue(block, currentLoop)) + if (mapBlockToRegion.TryGetValue(block, currentRegion)) { - if (currentLoop != breakLoop) + if (currentRegion != breakTargetRegion) { MultiLevelBreakInfo breakInfo; breakInfo.breakInst = branch; - breakInfo.breakTargetLoop = breakLoop; - breakInfo.currentLoop = currentLoop; + breakInfo.breakTargetRegion = breakTargetRegion; + breakInfo.currentRegion = currentRegion; 
multiLevelBreaks.add(breakInfo); } } @@ -165,7 +198,7 @@ struct EliminateMultiLevelBreakContext IRBuilder builder(&sharedBuilder); builder.setInsertInto(func); - OrderedHashSet<LoopInfo*> skippedOverLoops; + OrderedHashSet<BreakableRegionInfo*> skippedOverRegions; auto unreachableBlock = builder.emitBlock(); builder.setInsertInto(unreachableBlock); builder.emitUnreachable(); @@ -174,64 +207,65 @@ struct EliminateMultiLevelBreakContext // Rewrite multi-level breaks with single level break + target level argument. for (auto breakInfo : funcInfo.multiLevelBreaks) { - auto loop = breakInfo.currentLoop; - while (loop) + auto region = breakInfo.currentRegion; + while (region) { - skippedOverLoops.Add(loop); - loop = loop->parent; - if (loop == breakInfo.breakTargetLoop) + skippedOverRegions.Add(region); + region = region->parent; + if (region == breakInfo.breakTargetRegion) break; } builder.setInsertBefore(breakInfo.breakInst); - auto targetLevelInst = builder.getIntValue(builder.getIntType(), breakInfo.breakTargetLoop->level); - builder.emitBranch(breakInfo.currentLoop->getBreakBlock(), 1, &targetLevelInst); + auto targetLevelInst = builder.getIntValue(builder.getIntType(), breakInfo.breakTargetRegion->level); + builder.emitBranch(breakInfo.currentRegion->getBreakBlock(), 1, &targetLevelInst); breakInfo.breakInst->removeAndDeallocate(); } // Rewrite skipped-over break blocks to accept a target level argument. builder.setInsertInto(func); - OrderedDictionary<IRBlock*, int> mapNewBreakBlockToLoopLevel; - for (auto skippedLoop : skippedOverLoops) + OrderedDictionary<IRBlock*, int> mapNewBreakBlockToRegionLevel; + for (auto skippedRegion : skippedOverRegions) { - auto breakBlock = skippedLoop->getBreakBlock(); + auto breakBlock = skippedRegion->getBreakBlock(); // The existing break block cannot have parameters. We assume that PHI-elimination is // run before this pass. 
SLANG_RELEASE_ASSERT(breakBlock->getFirstParam() == nullptr); - // The new CFG structure will be: newBreakBlock --> newBreakBodyBlock { IfElse (-->oldBreakBlock, -->outerBreakBlock) } - // `newBreakBlock` defines the `IRParam` for the break target, then immediately jumps to `newBreakBodyBlock` for the actual branch. We need this - // separation to avoid introducing critical edge to the CFG (blocks cannot have more - // than 1 predecessors and more than 1 successors at the same time). + // The new CFG structure will be: newBreakBlock --> newBreakBodyBlock { IfElse + // (-->oldBreakBlock, -->outerBreakBlock) } `newBreakBlock` defines the `IRParam` for + // the break target, then immediately jumps to `newBreakBodyBlock` for the actual + // branch. We need this separation to avoid introducing critical edge to the CFG (blocks + // cannot have more than 1 predecessors and more than 1 successors at the same time). auto jumpToOuterBlock = builder.createBlock(); auto newBreakBodyBlock = builder.createBlock(); auto newBreakBlock = builder.createBlock(); newBreakBlock->insertBefore(breakBlock); newBreakBodyBlock->insertAfter(breakBlock); jumpToOuterBlock->insertAfter(newBreakBlock); - mapNewBreakBlockToLoopLevel[newBreakBlock] = skippedLoop->level; + mapNewBreakBlockToRegionLevel[newBreakBlock] = skippedRegion->level; breakBlock->replaceUsesWith(newBreakBlock); builder.setInsertInto(newBreakBlock); auto targetLevelParam = builder.emitParam(builder.getIntType()); builder.emitBranch(newBreakBodyBlock); builder.setInsertInto(newBreakBodyBlock); - auto levelNeq = builder.emitNeq(targetLevelParam, builder.getIntValue(builder.getIntType(), skippedLoop->level)); + auto levelNeq = builder.emitNeq(targetLevelParam, builder.getIntValue(builder.getIntType(), skippedRegion->level)); builder.emitIfElse(levelNeq, jumpToOuterBlock, breakBlock, unreachableBlock); builder.setInsertInto(jumpToOuterBlock); - if (skippedOverLoops.Contains(skippedLoop->parent)) + if 
(skippedOverRegions.Contains(skippedRegion->parent)) { - builder.emitBranch(skippedLoop->parent->getBreakBlock(), 1, (IRInst**)&targetLevelParam); + builder.emitBranch(skippedRegion->parent->getBreakBlock(), 1, (IRInst**)&targetLevelParam); } else { - builder.emitBranch(skippedLoop->parent->getBreakBlock()); + builder.emitBranch(skippedRegion->parent->getBreakBlock()); } } - // Once we have rewritten loops' break blocks with additional targetLevel parameter, all + // Once we have rewritten regions' break blocks with additional targetLevel parameter, all // original branches into that block without a parameter will now need to provide a default - // value equal to the level of its corresponding loop. - for (auto breakBlockKV : mapNewBreakBlockToLoopLevel) + // value equal to the level of its corresponding region. + for (auto breakBlockKV : mapNewBreakBlockToRegionLevel) { auto breakBlock = breakBlockKV.Key; auto level = breakBlockKV.Value; @@ -252,6 +286,12 @@ struct EliminateMultiLevelBreakContext // For complex branches, insert an intermediate block so we can specify the // target index argument. { + if (user->getOp() == kIROp_Switch && &(as<IRSwitch>(user)->breakLabel) == use) + { + // If this is the "breakLabel" operand of the original switch inst, don't do anything + // since it is not an actual branch. 
+ continue; + } builder.setInsertInto(func); auto tmpBlock = builder.createBlock(); tmpBlock->insertAfter(user->getParent()); diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index 70d98d10c..8f5412fa6 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -516,6 +516,28 @@ void performMandatoryEarlyInlining(IRModule* module) pass.considerAllCallSites(); } +struct ForceInliningPass : InliningPassBase +{ + typedef InliningPassBase Super; + + ForceInliningPass(IRModule* module) + : Super(module) + {} + + bool shouldInline(CallSiteInfo const& info) + { + if (info.callee->findDecoration<IRForceInlineDecoration>() || + info.callee->findDecoration<IRUnsafeForceInlineEarlyDecoration>()) + return true; + return false; + } +}; + +void performForceInlining(IRModule* module) +{ + ForceInliningPass pass(module); + pass.considerAllCallSites(); +} // Defined in slang-ir-specialize-resource.cpp bool isResourceType(IRType* type); diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h index 8ac23f6b0..70c7c3321 100644 --- a/source/slang/slang-ir-inline.h +++ b/source/slang/slang-ir-inline.h @@ -9,6 +9,9 @@ namespace Slang /// Inline any call sites to functions marked `[unsafeForceInlineEarly]` void performMandatoryEarlyInlining(IRModule* module); + /// Inline any call sites to functions marked `[ForceInline]` + void performForceInlining(IRModule* module); + /// Inline calls to functions that returns a resource/sampler via either return value or output parameter. 
void performGLSLResourceReturnFunctionInlining(IRModule* module); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index f63a093aa..8f8261af5 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -639,6 +639,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// An `[unsafeForceInlineEarly]` decoration specifies that calls to this function should be inline after initial codegen INST(UnsafeForceInlineEarlyDecoration, unsafeForceInlineEarly, 0, 0) + /// A `[ForceInline]` decoration indicates the callee should be inlined by the Slang compiler. + INST(ForceInlineDecoration, ForceInline, 0, 0) + /// A `[naturalSizeAndAlignment(s,a)]` decoration is attached to a type to indicate that is has natural size `s` and alignment `a` INST(NaturalSizeAndAlignmentDecoration, naturalSizeAndAlignment, 2, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index e09d60c3e..98bc6a0a2 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -501,6 +501,8 @@ struct IRFormatDecoration : IRDecoration IR_SIMPLE_DECORATION(UnsafeForceInlineEarlyDecoration) +IR_SIMPLE_DECORATION(ForceInlineDecoration) + struct IRNaturalSizeAndAlignmentDecoration : IRDecoration { enum { kOp = kIROp_NaturalSizeAndAlignmentDecoration }; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index bd724aa9d..b03f3ae62 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8102,6 +8102,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration); } + if (decl->findModifier<ForceInlineAttribute>()) + { + getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); + } + if (auto attr = decl->findModifier<CustomJVPAttribute>()) { auto loweredVal = lowerLValueExpr(this->context, attr->funcDeclRef); |
