diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2024-04-01 22:02:25 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-04-01 22:02:25 -0700 |
| commit | 251f55c5ec4cb2b7432e71d6ba8adc96700d35c2 (patch) | |
| tree | 6360ae937545943a97f3a380cbcb3c2d8fb950bd | |
| parent | daf63cc983fd5f8f2b24872a9125e0394ed2180e (diff) | |
Support SM6.6 keyword "WaveSize" (#3871)
Resolves an issue #3385
Shader Model 6.6 added a new keyowrd, "WaveSize". See the following link
for more details:
https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | slang.h | 10 | ||||
| -rw-r--r-- | source/slang/core.meta.slang | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 41 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit-hlsl.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-reflection-api.cpp | 22 | ||||
| -rw-r--r-- | tests/diagnostics/wavesize-invalid-size.slang | 18 | ||||
| -rw-r--r-- | tests/hlsl/wave-size.slang | 13 |
15 files changed, 170 insertions, 0 deletions
@@ -2521,6 +2521,10 @@ extern "C" SlangUInt axisCount, SlangUInt* outSizeAlongAxis); + SLANG_API void spReflectionEntryPoint_getComputeWaveSize( + SlangReflectionEntryPoint* entryPoint, + SlangUInt* outWaveSize); + SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput( SlangReflectionEntryPoint* entryPoint); @@ -3335,6 +3339,12 @@ namespace slang return spReflectionEntryPoint_getComputeThreadGroupSize((SlangReflectionEntryPoint*) this, axisCount, outSizeAlongAxis); } + void getComputeWaveSize( + SlangUInt* outWaveSize) + { + return spReflectionEntryPoint_getComputeWaveSize((SlangReflectionEntryPoint*)this, outWaveSize); + } + bool usesAnySampleRateInput() { return 0 != spReflectionEntryPoint_usesAnySampleRateInput((SlangReflectionEntryPoint*) this); diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index ead81fcf5..8d91f27ab 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2458,6 +2458,9 @@ attribute_syntax [earlydepthstencil] : EarlyDepthStencilAttribute; __attributeTarget(FuncDecl) attribute_syntax [numthreads(x: int, y: int = 1, z: int = 1)] : NumThreadsAttribute; +__attributeTarget(FuncDecl) +attribute_syntax [WaveSize(numLanes: int)] : WaveSizeAttribute; + // __attributeTarget(VarDeclBase) attribute_syntax [__vulkanRayPayload(location : int = -1)] : VulkanRayPayloadAttribute; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 847844e5f..31cac6dda 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -870,6 +870,17 @@ class NumThreadsAttribute : public Attribute IntVal* z; }; +class WaveSizeAttribute : public Attribute +{ + SLANG_AST_CLASS(WaveSizeAttribute) + + // "numLanes" must be a compile time constant integer + // value of an allowed wave size, which is one of the + // followings: 4, 8, 16, 32, 64 or 128. + // + IntVal* numLanes; +}; + class MaxVertexCountAttribute : public Attribute { SLANG_AST_CLASS(MaxVertexCountAttribute) diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 1587fa1b0..a387e458c 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -362,6 +362,47 @@ namespace Slang numThreadsAttr->y = values[1]; numThreadsAttr->z = values[2]; } + else if (auto waveSizeAttr = as<WaveSizeAttribute>(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + + IntVal* value = nullptr; + + auto arg = attr->args[0]; + if (arg) + { + auto intValue = checkLinkTimeConstantIntVal(arg); + if (!intValue) + { + return false; + } + if (auto constIntVal = as<ConstantIntVal>(intValue)) + { + bool isValidWaveSize = false; + const IntegerLiteralValue waveSize = constIntVal->getValue(); + for (int validWaveSize : { 4, 8, 16, 32, 64, 128 }) + { + if (validWaveSize == waveSize) + { + isValidWaveSize = true; + break; + } + } + if (!isValidWaveSize) + { + getSink()->diagnose(attr, Diagnostics::invalidWaveSize, constIntVal->getValue()); + return false; + } + } + value = intValue; + } + else + { + value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + } + + waveSizeAttr->numLanes = value; + } else if (auto anyValueSizeAttr = as<AnyValueSizeAttribute>(attr)) { // This case handles GLSL-oriented layout attributes diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index e1c929f7e..16e494d4a 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -407,6 +407,7 @@ DIAGNOSTIC(31100, Error, unknownStageName, "unknown stage name '$0'") DIAGNOSTIC(31101, Error, unknownImageFormatName, "unknown image format '$0'") DIAGNOSTIC(31101, Error, unknownDiagnosticName, "unknown diagnostic '$0'") DIAGNOSTIC(31102, Error, nonPositiveNumThreads, "expected a positive integer in 'numthreads' attribute, got '$0'") +DIAGNOSTIC(31103, Error, invalidWaveSize, "expected a power of 2 between 4 and 128, inclusive, in 'WaveSize' attribute, got '$0'") DIAGNOSTIC(31120, Error, invalidAttributeTarget, "invalid syntax target for user defined attribute") diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 5813819b4..58ca39b69 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -268,6 +268,16 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type) return decor; } +/* static */IRWaveSizeDecoration* CLikeSourceEmitter::getComputeWaveSize(IRFunc* func, Int* outWaveSize) +{ + IRWaveSizeDecoration* decor = func->findDecoration<IRWaveSizeDecoration>(); + if (decor) + { + *outWaveSize = Int(getIntVal(decor->getOperand(0))); + } + return decor; +} + List<IRWitnessTableEntry*> CLikeSourceEmitter::getSortedWitnessTableEntries(IRWitnessTable* witnessTable) { List<IRWitnessTableEntry*> sortedWitnessTableEntries; diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index c559bb135..7778c78f6 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -451,6 +451,9 @@ public: /// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1 static IRNumThreadsDecoration* getComputeThreadGroupSize(IRFunc* func, Int outNumThreads[kThreadGroupAxisCount]); + /// Finds the IRWaveSizeDecoration and gets the size from that. + static IRWaveSizeDecoration* getComputeWaveSize(IRFunc* func, Int *outWaveSize); + protected: diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 92866f9c4..37411c93e 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -351,10 +351,22 @@ void HLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin m_writer->emit(")]\n"); }; + auto emitWaveSizeAttribute = [&]() + { + Int waveSize; + if (getComputeWaveSize(irFunc, &waveSize)) + { + m_writer->emit("[WaveSize("); + m_writer->emit(waveSize); + m_writer->emit(")]\n"); + } + }; + switch (stage) { case Stage::Compute: { + emitWaveSizeAttribute(); emitNumThreadsAttribute(); } break; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index db1200926..b717e50b3 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -736,6 +736,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(MaxVertexCountDecoration, maxVertexCount, 1, 0) INST(InstanceDecoration, instance, 1, 0) INST(NumThreadsDecoration, numThreads, 3, 0) + INST(WaveSizeDecoration, waveSize, 1, 0) // Added to IRParam parameters to an entry point /* GeometryInputPrimitiveTypeDecoration */ diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 165c49f11..547b034a7 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -480,6 +480,14 @@ struct IRNumThreadsDecoration : IRDecoration IRIntLit* getExtentAlongAxis(int axis) { return cast<IRIntLit>(getOperand(axis)); } }; +struct IRWaveSizeDecoration : IRDecoration +{ + enum { kOp = kIROp_WaveSizeDecoration }; + IR_LEAF_ISA(WaveSizeDecoration) + + IRIntLit* getNumLanes() { return cast<IRIntLit>(getOperand(0)); } +}; + struct IREntryPointDecoration : IRDecoration { enum { kOp = kIROp_EntryPointDecoration }; @@ -3581,6 +3589,7 @@ public: IRInst* addFloatingModeOverrideDecoration(IRInst* dest, FloatingPointMode mode); IRInst* addNumThreadsDecoration(IRInst* inst, IRInst* x, IRInst* y, IRInst* z); + IRInst* addWaveSizeDecoration(IRInst* inst, IRInst* numLanes); IRInst* emitSpecializeInst( IRType* type, diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 95f2a7c75..85464446d 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5105,6 +5105,15 @@ namespace Slang return addDecoration(inst, kIROp_NumThreadsDecoration, operands, 3); } + IRInst* IRBuilder::addWaveSizeDecoration(IRInst* inst, IRInst* numLanes) + { + IRInst* operands[1] = { + numLanes + }; + + return addDecoration(inst, kIROp_WaveSizeDecoration, operands, 1); + } + IRInst* IRBuilder::emitSwizzle( IRType* type, IRInst* base, diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 4372d6e7c..1e324b4a4 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -9625,6 +9625,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getSimpleVal(context, lowerVal(context, numThreadsAttr->z)) ); } + else if (auto waveSizeAttr = as<WaveSizeAttribute>(modifier)) + { + getBuilder()->addWaveSizeDecoration( + irFunc, + getSimpleVal(context, lowerVal(context, waveSizeAttr->numLanes)) + ); + } else if (as<ReadNoneAttribute>(modifier)) { getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc); diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 9b20e2933..90103d8d9 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -2840,6 +2840,28 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( } } +SLANG_API void spReflectionEntryPoint_getComputeWaveSize( + SlangReflectionEntryPoint* inEntryPoint, + SlangUInt* outWaveSize) +{ + auto entryPointLayout = convert(inEntryPoint); + + if (!entryPointLayout) return; + if (!outWaveSize) return; + + auto entryPointFunc = entryPointLayout->entryPoint; + if (!entryPointFunc) return; + + // First look for the HLSL case, where we have an attribute attached to the entry point function + if (auto waveSizeAttribute = entryPointFunc.getDecl()->findModifier<WaveSizeAttribute>()) + { + if (auto cint = entryPointLayout->program->tryFoldIntVal(waveSizeAttribute->numLanes)) + *outWaveSize = (SlangUInt)cint->getValue(); + else if (waveSizeAttribute->numLanes) + *outWaveSize = 0; + } +} + SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput( SlangReflectionEntryPoint* inEntryPoint) { diff --git a/tests/diagnostics/wavesize-invalid-size.slang b/tests/diagnostics/wavesize-invalid-size.slang new file mode 100644 index 000000000..840434137 --- /dev/null +++ b/tests/diagnostics/wavesize-invalid-size.slang @@ -0,0 +1,18 @@ +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -stage compute -entry computeMain + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +// Print an error when the numLanes is an invalid value for WaveSize. +// The value has to be a power of 2 between 4 and 128, inclusive. +// In other words, the set: [4, 8, 16, 32, 64, 128]. + +// "5" is an invalid value for WaveSize +// CHECK: error 31103: +[WaveSize(5)] +[numthreads(4, 1, 1)] +void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) +{ + int tid = dispatchThreadID.x; + outputBuffer[tid] = tid; +} diff --git a/tests/hlsl/wave-size.slang b/tests/hlsl/wave-size.slang new file mode 100644 index 000000000..3fc6363c4 --- /dev/null +++ b/tests/hlsl/wave-size.slang @@ -0,0 +1,13 @@ +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -stage compute -entry computeMain + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +// CHECK: [WaveSize(4)] +[WaveSize(4)] +[numthreads(4, 1, 1)] +void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) +{ + int tid = dispatchThreadID.x; + outputBuffer[tid] = tid; +} |
