summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-12-04 21:19:25 -0800
committerGitHub <noreply@github.com>2024-12-05 13:19:25 +0800
commitbd50f9947905feb5199c7cfe29c640084f882199 (patch)
tree17f259ccbfc320d139f63f5bcbdb28b4134dd2f5
parent46aee66664732b85f50e377687182715b3ec89e7 (diff)
Make fvk-invert-y work on mesh shader ouptuts. (#5760)
Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
-rw-r--r--source/slang/slang-ir-vk-invert-y.cpp67
-rw-r--r--tests/spirv/mesh-shader-invert-y.slang24
2 files changed, 62 insertions, 29 deletions
diff --git a/source/slang/slang-ir-vk-invert-y.cpp b/source/slang/slang-ir-vk-invert-y.cpp
index 8a59472dd..e7fc81144 100644
--- a/source/slang/slang-ir-vk-invert-y.cpp
+++ b/source/slang/slang-ir-vk-invert-y.cpp
@@ -1,6 +1,7 @@
#include "slang-ir-vk-invert-y.h"
#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
#include "slang-ir.h"
namespace Slang
@@ -28,37 +29,45 @@ void invertYOfPositionOutput(IRModule* module)
{
// Find all loads and stores to it.
IRBuilder builder(module);
- traverseUses(
- globalInst,
- [&](IRUse* use)
+ List<IRUse*> useWorkList;
+ auto processUse = [&](IRUse* use)
+ {
+ if (auto store = as<IRStore>(use->getUser()))
{
- if (auto store = as<IRStore>(use->getUser()))
- {
- if (store->getPtr() != globalInst)
- return;
+ if (getRootAddr(store->getPtr()) != globalInst)
+ return;
- builder.setInsertBefore(store);
- auto originalVal = store->getVal();
- auto invertedVal = _invertYOfVector(builder, originalVal);
- builder.replaceOperand(&store->val, invertedVal);
- }
- else if (auto load = as<IRLoad>(use->getUser()))
- {
- // Since we negate the y coordinate before writing
- // to gl_Position, we also need to negate the value after reading from it.
- builder.setInsertAfter(load);
- // Store existing uses of the load that we are going to replace with
- // inverted val later.
- List<IRUse*> oldUses;
- for (auto loadUse = load->firstUse; loadUse; loadUse = loadUse->nextUse)
- oldUses.add(loadUse);
- // Get the inverted vector.
- auto invertedVal = _invertYOfVector(builder, load);
- // Replace original uses with the invertex vector.
- for (auto loadUse : oldUses)
- builder.replaceOperand(loadUse, invertedVal);
- }
- });
+ builder.setInsertBefore(store);
+ auto originalVal = store->getVal();
+ auto invertedVal = _invertYOfVector(builder, originalVal);
+ builder.replaceOperand(&store->val, invertedVal);
+ }
+ else if (auto load = as<IRLoad>(use->getUser()))
+ {
+ // Since we negate the y coordinate before writing
+ // to gl_Position, we also need to negate the value after reading from it.
+ builder.setInsertAfter(load);
+ // Store existing uses of the load that we are going to replace with
+ // inverted val later.
+ List<IRUse*> oldUses;
+ for (auto loadUse = load->firstUse; loadUse; loadUse = loadUse->nextUse)
+ oldUses.add(loadUse);
+ // Get the inverted vector.
+ auto invertedVal = _invertYOfVector(builder, load);
+ // Replace original uses with the invertex vector.
+ for (auto loadUse : oldUses)
+ builder.replaceOperand(loadUse, invertedVal);
+ }
+ else if (auto getElementPtr = as<IRGetElementPtr>(use->getUser()))
+ {
+ traverseUses(getElementPtr, [&](IRUse* use) { useWorkList.add(use); });
+ }
+ };
+ traverseUses(globalInst, processUse);
+ for (Index i = 0; i < useWorkList.getCount(); i++)
+ {
+ processUse(useWorkList[i]);
+ }
}
}
}
diff --git a/tests/spirv/mesh-shader-invert-y.slang b/tests/spirv/mesh-shader-invert-y.slang
new file mode 100644
index 000000000..e385167ae
--- /dev/null
+++ b/tests/spirv/mesh-shader-invert-y.slang
@@ -0,0 +1,24 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -fvk-invert-y
+
+// CHECK: %[[CONST:[0-9]+]] = OpConstantComposite %v4float %float_0 %float_n1 %float_0 %float_0
+// CHECK: OpStore {{.*}} %[[CONST]]
+
+struct PS_IN
+{
+ float4 vPositionPs : SV_Position;
+};
+
+[ shader("mesh") ]
+[ outputtopology( "line" ) ]
+[ numthreads( 64, 1, 1 ) ]
+void main(
+ uint nThreadId : SV_DispatchThreadID,
+ uint nGroupThreadId : SV_GroupThreadID,
+ uint nGroupId : SV_GroupID,
+ out vertices PS_IN outputVerts[ 128 ],
+ out indices uint2 outputIB[ 64 ] )
+{
+ SetMeshOutputCounts( 128, 64 );
+
+ outputVerts[ 2 * nGroupThreadId ].vPositionPs = float4( 0, 1, 0, 0 );
+}