summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-18 12:48:46 -0700
committerGitHub <noreply@github.com>2023-08-18 12:48:46 -0700
commitf94b2f7a328a898c5e3dc1389d08e0b7ce6e092e (patch)
tree129f39703b10b5684825ce8626d3a4e908970fad /source
parent4de3d9b1987fddf8d95efe75aab592282b672a97 (diff)
Allow loop counters to be used as constexpr arguments. (#3139)
* Allow loop counters to be used as constexpr arguments. * Fix. * Fix. * Fix. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp2
-rw-r--r--source/slang/slang-ir-constexpr.cpp278
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-lower-to-ir.cpp43
4 files changed, 225 insertions, 99 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index d46fba5e3..8a88e69ad 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -790,6 +790,8 @@ void normalizeCFG(
disableIRValidationAtInsert();
constructSSA(module, func);
enableIRValidationAtInsert();
+
+ module->invalidateAnalysisForInst(func);
#if _DEBUG
validateIRInst(maybeFindOuterGeneric(func));
#endif
diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp
index dbfec9ae7..5731238ee 100644
--- a/source/slang/slang-ir-constexpr.cpp
+++ b/source/slang/slang-ir-constexpr.cpp
@@ -3,6 +3,7 @@
#include "slang-ir.h"
#include "slang-ir-insts.h"
+#include "slang-ir-dominators.h"
namespace Slang {
@@ -74,6 +75,7 @@ bool opCanBeConstExpr(IROp op)
case kIROp_IntLit:
case kIROp_FloatLit:
case kIROp_BoolLit:
+ case kIROp_Param:
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
@@ -142,13 +144,35 @@ bool opCanBeConstExpr(IROp op)
}
}
-bool opCanBeConstExpr(IRInst* value)
+bool opCanBeConstExprByForwardPass(IRInst* value)
{
// TODO: realistically need to special-case `call`
// operations here, so that we check whether the
// callee function is fixed/known, and if it is
// whether it has been declared as constant-foldable
+ if (value->getOp() == kIROp_Param)
+ return false;
+ return opCanBeConstExpr(value->getOp());
+}
+
+IRLoop* isLoopPhi(IRParam* param)
+{
+ IRBlock* bb = cast<IRBlock>(param->getParent());
+ for (auto pred : bb->getPredecessors())
+ {
+ auto loop = as<IRLoop>(pred->getTerminator());
+ if (loop)
+ {
+ return loop;
+ }
+ }
+ return nullptr;
+}
+bool opCanBeConstExprByBackwardPass(IRInst* value)
+{
+ if (value->getOp() == kIROp_Param)
+ return isLoopPhi(as<IRParam>(value));
return opCanBeConstExpr(value->getOp());
}
@@ -159,6 +183,110 @@ void markConstExpr(
Slang::markConstExpr(context->getBuilder(), value);
}
+void maybeAddToWorkList(
+ PropagateConstExprContext* context,
+ IRInst* gv)
+{
+ if (!context->onWorkList.contains(gv))
+ {
+ context->workList.add(gv);
+ context->onWorkList.add(gv);
+ }
+}
+
+bool maybeMarkConstExprBackwardPass(
+ PropagateConstExprContext* context,
+ IRInst* value)
+{
+ if (isConstExpr(value))
+ return false;
+
+ if (!opCanBeConstExprByBackwardPass(value))
+ return false;
+
+ markConstExpr(context, value);
+
+ // TODO: we should only allow function parameters to be
+ // changed to be `constexpr` when we are compiling "application"
+ // code, and not library code.
+ // (Or eventually we'd have a rule that only non-`public` symbols
+ // can have this kind of propagation applied).
+
+ if (value->getOp() == kIROp_Param)
+ {
+ auto param = (IRParam*)value;
+ auto block = (IRBlock*)param->parent;
+ auto code = block->getParent();
+
+ if (block == code->getFirstBlock())
+ {
+ // We've just changed a function parameter to
+ // be `constexpr`. We need to remember that
+ // fact so taht we can mark callers of this
+ // function as `constexpr` themselves.
+
+ for (auto u = code->firstUse; u; u = u->nextUse)
+ {
+ auto user = u->getUser();
+
+ switch (user->getOp())
+ {
+ case kIROp_Call:
+ {
+ auto inst = (IRCall*)user;
+ auto caller = as<IRGlobalValueWithCode>(inst->getParent()->getParent());
+ maybeAddToWorkList(context, caller);
+ }
+ break;
+
+ default:
+ break;
+ }
+ }
+ }
+ }
+
+ return true;
+}
+
+// Produce an estimate on whether a loop is unrollable, by checking
+// if there is at least one exit path where all the conditions along
+// the control path has a constexpr condition.
+bool isUnrollableLoop(IRLoop* loop)
+{
+ // A loop is unrollable if all exit conditions are constexpr.
+ auto breakBlock = loop->getBreakBlock();
+ auto func = getParentFunc(loop);
+ auto domTree = loop->getModule()->findOrCreateDominatorTree(func);
+ List<IRBlock*> workList;
+ bool result = false;
+ for (auto pred : breakBlock->getPredecessors())
+ {
+ workList.clear();
+ workList.add(pred);
+ for (Index i = 0; i < workList.getCount(); i++)
+ {
+ auto block = workList[i];
+ if (auto ifElse = as<IRConditionalBranch>(block->getTerminator()))
+ {
+ if (!isConstExpr(ifElse->getCondition()))
+ return false;
+ }
+ else if (auto switchInst = as<IRSwitch>(block->getTerminator()))
+ {
+ if (!isConstExpr(ifElse->getCondition()))
+ return false;
+ }
+ auto idom = domTree->getImmediateDominator(block);
+ if (idom && idom != loop->getParent())
+ workList.add(idom);
+ }
+ // We found at least one exit path that is constexpr,
+ // we will regard this loop as unrollable.
+ result = true;
+ }
+ return result;
+}
// Propagate `constexpr`-ness in a forward direction, from the
// operands of an instruction to the instruction itself.
@@ -179,7 +307,7 @@ bool propagateConstExprForward(
continue;
// Is the operation one that we can actually make be constexpr?
- if(!opCanBeConstExpr(ii))
+ if(!opCanBeConstExprByForwardPass(ii))
continue;
// Are all arguments `constexpr`?
@@ -211,71 +339,6 @@ bool propagateConstExprForward(
}
}
-void maybeAddToWorkList(
- PropagateConstExprContext* context,
- IRInst* gv)
-{
- if( !context->onWorkList.contains(gv) )
- {
- context->workList.add(gv);
- context->onWorkList.add(gv);
- }
-}
-
-bool maybeMarkConstExpr(
- PropagateConstExprContext* context,
- IRInst* value)
-{
- if(isConstExpr(value))
- return false;
-
- if(!opCanBeConstExpr(value))
- return false;
-
- markConstExpr(context, value);
-
- // TODO: we should only allow function parameters to be
- // changed to be `constexpr` when we are compiling "application"
- // code, and not library code.
- // (Or eventually we'd have a rule that only non-`public` symbols
- // can have this kind of propagation applied).
-
- if(value->getOp() == kIROp_Param)
- {
- auto param = (IRParam*) value;
- auto block = (IRBlock*) param->parent;
- auto code = block->getParent();
-
- if(block == code->getFirstBlock())
- {
- // We've just changed a function parameter to
- // be `constexpr`. We need to remember that
- // fact so taht we can mark callers of this
- // function as `constexpr` themselves.
-
- for( auto u = code->firstUse; u; u = u->nextUse )
- {
- auto user = u->getUser();
-
- switch( user->getOp() )
- {
- case kIROp_Call:
- {
- auto inst = (IRCall*) user;
- auto caller = as<IRGlobalValueWithCode>(inst->getParent()->getParent());
- maybeAddToWorkList(context, caller);
- }
- break;
-
- default:
- break;
- }
- }
- }
- }
-
- return true;
-}
// Propagate `constexpr`-ness in a backward direction, from an instruction
// to its operands.
@@ -312,10 +375,10 @@ bool propagateConstExprBackward(
if(isConstExpr(arg))
continue;
- if(!opCanBeConstExpr(arg))
+ if(!opCanBeConstExprByBackwardPass(arg))
continue;
- if( maybeMarkConstExpr(context, arg) )
+ if( maybeMarkConstExprBackwardPass(context, arg) )
{
changedThisIteration = true;
}
@@ -383,7 +446,7 @@ bool propagateConstExprBackward(
if(isConstExpr(param))
{
- if(maybeMarkConstExpr(context, arg))
+ if(maybeMarkConstExprBackwardPass(context, arg))
{
changedThisIteration = true;
}
@@ -411,7 +474,7 @@ bool propagateConstExprBackward(
auto arg = callInst->getOperand(firstCallArg + pp);
if( isConstExpr(paramType) )
{
- if( maybeMarkConstExpr(context, arg) )
+ if(maybeMarkConstExprBackwardPass(context, arg) )
{
changedThisIteration = true;
}
@@ -439,15 +502,14 @@ bool propagateConstExprBackward(
for(auto pred : bb->getPredecessors())
{
- auto terminator = pred->getLastInst();
- if(terminator->getOp() != kIROp_unconditionalBranch)
+ auto terminator = as<IRUnconditionalBranch>(pred->getLastInst());
+ if(!terminator)
continue;
- UInt operandIndex = paramIndex + 1;
- SLANG_RELEASE_ASSERT(operandIndex < terminator->getOperandCount());
+ SLANG_RELEASE_ASSERT(paramIndex < terminator->getArgCount());
- auto operand = terminator->getOperand(operandIndex);
- if( maybeMarkConstExpr(context, operand) )
+ auto operand = terminator->getArg(paramIndex);
+ if(maybeMarkConstExprBackwardPass(context, operand) )
{
changedThisIteration = true;
}
@@ -463,7 +525,6 @@ bool propagateConstExprBackward(
anyChanges = true;
}
}
-
// Validate use of `constexpr` within a function (in particular,
// diagnose places where a value that must be contexpr depends
// on a value that cannot be)
@@ -484,9 +545,32 @@ void validateConstExpr(
for( UInt aa = 0; aa < argCount; ++aa )
{
auto arg = ii->getOperand(aa);
-
- if( !isConstExpr(arg) )
+ bool shouldDiagnose = !isConstExpr(arg);
+ if (!shouldDiagnose)
+ {
+ if (auto param = as<IRParam>(arg))
+ {
+ if (IRLoop * loopInst = isLoopPhi(param))
+ {
+ // If the param is a phi node in a loop that
+ // does not depend on non-constexpr values, we
+ // can make it constexpr by force unrolling the
+ // loop, if the loop is unrollable.
+ if (isUnrollableLoop(loopInst))
+ {
+ if (!loopInst->findDecoration<IRForceUnrollDecoration>())
+ {
+ context->getBuilder()->addLoopForceUnrollDecoration(loopInst, 0);
+ }
+ continue;
+ }
+ shouldDiagnose = true;
+ }
+ }
+ }
+ if (shouldDiagnose)
{
+
// Diagnose the failure.
context->getSink()->diagnose(ii->sourceLoc, Diagnostics::needCompileTimeConstant);
@@ -499,6 +583,24 @@ void validateConstExpr(
}
}
+void propagateInFunc(PropagateConstExprContext* context, IRGlobalValueWithCode* code)
+{
+ for (;;)
+ {
+ bool anyChange = false;
+ if (propagateConstExprForward(context, code))
+ {
+ anyChange = true;
+ }
+ if (propagateConstExprBackward(context, code))
+ {
+ anyChange = true;
+ }
+ if (!anyChange)
+ break;
+ }
+}
+
void propagateConstExpr(
IRModule* module,
DiagnosticSink* sink)
@@ -548,21 +650,7 @@ void propagateConstExpr(
case kIROp_GlobalVar:
{
IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) gv;
-
- for( ;;)
- {
- bool anyChange = false;
- if( propagateConstExprForward(&context, code) )
- {
- anyChange = true;
- }
- if( propagateConstExprBackward(&context, code) )
- {
- anyChange = true;
- }
- if(!anyChange)
- break;
- }
+ propagateInFunc(&context, code);
}
break;
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 64663120d..22355bd7e 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -643,7 +643,6 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(FlattenDecoration, flatten, 0, 0)
INST(LoopControlDecoration, loopControl, 1, 0)
INST(LoopMaxItersDecoration, loopMaxIters, 1, 0)
- INST(LoopInferredMaxItersDecoration, loopInferredMaxIters, 2, 0)
INST(LoopExitPrimalValueDecoration, loopExitPrimalValue, 2, 0)
INST(IntrinsicOpDecoration, intrinsicOp, 1, 0)
/* TargetSpecificDecoration */
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 489a89287..0d7e27bc4 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -29,6 +29,7 @@
#include "slang-ir-lower-error-handling.h"
#include "slang-ir-obfuscate-loc.h"
#include "slang-ir-use-uninitialized-out-param.h"
+#include "slang-ir-peephole.h"
#include "slang-mangle.h"
#include "slang-type-layout.h"
@@ -8761,6 +8762,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute
context->setGlobalValue(decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc)));
+ bool isInline = false;
+
for (auto modifier : decl->modifiers)
{
if (as<RequiresNVAPIAttribute>(modifier))
@@ -8858,10 +8861,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
else if (as<UnsafeForceInlineEarlyAttribute>(modifier))
{
getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration);
+ isInline = true;
}
else if (as<ForceInlineAttribute>(modifier))
{
getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
+ isInline = true;
}
else if (as<TreatAsDifferentiableAttribute>(modifier))
{
@@ -8871,6 +8876,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op);
getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op);
+ isInline = true;
}
else if (as<UserDefinedDerivativeAttribute>(modifier) || as<PrimalSubstituteAttribute>(modifier))
{
@@ -8930,6 +8936,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
+ if (!isInline)
+ {
+ // If there are any constant expr rate parameters, we should inline this function.
+ // TODO: consider specializing them instead of inlining.
+ for (auto param : decl->getParameters())
+ {
+ if (param->hasModifier<ConstExprModifier>())
+ {
+ getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
+ isInline = true;
+ break;
+ }
+ }
+ }
+
if (auto diffAttr = decl->findModifier<DifferentiableAttribute>())
{
if (decl->body)
@@ -9698,6 +9719,8 @@ RefPtr<IRModule> generateIRForTranslationUnit(
constructSSA(module);
simplifyCFG(module);
applySparseConditionalConstantPropagation(module, compileRequest->getSink());
+ peepholeOptimize(module);
+
for (auto inst : module->getGlobalInsts())
{
if (auto func = as<IRGlobalValueWithCode>(inst))
@@ -9732,14 +9755,28 @@ RefPtr<IRModule> generateIRForTranslationUnit(
// - If sccp is unable to eliminate the outer 'if' then we end up with
// duplicated code the the conditional value. Users don't tend to put
// huge gobs of code in the conditional expression in loops however.
+
invertLoops(module);
// Next, attempt to promote local variables to SSA
// temporaries and do basic simplifications.
//
- constructSSA(module);
- simplifyCFG(module);
- applySparseConditionalConstantPropagation(module, compileRequest->getSink());
+ for (;;)
+ {
+ bool changed = false;
+ performMandatoryEarlyInlining(module);
+ changed |= constructSSA(module);
+ simplifyCFG(module);
+ changed |= applySparseConditionalConstantPropagation(module, compileRequest->getSink());
+ changed |= peepholeOptimize(module);
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto func = as<IRGlobalValueWithCode>(inst))
+ eliminateDeadCode(func);
+ }
+ if (!changed)
+ break;
+ }
// Propagate `constexpr`-ness through the dataflow graph (and the
// call graph) based on constraints imposed by different instructions.