summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--slang.h10
-rw-r--r--source/slang/core.meta.slang3
-rw-r--r--source/slang/slang-ast-modifier.h11
-rw-r--r--source/slang/slang-check-modifier.cpp41
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-emit-c-like.cpp10
-rw-r--r--source/slang/slang-emit-c-like.h3
-rw-r--r--source/slang/slang-emit-hlsl.cpp12
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h9
-rw-r--r--source/slang/slang-ir.cpp9
-rw-r--r--source/slang/slang-lower-to-ir.cpp7
-rw-r--r--source/slang/slang-reflection-api.cpp22
-rw-r--r--tests/diagnostics/wavesize-invalid-size.slang18
-rw-r--r--tests/hlsl/wave-size.slang13
15 files changed, 170 insertions, 0 deletions
diff --git a/slang.h b/slang.h
index e31de7656..0c3c41fe0 100644
--- a/slang.h
+++ b/slang.h
@@ -2521,6 +2521,10 @@ extern "C"
SlangUInt axisCount,
SlangUInt* outSizeAlongAxis);
+ SLANG_API void spReflectionEntryPoint_getComputeWaveSize(
+ SlangReflectionEntryPoint* entryPoint,
+ SlangUInt* outWaveSize);
+
SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput(
SlangReflectionEntryPoint* entryPoint);
@@ -3335,6 +3339,12 @@ namespace slang
return spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*) this, axisCount, outSizeAlongAxis);
}
+        /// Forwards to `spReflectionEntryPoint_getComputeWaveSize`, passing this
+        /// object as the entry point; the result is stored through `outWaveSize`.
+        void getComputeWaveSize(SlangUInt* outWaveSize)
+        {
+            auto* self = (SlangReflectionEntryPoint*)this;
+            spReflectionEntryPoint_getComputeWaveSize(self, outWaveSize);
+        }
+
bool usesAnySampleRateInput()
{
return 0 != spReflectionEntryPoint_usesAnySampleRateInput((SlangReflectionEntryPoint*) this);
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index ead81fcf5..8d91f27ab 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2458,6 +2458,9 @@ attribute_syntax [earlydepthstencil] : EarlyDepthStencilAttribute;
__attributeTarget(FuncDecl)
attribute_syntax [numthreads(x: int, y: int = 1, z: int = 1)] : NumThreadsAttribute;
+__attributeTarget(FuncDecl)
+attribute_syntax [WaveSize(numLanes: int)] : WaveSizeAttribute;
+
//
__attributeTarget(VarDeclBase)
attribute_syntax [__vulkanRayPayload(location : int = -1)] : VulkanRayPayloadAttribute;
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 847844e5f..31cac6dda 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -870,6 +870,17 @@ class NumThreadsAttribute : public Attribute
IntVal* z;
};
+class WaveSizeAttribute : public Attribute
+{
+ SLANG_AST_CLASS(WaveSizeAttribute)
+
+ // AST attribute for `[WaveSize(numLanes)]` on a function.
+ // "numLanes" must be a compile-time constant integer equal to
+ // one of the allowed wave sizes: 4, 8, 16, 32, 64 or 128.
+ //
+ IntVal* numLanes;
+};
+
class MaxVertexCountAttribute : public Attribute
{
SLANG_AST_CLASS(MaxVertexCountAttribute)
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 1587fa1b0..a387e458c 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -362,6 +362,47 @@ namespace Slang
numThreadsAttr->y = values[1];
numThreadsAttr->z = values[2];
}
+        else if (auto waveSizeAttr = as<WaveSizeAttribute>(attr))
+        {
+            // `WaveSize` is declared with exactly one argument: the lane count.
+            SLANG_ASSERT(attr->args.getCount() == 1);
+
+            auto laneCountExpr = attr->args[0];
+            if (!laneCountExpr)
+            {
+                // No argument expression to check: fall back to a lane count of 1.
+                waveSizeAttr->numLanes = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
+            }
+            else
+            {
+                auto laneCount = checkLinkTimeConstantIntVal(laneCountExpr);
+                if (!laneCount)
+                    return false;
+
+                // Only a value already known as a constant can be validated here;
+                // any other value is stored on the attribute without a range check.
+                if (auto knownLaneCount = as<ConstantIntVal>(laneCount))
+                {
+                    const IntegerLiteralValue waveSize = knownLaneCount->getValue();
+
+                    // Accept exactly 4, 8, 16, 32, 64 and 128. The range test runs
+                    // first so the power-of-two test only sees small positive values.
+                    const bool isValidWaveSize =
+                        waveSize >= 4 && waveSize <= 128 && (waveSize & (waveSize - 1)) == 0;
+                    if (!isValidWaveSize)
+                    {
+                        getSink()->diagnose(attr, Diagnostics::invalidWaveSize, knownLaneCount->getValue());
+                        return false;
+                    }
+                }
+
+                waveSizeAttr->numLanes = laneCount;
+            }
+        }
else if (auto anyValueSizeAttr = as<AnyValueSizeAttribute>(attr))
{
// This case handles GLSL-oriented layout attributes
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index e1c929f7e..16e494d4a 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -407,6 +407,7 @@ DIAGNOSTIC(31100, Error, unknownStageName, "unknown stage name '$0'")
DIAGNOSTIC(31101, Error, unknownImageFormatName, "unknown image format '$0'")
DIAGNOSTIC(31101, Error, unknownDiagnosticName, "unknown diagnostic '$0'")
DIAGNOSTIC(31102, Error, nonPositiveNumThreads, "expected a positive integer in 'numthreads' attribute, got '$0'")
+DIAGNOSTIC(31103, Error, invalidWaveSize, "expected a power of 2 between 4 and 128, inclusive, in 'WaveSize' attribute, got '$0'")
DIAGNOSTIC(31120, Error, invalidAttributeTarget, "invalid syntax target for user defined attribute")
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 5813819b4..58ca39b69 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -268,6 +268,16 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type)
return decor;
}
+/// Looks up the wave-size decoration on `func`.
+/// When one is present, its operand 0 is written to `*outWaveSize` as an integer
+/// and the decoration is returned; otherwise returns null and leaves
+/// `*outWaveSize` untouched.
+/* static */IRWaveSizeDecoration* CLikeSourceEmitter::getComputeWaveSize(IRFunc* func, Int* outWaveSize)
+{
+    auto waveSizeDecor = func->findDecoration<IRWaveSizeDecoration>();
+    if (!waveSizeDecor)
+        return nullptr;
+
+    *outWaveSize = Int(getIntVal(waveSizeDecor->getOperand(0)));
+    return waveSizeDecor;
+}
+
List<IRWitnessTableEntry*> CLikeSourceEmitter::getSortedWitnessTableEntries(IRWitnessTable* witnessTable)
{
List<IRWitnessTableEntry*> sortedWitnessTableEntries;
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index c559bb135..7778c78f6 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -451,6 +451,9 @@ public:
/// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1
static IRNumThreadsDecoration* getComputeThreadGroupSize(IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]);
+ /// Finds the IRWaveSizeDecoration and gets the size from that.
+ static IRWaveSizeDecoration* getComputeWaveSize(IRFunc* func, Int *outWaveSize);
+
protected:
diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp
index 92866f9c4..37411c93e 100644
--- a/source/slang/slang-emit-hlsl.cpp
+++ b/source/slang/slang-emit-hlsl.cpp
@@ -351,10 +351,22 @@ void HLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin
m_writer->emit(")]\n");
};
+    // Emits `[WaveSize(N)]` when `irFunc` carries a wave-size decoration,
+    // and emits nothing otherwise.
+    auto emitWaveSizeAttribute = [&]()
+    {
+        Int numLanes;
+        const bool hasWaveSize = getComputeWaveSize(irFunc, &numLanes) != nullptr;
+        if (!hasWaveSize)
+            return;
+
+        m_writer->emit("[WaveSize(");
+        m_writer->emit(numLanes);
+        m_writer->emit(")]\n");
+    };
+
switch (stage)
{
case Stage::Compute:
{
+ emitWaveSizeAttribute();
emitNumThreadsAttribute();
}
break;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index db1200926..b717e50b3 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -736,6 +736,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(MaxVertexCountDecoration, maxVertexCount, 1, 0)
INST(InstanceDecoration, instance, 1, 0)
INST(NumThreadsDecoration, numThreads, 3, 0)
+ INST(WaveSizeDecoration, waveSize, 1, 0)
// Added to IRParam parameters to an entry point
/* GeometryInputPrimitiveTypeDecoration */
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 165c49f11..547b034a7 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -480,6 +480,14 @@ struct IRNumThreadsDecoration : IRDecoration
IRIntLit* getExtentAlongAxis(int axis) { return cast<IRIntLit>(getOperand(axis)); }
};
+struct IRWaveSizeDecoration : IRDecoration // IR decoration recording the lane count of a `WaveSize` attribute
+{
+ enum { kOp = kIROp_WaveSizeDecoration };
+ IR_LEAF_ISA(WaveSizeDecoration)
+ // Operand 0 holds the lane count; the accessor casts it to an integer literal.
+ IRIntLit* getNumLanes() { return cast<IRIntLit>(getOperand(0)); }
+};
+
struct IREntryPointDecoration : IRDecoration
{
enum { kOp = kIROp_EntryPointDecoration };
@@ -3581,6 +3589,7 @@ public:
IRInst* addFloatingModeOverrideDecoration(IRInst* dest, FloatingPointMode mode);
IRInst* addNumThreadsDecoration(IRInst* inst, IRInst* x, IRInst* y, IRInst* z);
+ IRInst* addWaveSizeDecoration(IRInst* inst, IRInst* numLanes);
IRInst* emitSpecializeInst(
IRType* type,
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 95f2a7c75..85464446d 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5105,6 +5105,15 @@ namespace Slang
return addDecoration(inst, kIROp_NumThreadsDecoration, operands, 3);
}
+    /// Attaches a `kIROp_WaveSizeDecoration` to `inst` with `numLanes` as its
+    /// single operand, and returns the result of `addDecoration`.
+    IRInst* IRBuilder::addWaveSizeDecoration(IRInst* inst, IRInst* numLanes)
+    {
+        // A pointer to the lone operand serves as a one-element operand list.
+        return addDecoration(inst, kIROp_WaveSizeDecoration, &numLanes, 1);
+    }
+
IRInst* IRBuilder::emitSwizzle(
IRType* type,
IRInst* base,
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 4372d6e7c..1e324b4a4 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9625,6 +9625,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getSimpleVal(context, lowerVal(context, numThreadsAttr->z))
);
}
+            else if (auto waveSizeAttr = as<WaveSizeAttribute>(modifier))
+            {
+                // Lower the attribute's lane count and record it on the IR
+                // function as a wave-size decoration.
+                auto loweredNumLanes = lowerVal(context, waveSizeAttr->numLanes);
+                getBuilder()->addWaveSizeDecoration(irFunc, getSimpleVal(context, loweredNumLanes));
+            }
else if (as<ReadNoneAttribute>(modifier))
{
getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc);
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 9b20e2933..90103d8d9 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -2840,6 +2840,28 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize(
}
}
+/// Reports the wave size requested by a `WaveSizeAttribute` on the entry point function.
+///
+/// `*outWaveSize` receives the attribute's lane count when it folds to a constant,
+/// and 0 in every other case (invalid entry point, no attribute, or a lane count
+/// that cannot be folded), so the caller never reads an unwritten value.
+/// Does nothing when `outWaveSize` is null.
+SLANG_API void spReflectionEntryPoint_getComputeWaveSize(
+    SlangReflectionEntryPoint* inEntryPoint,
+    SlangUInt* outWaveSize)
+{
+    if (!outWaveSize) return;
+
+    // Publish the "no known wave size" result up front so that every early
+    // exit below leaves the out-parameter defined.
+    *outWaveSize = 0;
+
+    auto entryPointLayout = convert(inEntryPoint);
+    if (!entryPointLayout) return;
+
+    auto entryPointFunc = entryPointLayout->entryPoint;
+    if (!entryPointFunc) return;
+
+    // The wave size comes from an attribute attached to the entry point function.
+    auto waveSizeAttribute = entryPointFunc.getDecl()->findModifier<WaveSizeAttribute>();
+    if (!waveSizeAttribute) return;
+
+    if (auto cint = entryPointLayout->program->tryFoldIntVal(waveSizeAttribute->numLanes))
+        *outWaveSize = (SlangUInt)cint->getValue();
+}
+
SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput(
SlangReflectionEntryPoint* inEntryPoint)
{
diff --git a/tests/diagnostics/wavesize-invalid-size.slang b/tests/diagnostics/wavesize-invalid-size.slang
new file mode 100644
index 000000000..840434137
--- /dev/null
+++ b/tests/diagnostics/wavesize-invalid-size.slang
@@ -0,0 +1,18 @@
+//TEST:SIMPLE(filecheck=CHECK): -target hlsl -stage compute -entry computeMain
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+// Print an error when the numLanes is an invalid value for WaveSize.
+// The value has to be a power of 2 between 4 and 128, inclusive.
+// In other words, the set: [4, 8, 16, 32, 64, 128].
+
+// "5" is an invalid value for WaveSize
+// CHECK: error 31103:
+[WaveSize(5)]
+[numthreads(4, 1, 1)]
+void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
+{
+ int tid = dispatchThreadID.x;
+ outputBuffer[tid] = tid;
+}
diff --git a/tests/hlsl/wave-size.slang b/tests/hlsl/wave-size.slang
new file mode 100644
index 000000000..3fc6363c4
--- /dev/null
+++ b/tests/hlsl/wave-size.slang
@@ -0,0 +1,13 @@
+//TEST:SIMPLE(filecheck=CHECK): -target hlsl -stage compute -entry computeMain
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+// CHECK: [WaveSize(4)]
+[WaveSize(4)]
+[numthreads(4, 1, 1)]
+void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
+{
+ int tid = dispatchThreadID.x;
+ outputBuffer[tid] = tid;
+}