diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-legalize-varying-params.cpp | 224 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-varying-params.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 79 |
3 files changed, 197 insertions, 119 deletions
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index d6a12b0b3..582af4ac8 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -177,6 +177,118 @@ void assign(IRBuilder& builder, LegalizedVaryingVal const& dest, IRInst* src) assign(builder, dest, LegalizedVaryingVal::makeValue(src)); } + +// Several of the derived calcluations rely on having +// access to the "group extents" of a compute shader. +// That information is expected to be present on +// the entry point as a `[numthreads(...)]` attribute, +// and we define a convenience routine for accessing +// that information. + +IRInst* emitCalcGroupExtents( + IRBuilder& builder, + IRFunc* entryPoint, + IRVectorType* type) +{ + if (auto numThreadsDecor = entryPoint->findDecoration<IRNumThreadsDecoration>()) + { + static const int kAxisCount = 3; + IRInst* groupExtentAlongAxis[kAxisCount] = {}; + + for (int axis = 0; axis < kAxisCount; axis++) + { + auto litValue = as<IRIntLit>(numThreadsDecor->getExtentAlongAxis(axis)); + if (!litValue) + return nullptr; + + groupExtentAlongAxis[axis] = builder.getIntValue(type->getElementType(), litValue->getValue()); + } + + return builder.emitMakeVector(type, kAxisCount, groupExtentAlongAxis); + } + + // TODO: We may want to implement a backup option here, + // in case we ever want to support compute shaders with + // dynamic/flexible group size on targets that allow it. + // + SLANG_UNEXPECTED("Expected '[numthreads(...)]' attribute on compute entry point."); + UNREACHABLE_RETURN(nullptr); +} + +// There are some cases of system-value inputs that can be derived +// from other inputs; notably compute shaders support `SV_DispatchThreadID` +// and `SV_GroupIndex` which can both be derived from the more primitive +// `SV_GroupID` and `SV_GroupThreadID`, together with the extents +// of the thread group (which are specified with `[numthreads(...)]`). +// +// As a utilty to target-specific subtypes, we define helpers for +// calculating the value of these derived system values from the +// more primitive ones. + + /// Emit code to calculate `SV_DispatchThreadID` +IRInst* emitCalcDispatchThreadID( + IRBuilder& builder, + IRType* type, + IRInst* groupID, + IRInst* groupThreadID, + IRInst* groupExtents) +{ + // The dispatch thread ID can be computed as: + // + // dispatchThreadID = groupID*groupExtents + groupThreadID + // + // where `groupExtents` is the X,Y,Z extents of + // each thread group in threads (as given by + // `[numthreads(X,Y,Z)]`). + + return builder.emitAdd(type, + builder.emitMul(type, + groupID, + groupExtents), + groupThreadID); +} + +/// Emit code to calculate `SV_GroupIndex` +IRInst* emitCalcGroupThreadIndex( + IRBuilder& builder, + IRInst* groupThreadID, + IRInst* groupExtents) +{ + auto intType = builder.getIntType(); + auto uintType = builder.getBasicType(BaseType::UInt); + + // The group thread index can be computed as: + // + // groupThreadIndex = groupThreadID.x + // + groupThreadID.y*groupExtents.x + // + groupThreadID.z*groupExtents.x*groupExtents.z; + // + // or equivalently (with one less multiply): + // + // groupThreadIndex = (groupThreadID.z * groupExtents.y + // + groupThreadID.y) * groupExtents.x + // + groupThreadID.x; + // + + // `offset = groupThreadID.z` + auto zAxis = builder.getIntValue(intType, 2); + IRInst* offset = builder.emitElementExtract(uintType, groupThreadID, zAxis); + + // `offset *= groupExtents.y` + // `offset += groupExtents.y` + auto yAxis = builder.getIntValue(intType, 1); + offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, yAxis)); + offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, yAxis)); + + // `offset *= groupExtents.x` + // `offset += groupExtents.x` + auto xAxis = builder.getIntValue(intType, 0); + offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, xAxis)); + offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, xAxis)); + + return offset; +} + /// Context for the IR pass that legalizing entry-point /// varying parameters for a target. /// @@ -915,116 +1027,6 @@ protected: return LegalizedVaryingVal(); } - - // There are some cases of system-value inputs that can be derived - // from other inputs; notably compute shaders support `SV_DispatchThreadID` - // and `SV_GroupIndex` which can both be derived from the more primitive - // `SV_GroupID` and `SV_GroupThreadID`, together with the extents - // of the thread group (which are specified with `[numthreads(...)]`). - // - // As a utilty to target-specific subtypes, we define helpers for - // calculating the value of these derived system values from the - // more primitive ones. - - /// Emit code to calculate `SV_DispatchThreadID` - IRInst* emitCalcDispatchThreadID( - IRBuilder& builder, - IRType* type, - IRInst* groupID, - IRInst* groupThreadID, - IRInst* groupExtents) - { - // The dispatch thread ID can be computed as: - // - // dispatchThreadID = groupID*groupExtents + groupThreadID - // - // where `groupExtents` is the X,Y,Z extents of - // each thread group in threads (as given by - // `[numthreads(X,Y,Z)]`). - - return builder.emitAdd(type, - builder.emitMul(type, - groupID, - groupExtents), - groupThreadID); - } - - /// Emit code to calculate `SV_GroupIndex` - IRInst* emitCalcGroupThreadIndex( - IRBuilder& builder, - IRInst* groupThreadID, - IRInst* groupExtents) - { - auto intType = builder.getIntType(); - auto uintType = builder.getBasicType(BaseType::UInt); - - // The group thread index can be computed as: - // - // groupThreadIndex = groupThreadID.x - // + groupThreadID.y*groupExtents.x - // + groupThreadID.z*groupExtents.x*groupExtents.z; - // - // or equivalently (with one less multiply): - // - // groupThreadIndex = (groupThreadID.z * groupExtents.y - // + groupThreadID.y) * groupExtents.x - // + groupThreadID.x; - // - - // `offset = groupThreadID.z` - auto zAxis = builder.getIntValue(intType, 2); - IRInst* offset = builder.emitElementExtract(uintType, groupThreadID, zAxis); - - // `offset *= groupExtents.y` - // `offset += groupExtents.y` - auto yAxis = builder.getIntValue(intType, 1); - offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, yAxis)); - offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, yAxis)); - - // `offset *= groupExtents.x` - // `offset += groupExtents.x` - auto xAxis = builder.getIntValue(intType, 0); - offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, xAxis)); - offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, xAxis)); - - return offset; - } - - // Several of the derived calcluations rely on having - // access to the "group extents" of a compute shader. - // That information is expected to be present on - // the entry point as a `[numthreads(...)]` attribute, - // and we define a convenience routine for accessing - // that information. - - IRInst* emitCalcGroupExtents( - IRBuilder& builder, - IRVectorType* type) - { - if(auto numThreadsDecor = m_entryPointFunc->findDecoration<IRNumThreadsDecoration>()) - { - static const int kAxisCount = 3; - IRInst* groupExtentAlongAxis[kAxisCount] = {}; - - for( int axis = 0; axis < kAxisCount; axis++ ) - { - auto litValue = as<IRIntLit>(numThreadsDecor->getExtentAlongAxis(axis)); - if(!litValue) - return nullptr; - - groupExtentAlongAxis[axis] = builder.getIntValue(type->getElementType(), litValue->getValue()); - } - - return builder.emitMakeVector(type, kAxisCount, groupExtentAlongAxis); - } - - // TODO: We may want to implement a backup option here, - // in case we ever want to support compute shaders with - // dynamic/flexible group size on targets that allow it. - // - SLANG_UNEXPECTED("Expected '[numthreads(...)]' attribute on compute entry point."); - UNREACHABLE_RETURN(nullptr); - } }; // With the target-independent core of the pass out of the way, we can @@ -1391,7 +1393,7 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize // CPU target, we'd need to change it so that the thread-group size can // be passed in as part of `ComputeVaryingThreadInput`. // - groupExtents = emitCalcGroupExtents(builder, uint3Type); + groupExtents = emitCalcGroupExtents(builder, m_entryPointFunc, uint3Type); dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents); diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h index ff93f38dd..952192def 100644 --- a/source/slang/slang-ir-legalize-varying-params.h +++ b/source/slang/slang-ir-legalize-varying-params.h @@ -8,6 +8,10 @@ class DiagnosticSink; struct IRFunc; struct IRModule; +struct IRInst; +struct IRFunc; +struct IRVectorType; +struct IRBuilder; void legalizeEntryPointVaryingParamsForCPU( IRModule* module, @@ -17,4 +21,13 @@ void legalizeEntryPointVaryingParamsForCUDA( IRModule* module, DiagnosticSink* sink); +IRInst* emitCalcGroupThreadIndex( + IRBuilder& builder, + IRInst* groupThreadID, + IRInst* groupExtents); + +IRInst* emitCalcGroupExtents( + IRBuilder& builder, + IRFunc* entryPoint, + IRVectorType* type); } diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index 6b6c86040..f771f7a33 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -5,9 +5,12 @@ #include "slang-ir-util.h" #include "slang-ir-clone.h" #include "slang-ir-specialize-address-space.h" +#include "slang-ir-legalize-varying-params.h" namespace Slang { + const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid"); + struct EntryPointInfo { IRFunc* entryPointFunc; @@ -229,6 +232,11 @@ namespace Slang bool isSpecial; }; + IRType* getGroupThreadIdType(IRBuilder& builder) + { + return builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)); + } + MetalSystemValueInfo getSystemValueInfo(IRBuilder& builder, String semanticName, UInt attrIndex) { SLANG_UNUSED(attrIndex); @@ -288,7 +296,8 @@ namespace Slang } else if (semanticName == "sv_groupid") { - result.isSpecial = true; + result.metalSystemValueName = toSlice("threadgroup_position_in_grid"); + result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)); } else if (semanticName == "sv_groupindex") { @@ -297,7 +306,7 @@ namespace Slang else if (semanticName == "sv_groupthreadid") { result.metalSystemValueName = toSlice("thread_position_in_threadgroup"); - result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)); + result.requiredType = getGroupThreadIdType(builder); } else if (semanticName == "sv_gsinstanceid") { @@ -629,10 +638,8 @@ namespace Slang UInt attrIndex; }; List<SystemValLegalizationWorkItem> systemValWorkItems; - List<SystemValLegalizationWorkItem> workList; IRBuilder builder(entryPoint.entryPointFunc); - List<IRParam*> params; for (auto param : entryPoint.entryPointFunc->getParams()) { @@ -655,8 +662,12 @@ namespace Slang auto sysAttrIndex = sysValAttr->getIndex(); systemValWorkItems.add({ param, semanticName, sysAttrIndex }); } - for (auto workItem : systemValWorkItems) + + IRParam* groupThreadId = nullptr; + for (auto index = 0; index < systemValWorkItems.getCount(); index++) { + auto workItem = systemValWorkItems[index]; + auto param = workItem.param; auto semanticName = workItem.attrName; auto sysAttrIndex = workItem.attrIndex; @@ -671,10 +682,62 @@ namespace Slang param->replaceUsesWith(val); param->removeAndDeallocate(); } - else + else if (semanticName == "sv_groupindex") { - // Process special cases after trivial cases. - workList.add(workItem); + // Ensure we have a cached "sv_groupthreadid" + if (!groupThreadId) + { + for (auto i : systemValWorkItems) + { + if (i.attrName == groupThreadIDString) + { + groupThreadId = i.param; + } + } + if (!groupThreadId) + { + // Add the missing groupthreadid needed to compute sv_groupindex + IRBuilder groupThreadIdBuilder(builder); + groupThreadIdBuilder.setInsertInto(entryPoint.entryPointFunc->getFirstBlock()); + groupThreadId = groupThreadIdBuilder.emitParamAtHead(getGroupThreadIdType(groupThreadIdBuilder)); + groupThreadIdBuilder.addNameHintDecoration(groupThreadId, groupThreadIDString); + + // Since "sv_groupindex" will be translated out to a global var and no longer be considered a system value + // we can reuse its layout and semantic info + Index foundRequiredDecorations = 0; + IRLayoutDecoration* layoutDecoration = nullptr; + UInt semanticIndex = 0; + for (auto decoration : param->getDecorations()) + { + if (auto layoutDecorationTmp = as<IRLayoutDecoration>(decoration)) + { + layoutDecoration = layoutDecorationTmp; + foundRequiredDecorations++; + } + else if (auto semanticDecoration = as<IRSemanticDecoration>(decoration)) + { + semanticIndex = semanticDecoration->getSemanticIndex(); + groupThreadIdBuilder.addSemanticDecoration(groupThreadId, groupThreadIDString, (int)semanticIndex); + foundRequiredDecorations++; + } + if (foundRequiredDecorations >= 2) + break; + } + SLANG_ASSERT(layoutDecoration); + layoutDecoration->removeFromParent(); + layoutDecoration->insertAtStart(groupThreadId); + systemValWorkItems.add({ groupThreadId, groupThreadIDString, semanticIndex }); + } + } + + IRBuilder svBuilder(builder.getModule()); + svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); + auto computeExtent = emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, builder.getVectorType(builder.getUIntType(), builder.getIntValue(builder.getIntType(), 3))); + auto groupIndexCalc = emitCalcGroupThreadIndex(svBuilder, groupThreadId, computeExtent); + svBuilder.addNameHintDecoration(groupIndexCalc, UnownedStringSlice("sv_groupindex")); + + param->replaceUsesWith(groupIndexCalc); + param->removeAndDeallocate(); } } if (info.isUnsupported) |
