summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-18 12:48:46 -0700
committerGitHub <noreply@github.com>2023-08-18 12:48:46 -0700
commitf94b2f7a328a898c5e3dc1389d08e0b7ce6e092e (patch)
tree129f39703b10b5684825ce8626d3a4e908970fad
parent4de3d9b1987fddf8d95efe75aab592282b672a97 (diff)
Allow loop counters to be used as constexpr arguments. (#3139)
* Allow loop counters to be used as constexpr arguments. * Fix. * Fix. * Fix. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp2
-rw-r--r--source/slang/slang-ir-constexpr.cpp278
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-lower-to-ir.cpp43
-rw-r--r--tests/diagnostics/autodiff-data-flow-2.slang1
-rw-r--r--tests/diagnostics/autodiff-data-flow.slang1
-rw-r--r--tests/diagnostics/constexpr-error.slang.expected6
-rw-r--r--tests/expected-failure.txt1
-rw-r--r--tests/language-feature/constants/constexpr-loop.slang24
-rw-r--r--tests/language-feature/constants/constexpr-loop.slang.expected.txt2
10 files changed, 260 insertions, 99 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index d46fba5e3..8a88e69ad 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -790,6 +790,8 @@ void normalizeCFG(
disableIRValidationAtInsert();
constructSSA(module, func);
enableIRValidationAtInsert();
+
+ module->invalidateAnalysisForInst(func);
#if _DEBUG
validateIRInst(maybeFindOuterGeneric(func));
#endif
diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp
index dbfec9ae7..5731238ee 100644
--- a/source/slang/slang-ir-constexpr.cpp
+++ b/source/slang/slang-ir-constexpr.cpp
@@ -3,6 +3,7 @@
#include "slang-ir.h"
#include "slang-ir-insts.h"
+#include "slang-ir-dominators.h"
namespace Slang {
@@ -74,6 +75,7 @@ bool opCanBeConstExpr(IROp op)
case kIROp_IntLit:
case kIROp_FloatLit:
case kIROp_BoolLit:
+ case kIROp_Param:
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
@@ -142,13 +144,35 @@ bool opCanBeConstExpr(IROp op)
}
}
-bool opCanBeConstExpr(IRInst* value)
+bool opCanBeConstExprByForwardPass(IRInst* value)
{
// TODO: realistically need to special-case `call`
// operations here, so that we check whether the
// callee function is fixed/known, and if it is
// whether it has been declared as constant-foldable
+ if (value->getOp() == kIROp_Param)
+ return false;
+ return opCanBeConstExpr(value->getOp());
+}
+
+IRLoop* isLoopPhi(IRParam* param)
+{
+ IRBlock* bb = cast<IRBlock>(param->getParent());
+ for (auto pred : bb->getPredecessors())
+ {
+ auto loop = as<IRLoop>(pred->getTerminator());
+ if (loop)
+ {
+ return loop;
+ }
+ }
+ return nullptr;
+}
+bool opCanBeConstExprByBackwardPass(IRInst* value)
+{
+ if (value->getOp() == kIROp_Param)
+ return isLoopPhi(as<IRParam>(value));
return opCanBeConstExpr(value->getOp());
}
@@ -159,6 +183,110 @@ void markConstExpr(
Slang::markConstExpr(context->getBuilder(), value);
}
+void maybeAddToWorkList(
+ PropagateConstExprContext* context,
+ IRInst* gv)
+{
+ if (!context->onWorkList.contains(gv))
+ {
+ context->workList.add(gv);
+ context->onWorkList.add(gv);
+ }
+}
+
+bool maybeMarkConstExprBackwardPass(
+ PropagateConstExprContext* context,
+ IRInst* value)
+{
+ if (isConstExpr(value))
+ return false;
+
+ if (!opCanBeConstExprByBackwardPass(value))
+ return false;
+
+ markConstExpr(context, value);
+
+ // TODO: we should only allow function parameters to be
+ // changed to be `constexpr` when we are compiling "application"
+ // code, and not library code.
+ // (Or eventually we'd have a rule that only non-`public` symbols
+ // can have this kind of propagation applied).
+
+ if (value->getOp() == kIROp_Param)
+ {
+ auto param = (IRParam*)value;
+ auto block = (IRBlock*)param->parent;
+ auto code = block->getParent();
+
+ if (block == code->getFirstBlock())
+ {
+ // We've just changed a function parameter to
+ // be `constexpr`. We need to remember that
+ // fact so taht we can mark callers of this
+ // function as `constexpr` themselves.
+
+ for (auto u = code->firstUse; u; u = u->nextUse)
+ {
+ auto user = u->getUser();
+
+ switch (user->getOp())
+ {
+ case kIROp_Call:
+ {
+ auto inst = (IRCall*)user;
+ auto caller = as<IRGlobalValueWithCode>(inst->getParent()->getParent());
+ maybeAddToWorkList(context, caller);
+ }
+ break;
+
+ default:
+ break;
+ }
+ }
+ }
+ }
+
+ return true;
+}
+
+// Produce an estimate on whether a loop is unrollable, by checking
+// if there is at least one exit path where all the conditions along
+// the control path has a constexpr condition.
+bool isUnrollableLoop(IRLoop* loop)
+{
+ // A loop is unrollable if all exit conditions are constexpr.
+ auto breakBlock = loop->getBreakBlock();
+ auto func = getParentFunc(loop);
+ auto domTree = loop->getModule()->findOrCreateDominatorTree(func);
+ List<IRBlock*> workList;
+ bool result = false;
+ for (auto pred : breakBlock->getPredecessors())
+ {
+ workList.clear();
+ workList.add(pred);
+ for (Index i = 0; i < workList.getCount(); i++)
+ {
+ auto block = workList[i];
+ if (auto ifElse = as<IRConditionalBranch>(block->getTerminator()))
+ {
+ if (!isConstExpr(ifElse->getCondition()))
+ return false;
+ }
+ else if (auto switchInst = as<IRSwitch>(block->getTerminator()))
+ {
+ if (!isConstExpr(ifElse->getCondition()))
+ return false;
+ }
+ auto idom = domTree->getImmediateDominator(block);
+ if (idom && idom != loop->getParent())
+ workList.add(idom);
+ }
+ // We found at least one exit path that is constexpr,
+ // we will regard this loop as unrollable.
+ result = true;
+ }
+ return result;
+}
// Propagate `constexpr`-ness in a forward direction, from the
// operands of an instruction to the instruction itself.
@@ -179,7 +307,7 @@ bool propagateConstExprForward(
continue;
// Is the operation one that we can actually make be constexpr?
- if(!opCanBeConstExpr(ii))
+ if(!opCanBeConstExprByForwardPass(ii))
continue;
// Are all arguments `constexpr`?
@@ -211,71 +339,6 @@ bool propagateConstExprForward(
}
}
-void maybeAddToWorkList(
- PropagateConstExprContext* context,
- IRInst* gv)
-{
- if( !context->onWorkList.contains(gv) )
- {
- context->workList.add(gv);
- context->onWorkList.add(gv);
- }
-}
-
-bool maybeMarkConstExpr(
- PropagateConstExprContext* context,
- IRInst* value)
-{
- if(isConstExpr(value))
- return false;
-
- if(!opCanBeConstExpr(value))
- return false;
-
- markConstExpr(context, value);
-
- // TODO: we should only allow function parameters to be
- // changed to be `constexpr` when we are compiling "application"
- // code, and not library code.
- // (Or eventually we'd have a rule that only non-`public` symbols
- // can have this kind of propagation applied).
-
- if(value->getOp() == kIROp_Param)
- {
- auto param = (IRParam*) value;
- auto block = (IRBlock*) param->parent;
- auto code = block->getParent();
-
- if(block == code->getFirstBlock())
- {
- // We've just changed a function parameter to
- // be `constexpr`. We need to remember that
- // fact so taht we can mark callers of this
- // function as `constexpr` themselves.
-
- for( auto u = code->firstUse; u; u = u->nextUse )
- {
- auto user = u->getUser();
-
- switch( user->getOp() )
- {
- case kIROp_Call:
- {
- auto inst = (IRCall*) user;
- auto caller = as<IRGlobalValueWithCode>(inst->getParent()->getParent());
- maybeAddToWorkList(context, caller);
- }
- break;
-
- default:
- break;
- }
- }
- }
- }
-
- return true;
-}
// Propagate `constexpr`-ness in a backward direction, from an instruction
// to its operands.
@@ -312,10 +375,10 @@ bool propagateConstExprBackward(
if(isConstExpr(arg))
continue;
- if(!opCanBeConstExpr(arg))
+ if(!opCanBeConstExprByBackwardPass(arg))
continue;
- if( maybeMarkConstExpr(context, arg) )
+ if( maybeMarkConstExprBackwardPass(context, arg) )
{
changedThisIteration = true;
}
@@ -383,7 +446,7 @@ bool propagateConstExprBackward(
if(isConstExpr(param))
{
- if(maybeMarkConstExpr(context, arg))
+ if(maybeMarkConstExprBackwardPass(context, arg))
{
changedThisIteration = true;
}
@@ -411,7 +474,7 @@ bool propagateConstExprBackward(
auto arg = callInst->getOperand(firstCallArg + pp);
if( isConstExpr(paramType) )
{
- if( maybeMarkConstExpr(context, arg) )
+ if(maybeMarkConstExprBackwardPass(context, arg) )
{
changedThisIteration = true;
}
@@ -439,15 +502,14 @@ bool propagateConstExprBackward(
for(auto pred : bb->getPredecessors())
{
- auto terminator = pred->getLastInst();
- if(terminator->getOp() != kIROp_unconditionalBranch)
+ auto terminator = as<IRUnconditionalBranch>(pred->getLastInst());
+ if(!terminator)
continue;
- UInt operandIndex = paramIndex + 1;
- SLANG_RELEASE_ASSERT(operandIndex < terminator->getOperandCount());
+ SLANG_RELEASE_ASSERT(paramIndex < terminator->getArgCount());
- auto operand = terminator->getOperand(operandIndex);
- if( maybeMarkConstExpr(context, operand) )
+ auto operand = terminator->getArg(paramIndex);
+ if(maybeMarkConstExprBackwardPass(context, operand) )
{
changedThisIteration = true;
}
@@ -463,7 +525,6 @@ bool propagateConstExprBackward(
anyChanges = true;
}
}
-
// Validate use of `constexpr` within a function (in particular,
// diagnose places where a value that must be contexpr depends
// on a value that cannot be)
@@ -484,9 +545,32 @@ void validateConstExpr(
for( UInt aa = 0; aa < argCount; ++aa )
{
auto arg = ii->getOperand(aa);
-
- if( !isConstExpr(arg) )
+ bool shouldDiagnose = !isConstExpr(arg);
+ if (!shouldDiagnose)
+ {
+ if (auto param = as<IRParam>(arg))
+ {
+ if (IRLoop * loopInst = isLoopPhi(param))
+ {
+ // If the param is a phi node in a loop that
+ // does not depend on non-constexpr values, we
+ // can make it constexpr by force unrolling the
+ // loop, if the loop is unrollable.
+ if (isUnrollableLoop(loopInst))
+ {
+ if (!loopInst->findDecoration<IRForceUnrollDecoration>())
+ {
+ context->getBuilder()->addLoopForceUnrollDecoration(loopInst, 0);
+ }
+ continue;
+ }
+ shouldDiagnose = true;
+ }
+ }
+ }
+ if (shouldDiagnose)
{
+
// Diagnose the failure.
context->getSink()->diagnose(ii->sourceLoc, Diagnostics::needCompileTimeConstant);
@@ -499,6 +583,24 @@ void validateConstExpr(
}
}
+void propagateInFunc(PropagateConstExprContext* context, IRGlobalValueWithCode* code)
+{
+ for (;;)
+ {
+ bool anyChange = false;
+ if (propagateConstExprForward(context, code))
+ {
+ anyChange = true;
+ }
+ if (propagateConstExprBackward(context, code))
+ {
+ anyChange = true;
+ }
+ if (!anyChange)
+ break;
+ }
+}
+
void propagateConstExpr(
IRModule* module,
DiagnosticSink* sink)
@@ -548,21 +650,7 @@ void propagateConstExpr(
case kIROp_GlobalVar:
{
IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) gv;
-
- for( ;;)
- {
- bool anyChange = false;
- if( propagateConstExprForward(&context, code) )
- {
- anyChange = true;
- }
- if( propagateConstExprBackward(&context, code) )
- {
- anyChange = true;
- }
- if(!anyChange)
- break;
- }
+ propagateInFunc(&context, code);
}
break;
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 64663120d..22355bd7e 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -643,7 +643,6 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(FlattenDecoration, flatten, 0, 0)
INST(LoopControlDecoration, loopControl, 1, 0)
INST(LoopMaxItersDecoration, loopMaxIters, 1, 0)
- INST(LoopInferredMaxItersDecoration, loopInferredMaxIters, 2, 0)
INST(LoopExitPrimalValueDecoration, loopExitPrimalValue, 2, 0)
INST(IntrinsicOpDecoration, intrinsicOp, 1, 0)
/* TargetSpecificDecoration */
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 489a89287..0d7e27bc4 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -29,6 +29,7 @@
#include "slang-ir-lower-error-handling.h"
#include "slang-ir-obfuscate-loc.h"
#include "slang-ir-use-uninitialized-out-param.h"
+#include "slang-ir-peephole.h"
#include "slang-mangle.h"
#include "slang-type-layout.h"
@@ -8761,6 +8762,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute
context->setGlobalValue(decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc)));
+ bool isInline = false;
+
for (auto modifier : decl->modifiers)
{
if (as<RequiresNVAPIAttribute>(modifier))
@@ -8858,10 +8861,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
else if (as<UnsafeForceInlineEarlyAttribute>(modifier))
{
getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration);
+ isInline = true;
}
else if (as<ForceInlineAttribute>(modifier))
{
getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
+ isInline = true;
}
else if (as<TreatAsDifferentiableAttribute>(modifier))
{
@@ -8871,6 +8876,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op);
getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op);
+ isInline = true;
}
else if (as<UserDefinedDerivativeAttribute>(modifier) || as<PrimalSubstituteAttribute>(modifier))
{
@@ -8930,6 +8936,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
+ if (!isInline)
+ {
+ // If there are any constant expr rate parameters, we should inline this function.
+ // TODO: consider specializing them instead of inlining.
+ for (auto param : decl->getParameters())
+ {
+ if (param->hasModifier<ConstExprModifier>())
+ {
+ getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
+ isInline = true;
+ break;
+ }
+ }
+ }
+
if (auto diffAttr = decl->findModifier<DifferentiableAttribute>())
{
if (decl->body)
@@ -9698,6 +9719,8 @@ RefPtr<IRModule> generateIRForTranslationUnit(
constructSSA(module);
simplifyCFG(module);
applySparseConditionalConstantPropagation(module, compileRequest->getSink());
+ peepholeOptimize(module);
+
for (auto inst : module->getGlobalInsts())
{
if (auto func = as<IRGlobalValueWithCode>(inst))
@@ -9732,14 +9755,28 @@ RefPtr<IRModule> generateIRForTranslationUnit(
// - If sccp is unable to eliminate the outer 'if' then we end up with
// duplicated code the the conditional value. Users don't tend to put
// huge gobs of code in the conditional expression in loops however.
+
invertLoops(module);
// Next, attempt to promote local variables to SSA
// temporaries and do basic simplifications.
//
- constructSSA(module);
- simplifyCFG(module);
- applySparseConditionalConstantPropagation(module, compileRequest->getSink());
+ for (;;)
+ {
+ bool changed = false;
+ performMandatoryEarlyInlining(module);
+ changed |= constructSSA(module);
+ simplifyCFG(module);
+ changed |= applySparseConditionalConstantPropagation(module, compileRequest->getSink());
+ changed |= peepholeOptimize(module);
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto func = as<IRGlobalValueWithCode>(inst))
+ eliminateDeadCode(func);
+ }
+ if (!changed)
+ break;
+ }
// Propagate `constexpr`-ness through the dataflow graph (and the
// call graph) based on constraints imposed by different instructions.
diff --git a/tests/diagnostics/autodiff-data-flow-2.slang b/tests/diagnostics/autodiff-data-flow-2.slang
index 3148c6a41..42aee0d01 100644
--- a/tests/diagnostics/autodiff-data-flow-2.slang
+++ b/tests/diagnostics/autodiff-data-flow-2.slang
@@ -28,6 +28,7 @@ float h(float x)
// error: dynamic loop without [MaxIters] or [ForceUnroll]
for (int i = 0; i < (int)x; i++)
{
+ no_diff debugBreak();
}
return val;
diff --git a/tests/diagnostics/autodiff-data-flow.slang b/tests/diagnostics/autodiff-data-flow.slang
index e8d9502e4..bbade0e0a 100644
--- a/tests/diagnostics/autodiff-data-flow.slang
+++ b/tests/diagnostics/autodiff-data-flow.slang
@@ -28,6 +28,7 @@ void g(float x)
for (int i = 0; i < 5; i++) // Not ok, we can't infer the loop iterations because the body modifies induction var.
{
i = (int)x;
+ no_diff debugBreak();
}
return;
}
diff --git a/tests/diagnostics/constexpr-error.slang.expected b/tests/diagnostics/constexpr-error.slang.expected
index f6c27b006..6f124fe34 100644
--- a/tests/diagnostics/constexpr-error.slang.expected
+++ b/tests/diagnostics/constexpr-error.slang.expected
@@ -6,9 +6,15 @@ tests/diagnostics/constexpr-error.slang(27): error 40006: expected a compile-tim
tests/diagnostics/constexpr-error.slang(35): error 40006: expected a compile-time constant
result += t.Sample(s, uv, int2(ii));
^
+tests/diagnostics/constexpr-error.slang(39): error 40006: expected a compile-time constant
+ for(uint jj = 0; jj < uv.y; jj++)
+ ^~
tests/diagnostics/constexpr-error.slang(41): error 40006: expected a compile-time constant
result += t.Sample(s, uv, int2(jj));
^
+tests/diagnostics/constexpr-error.slang(39): error 40006: expected a compile-time constant
+ for(uint jj = 0; jj < uv.y; jj++)
+ ^~
}
standard output = {
}
diff --git a/tests/expected-failure.txt b/tests/expected-failure.txt
index f40435ec1..98cf36724 100644
--- a/tests/expected-failure.txt
+++ b/tests/expected-failure.txt
@@ -152,6 +152,7 @@ tests/hlsl-intrinsic/wave-mask/wave-vector.slang.3 (vk)
tests/hlsl-intrinsic/wave-mask/wave.slang.3 (vk)
tests/ir/loop-unroll-0.slang.1 (vk)
tests/ir/string-literal-hash.slang.1 (vk)
+tests/language-feature/constants/constexpr-loop.slang.1 (vk)
tests/language-feature/general-inline.slang.1 (vk)
tests/language-feature/simple-inline.slang.1 (vk)
tests/language-feature/initializer-lists/default-init-16bit-types.slang (vk)
diff --git a/tests/language-feature/constants/constexpr-loop.slang b/tests/language-feature/constants/constexpr-loop.slang
new file mode 100644
index 000000000..31b4294a5
--- /dev/null
+++ b/tests/language-feature/constants/constexpr-loop.slang
@@ -0,0 +1,24 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT: set g_texture = Texture2D(size=8, content = one)
+//TEST_INPUT: set g_sampler = Sampler
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+
+Texture2D g_texture;
+SamplerState g_sampler;
+RWStructuredBuffer<float> outputBuffer;
+
+float4 sample(float2 uv, constexpr int2 ofs )
+{
+ return g_texture.SampleLevel(g_sampler, uv, 0.0, ofs);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
+{
+ float4 result = 0;
+ for (int i = 0; i < 3; i++)
+ result += sample(float2(1.0), int2(i, i));
+ outputBuffer[dispatchThreadID.x] = result.x;
+}
diff --git a/tests/language-feature/constants/constexpr-loop.slang.expected.txt b/tests/language-feature/constants/constexpr-loop.slang.expected.txt
new file mode 100644
index 000000000..a6122d7ce
--- /dev/null
+++ b/tests/language-feature/constants/constexpr-loop.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+3.000000