summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2024-04-01 22:02:25 -0700
committerGitHub <noreply@github.com>2024-04-01 22:02:25 -0700
commit251f55c5ec4cb2b7432e71d6ba8adc96700d35c2 (patch)
tree6360ae937545943a97f3a380cbcb3c2d8fb950bd /source
parentdaf63cc983fd5f8f2b24872a9125e0394ed2180e (diff)
Support SM6.6 keyword "WaveSize" (#3871)
Resolves an issue #3385 Shader Model 6.6 added a new keyowrd, "WaveSize". See the following link for more details: https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang3
-rw-r--r--source/slang/slang-ast-modifier.h11
-rw-r--r--source/slang/slang-check-modifier.cpp41
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-emit-c-like.cpp10
-rw-r--r--source/slang/slang-emit-c-like.h3
-rw-r--r--source/slang/slang-emit-hlsl.cpp12
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h9
-rw-r--r--source/slang/slang-ir.cpp9
-rw-r--r--source/slang/slang-lower-to-ir.cpp7
-rw-r--r--source/slang/slang-reflection-api.cpp22
12 files changed, 129 insertions, 0 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index ead81fcf5..8d91f27ab 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2458,6 +2458,9 @@ attribute_syntax [earlydepthstencil] : EarlyDepthStencilAttribute;
__attributeTarget(FuncDecl)
attribute_syntax [numthreads(x: int, y: int = 1, z: int = 1)] : NumThreadsAttribute;
+__attributeTarget(FuncDecl)
+attribute_syntax [WaveSize(numLanes: int)] : WaveSizeAttribute;
+
//
__attributeTarget(VarDeclBase)
attribute_syntax [__vulkanRayPayload(location : int = -1)] : VulkanRayPayloadAttribute;
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 847844e5f..31cac6dda 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -870,6 +870,17 @@ class NumThreadsAttribute : public Attribute
IntVal* z;
};
+class WaveSizeAttribute : public Attribute
+{
+ SLANG_AST_CLASS(WaveSizeAttribute)
+
+ // "numLanes" must be a compile time constant integer
+ // value of an allowed wave size, which is one of the
+ // followings: 4, 8, 16, 32, 64 or 128.
+ //
+ IntVal* numLanes;
+};
+
class MaxVertexCountAttribute : public Attribute
{
SLANG_AST_CLASS(MaxVertexCountAttribute)
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 1587fa1b0..a387e458c 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -362,6 +362,47 @@ namespace Slang
numThreadsAttr->y = values[1];
numThreadsAttr->z = values[2];
}
+ else if (auto waveSizeAttr = as<WaveSizeAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
+
+ IntVal* value = nullptr;
+
+ auto arg = attr->args[0];
+ if (arg)
+ {
+ auto intValue = checkLinkTimeConstantIntVal(arg);
+ if (!intValue)
+ {
+ return false;
+ }
+ if (auto constIntVal = as<ConstantIntVal>(intValue))
+ {
+ bool isValidWaveSize = false;
+ const IntegerLiteralValue waveSize = constIntVal->getValue();
+ for (int validWaveSize : { 4, 8, 16, 32, 64, 128 })
+ {
+ if (validWaveSize == waveSize)
+ {
+ isValidWaveSize = true;
+ break;
+ }
+ }
+ if (!isValidWaveSize)
+ {
+ getSink()->diagnose(attr, Diagnostics::invalidWaveSize, constIntVal->getValue());
+ return false;
+ }
+ }
+ value = intValue;
+ }
+ else
+ {
+ value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
+ }
+
+ waveSizeAttr->numLanes = value;
+ }
else if (auto anyValueSizeAttr = as<AnyValueSizeAttribute>(attr))
{
// This case handles GLSL-oriented layout attributes
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index e1c929f7e..16e494d4a 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -407,6 +407,7 @@ DIAGNOSTIC(31100, Error, unknownStageName, "unknown stage name '$0'")
DIAGNOSTIC(31101, Error, unknownImageFormatName, "unknown image format '$0'")
DIAGNOSTIC(31101, Error, unknownDiagnosticName, "unknown diagnostic '$0'")
DIAGNOSTIC(31102, Error, nonPositiveNumThreads, "expected a positive integer in 'numthreads' attribute, got '$0'")
+DIAGNOSTIC(31103, Error, invalidWaveSize, "expected a power of 2 between 4 and 128, inclusive, in 'WaveSize' attribute, got '$0'")
DIAGNOSTIC(31120, Error, invalidAttributeTarget, "invalid syntax target for user defined attribute")
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 5813819b4..58ca39b69 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -268,6 +268,16 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type)
return decor;
}
+/* static */IRWaveSizeDecoration* CLikeSourceEmitter::getComputeWaveSize(IRFunc* func, Int* outWaveSize)
+{
+ IRWaveSizeDecoration* decor = func->findDecoration<IRWaveSizeDecoration>();
+ if (decor)
+ {
+ *outWaveSize = Int(getIntVal(decor->getOperand(0)));
+ }
+ return decor;
+}
+
List<IRWitnessTableEntry*> CLikeSourceEmitter::getSortedWitnessTableEntries(IRWitnessTable* witnessTable)
{
List<IRWitnessTableEntry*> sortedWitnessTableEntries;
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index c559bb135..7778c78f6 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -451,6 +451,9 @@ public:
/// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1
static IRNumThreadsDecoration* getComputeThreadGroupSize(IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]);
+ /// Finds the IRWaveSizeDecoration and gets the size from that.
+ static IRWaveSizeDecoration* getComputeWaveSize(IRFunc* func, Int *outWaveSize);
+
protected:
diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp
index 92866f9c4..37411c93e 100644
--- a/source/slang/slang-emit-hlsl.cpp
+++ b/source/slang/slang-emit-hlsl.cpp
@@ -351,10 +351,22 @@ void HLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin
m_writer->emit(")]\n");
};
+ auto emitWaveSizeAttribute = [&]()
+ {
+ Int waveSize;
+ if (getComputeWaveSize(irFunc, &waveSize))
+ {
+ m_writer->emit("[WaveSize(");
+ m_writer->emit(waveSize);
+ m_writer->emit(")]\n");
+ }
+ };
+
switch (stage)
{
case Stage::Compute:
{
+ emitWaveSizeAttribute();
emitNumThreadsAttribute();
}
break;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index db1200926..b717e50b3 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -736,6 +736,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(MaxVertexCountDecoration, maxVertexCount, 1, 0)
INST(InstanceDecoration, instance, 1, 0)
INST(NumThreadsDecoration, numThreads, 3, 0)
+ INST(WaveSizeDecoration, waveSize, 1, 0)
// Added to IRParam parameters to an entry point
/* GeometryInputPrimitiveTypeDecoration */
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 165c49f11..547b034a7 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -480,6 +480,14 @@ struct IRNumThreadsDecoration : IRDecoration
IRIntLit* getExtentAlongAxis(int axis) { return cast<IRIntLit>(getOperand(axis)); }
};
+struct IRWaveSizeDecoration : IRDecoration
+{
+ enum { kOp = kIROp_WaveSizeDecoration };
+ IR_LEAF_ISA(WaveSizeDecoration)
+
+ IRIntLit* getNumLanes() { return cast<IRIntLit>(getOperand(0)); }
+};
+
struct IREntryPointDecoration : IRDecoration
{
enum { kOp = kIROp_EntryPointDecoration };
@@ -3581,6 +3589,7 @@ public:
IRInst* addFloatingModeOverrideDecoration(IRInst* dest, FloatingPointMode mode);
IRInst* addNumThreadsDecoration(IRInst* inst, IRInst* x, IRInst* y, IRInst* z);
+ IRInst* addWaveSizeDecoration(IRInst* inst, IRInst* numLanes);
IRInst* emitSpecializeInst(
IRType* type,
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 95f2a7c75..85464446d 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5105,6 +5105,15 @@ namespace Slang
return addDecoration(inst, kIROp_NumThreadsDecoration, operands, 3);
}
+ IRInst* IRBuilder::addWaveSizeDecoration(IRInst* inst, IRInst* numLanes)
+ {
+ IRInst* operands[1] = {
+ numLanes
+ };
+
+ return addDecoration(inst, kIROp_WaveSizeDecoration, operands, 1);
+ }
+
IRInst* IRBuilder::emitSwizzle(
IRType* type,
IRInst* base,
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 4372d6e7c..1e324b4a4 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9625,6 +9625,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getSimpleVal(context, lowerVal(context, numThreadsAttr->z))
);
}
+ else if (auto waveSizeAttr = as<WaveSizeAttribute>(modifier))
+ {
+ getBuilder()->addWaveSizeDecoration(
+ irFunc,
+ getSimpleVal(context, lowerVal(context, waveSizeAttr->numLanes))
+ );
+ }
else if (as<ReadNoneAttribute>(modifier))
{
getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc);
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 9b20e2933..90103d8d9 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -2840,6 +2840,28 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize(
}
}
+SLANG_API void spReflectionEntryPoint_getComputeWaveSize(
+ SlangReflectionEntryPoint* inEntryPoint,
+ SlangUInt* outWaveSize)
+{
+ auto entryPointLayout = convert(inEntryPoint);
+
+ if (!entryPointLayout) return;
+ if (!outWaveSize) return;
+
+ auto entryPointFunc = entryPointLayout->entryPoint;
+ if (!entryPointFunc) return;
+
+ // First look for the HLSL case, where we have an attribute attached to the entry point function
+ if (auto waveSizeAttribute = entryPointFunc.getDecl()->findModifier<WaveSizeAttribute>())
+ {
+ if (auto cint = entryPointLayout->program->tryFoldIntVal(waveSizeAttribute->numLanes))
+ *outWaveSize = (SlangUInt)cint->getValue();
+ else if (waveSizeAttribute->numLanes)
+ *outWaveSize = 0;
+ }
+}
+
SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput(
SlangReflectionEntryPoint* inEntryPoint)
{