summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2024-06-14 01:29:35 -0400
committerGitHub <noreply@github.com>2024-06-13 22:29:35 -0700
commit2cc96907e4152291e0b6bca78a0bfbc69ddb8839 (patch)
treec5ebf1cc6afc99d2058c673e79b1fef2d60f5f46 /source
parenta6b8348f69a4cd1ab1edbc1ccf1133c807b81b5b (diff)
Implement for metal `SV_GroupIndex` (#4385)
* Implement for metal `SV_GroupIndex` 1. If we don't have `sv_GroupThreadId` available we create one using `SV_GroupIndex`s location data. 2. We emit code emulating `sv_GroupThreadId` from the same logic that CUDA/CPP uses. * address most review comments Addressed all but two: [1](https://github.com/shader-slang/slang/pull/4385#discussion_r1639058473) and [2](https://github.com/shader-slang/slang/pull/4385#issuecomment-2166934855) I want to enable tests and be sure there is no bugs using CI before I redesign the code so I have a working fallback. * address comment, enable tests enable now functioning tests due to `SV_GroupIndex` working with metal * syntax error with groupThreadID search did `= param` instead of `= i.param` * add `sv_groupid` for test + test fixes * disable test that won't work regardless
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp224
-rw-r--r--source/slang/slang-ir-legalize-varying-params.h13
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp79
3 files changed, 197 insertions, 119 deletions
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index d6a12b0b3..582af4ac8 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -177,6 +177,118 @@ void assign(IRBuilder& builder, LegalizedVaryingVal const& dest, IRInst* src)
assign(builder, dest, LegalizedVaryingVal::makeValue(src));
}
+
+// Several of the derived calcluations rely on having
+// access to the "group extents" of a compute shader.
+// That information is expected to be present on
+// the entry point as a `[numthreads(...)]` attribute,
+// and we define a convenience routine for accessing
+// that information.
+
+IRInst* emitCalcGroupExtents(
+ IRBuilder& builder,
+ IRFunc* entryPoint,
+ IRVectorType* type)
+{
+ if (auto numThreadsDecor = entryPoint->findDecoration<IRNumThreadsDecoration>())
+ {
+ static const int kAxisCount = 3;
+ IRInst* groupExtentAlongAxis[kAxisCount] = {};
+
+ for (int axis = 0; axis < kAxisCount; axis++)
+ {
+ auto litValue = as<IRIntLit>(numThreadsDecor->getExtentAlongAxis(axis));
+ if (!litValue)
+ return nullptr;
+
+ groupExtentAlongAxis[axis] = builder.getIntValue(type->getElementType(), litValue->getValue());
+ }
+
+ return builder.emitMakeVector(type, kAxisCount, groupExtentAlongAxis);
+ }
+
+ // TODO: We may want to implement a backup option here,
+ // in case we ever want to support compute shaders with
+ // dynamic/flexible group size on targets that allow it.
+ //
+ SLANG_UNEXPECTED("Expected '[numthreads(...)]' attribute on compute entry point.");
+ UNREACHABLE_RETURN(nullptr);
+}
+
+// There are some cases of system-value inputs that can be derived
+// from other inputs; notably compute shaders support `SV_DispatchThreadID`
+// and `SV_GroupIndex` which can both be derived from the more primitive
+// `SV_GroupID` and `SV_GroupThreadID`, together with the extents
+// of the thread group (which are specified with `[numthreads(...)]`).
+//
+// As a utilty to target-specific subtypes, we define helpers for
+// calculating the value of these derived system values from the
+// more primitive ones.
+
+ /// Emit code to calculate `SV_DispatchThreadID`
+IRInst* emitCalcDispatchThreadID(
+ IRBuilder& builder,
+ IRType* type,
+ IRInst* groupID,
+ IRInst* groupThreadID,
+ IRInst* groupExtents)
+{
+ // The dispatch thread ID can be computed as:
+ //
+ // dispatchThreadID = groupID*groupExtents + groupThreadID
+ //
+ // where `groupExtents` is the X,Y,Z extents of
+ // each thread group in threads (as given by
+ // `[numthreads(X,Y,Z)]`).
+
+ return builder.emitAdd(type,
+ builder.emitMul(type,
+ groupID,
+ groupExtents),
+ groupThreadID);
+}
+
+/// Emit code to calculate `SV_GroupIndex`
+IRInst* emitCalcGroupThreadIndex(
+ IRBuilder& builder,
+ IRInst* groupThreadID,
+ IRInst* groupExtents)
+{
+ auto intType = builder.getIntType();
+ auto uintType = builder.getBasicType(BaseType::UInt);
+
+ // The group thread index can be computed as:
+ //
+ // groupThreadIndex = groupThreadID.x
+ // + groupThreadID.y*groupExtents.x
+ // + groupThreadID.z*groupExtents.x*groupExtents.z;
+ //
+ // or equivalently (with one less multiply):
+ //
+ // groupThreadIndex = (groupThreadID.z * groupExtents.y
+ // + groupThreadID.y) * groupExtents.x
+ // + groupThreadID.x;
+ //
+
+ // `offset = groupThreadID.z`
+ auto zAxis = builder.getIntValue(intType, 2);
+ IRInst* offset = builder.emitElementExtract(uintType, groupThreadID, zAxis);
+
+ // `offset *= groupExtents.y`
+ // `offset += groupExtents.y`
+ auto yAxis = builder.getIntValue(intType, 1);
+ offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, yAxis));
+ offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, yAxis));
+
+ // `offset *= groupExtents.x`
+ // `offset += groupExtents.x`
+ auto xAxis = builder.getIntValue(intType, 0);
+ offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, xAxis));
+ offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, xAxis));
+
+ return offset;
+}
+
/// Context for the IR pass that legalizing entry-point
/// varying parameters for a target.
///
@@ -915,116 +1027,6 @@ protected:
return LegalizedVaryingVal();
}
-
- // There are some cases of system-value inputs that can be derived
- // from other inputs; notably compute shaders support `SV_DispatchThreadID`
- // and `SV_GroupIndex` which can both be derived from the more primitive
- // `SV_GroupID` and `SV_GroupThreadID`, together with the extents
- // of the thread group (which are specified with `[numthreads(...)]`).
- //
- // As a utilty to target-specific subtypes, we define helpers for
- // calculating the value of these derived system values from the
- // more primitive ones.
-
- /// Emit code to calculate `SV_DispatchThreadID`
- IRInst* emitCalcDispatchThreadID(
- IRBuilder& builder,
- IRType* type,
- IRInst* groupID,
- IRInst* groupThreadID,
- IRInst* groupExtents)
- {
- // The dispatch thread ID can be computed as:
- //
- // dispatchThreadID = groupID*groupExtents + groupThreadID
- //
- // where `groupExtents` is the X,Y,Z extents of
- // each thread group in threads (as given by
- // `[numthreads(X,Y,Z)]`).
-
- return builder.emitAdd(type,
- builder.emitMul(type,
- groupID,
- groupExtents),
- groupThreadID);
- }
-
- /// Emit code to calculate `SV_GroupIndex`
- IRInst* emitCalcGroupThreadIndex(
- IRBuilder& builder,
- IRInst* groupThreadID,
- IRInst* groupExtents)
- {
- auto intType = builder.getIntType();
- auto uintType = builder.getBasicType(BaseType::UInt);
-
- // The group thread index can be computed as:
- //
- // groupThreadIndex = groupThreadID.x
- // + groupThreadID.y*groupExtents.x
- // + groupThreadID.z*groupExtents.x*groupExtents.z;
- //
- // or equivalently (with one less multiply):
- //
- // groupThreadIndex = (groupThreadID.z * groupExtents.y
- // + groupThreadID.y) * groupExtents.x
- // + groupThreadID.x;
- //
-
- // `offset = groupThreadID.z`
- auto zAxis = builder.getIntValue(intType, 2);
- IRInst* offset = builder.emitElementExtract(uintType, groupThreadID, zAxis);
-
- // `offset *= groupExtents.y`
- // `offset += groupExtents.y`
- auto yAxis = builder.getIntValue(intType, 1);
- offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, yAxis));
- offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, yAxis));
-
- // `offset *= groupExtents.x`
- // `offset += groupExtents.x`
- auto xAxis = builder.getIntValue(intType, 0);
- offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, xAxis));
- offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, xAxis));
-
- return offset;
- }
-
- // Several of the derived calcluations rely on having
- // access to the "group extents" of a compute shader.
- // That information is expected to be present on
- // the entry point as a `[numthreads(...)]` attribute,
- // and we define a convenience routine for accessing
- // that information.
-
- IRInst* emitCalcGroupExtents(
- IRBuilder& builder,
- IRVectorType* type)
- {
- if(auto numThreadsDecor = m_entryPointFunc->findDecoration<IRNumThreadsDecoration>())
- {
- static const int kAxisCount = 3;
- IRInst* groupExtentAlongAxis[kAxisCount] = {};
-
- for( int axis = 0; axis < kAxisCount; axis++ )
- {
- auto litValue = as<IRIntLit>(numThreadsDecor->getExtentAlongAxis(axis));
- if(!litValue)
- return nullptr;
-
- groupExtentAlongAxis[axis] = builder.getIntValue(type->getElementType(), litValue->getValue());
- }
-
- return builder.emitMakeVector(type, kAxisCount, groupExtentAlongAxis);
- }
-
- // TODO: We may want to implement a backup option here,
- // in case we ever want to support compute shaders with
- // dynamic/flexible group size on targets that allow it.
- //
- SLANG_UNEXPECTED("Expected '[numthreads(...)]' attribute on compute entry point.");
- UNREACHABLE_RETURN(nullptr);
- }
};
// With the target-independent core of the pass out of the way, we can
@@ -1391,7 +1393,7 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize
// CPU target, we'd need to change it so that the thread-group size can
// be passed in as part of `ComputeVaryingThreadInput`.
//
- groupExtents = emitCalcGroupExtents(builder, uint3Type);
+ groupExtents = emitCalcGroupExtents(builder, m_entryPointFunc, uint3Type);
dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents);
diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h
index ff93f38dd..952192def 100644
--- a/source/slang/slang-ir-legalize-varying-params.h
+++ b/source/slang/slang-ir-legalize-varying-params.h
@@ -8,6 +8,10 @@ class DiagnosticSink;
struct IRFunc;
struct IRModule;
+struct IRInst;
+struct IRFunc;
+struct IRVectorType;
+struct IRBuilder;
void legalizeEntryPointVaryingParamsForCPU(
IRModule* module,
@@ -17,4 +21,13 @@ void legalizeEntryPointVaryingParamsForCUDA(
IRModule* module,
DiagnosticSink* sink);
+IRInst* emitCalcGroupThreadIndex(
+ IRBuilder& builder,
+ IRInst* groupThreadID,
+ IRInst* groupExtents);
+
+IRInst* emitCalcGroupExtents(
+ IRBuilder& builder,
+ IRFunc* entryPoint,
+ IRVectorType* type);
}
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index 6b6c86040..f771f7a33 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -5,9 +5,12 @@
#include "slang-ir-util.h"
#include "slang-ir-clone.h"
#include "slang-ir-specialize-address-space.h"
+#include "slang-ir-legalize-varying-params.h"
namespace Slang
{
+ const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid");
+
struct EntryPointInfo
{
IRFunc* entryPointFunc;
@@ -229,6 +232,11 @@ namespace Slang
bool isSpecial;
};
+ IRType* getGroupThreadIdType(IRBuilder& builder)
+ {
+ return builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
+ }
+
MetalSystemValueInfo getSystemValueInfo(IRBuilder& builder, String semanticName, UInt attrIndex)
{
SLANG_UNUSED(attrIndex);
@@ -288,7 +296,8 @@ namespace Slang
}
else if (semanticName == "sv_groupid")
{
- result.isSpecial = true;
+ result.metalSystemValueName = toSlice("threadgroup_position_in_grid");
+ result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
}
else if (semanticName == "sv_groupindex")
{
@@ -297,7 +306,7 @@ namespace Slang
else if (semanticName == "sv_groupthreadid")
{
result.metalSystemValueName = toSlice("thread_position_in_threadgroup");
- result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
+ result.requiredType = getGroupThreadIdType(builder);
}
else if (semanticName == "sv_gsinstanceid")
{
@@ -629,10 +638,8 @@ namespace Slang
UInt attrIndex;
};
List<SystemValLegalizationWorkItem> systemValWorkItems;
- List<SystemValLegalizationWorkItem> workList;
IRBuilder builder(entryPoint.entryPointFunc);
- List<IRParam*> params;
for (auto param : entryPoint.entryPointFunc->getParams())
{
@@ -655,8 +662,12 @@ namespace Slang
auto sysAttrIndex = sysValAttr->getIndex();
systemValWorkItems.add({ param, semanticName, sysAttrIndex });
}
- for (auto workItem : systemValWorkItems)
+
+ IRParam* groupThreadId = nullptr;
+ for (auto index = 0; index < systemValWorkItems.getCount(); index++)
{
+ auto workItem = systemValWorkItems[index];
+
auto param = workItem.param;
auto semanticName = workItem.attrName;
auto sysAttrIndex = workItem.attrIndex;
@@ -671,10 +682,62 @@ namespace Slang
param->replaceUsesWith(val);
param->removeAndDeallocate();
}
- else
+ else if (semanticName == "sv_groupindex")
{
- // Process special cases after trivial cases.
- workList.add(workItem);
+ // Ensure we have a cached "sv_groupthreadid"
+ if (!groupThreadId)
+ {
+ for (auto i : systemValWorkItems)
+ {
+ if (i.attrName == groupThreadIDString)
+ {
+ groupThreadId = i.param;
+ }
+ }
+ if (!groupThreadId)
+ {
+ // Add the missing groupthreadid needed to compute sv_groupindex
+ IRBuilder groupThreadIdBuilder(builder);
+ groupThreadIdBuilder.setInsertInto(entryPoint.entryPointFunc->getFirstBlock());
+ groupThreadId = groupThreadIdBuilder.emitParamAtHead(getGroupThreadIdType(groupThreadIdBuilder));
+ groupThreadIdBuilder.addNameHintDecoration(groupThreadId, groupThreadIDString);
+
+ // Since "sv_groupindex" will be translated out to a global var and no longer be considered a system value
+ // we can reuse its layout and semantic info
+ Index foundRequiredDecorations = 0;
+ IRLayoutDecoration* layoutDecoration = nullptr;
+ UInt semanticIndex = 0;
+ for (auto decoration : param->getDecorations())
+ {
+ if (auto layoutDecorationTmp = as<IRLayoutDecoration>(decoration))
+ {
+ layoutDecoration = layoutDecorationTmp;
+ foundRequiredDecorations++;
+ }
+ else if (auto semanticDecoration = as<IRSemanticDecoration>(decoration))
+ {
+ semanticIndex = semanticDecoration->getSemanticIndex();
+ groupThreadIdBuilder.addSemanticDecoration(groupThreadId, groupThreadIDString, (int)semanticIndex);
+ foundRequiredDecorations++;
+ }
+ if (foundRequiredDecorations >= 2)
+ break;
+ }
+ SLANG_ASSERT(layoutDecoration);
+ layoutDecoration->removeFromParent();
+ layoutDecoration->insertAtStart(groupThreadId);
+ systemValWorkItems.add({ groupThreadId, groupThreadIDString, semanticIndex });
+ }
+ }
+
+ IRBuilder svBuilder(builder.getModule());
+ svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst());
+ auto computeExtent = emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, builder.getVectorType(builder.getUIntType(), builder.getIntValue(builder.getIntType(), 3)));
+ auto groupIndexCalc = emitCalcGroupThreadIndex(svBuilder, groupThreadId, computeExtent);
+ svBuilder.addNameHintDecoration(groupIndexCalc, UnownedStringSlice("sv_groupindex"));
+
+ param->replaceUsesWith(groupIndexCalc);
+ param->removeAndDeallocate();
}
}
if (info.isUnsupported)