diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-18 12:48:46 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-18 12:48:46 -0700 |
| commit | f94b2f7a328a898c5e3dc1389d08e0b7ce6e092e (patch) | |
| tree | 129f39703b10b5684825ce8626d3a4e908970fad | |
| parent | 4de3d9b1987fddf8d95efe75aab592282b672a97 (diff) | |
Allow loop counters to be used as constexpr arguments. (#3139)
* Allow loop counters to be used as constexpr arguments.
* Fix.
* Fix.
* Fix.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-constexpr.cpp | 278 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 43 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff-data-flow-2.slang | 1 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff-data-flow.slang | 1 | ||||
| -rw-r--r-- | tests/diagnostics/constexpr-error.slang.expected | 6 | ||||
| -rw-r--r-- | tests/expected-failure.txt | 1 | ||||
| -rw-r--r-- | tests/language-feature/constants/constexpr-loop.slang | 24 | ||||
| -rw-r--r-- | tests/language-feature/constants/constexpr-loop.slang.expected.txt | 2 |
10 files changed, 260 insertions, 99 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index d46fba5e3..8a88e69ad 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -790,6 +790,8 @@ void normalizeCFG( disableIRValidationAtInsert(); constructSSA(module, func); enableIRValidationAtInsert(); + + module->invalidateAnalysisForInst(func); #if _DEBUG validateIRInst(maybeFindOuterGeneric(func)); #endif diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp index dbfec9ae7..5731238ee 100644 --- a/source/slang/slang-ir-constexpr.cpp +++ b/source/slang/slang-ir-constexpr.cpp @@ -3,6 +3,7 @@ #include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir-dominators.h" namespace Slang { @@ -74,6 +75,7 @@ bool opCanBeConstExpr(IROp op) case kIROp_IntLit: case kIROp_FloatLit: case kIROp_BoolLit: + case kIROp_Param: case kIROp_Add: case kIROp_Sub: case kIROp_Mul: @@ -142,13 +144,35 @@ bool opCanBeConstExpr(IROp op) } } -bool opCanBeConstExpr(IRInst* value) +bool opCanBeConstExprByForwardPass(IRInst* value) { // TODO: realistically need to special-case `call` // operations here, so that we check whether the // callee function is fixed/known, and if it is // whether it has been declared as constant-foldable + if (value->getOp() == kIROp_Param) + return false; + return opCanBeConstExpr(value->getOp()); +} + +IRLoop* isLoopPhi(IRParam* param) +{ + IRBlock* bb = cast<IRBlock>(param->getParent()); + for (auto pred : bb->getPredecessors()) + { + auto loop = as<IRLoop>(pred->getTerminator()); + if (loop) + { + return loop; + } + } + return nullptr; +} +bool opCanBeConstExprByBackwardPass(IRInst* value) +{ + if (value->getOp() == kIROp_Param) + return isLoopPhi(as<IRParam>(value)); return opCanBeConstExpr(value->getOp()); } @@ -159,6 +183,110 @@ void markConstExpr( Slang::markConstExpr(context->getBuilder(), value); } +void maybeAddToWorkList( + PropagateConstExprContext* context, + IRInst* gv) +{ + if (!context->onWorkList.contains(gv)) + { + context->workList.add(gv); + context->onWorkList.add(gv); + } +} + +bool maybeMarkConstExprBackwardPass( + PropagateConstExprContext* context, + IRInst* value) +{ + if (isConstExpr(value)) + return false; + + if (!opCanBeConstExprByBackwardPass(value)) + return false; + + markConstExpr(context, value); + + // TODO: we should only allow function parameters to be + // changed to be `constexpr` when we are compiling "application" + // code, and not library code. + // (Or eventually we'd have a rule that only non-`public` symbols + // can have this kind of propagation applied). + + if (value->getOp() == kIROp_Param) + { + auto param = (IRParam*)value; + auto block = (IRBlock*)param->parent; + auto code = block->getParent(); + + if (block == code->getFirstBlock()) + { + // We've just changed a function parameter to + // be `constexpr`. We need to remember that + // fact so taht we can mark callers of this + // function as `constexpr` themselves. + + for (auto u = code->firstUse; u; u = u->nextUse) + { + auto user = u->getUser(); + + switch (user->getOp()) + { + case kIROp_Call: + { + auto inst = (IRCall*)user; + auto caller = as<IRGlobalValueWithCode>(inst->getParent()->getParent()); + maybeAddToWorkList(context, caller); + } + break; + + default: + break; + } + } + } + } + + return true; +} + +// Produce an estimate on whether a loop is unrollable, by checking +// if there is at least one exit path where all the conditions along +// the control path has a constexpr condition. +bool isUnrollableLoop(IRLoop* loop) +{ + // A loop is unrollable if all exit conditions are constexpr. + auto breakBlock = loop->getBreakBlock(); + auto func = getParentFunc(loop); + auto domTree = loop->getModule()->findOrCreateDominatorTree(func); + List<IRBlock*> workList; + bool result = false; + for (auto pred : breakBlock->getPredecessors()) + { + workList.clear(); + workList.add(pred); + for (Index i = 0; i < workList.getCount(); i++) + { + auto block = workList[i]; + if (auto ifElse = as<IRConditionalBranch>(block->getTerminator())) + { + if (!isConstExpr(ifElse->getCondition())) + return false; + } + else if (auto switchInst = as<IRSwitch>(block->getTerminator())) + { + if (!isConstExpr(ifElse->getCondition())) + return false; + } + auto idom = domTree->getImmediateDominator(block); + if (idom && idom != loop->getParent()) + workList.add(idom); + } + // We found at least one exit path that is constexpr, + // we will regard this loop as unrollable. + result = true; + } + return result; +} // Propagate `constexpr`-ness in a forward direction, from the // operands of an instruction to the instruction itself. @@ -179,7 +307,7 @@ bool propagateConstExprForward( continue; // Is the operation one that we can actually make be constexpr? - if(!opCanBeConstExpr(ii)) + if(!opCanBeConstExprByForwardPass(ii)) continue; // Are all arguments `constexpr`? @@ -211,71 +339,6 @@ bool propagateConstExprForward( } } -void maybeAddToWorkList( - PropagateConstExprContext* context, - IRInst* gv) -{ - if( !context->onWorkList.contains(gv) ) - { - context->workList.add(gv); - context->onWorkList.add(gv); - } -} - -bool maybeMarkConstExpr( - PropagateConstExprContext* context, - IRInst* value) -{ - if(isConstExpr(value)) - return false; - - if(!opCanBeConstExpr(value)) - return false; - - markConstExpr(context, value); - - // TODO: we should only allow function parameters to be - // changed to be `constexpr` when we are compiling "application" - // code, and not library code. - // (Or eventually we'd have a rule that only non-`public` symbols - // can have this kind of propagation applied). - - if(value->getOp() == kIROp_Param) - { - auto param = (IRParam*) value; - auto block = (IRBlock*) param->parent; - auto code = block->getParent(); - - if(block == code->getFirstBlock()) - { - // We've just changed a function parameter to - // be `constexpr`. We need to remember that - // fact so taht we can mark callers of this - // function as `constexpr` themselves. - - for( auto u = code->firstUse; u; u = u->nextUse ) - { - auto user = u->getUser(); - - switch( user->getOp() ) - { - case kIROp_Call: - { - auto inst = (IRCall*) user; - auto caller = as<IRGlobalValueWithCode>(inst->getParent()->getParent()); - maybeAddToWorkList(context, caller); - } - break; - - default: - break; - } - } - } - } - - return true; -} // Propagate `constexpr`-ness in a backward direction, from an instruction // to its operands. @@ -312,10 +375,10 @@ bool propagateConstExprBackward( if(isConstExpr(arg)) continue; - if(!opCanBeConstExpr(arg)) + if(!opCanBeConstExprByBackwardPass(arg)) continue; - if( maybeMarkConstExpr(context, arg) ) + if( maybeMarkConstExprBackwardPass(context, arg) ) { changedThisIteration = true; } @@ -383,7 +446,7 @@ bool propagateConstExprBackward( if(isConstExpr(param)) { - if(maybeMarkConstExpr(context, arg)) + if(maybeMarkConstExprBackwardPass(context, arg)) { changedThisIteration = true; } @@ -411,7 +474,7 @@ bool propagateConstExprBackward( auto arg = callInst->getOperand(firstCallArg + pp); if( isConstExpr(paramType) ) { - if( maybeMarkConstExpr(context, arg) ) + if(maybeMarkConstExprBackwardPass(context, arg) ) { changedThisIteration = true; } @@ -439,15 +502,14 @@ bool propagateConstExprBackward( for(auto pred : bb->getPredecessors()) { - auto terminator = pred->getLastInst(); - if(terminator->getOp() != kIROp_unconditionalBranch) + auto terminator = as<IRUnconditionalBranch>(pred->getLastInst()); + if(!terminator) continue; - UInt operandIndex = paramIndex + 1; - SLANG_RELEASE_ASSERT(operandIndex < terminator->getOperandCount()); + SLANG_RELEASE_ASSERT(paramIndex < terminator->getArgCount()); - auto operand = terminator->getOperand(operandIndex); - if( maybeMarkConstExpr(context, operand) ) + auto operand = terminator->getArg(paramIndex); + if(maybeMarkConstExprBackwardPass(context, operand) ) { changedThisIteration = true; } @@ -463,7 +525,6 @@ bool propagateConstExprBackward( anyChanges = true; } } - // Validate use of `constexpr` within a function (in particular, // diagnose places where a value that must be contexpr depends // on a value that cannot be) @@ -484,9 +545,32 @@ void validateConstExpr( for( UInt aa = 0; aa < argCount; ++aa ) { auto arg = ii->getOperand(aa); - - if( !isConstExpr(arg) ) + bool shouldDiagnose = !isConstExpr(arg); + if (!shouldDiagnose) + { + if (auto param = as<IRParam>(arg)) + { + if (IRLoop * loopInst = isLoopPhi(param)) + { + // If the param is a phi node in a loop that + // does not depend on non-constexpr values, we + // can make it constexpr by force unrolling the + // loop, if the loop is unrollable. + if (isUnrollableLoop(loopInst)) + { + if (!loopInst->findDecoration<IRForceUnrollDecoration>()) + { + context->getBuilder()->addLoopForceUnrollDecoration(loopInst, 0); + } + continue; + } + shouldDiagnose = true; + } + } + } + if (shouldDiagnose) { + // Diagnose the failure. context->getSink()->diagnose(ii->sourceLoc, Diagnostics::needCompileTimeConstant); @@ -499,6 +583,24 @@ void validateConstExpr( } } +void propagateInFunc(PropagateConstExprContext* context, IRGlobalValueWithCode* code) +{ + for (;;) + { + bool anyChange = false; + if (propagateConstExprForward(context, code)) + { + anyChange = true; + } + if (propagateConstExprBackward(context, code)) + { + anyChange = true; + } + if (!anyChange) + break; + } +} + void propagateConstExpr( IRModule* module, DiagnosticSink* sink) @@ -548,21 +650,7 @@ void propagateConstExpr( case kIROp_GlobalVar: { IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) gv; - - for( ;;) - { - bool anyChange = false; - if( propagateConstExprForward(&context, code) ) - { - anyChange = true; - } - if( propagateConstExprBackward(&context, code) ) - { - anyChange = true; - } - if(!anyChange) - break; - } + propagateInFunc(&context, code); } break; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 64663120d..22355bd7e 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -643,7 +643,6 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(FlattenDecoration, flatten, 0, 0) INST(LoopControlDecoration, loopControl, 1, 0) INST(LoopMaxItersDecoration, loopMaxIters, 1, 0) - INST(LoopInferredMaxItersDecoration, loopInferredMaxIters, 2, 0) INST(LoopExitPrimalValueDecoration, loopExitPrimalValue, 2, 0) INST(IntrinsicOpDecoration, intrinsicOp, 1, 0) /* TargetSpecificDecoration */ diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 489a89287..0d7e27bc4 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -29,6 +29,7 @@ #include "slang-ir-lower-error-handling.h" #include "slang-ir-obfuscate-loc.h" #include "slang-ir-use-uninitialized-out-param.h" +#include "slang-ir-peephole.h" #include "slang-mangle.h" #include "slang-type-layout.h" @@ -8761,6 +8762,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute context->setGlobalValue(decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); + bool isInline = false; + for (auto modifier : decl->modifiers) { if (as<RequiresNVAPIAttribute>(modifier)) @@ -8858,10 +8861,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> else if (as<UnsafeForceInlineEarlyAttribute>(modifier)) { getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration); + isInline = true; } else if (as<ForceInlineAttribute>(modifier)) { getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); + isInline = true; } else if (as<TreatAsDifferentiableAttribute>(modifier)) { @@ -8871,6 +8876,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op); getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op); + isInline = true; } else if (as<UserDefinedDerivativeAttribute>(modifier) || as<PrimalSubstituteAttribute>(modifier)) { @@ -8930,6 +8936,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } + if (!isInline) + { + // If there are any constant expr rate parameters, we should inline this function. + // TODO: consider specializing them instead of inlining. + for (auto param : decl->getParameters()) + { + if (param->hasModifier<ConstExprModifier>()) + { + getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); + isInline = true; + break; + } + } + } + if (auto diffAttr = decl->findModifier<DifferentiableAttribute>()) { if (decl->body) @@ -9698,6 +9719,8 @@ RefPtr<IRModule> generateIRForTranslationUnit( constructSSA(module); simplifyCFG(module); applySparseConditionalConstantPropagation(module, compileRequest->getSink()); + peepholeOptimize(module); + for (auto inst : module->getGlobalInsts()) { if (auto func = as<IRGlobalValueWithCode>(inst)) @@ -9732,14 +9755,28 @@ RefPtr<IRModule> generateIRForTranslationUnit( // - If sccp is unable to eliminate the outer 'if' then we end up with // duplicated code the the conditional value. Users don't tend to put // huge gobs of code in the conditional expression in loops however. + invertLoops(module); // Next, attempt to promote local variables to SSA // temporaries and do basic simplifications. // - constructSSA(module); - simplifyCFG(module); - applySparseConditionalConstantPropagation(module, compileRequest->getSink()); + for (;;) + { + bool changed = false; + performMandatoryEarlyInlining(module); + changed |= constructSSA(module); + simplifyCFG(module); + changed |= applySparseConditionalConstantPropagation(module, compileRequest->getSink()); + changed |= peepholeOptimize(module); + for (auto inst : module->getGlobalInsts()) + { + if (auto func = as<IRGlobalValueWithCode>(inst)) + eliminateDeadCode(func); + } + if (!changed) + break; + } // Propagate `constexpr`-ness through the dataflow graph (and the // call graph) based on constraints imposed by different instructions. diff --git a/tests/diagnostics/autodiff-data-flow-2.slang b/tests/diagnostics/autodiff-data-flow-2.slang index 3148c6a41..42aee0d01 100644 --- a/tests/diagnostics/autodiff-data-flow-2.slang +++ b/tests/diagnostics/autodiff-data-flow-2.slang @@ -28,6 +28,7 @@ float h(float x) // error: dynamic loop without [MaxIters] or [ForceUnroll] for (int i = 0; i < (int)x; i++) { + no_diff debugBreak(); } return val; diff --git a/tests/diagnostics/autodiff-data-flow.slang b/tests/diagnostics/autodiff-data-flow.slang index e8d9502e4..bbade0e0a 100644 --- a/tests/diagnostics/autodiff-data-flow.slang +++ b/tests/diagnostics/autodiff-data-flow.slang @@ -28,6 +28,7 @@ void g(float x) for (int i = 0; i < 5; i++) // Not ok, we can't infer the loop iterations because the body modifies induction var. { i = (int)x; + no_diff debugBreak(); } return; } diff --git a/tests/diagnostics/constexpr-error.slang.expected b/tests/diagnostics/constexpr-error.slang.expected index f6c27b006..6f124fe34 100644 --- a/tests/diagnostics/constexpr-error.slang.expected +++ b/tests/diagnostics/constexpr-error.slang.expected @@ -6,9 +6,15 @@ tests/diagnostics/constexpr-error.slang(27): error 40006: expected a compile-tim tests/diagnostics/constexpr-error.slang(35): error 40006: expected a compile-time constant result += t.Sample(s, uv, int2(ii)); ^ +tests/diagnostics/constexpr-error.slang(39): error 40006: expected a compile-time constant + for(uint jj = 0; jj < uv.y; jj++) + ^~ tests/diagnostics/constexpr-error.slang(41): error 40006: expected a compile-time constant result += t.Sample(s, uv, int2(jj)); ^ +tests/diagnostics/constexpr-error.slang(39): error 40006: expected a compile-time constant + for(uint jj = 0; jj < uv.y; jj++) + ^~ } standard output = { } diff --git a/tests/expected-failure.txt b/tests/expected-failure.txt index f40435ec1..98cf36724 100644 --- a/tests/expected-failure.txt +++ b/tests/expected-failure.txt @@ -152,6 +152,7 @@ tests/hlsl-intrinsic/wave-mask/wave-vector.slang.3 (vk) tests/hlsl-intrinsic/wave-mask/wave.slang.3 (vk) tests/ir/loop-unroll-0.slang.1 (vk) tests/ir/string-literal-hash.slang.1 (vk) +tests/language-feature/constants/constexpr-loop.slang.1 (vk) tests/language-feature/general-inline.slang.1 (vk) tests/language-feature/simple-inline.slang.1 (vk) tests/language-feature/initializer-lists/default-init-16bit-types.slang (vk) diff --git a/tests/language-feature/constants/constexpr-loop.slang b/tests/language-feature/constants/constexpr-loop.slang new file mode 100644 index 000000000..31b4294a5 --- /dev/null +++ b/tests/language-feature/constants/constexpr-loop.slang @@ -0,0 +1,24 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT: set g_texture = Texture2D(size=8, content = one) +//TEST_INPUT: set g_sampler = Sampler +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer + +Texture2D g_texture; +SamplerState g_sampler; +RWStructuredBuffer<float> outputBuffer; + +float4 sample(float2 uv, constexpr int2 ofs ) +{ + return g_texture.SampleLevel(g_sampler, uv, 0.0, ofs); +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + float4 result = 0; + for (int i = 0; i < 3; i++) + result += sample(float2(1.0), int2(i, i)); + outputBuffer[dispatchThreadID.x] = result.x; +} diff --git a/tests/language-feature/constants/constexpr-loop.slang.expected.txt b/tests/language-feature/constants/constexpr-loop.slang.expected.txt new file mode 100644 index 000000000..a6122d7ce --- /dev/null +++ b/tests/language-feature/constants/constexpr-loop.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +3.000000 |
