diff options
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 44 | ||||
| -rw-r--r-- | tests/compute/compile-time-loop.slang | 92 | ||||
| -rw-r--r-- | tests/compute/compile-time-loop.slang.expected.txt | 1 |
3 files changed, 128 insertions, 9 deletions
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 827504122..326d25649 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -1637,9 +1637,37 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> SLANG_UNEXPECTED("`case` or `default` not under `switch`"); } - void visitCompileTimeForStmt(CompileTimeForStmt*) + void visitCompileTimeForStmt(CompileTimeForStmt* stmt) { - SLANG_UNIMPLEMENTED_X("IR lowering of CompileTimeForStmt"); + // The user is asking us to emit code for the loop + // body for each value in the given integer range. + // For now, we will handle this by repeatedly lowering + // the body statement, with the loop variable bound + // to a different integer literal value each time. + // + // TODO: eventually we might handle this as just an + // ordinary loop, with an `[unroll]` attribute on + // it that we would respect. + + auto rangeBeginVal = GetIntVal(stmt->rangeBeginVal); + auto rangeEndVal = GetIntVal(stmt->rangeEndVal); + + if (rangeBeginVal >= rangeEndVal) + return; + + auto varDecl = stmt->varDecl; + auto varType = varDecl->type; + + for (IntegerLiteralValue ii = rangeBeginVal; ii < rangeEndVal; ++ii) + { + auto constVal = getBuilder()->getIntValue( + varType, + ii); + + context->shared->declValues[varDecl] = LoweredValInfo::simple(constVal); + + lowerStmt(context, stmt->body); + } } // Create a basic block in the current function, @@ -2590,9 +2618,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // A global variable's SSA value is a *pointer* to // the underlying storage. auto globalVal = LoweredValInfo::ptr(irGlobal); - context->shared->declValues.Add( - DeclRef<VarDeclBase>(decl, nullptr), - globalVal); + context->shared->declValues[ + DeclRef<VarDeclBase>(decl, nullptr)] = globalVal; if( auto initExpr = decl->initExpr ) { @@ -2667,9 +2694,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> assign(context, varVal, initVal); } - context->shared->declValues.Add( - DeclRef<VarDeclBase>(decl, nullptr), - varVal); + context->shared->declValues[ + DeclRef<VarDeclBase>(decl, nullptr)] = varVal; return varVal; } @@ -3214,7 +3240,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if( auto paramDecl = paramInfo.decl ) { DeclRef<VarDeclBase> paramDeclRef = makeDeclRef(paramDecl); - subContext->shared->declValues.Add(paramDeclRef, paramVal); + subContext->shared->declValues[paramDeclRef] = paramVal; } if (paramInfo.isThisParam) diff --git a/tests/compute/compile-time-loop.slang b/tests/compute/compile-time-loop.slang new file mode 100644 index 000000000..43b35d42b --- /dev/null +++ b/tests/compute/compile-time-loop.slang @@ -0,0 +1,92 @@ +//TEST(compute):COMPARE_RENDER_COMPUTE:-xslang -use-ir + +//TEST_INPUT: Texture2D(size=4, content = one) : dxbinding(0),glbinding(0) +//TEST_INPUT: Sampler : dxbinding(0),glbinding(0) + +//TEST_INPUT: ubuffer(data=[0], stride=4):dxbinding(1),glbinding(0),out + +Texture2D t; +SamplerState s; +RWStructuredBuffer<float> outputBuffer; + +cbuffer Uniforms +{ + float4x4 modelViewProjection; +} + +struct AssembledVertex +{ + float3 position; + float3 color; + float2 uv; +}; + +struct CoarseVertex +{ + float3 color; + float2 uv; +}; + +struct Fragment +{ + float4 color; +}; + +// Vertex Shader + +struct VertexStageInput +{ + AssembledVertex assembledVertex : A; +}; + +struct VertexStageOutput +{ + CoarseVertex coarseVertex : CoarseVertex; + float4 sv_position : SV_Position; +}; + +VertexStageOutput vertexMain(VertexStageInput input) +{ + VertexStageOutput output; + + float3 position = input.assembledVertex.position; + float3 color = input.assembledVertex.color; + + output.coarseVertex.color = color; + output.sv_position = mul(modelViewProjection, float4(position, 1.0)); + output.coarseVertex.uv = input.assembledVertex.uv; + return output; +} + +// Fragment Shader + +struct FragmentStageInput +{ + CoarseVertex coarseVertex : CoarseVertex; +}; + +struct FragmentStageOutput +{ + Fragment fragment : SV_Target; +}; + +FragmentStageOutput fragmentMain(FragmentStageInput input) +{ + FragmentStageOutput output; + + float3 color = input.coarseVertex.color; + float2 uv = input.coarseVertex.uv; + output.fragment.color = float4(color, 1.0); + + + float4 result = 0; + $for(i in Range(0,5)) + { + float4 v = t.Sample(s, uv, int2(i - 2, 0)); + result += v; + } + + outputBuffer[0] = result.x; + + return output; +} diff --git a/tests/compute/compile-time-loop.slang.expected.txt b/tests/compute/compile-time-loop.slang.expected.txt new file mode 100644 index 000000000..2b58069cf --- /dev/null +++ b/tests/compute/compile-time-loop.slang.expected.txt @@ -0,0 +1 @@ +40A00000 |
