summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--source/slang/lower-to-ir.cpp44
-rw-r--r--tests/compute/compile-time-loop.slang92
-rw-r--r--tests/compute/compile-time-loop.slang.expected.txt1
3 files changed, 128 insertions, 9 deletions
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 827504122..326d25649 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -1637,9 +1637,37 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
SLANG_UNEXPECTED("`case` or `default` not under `switch`");
}
- void visitCompileTimeForStmt(CompileTimeForStmt*)
+ void visitCompileTimeForStmt(CompileTimeForStmt* stmt)
{
- SLANG_UNIMPLEMENTED_X("IR lowering of CompileTimeForStmt");
+ // The user is asking us to emit code for the loop
+ // body for each value in the given integer range.
+ // For now, we will handle this by repeatedly lowering
+ // the body statement, with the loop variable bound
+ // to a different integer literal value each time.
+ //
+ // TODO: eventually we might handle this as just an
+ // ordinary loop, with an `[unroll]` attribute on
+ // it that we would respect.
+
+ auto rangeBeginVal = GetIntVal(stmt->rangeBeginVal);
+ auto rangeEndVal = GetIntVal(stmt->rangeEndVal);
+
+ if (rangeBeginVal >= rangeEndVal)
+ return;
+
+ auto varDecl = stmt->varDecl;
+ auto varType = varDecl->type;
+
+ for (IntegerLiteralValue ii = rangeBeginVal; ii < rangeEndVal; ++ii)
+ {
+ auto constVal = getBuilder()->getIntValue(
+ varType,
+ ii);
+
+ context->shared->declValues[varDecl] = LoweredValInfo::simple(constVal);
+
+ lowerStmt(context, stmt->body);
+ }
}
// Create a basic block in the current function,
@@ -2590,9 +2618,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// A global variable's SSA value is a *pointer* to
// the underlying storage.
auto globalVal = LoweredValInfo::ptr(irGlobal);
- context->shared->declValues.Add(
- DeclRef<VarDeclBase>(decl, nullptr),
- globalVal);
+ context->shared->declValues[
+ DeclRef<VarDeclBase>(decl, nullptr)] = globalVal;
if( auto initExpr = decl->initExpr )
{
@@ -2667,9 +2694,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
assign(context, varVal, initVal);
}
- context->shared->declValues.Add(
- DeclRef<VarDeclBase>(decl, nullptr),
- varVal);
+ context->shared->declValues[
+ DeclRef<VarDeclBase>(decl, nullptr)] = varVal;
return varVal;
}
@@ -3214,7 +3240,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
if( auto paramDecl = paramInfo.decl )
{
DeclRef<VarDeclBase> paramDeclRef = makeDeclRef(paramDecl);
- subContext->shared->declValues.Add(paramDeclRef, paramVal);
+ subContext->shared->declValues[paramDeclRef] = paramVal;
}
if (paramInfo.isThisParam)
diff --git a/tests/compute/compile-time-loop.slang b/tests/compute/compile-time-loop.slang
new file mode 100644
index 000000000..43b35d42b
--- /dev/null
+++ b/tests/compute/compile-time-loop.slang
@@ -0,0 +1,92 @@
+//TEST(compute):COMPARE_RENDER_COMPUTE:-xslang -use-ir
+
+//TEST_INPUT: Texture2D(size=4, content = one) : dxbinding(0),glbinding(0)
+//TEST_INPUT: Sampler : dxbinding(0),glbinding(0)
+
+//TEST_INPUT: ubuffer(data=[0], stride=4):dxbinding(1),glbinding(0),out
+
+Texture2D t;
+SamplerState s;
+RWStructuredBuffer<float> outputBuffer;
+
+cbuffer Uniforms
+{
+ float4x4 modelViewProjection;
+}
+
+struct AssembledVertex
+{
+ float3 position;
+ float3 color;
+ float2 uv;
+};
+
+struct CoarseVertex
+{
+ float3 color;
+ float2 uv;
+};
+
+struct Fragment
+{
+ float4 color;
+};
+
+// Vertex Shader
+
+struct VertexStageInput
+{
+ AssembledVertex assembledVertex : A;
+};
+
+struct VertexStageOutput
+{
+ CoarseVertex coarseVertex : CoarseVertex;
+ float4 sv_position : SV_Position;
+};
+
+VertexStageOutput vertexMain(VertexStageInput input)
+{
+ VertexStageOutput output;
+
+ float3 position = input.assembledVertex.position;
+ float3 color = input.assembledVertex.color;
+
+ output.coarseVertex.color = color;
+ output.sv_position = mul(modelViewProjection, float4(position, 1.0));
+ output.coarseVertex.uv = input.assembledVertex.uv;
+ return output;
+}
+
+// Fragment Shader
+
+struct FragmentStageInput
+{
+ CoarseVertex coarseVertex : CoarseVertex;
+};
+
+struct FragmentStageOutput
+{
+ Fragment fragment : SV_Target;
+};
+
+FragmentStageOutput fragmentMain(FragmentStageInput input)
+{
+ FragmentStageOutput output;
+
+ float3 color = input.coarseVertex.color;
+ float2 uv = input.coarseVertex.uv;
+ output.fragment.color = float4(color, 1.0);
+
+
+ float4 result = 0;
+ $for(i in Range(0,5))
+ {
+ float4 v = t.Sample(s, uv, int2(i - 2, 0));
+ result += v;
+ }
+
+ outputBuffer[0] = result.x;
+
+ return output;
+}
diff --git a/tests/compute/compile-time-loop.slang.expected.txt b/tests/compute/compile-time-loop.slang.expected.txt
new file mode 100644
index 000000000..2b58069cf
--- /dev/null
+++ b/tests/compute/compile-time-loop.slang.expected.txt
@@ -0,0 +1 @@
+40A00000