From f94b2f7a328a898c5e3dc1389d08e0b7ce6e092e Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 18 Aug 2023 12:48:46 -0700 Subject: Allow loop counters to be used as constexpr arguments. (#3139) * Allow loop counters to be used as constexpr arguments. * Fix. * Fix. * Fix. * Fix. --------- Co-authored-by: Yong He --- source/slang/slang-ir-constexpr.cpp | 278 ++++++++++++++++++++++++------------ 1 file changed, 183 insertions(+), 95 deletions(-) (limited to 'source/slang/slang-ir-constexpr.cpp') diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp index dbfec9ae7..5731238ee 100644 --- a/source/slang/slang-ir-constexpr.cpp +++ b/source/slang/slang-ir-constexpr.cpp @@ -3,6 +3,7 @@ #include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir-dominators.h" namespace Slang { @@ -74,6 +75,7 @@ bool opCanBeConstExpr(IROp op) case kIROp_IntLit: case kIROp_FloatLit: case kIROp_BoolLit: + case kIROp_Param: case kIROp_Add: case kIROp_Sub: case kIROp_Mul: @@ -142,13 +144,35 @@ bool opCanBeConstExpr(IROp op) } } -bool opCanBeConstExpr(IRInst* value) +bool opCanBeConstExprByForwardPass(IRInst* value) { // TODO: realistically need to special-case `call` // operations here, so that we check whether the // callee function is fixed/known, and if it is // whether it has been declared as constant-foldable + if (value->getOp() == kIROp_Param) + return false; + return opCanBeConstExpr(value->getOp()); +} + +IRLoop* isLoopPhi(IRParam* param) +{ + IRBlock* bb = cast(param->getParent()); + for (auto pred : bb->getPredecessors()) + { + auto loop = as(pred->getTerminator()); + if (loop) + { + return loop; + } + } + return nullptr; +} +bool opCanBeConstExprByBackwardPass(IRInst* value) +{ + if (value->getOp() == kIROp_Param) + return isLoopPhi(as(value)); return opCanBeConstExpr(value->getOp()); } @@ -159,6 +183,110 @@ void markConstExpr( Slang::markConstExpr(context->getBuilder(), value); } +void maybeAddToWorkList( + PropagateConstExprContext* context, + IRInst* gv) +{ + if (!context->onWorkList.contains(gv)) + { + context->workList.add(gv); + context->onWorkList.add(gv); + } +} + +bool maybeMarkConstExprBackwardPass( + PropagateConstExprContext* context, + IRInst* value) +{ + if (isConstExpr(value)) + return false; + + if (!opCanBeConstExprByBackwardPass(value)) + return false; + + markConstExpr(context, value); + + // TODO: we should only allow function parameters to be + // changed to be `constexpr` when we are compiling "application" + // code, and not library code. + // (Or eventually we'd have a rule that only non-`public` symbols + // can have this kind of propagation applied). + + if (value->getOp() == kIROp_Param) + { + auto param = (IRParam*)value; + auto block = (IRBlock*)param->parent; + auto code = block->getParent(); + + if (block == code->getFirstBlock()) + { + // We've just changed a function parameter to + // be `constexpr`. We need to remember that + // fact so taht we can mark callers of this + // function as `constexpr` themselves. + + for (auto u = code->firstUse; u; u = u->nextUse) + { + auto user = u->getUser(); + + switch (user->getOp()) + { + case kIROp_Call: + { + auto inst = (IRCall*)user; + auto caller = as(inst->getParent()->getParent()); + maybeAddToWorkList(context, caller); + } + break; + + default: + break; + } + } + } + } + + return true; +} + +// Produce an estimate on whether a loop is unrollable, by checking +// if there is at least one exit path where all the conditions along +// the control path has a constexpr condition. +bool isUnrollableLoop(IRLoop* loop) +{ + // A loop is unrollable if all exit conditions are constexpr. + auto breakBlock = loop->getBreakBlock(); + auto func = getParentFunc(loop); + auto domTree = loop->getModule()->findOrCreateDominatorTree(func); + List workList; + bool result = false; + for (auto pred : breakBlock->getPredecessors()) + { + workList.clear(); + workList.add(pred); + for (Index i = 0; i < workList.getCount(); i++) + { + auto block = workList[i]; + if (auto ifElse = as(block->getTerminator())) + { + if (!isConstExpr(ifElse->getCondition())) + return false; + } + else if (auto switchInst = as(block->getTerminator())) + { + if (!isConstExpr(ifElse->getCondition())) + return false; + } + auto idom = domTree->getImmediateDominator(block); + if (idom && idom != loop->getParent()) + workList.add(idom); + } + // We found at least one exit path that is constexpr, + // we will regard this loop as unrollable. + result = true; + } + return result; +} // Propagate `constexpr`-ness in a forward direction, from the // operands of an instruction to the instruction itself. @@ -179,7 +307,7 @@ bool propagateConstExprForward( continue; // Is the operation one that we can actually make be constexpr? - if(!opCanBeConstExpr(ii)) + if(!opCanBeConstExprByForwardPass(ii)) continue; // Are all arguments `constexpr`? @@ -211,71 +339,6 @@ bool propagateConstExprForward( } } -void maybeAddToWorkList( - PropagateConstExprContext* context, - IRInst* gv) -{ - if( !context->onWorkList.contains(gv) ) - { - context->workList.add(gv); - context->onWorkList.add(gv); - } -} - -bool maybeMarkConstExpr( - PropagateConstExprContext* context, - IRInst* value) -{ - if(isConstExpr(value)) - return false; - - if(!opCanBeConstExpr(value)) - return false; - - markConstExpr(context, value); - - // TODO: we should only allow function parameters to be - // changed to be `constexpr` when we are compiling "application" - // code, and not library code. - // (Or eventually we'd have a rule that only non-`public` symbols - // can have this kind of propagation applied). - - if(value->getOp() == kIROp_Param) - { - auto param = (IRParam*) value; - auto block = (IRBlock*) param->parent; - auto code = block->getParent(); - - if(block == code->getFirstBlock()) - { - // We've just changed a function parameter to - // be `constexpr`. We need to remember that - // fact so taht we can mark callers of this - // function as `constexpr` themselves. - - for( auto u = code->firstUse; u; u = u->nextUse ) - { - auto user = u->getUser(); - - switch( user->getOp() ) - { - case kIROp_Call: - { - auto inst = (IRCall*) user; - auto caller = as(inst->getParent()->getParent()); - maybeAddToWorkList(context, caller); - } - break; - - default: - break; - } - } - } - } - - return true; -} // Propagate `constexpr`-ness in a backward direction, from an instruction // to its operands. @@ -312,10 +375,10 @@ bool propagateConstExprBackward( if(isConstExpr(arg)) continue; - if(!opCanBeConstExpr(arg)) + if(!opCanBeConstExprByBackwardPass(arg)) continue; - if( maybeMarkConstExpr(context, arg) ) + if( maybeMarkConstExprBackwardPass(context, arg) ) { changedThisIteration = true; } @@ -383,7 +446,7 @@ bool propagateConstExprBackward( if(isConstExpr(param)) { - if(maybeMarkConstExpr(context, arg)) + if(maybeMarkConstExprBackwardPass(context, arg)) { changedThisIteration = true; } @@ -411,7 +474,7 @@ bool propagateConstExprBackward( auto arg = callInst->getOperand(firstCallArg + pp); if( isConstExpr(paramType) ) { - if( maybeMarkConstExpr(context, arg) ) + if(maybeMarkConstExprBackwardPass(context, arg) ) { changedThisIteration = true; } @@ -439,15 +502,14 @@ bool propagateConstExprBackward( for(auto pred : bb->getPredecessors()) { - auto terminator = pred->getLastInst(); - if(terminator->getOp() != kIROp_unconditionalBranch) + auto terminator = as(pred->getLastInst()); + if(!terminator) continue; - UInt operandIndex = paramIndex + 1; - SLANG_RELEASE_ASSERT(operandIndex < terminator->getOperandCount()); + SLANG_RELEASE_ASSERT(paramIndex < terminator->getArgCount()); - auto operand = terminator->getOperand(operandIndex); - if( maybeMarkConstExpr(context, operand) ) + auto operand = terminator->getArg(paramIndex); + if(maybeMarkConstExprBackwardPass(context, operand) ) { changedThisIteration = true; } @@ -463,7 +525,6 @@ bool propagateConstExprBackward( anyChanges = true; } } - // Validate use of `constexpr` within a function (in particular, // diagnose places where a value that must be contexpr depends // on a value that cannot be) @@ -484,9 +545,32 @@ void validateConstExpr( for( UInt aa = 0; aa < argCount; ++aa ) { auto arg = ii->getOperand(aa); - - if( !isConstExpr(arg) ) + bool shouldDiagnose = !isConstExpr(arg); + if (!shouldDiagnose) + { + if (auto param = as(arg)) + { + if (IRLoop * loopInst = isLoopPhi(param)) + { + // If the param is a phi node in a loop that + // does not depend on non-constexpr values, we + // can make it constexpr by force unrolling the + // loop, if the loop is unrollable. + if (isUnrollableLoop(loopInst)) + { + if (!loopInst->findDecoration()) + { + context->getBuilder()->addLoopForceUnrollDecoration(loopInst, 0); + } + continue; + } + shouldDiagnose = true; + } + } + } + if (shouldDiagnose) { + // Diagnose the failure. context->getSink()->diagnose(ii->sourceLoc, Diagnostics::needCompileTimeConstant); @@ -499,6 +583,24 @@ void validateConstExpr( } } +void propagateInFunc(PropagateConstExprContext* context, IRGlobalValueWithCode* code) +{ + for (;;) + { + bool anyChange = false; + if (propagateConstExprForward(context, code)) + { + anyChange = true; + } + if (propagateConstExprBackward(context, code)) + { + anyChange = true; + } + if (!anyChange) + break; + } +} + void propagateConstExpr( IRModule* module, DiagnosticSink* sink) @@ -548,21 +650,7 @@ void propagateConstExpr( case kIROp_GlobalVar: { IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) gv; - - for( ;;) - { - bool anyChange = false; - if( propagateConstExprForward(&context, code) ) - { - anyChange = true; - } - if( propagateConstExprBackward(&context, code) ) - { - anyChange = true; - } - if(!anyChange) - break; - } + propagateInFunc(&context, code); } break; } -- cgit v1.2.3