summary refs log tree commit diff stats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r-- tests/hlsl-intrinsic/wave-mask/wave.slang 21
1 files changed, 15 insertions, 6 deletions
diff --git a/tests/hlsl-intrinsic/wave-mask/wave.slang b/tests/hlsl-intrinsic/wave-mask/wave.slang
index 6b641906d..346940cb2 100644
--- a/tests/hlsl-intrinsic/wave-mask/wave.slang
+++ b/tests/hlsl-intrinsic/wave-mask/wave.slang
@@ -14,7 +14,7 @@ groupshared int sharedMem[32];
int exclusivePrefixSum(WaveMask mask, int index, int waveLaneId, int originalValue, int elementCount)
{
- WaveMask localMask = WaveMaskBallot(mask, waveLaneId < elementCount);
+ WaveMask localMask = WaveMaskBallot(mask, index < elementCount);
sharedMem[index] = 0;
@@ -23,7 +23,7 @@ int exclusivePrefixSum(WaveMask mask, int index, int waveLaneId, int originalVal
int temp = 0;
int val = originalValue;
- for(int i = 1; i < elementCount; i += i)
+ for(int i = 1; i < elementCount; i += i)
{
int temp = WaveMaskShuffle(localMask, val, waveLaneId - i);
if(waveLaneId >= i)
@@ -37,25 +37,34 @@ int exclusivePrefixSum(WaveMask mask, int index, int waveLaneId, int originalVal
// Write to shared memory
sharedMem[index] = val;
-
- // Syncronizes on the mask, and ensures memory fence for shared data write
- WaveMaskSharedSync(localMask);
return val;
}
return 0;
}
+// It matters how kernels with WaveMask intrinsics are launched(!).
+// TODO(JS):
+// If I launch with a numthreads amount that is not the size of the Wave on the device, then some
+// lanes will not be executing at startup, and the kernel will have to know that is the case.
+// This works currently though because the mask is only used
+// on CUDA, and its Wave size is 32.
[numthreads(32, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
+ // Assumes all threads in the Wave are active at start.
+ WaveMask waveMask = ~WaveMask(0);
+
int index = int(dispatchThreadID.x);
const int waveLaneId = WaveGetLaneIndex();
const int value = inputBuffer[index];
const int elementCount = 9;
- exclusivePrefixSum(WaveGetActiveMask(), index, waveLaneId, value, elementCount);
+ exclusivePrefixSum(waveMask, index, waveLaneId, value, elementCount);
+
+ // We don't read from any other lane, so we don't actually need any sync
+ //WaveMaskSharedSync(waveMask);
// It returns the result, but we are going to read from shared memory, to check that aspect worked
int prefixValue = sharedMem[index];