summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r-- source/slang/core.meta.slang 3
-rw-r--r-- source/slang/slang-ast-modifier.h 8
-rw-r--r-- source/slang/slang-emit.cpp 6
-rw-r--r-- source/slang/slang-ir-eliminate-multilevel-break.cpp 176
-rw-r--r-- source/slang/slang-ir-inline.cpp 22
-rw-r--r-- source/slang/slang-ir-inline.h 3
-rw-r--r-- source/slang/slang-ir-inst-defs.h 3
-rw-r--r-- source/slang/slang-ir-insts.h 2
-rw-r--r-- source/slang/slang-lower-to-ir.cpp 5
9 files changed, 156 insertions, 72 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index b54c70236..1711102da 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2711,6 +2711,9 @@ attribute_syntax [__extern] : ExternAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [__unsafeForceInlineEarly] : UnsafeForceInlineEarlyAttribute;
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [ForceInline] : ForceInlineAttribute;
+
__attributeTarget(FuncDecl)
attribute_syntax [DllImport(modulePath: String)] : DllImportAttribute;
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 2d967445a..8868b7a1d 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -910,6 +910,14 @@ class UnsafeForceInlineEarlyAttribute : public Attribute
SLANG_AST_CLASS(UnsafeForceInlineEarlyAttribute)
};
+// A `[ForceInline]` attribute indicates that the callee should be inlined
+// by the Slang compiler.
+//
+class ForceInlineAttribute : public Attribute
+{
+ SLANG_AST_CLASS(ForceInlineAttribute)
+};
+
/// An attribute that marks a type declaration as either allowing or
/// disallowing the type to be inherited from in other modules.
class InheritanceControlAttribute : public Attribute { SLANG_AST_CLASS(InheritanceControlAttribute) };
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 6d85e1ce0..4666e80d8 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -386,10 +386,8 @@ Result linkAndOptimizeIR(
#endif
validateIRModuleIfEnabled(codeGenContext, irModule);
- // Inline calls to any functions marked with [__unsafeInlineEarly] again,
- // since we may be missing out cases prevented by the generic constructs
- // that we just lowered out.
- performMandatoryEarlyInlining(irModule);
+ // Inline calls to any functions marked with [__unsafeForceInlineEarly] or [ForceInline].
+ performForceInlining(irModule);
// Specialization can introduce dead code that could trip
// up downstream passes like type legalization, so we
diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp
index 269b74aad..2330991c5 100644
--- a/source/slang/slang-ir-eliminate-multilevel-break.cpp
+++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp
@@ -13,20 +13,32 @@ struct EliminateMultiLevelBreakContext
{
IRModule* irModule;
- struct LoopInfo : RefObject
+ struct BreakableRegionInfo : RefObject
{
- LoopInfo* parent = nullptr;
+ BreakableRegionInfo* parent = nullptr;
int level = 0;
- IRLoop* loopInst;
+ IRInst* headerInst;
List<IRBlock*> blocks;
HashSet<IRBlock*> blockSet;
- List<RefPtr<LoopInfo>> childLoops;
- IRBlock* getBreakBlock() { return loopInst->getBreakBlock(); }
+ List<RefPtr<BreakableRegionInfo>> childRegions;
+ IRBlock* getBreakBlock()
+ {
+ switch (headerInst->getOp())
+ {
+ case kIROp_loop:
+ return as<IRLoop>(headerInst)->getBreakBlock();
+ case kIROp_Switch:
+ return as<IRSwitch>(headerInst)->getBreakLabel();
+ default:
+ SLANG_UNREACHABLE("Unknown breakable inst");
+ }
+ }
+
template<typename Func>
void forEach(const Func& f)
{
f(this);
- for (auto child : childLoops)
+ for (auto child : childRegions)
child->forEach(f);
}
};
@@ -34,44 +46,56 @@ struct EliminateMultiLevelBreakContext
struct MultiLevelBreakInfo
{
IRUnconditionalBranch* breakInst;
- LoopInfo* currentLoop;
- LoopInfo* breakTargetLoop;
+ BreakableRegionInfo* currentRegion;
+ BreakableRegionInfo* breakTargetRegion;
};
struct FuncContext
{
- List<RefPtr<LoopInfo>> loops;
+ List<RefPtr<BreakableRegionInfo>> regions;
HashSet<IRBlock*> breakBlocks;
- Dictionary<IRBlock*, LoopInfo*> mapBreakBlockToLoop;
- Dictionary<IRBlock*, LoopInfo*> mapBlockToLoop;
+ Dictionary<IRBlock*, BreakableRegionInfo*> mapBreakBlockToRegion;
+ Dictionary<IRBlock*, BreakableRegionInfo*> mapBlockToRegion;
HashSet<IRBlock*> processedBlocks;
List<MultiLevelBreakInfo> multiLevelBreaks;
- void collectLoopBlocks(LoopInfo& info)
+ void collectBreakableRegionBlocks(BreakableRegionInfo& info)
{
- auto startBlock = info.loopInst->getTargetBlock();
- info.blockSet.Add(startBlock);
- info.blocks.add(startBlock);
- breakBlocks.Add(info.loopInst->getBreakBlock());
+ auto successors = as<IRBlock>(info.headerInst->getParent())->getSuccessors();
+ for (auto successor : successors)
+ {
+ if (info.blockSet.Add(successor))
+ info.blocks.add(successor);
+ }
+ // Push break block to a stack so we can easily check if a block is a break block in its
+ // parent regions.
+ breakBlocks.Add(info.getBreakBlock());
for (Index i = 0; i < info.blocks.getCount(); i++)
{
auto block = info.blocks[i];
if (!processedBlocks.Add(block))
continue;
- if (auto loopInst = as<IRLoop>(block->getTerminator()))
+ switch (block->getTerminator()->getOp())
{
- RefPtr<LoopInfo> childLoop = new LoopInfo();
- childLoop->loopInst = loopInst;
- childLoop->parent = &info;
- childLoop->level = info.level + 1;
- collectLoopBlocks(*childLoop);
- info.childLoops.add(childLoop);
- block = loopInst->getBreakBlock();
- if (info.blockSet.Add(block))
+ case kIROp_loop:
+ case kIROp_Switch:
{
- info.blocks.add(block);
+ // Both loop and switch insts mark the start of a breakable region.
+ RefPtr<BreakableRegionInfo> childRegion = new BreakableRegionInfo();
+ childRegion->headerInst = block->getTerminator();
+ childRegion->parent = &info;
+ childRegion->level = info.level + 1;
+ collectBreakableRegionBlocks(*childRegion);
+ info.childRegions.add(childRegion);
+ block = childRegion->getBreakBlock();
+ if (info.blockSet.Add(block))
+ {
+ info.blocks.add(block);
+ }
+ continue;
}
- continue;
+ default:
+ break;
}
for (auto succ : block->getSuccessors())
{
@@ -82,7 +106,9 @@ struct EliminateMultiLevelBreakContext
}
}
}
- breakBlocks.Remove(info.loopInst->getBreakBlock());
+
+ // Pop the break block from stack since we are no longer processing the region.
+ breakBlocks.Remove(info.getBreakBlock());
}
void gatherInfo(IRGlobalValueWithCode* func)
@@ -92,23 +118,30 @@ struct EliminateMultiLevelBreakContext
if (processedBlocks.Contains(block))
continue;
auto terminator = block->getTerminator();
- if (auto loop = as<IRLoop>(terminator))
+ switch (terminator->getOp())
{
- RefPtr<LoopInfo> loopInfo = new LoopInfo();
- loopInfo->loopInst = loop;
- collectLoopBlocks(*loopInfo);
- loops.add(loopInfo);
+ case kIROp_loop:
+ case kIROp_Switch:
+ {
+ RefPtr<BreakableRegionInfo> regionInfo = new BreakableRegionInfo();
+ regionInfo->headerInst = terminator;
+ collectBreakableRegionBlocks(*regionInfo);
+ regions.add(regionInfo);
+ }
+ break;
+ default:
+ break;
}
}
- for (auto& l : loops)
+ for (auto& l : regions)
{
l->forEach(
- [&](LoopInfo* loop)
+ [&](BreakableRegionInfo* region)
{
- mapBreakBlockToLoop.Add(loop->loopInst->getBreakBlock(), loop);
- for (auto block : loop->blocks)
- mapBlockToLoop.Add(block, loop);
+ mapBreakBlockToRegion.Add(region->getBreakBlock(), region);
+ for (auto block : region->blocks)
+ mapBlockToRegion.Add(block, region);
});
}
@@ -119,18 +152,18 @@ struct EliminateMultiLevelBreakContext
{
if (as<IRLoop>(terminator))
continue;
- LoopInfo* breakLoop = nullptr;
- LoopInfo* currentLoop = nullptr;
- if (!mapBreakBlockToLoop.TryGetValue(branch->getTargetBlock(), breakLoop))
+ BreakableRegionInfo* breakTargetRegion = nullptr;
+ BreakableRegionInfo* currentRegion = nullptr;
+ if (!mapBreakBlockToRegion.TryGetValue(branch->getTargetBlock(), breakTargetRegion))
continue;
- if (mapBlockToLoop.TryGetValue(block, currentLoop))
+ if (mapBlockToRegion.TryGetValue(block, currentRegion))
{
- if (currentLoop != breakLoop)
+ if (currentRegion != breakTargetRegion)
{
MultiLevelBreakInfo breakInfo;
breakInfo.breakInst = branch;
- breakInfo.breakTargetLoop = breakLoop;
- breakInfo.currentLoop = currentLoop;
+ breakInfo.breakTargetRegion = breakTargetRegion;
+ breakInfo.currentRegion = currentRegion;
multiLevelBreaks.add(breakInfo);
}
}
@@ -165,7 +198,7 @@ struct EliminateMultiLevelBreakContext
IRBuilder builder(&sharedBuilder);
builder.setInsertInto(func);
- OrderedHashSet<LoopInfo*> skippedOverLoops;
+ OrderedHashSet<BreakableRegionInfo*> skippedOverRegions;
auto unreachableBlock = builder.emitBlock();
builder.setInsertInto(unreachableBlock);
builder.emitUnreachable();
@@ -174,64 +207,65 @@ struct EliminateMultiLevelBreakContext
// Rewrite multi-level breaks with single level break + target level argument.
for (auto breakInfo : funcInfo.multiLevelBreaks)
{
- auto loop = breakInfo.currentLoop;
- while (loop)
+ auto region = breakInfo.currentRegion;
+ while (region)
{
- skippedOverLoops.Add(loop);
- loop = loop->parent;
- if (loop == breakInfo.breakTargetLoop)
+ skippedOverRegions.Add(region);
+ region = region->parent;
+ if (region == breakInfo.breakTargetRegion)
break;
}
builder.setInsertBefore(breakInfo.breakInst);
- auto targetLevelInst = builder.getIntValue(builder.getIntType(), breakInfo.breakTargetLoop->level);
- builder.emitBranch(breakInfo.currentLoop->getBreakBlock(), 1, &targetLevelInst);
+ auto targetLevelInst = builder.getIntValue(builder.getIntType(), breakInfo.breakTargetRegion->level);
+ builder.emitBranch(breakInfo.currentRegion->getBreakBlock(), 1, &targetLevelInst);
breakInfo.breakInst->removeAndDeallocate();
}
// Rewrite skipped-over break blocks to accept a target level argument.
builder.setInsertInto(func);
- OrderedDictionary<IRBlock*, int> mapNewBreakBlockToLoopLevel;
- for (auto skippedLoop : skippedOverLoops)
+ OrderedDictionary<IRBlock*, int> mapNewBreakBlockToRegionLevel;
+ for (auto skippedRegion : skippedOverRegions)
{
- auto breakBlock = skippedLoop->getBreakBlock();
+ auto breakBlock = skippedRegion->getBreakBlock();
// The existing break block cannot have parameters. We assume that PHI-elimination is
// run before this pass.
SLANG_RELEASE_ASSERT(breakBlock->getFirstParam() == nullptr);
- // The new CFG structure will be: newBreakBlock --> newBreakBodyBlock { IfElse (-->oldBreakBlock, -->outerBreakBlock) }
- // `newBreakBlock` defines the `IRParam` for the break target, then immediately jumps to `newBreakBodyBlock` for the actual branch. We need this
- // separation to avoid introducing critical edge to the CFG (blocks cannot have more
- // than 1 predecessors and more than 1 successors at the same time).
+ // The new CFG structure will be: newBreakBlock --> newBreakBodyBlock { IfElse
+ // (-->oldBreakBlock, -->outerBreakBlock) }. `newBreakBlock` defines the `IRParam` for
+ // the break target, then immediately jumps to `newBreakBodyBlock` for the actual
+ // branch. We need this separation to avoid introducing a critical edge into the CFG (a
+ // block cannot have more than 1 predecessor and more than 1 successor at the same time).
auto jumpToOuterBlock = builder.createBlock();
auto newBreakBodyBlock = builder.createBlock();
auto newBreakBlock = builder.createBlock();
newBreakBlock->insertBefore(breakBlock);
newBreakBodyBlock->insertAfter(breakBlock);
jumpToOuterBlock->insertAfter(newBreakBlock);
- mapNewBreakBlockToLoopLevel[newBreakBlock] = skippedLoop->level;
+ mapNewBreakBlockToRegionLevel[newBreakBlock] = skippedRegion->level;
breakBlock->replaceUsesWith(newBreakBlock);
builder.setInsertInto(newBreakBlock);
auto targetLevelParam = builder.emitParam(builder.getIntType());
builder.emitBranch(newBreakBodyBlock);
builder.setInsertInto(newBreakBodyBlock);
- auto levelNeq = builder.emitNeq(targetLevelParam, builder.getIntValue(builder.getIntType(), skippedLoop->level));
+ auto levelNeq = builder.emitNeq(targetLevelParam, builder.getIntValue(builder.getIntType(), skippedRegion->level));
builder.emitIfElse(levelNeq, jumpToOuterBlock, breakBlock, unreachableBlock);
builder.setInsertInto(jumpToOuterBlock);
- if (skippedOverLoops.Contains(skippedLoop->parent))
+ if (skippedOverRegions.Contains(skippedRegion->parent))
{
- builder.emitBranch(skippedLoop->parent->getBreakBlock(), 1, (IRInst**)&targetLevelParam);
+ builder.emitBranch(skippedRegion->parent->getBreakBlock(), 1, (IRInst**)&targetLevelParam);
}
else
{
- builder.emitBranch(skippedLoop->parent->getBreakBlock());
+ builder.emitBranch(skippedRegion->parent->getBreakBlock());
}
}
- // Once we have rewritten loops' break blocks with additional targetLevel parameter, all
+ // Once we have rewritten regions' break blocks with additional targetLevel parameter, all
// original branches into that block without a parameter will now need to provide a default
- // value equal to the level of its corresponding loop.
- for (auto breakBlockKV : mapNewBreakBlockToLoopLevel)
+ // value equal to the level of its corresponding region.
+ for (auto breakBlockKV : mapNewBreakBlockToRegionLevel)
{
auto breakBlock = breakBlockKV.Key;
auto level = breakBlockKV.Value;
@@ -252,6 +286,12 @@ struct EliminateMultiLevelBreakContext
// For complex branches, insert an intermediate block so we can specify the
// target index argument.
{
+ if (user->getOp() == kIROp_Switch && &(as<IRSwitch>(user)->breakLabel) == use)
+ {
+ // If this is the "breakLabel" operand of the original switch inst, don't do anything
+ // since it is not an actual branch.
+ continue;
+ }
builder.setInsertInto(func);
auto tmpBlock = builder.createBlock();
tmpBlock->insertAfter(user->getParent());
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index 70d98d10c..8f5412fa6 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -516,6 +516,28 @@ void performMandatoryEarlyInlining(IRModule* module)
pass.considerAllCallSites();
}
+struct ForceInliningPass : InliningPassBase
+{
+ typedef InliningPassBase Super;
+
+ ForceInliningPass(IRModule* module)
+ : Super(module)
+ {}
+
+ bool shouldInline(CallSiteInfo const& info)
+ {
+ if (info.callee->findDecoration<IRForceInlineDecoration>() ||
+ info.callee->findDecoration<IRUnsafeForceInlineEarlyDecoration>())
+ return true;
+ return false;
+ }
+};
+
+void performForceInlining(IRModule* module)
+{
+ ForceInliningPass pass(module);
+ pass.considerAllCallSites();
+}
// Defined in slang-ir-specialize-resource.cpp
bool isResourceType(IRType* type);
diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h
index 8ac23f6b0..70c7c3321 100644
--- a/source/slang/slang-ir-inline.h
+++ b/source/slang/slang-ir-inline.h
@@ -9,6 +9,9 @@ namespace Slang
/// Inline any call sites to functions marked `[unsafeForceInlineEarly]`
void performMandatoryEarlyInlining(IRModule* module);
+ /// Inline any call sites to functions marked `[ForceInline]` or `[__unsafeForceInlineEarly]`
+ void performForceInlining(IRModule* module);
+
/// Inline calls to functions that returns a resource/sampler via either return value or output parameter.
void performGLSLResourceReturnFunctionInlining(IRModule* module);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f63a093aa..8f8261af5 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -639,6 +639,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// An `[unsafeForceInlineEarly]` decoration specifies that calls to this function should be inline after initial codegen
INST(UnsafeForceInlineEarlyDecoration, unsafeForceInlineEarly, 0, 0)
+ /// A `[ForceInline]` decoration indicates the callee should be inlined by the Slang compiler.
+ INST(ForceInlineDecoration, ForceInline, 0, 0)
+
/// A `[naturalSizeAndAlignment(s,a)]` decoration is attached to a type to indicate that is has natural size `s` and alignment `a`
INST(NaturalSizeAndAlignmentDecoration, naturalSizeAndAlignment, 2, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index e09d60c3e..98bc6a0a2 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -501,6 +501,8 @@ struct IRFormatDecoration : IRDecoration
IR_SIMPLE_DECORATION(UnsafeForceInlineEarlyDecoration)
+IR_SIMPLE_DECORATION(ForceInlineDecoration)
+
struct IRNaturalSizeAndAlignmentDecoration : IRDecoration
{
enum { kOp = kIROp_NaturalSizeAndAlignmentDecoration };
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index bd724aa9d..b03f3ae62 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8102,6 +8102,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration);
}
+ if (decl->findModifier<ForceInlineAttribute>())
+ {
+ getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
+ }
+
if (auto attr = decl->findModifier<CustomJVPAttribute>())
{
auto loweredVal = lowerLValueExpr(this->context, attr->funcDeclRef);