summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-emit-spirv.cpp23
-rw-r--r--tests/bugs/branch-attribute.slang9
-rw-r--r--tests/bugs/branch-switch-attribute.slang6
3 files changed, 31 insertions, 7 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 0449c6b88..aa4c3bec8 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -3962,6 +3962,17 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
}
+ SpvSelectionControlMask getSpvBranchSelectionControl(IRInst* inst)
+ {
+ if (inst->findDecorationImpl(kIROp_BranchDecoration))
+ return SpvSelectionControlDontFlattenMask;
+
+ if (inst->findDecorationImpl(kIROp_FlattenDecoration))
+ return SpvSelectionControlFlattenMask;
+
+ return SpvSelectionControlMaskNone;
+ }
+
// The instructions that appear inside the basic blocks of
// functions are what we will call "local" instructions.
//
@@ -4269,7 +4280,11 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
auto ifelseInst = as<IRIfElse>(inst);
auto afterBlockID = getIRInstSpvID(ifelseInst->getAfterBlock());
- emitOpSelectionMerge(parent, nullptr, afterBlockID, SpvSelectionControlMaskNone);
+ emitOpSelectionMerge(
+ parent,
+ nullptr,
+ afterBlockID,
+ getSpvBranchSelectionControl(ifelseInst));
auto falseLabel = ifelseInst->getFalseBlock();
result = emitOpBranchConditional(
parent,
@@ -4284,7 +4299,11 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
auto switchInst = as<IRSwitch>(inst);
auto mergeBlockID = getIRInstSpvID(switchInst->getBreakLabel());
- emitOpSelectionMerge(parent, nullptr, mergeBlockID, SpvSelectionControlMaskNone);
+ emitOpSelectionMerge(
+ parent,
+ nullptr,
+ mergeBlockID,
+ getSpvBranchSelectionControl(switchInst));
result = emitInstCustomOperandFunc(
parent,
inst,
diff --git a/tests/bugs/branch-attribute.slang b/tests/bugs/branch-attribute.slang
index 17e30a278..95288ecfa 100644
--- a/tests/bugs/branch-attribute.slang
+++ b/tests/bugs/branch-attribute.slang
@@ -1,10 +1,13 @@
-//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain
+//TEST:SIMPLE(filecheck=HLSL): -target hlsl -profile cs_5_0 -entry computeMain
+//TEST:SIMPLE(filecheck=SPIRV): -target spirv -O0
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
-// CHECK: [branch]
-// CHECK: [flatten]
+// HLSL: [branch]
+// HLSL: [flatten]
+// SPIRV: OpSelectionMerge {{.*}} DontFlatten
+// SPIRV: OpSelectionMerge {{.*}} Flatten
[numthreads(4, 1, 1)]
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
diff --git a/tests/bugs/branch-switch-attribute.slang b/tests/bugs/branch-switch-attribute.slang
index 761b80e9e..5fe82b7e4 100644
--- a/tests/bugs/branch-switch-attribute.slang
+++ b/tests/bugs/branch-switch-attribute.slang
@@ -1,9 +1,11 @@
-//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain
+//TEST:SIMPLE(filecheck=HLSL): -target hlsl -profile cs_5_0 -entry computeMain
+//TEST:SIMPLE(filecheck=SPIRV): -target spirv -O0
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
-// CHECK: [branch]
+// HLSL: [branch]
+// SPIRV: OpSelectionMerge {{.*}} DontFlatten
[numthreads(4, 1, 1)]
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)