diff options
| author | Yong He <yonghe@outlook.com> | 2024-12-04 21:19:25 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-05 13:19:25 +0800 |
| commit | bd50f9947905feb5199c7cfe29c640084f882199 (patch) | |
| tree | 17f259ccbfc320d139f63f5bcbdb28b4134dd2f5 | |
| parent | 46aee66664732b85f50e377687182715b3ec89e7 (diff) | |
Make fvk-invert-y work on mesh shader ouptuts. (#5760)
Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
| -rw-r--r-- | source/slang/slang-ir-vk-invert-y.cpp | 67 | ||||
| -rw-r--r-- | tests/spirv/mesh-shader-invert-y.slang | 24 |
2 files changed, 62 insertions, 29 deletions
diff --git a/source/slang/slang-ir-vk-invert-y.cpp b/source/slang/slang-ir-vk-invert-y.cpp index 8a59472dd..e7fc81144 100644 --- a/source/slang/slang-ir-vk-invert-y.cpp +++ b/source/slang/slang-ir-vk-invert-y.cpp @@ -1,6 +1,7 @@ #include "slang-ir-vk-invert-y.h" #include "slang-ir-insts.h" +#include "slang-ir-util.h" #include "slang-ir.h" namespace Slang @@ -28,37 +29,45 @@ void invertYOfPositionOutput(IRModule* module) { // Find all loads and stores to it. IRBuilder builder(module); - traverseUses( - globalInst, - [&](IRUse* use) + List<IRUse*> useWorkList; + auto processUse = [&](IRUse* use) + { + if (auto store = as<IRStore>(use->getUser())) { - if (auto store = as<IRStore>(use->getUser())) - { - if (store->getPtr() != globalInst) - return; + if (getRootAddr(store->getPtr()) != globalInst) + return; - builder.setInsertBefore(store); - auto originalVal = store->getVal(); - auto invertedVal = _invertYOfVector(builder, originalVal); - builder.replaceOperand(&store->val, invertedVal); - } - else if (auto load = as<IRLoad>(use->getUser())) - { - // Since we negate the y coordinate before writing - // to gl_Position, we also need to negate the value after reading from it. - builder.setInsertAfter(load); - // Store existing uses of the load that we are going to replace with - // inverted val later. - List<IRUse*> oldUses; - for (auto loadUse = load->firstUse; loadUse; loadUse = loadUse->nextUse) - oldUses.add(loadUse); - // Get the inverted vector. - auto invertedVal = _invertYOfVector(builder, load); - // Replace original uses with the invertex vector. - for (auto loadUse : oldUses) - builder.replaceOperand(loadUse, invertedVal); - } - }); + builder.setInsertBefore(store); + auto originalVal = store->getVal(); + auto invertedVal = _invertYOfVector(builder, originalVal); + builder.replaceOperand(&store->val, invertedVal); + } + else if (auto load = as<IRLoad>(use->getUser())) + { + // Since we negate the y coordinate before writing + // to gl_Position, we also need to negate the value after reading from it. + builder.setInsertAfter(load); + // Store existing uses of the load that we are going to replace with + // inverted val later. + List<IRUse*> oldUses; + for (auto loadUse = load->firstUse; loadUse; loadUse = loadUse->nextUse) + oldUses.add(loadUse); + // Get the inverted vector. + auto invertedVal = _invertYOfVector(builder, load); + // Replace original uses with the invertex vector. + for (auto loadUse : oldUses) + builder.replaceOperand(loadUse, invertedVal); + } + else if (auto getElementPtr = as<IRGetElementPtr>(use->getUser())) + { + traverseUses(getElementPtr, [&](IRUse* use) { useWorkList.add(use); }); + } + }; + traverseUses(globalInst, processUse); + for (Index i = 0; i < useWorkList.getCount(); i++) + { + processUse(useWorkList[i]); + } } } } diff --git a/tests/spirv/mesh-shader-invert-y.slang b/tests/spirv/mesh-shader-invert-y.slang new file mode 100644 index 000000000..e385167ae --- /dev/null +++ b/tests/spirv/mesh-shader-invert-y.slang @@ -0,0 +1,24 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -fvk-invert-y + +// CHECK: %[[CONST:[0-9]+]] = OpConstantComposite %v4float %float_0 %float_n1 %float_0 %float_0 +// CHECK: OpStore {{.*}} %[[CONST]] + +struct PS_IN +{ + float4 vPositionPs : SV_Position; +}; + +[ shader("mesh") ] +[ outputtopology( "line" ) ] +[ numthreads( 64, 1, 1 ) ] +void main( + uint nThreadId : SV_DispatchThreadID, + uint nGroupThreadId : SV_GroupThreadID, + uint nGroupId : SV_GroupID, + out vertices PS_IN outputVerts[ 128 ], + out indices uint2 outputIB[ 64 ] ) +{ + SetMeshOutputCounts( 128, 64 ); + + outputVerts[ 2 * nGroupThreadId ].vPositionPs = float4( 0, 1, 0, 0 ); +} |
