summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp224
-rw-r--r--source/slang/slang-ir-legalize-varying-params.h13
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp79
-rw-r--r--tests/bugs/frexp-double.slang3
-rw-r--r--tests/bugs/frexp.slang2
-rw-r--r--tests/bugs/gh-3980.slang2
-rw-r--r--tests/compute/texture-simpler.slang2
-rw-r--r--tests/language-feature/higher-order-functions/simple.slang2
-rw-r--r--tests/language-feature/swizzles/matrix-swizzle-write-array.slang2
-rw-r--r--tests/language-feature/swizzles/matrix-swizzle-write-single.slang2
-rw-r--r--tests/language-feature/swizzles/matrix-swizzle-write-swizzle.slang2
-rw-r--r--tests/language-feature/swizzles/matrix-swizzle-write.slang2
-rw-r--r--tests/metal/thread_position_in_threadgroup.slang27
-rw-r--r--tests/pipeline/compute/compute-system-values.slang2
14 files changed, 235 insertions, 129 deletions
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index d6a12b0b3..582af4ac8 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -177,6 +177,118 @@ void assign(IRBuilder& builder, LegalizedVaryingVal const& dest, IRInst* src)
assign(builder, dest, LegalizedVaryingVal::makeValue(src));
}
+
+// Several of the derived calculations rely on having
+// access to the "group extents" of a compute shader.
+// That information is expected to be present on
+// the entry point as a `[numthreads(...)]` attribute,
+// and we define a convenience routine for accessing
+// that information.
+
+/// Emit a vector of `type` holding the thread-group extents of `entryPoint`.
+///
+/// The extents are read from the `IRNumThreadsDecoration` on `entryPoint`.
+/// Returns null if any of the three extents is not an integer literal, so
+/// callers should be prepared for a null result.
+IRInst* emitCalcGroupExtents(
+ IRBuilder& builder,
+ IRFunc* entryPoint,
+ IRVectorType* type)
+{
+ if (auto numThreadsDecor = entryPoint->findDecoration<IRNumThreadsDecoration>())
+ {
+ static const int kAxisCount = 3;
+ IRInst* groupExtentAlongAxis[kAxisCount] = {};
+
+ for (int axis = 0; axis < kAxisCount; axis++)
+ {
+ // Only literal extents can be turned into constants here;
+ // anything else makes the whole calculation bail out.
+ auto litValue = as<IRIntLit>(numThreadsDecor->getExtentAlongAxis(axis));
+ if (!litValue)
+ return nullptr;
+
+ // Re-create the literal with the element type of the requested vector.
+ groupExtentAlongAxis[axis] = builder.getIntValue(type->getElementType(), litValue->getValue());
+ }
+
+ return builder.emitMakeVector(type, kAxisCount, groupExtentAlongAxis);
+ }
+
+ // TODO: We may want to implement a backup option here,
+ // in case we ever want to support compute shaders with
+ // dynamic/flexible group size on targets that allow it.
+ //
+ SLANG_UNEXPECTED("Expected '[numthreads(...)]' attribute on compute entry point.");
+ UNREACHABLE_RETURN(nullptr);
+}
+
+// There are some cases of system-value inputs that can be derived
+// from other inputs; notably compute shaders support `SV_DispatchThreadID`
+// and `SV_GroupIndex` which can both be derived from the more primitive
+// `SV_GroupID` and `SV_GroupThreadID`, together with the extents
+// of the thread group (which are specified with `[numthreads(...)]`).
+//
+// As a utility to target-specific subtypes, we define helpers for
+// calculating the value of these derived system values from the
+// more primitive ones.
+
+ /// Emit code to calculate `SV_DispatchThreadID`
+ ///
+ /// `type` is used as the result type of both the multiply and the add.
+ /// `groupID`, `groupThreadID` and `groupExtents` are the operands of the
+ /// formula given in the body.
+IRInst* emitCalcDispatchThreadID(
+ IRBuilder& builder,
+ IRType* type,
+ IRInst* groupID,
+ IRInst* groupThreadID,
+ IRInst* groupExtents)
+{
+ // The dispatch thread ID can be computed as:
+ //
+ // dispatchThreadID = groupID*groupExtents + groupThreadID
+ //
+ // where `groupExtents` is the X,Y,Z extents of
+ // each thread group in threads (as given by
+ // `[numthreads(X,Y,Z)]`).
+
+ return builder.emitAdd(type,
+ builder.emitMul(type,
+ groupID,
+ groupExtents),
+ groupThreadID);
+}
+
+/// Emit code to calculate `SV_GroupIndex`
+///
+/// Extracts the x/y/z elements of `groupThreadID` and the x/y elements of
+/// `groupExtents` as `uint` values and returns the flattened index.
+IRInst* emitCalcGroupThreadIndex(
+ IRBuilder& builder,
+ IRInst* groupThreadID,
+ IRInst* groupExtents)
+{
+ auto intType = builder.getIntType();
+ auto uintType = builder.getBasicType(BaseType::UInt);
+
+ // The group thread index can be computed as:
+ //
+ // groupThreadIndex = groupThreadID.x
+ // + groupThreadID.y*groupExtents.x
+ // + groupThreadID.z*groupExtents.x*groupExtents.y;
+ //
+ // or equivalently (with one less multiply):
+ //
+ // groupThreadIndex = (groupThreadID.z * groupExtents.y
+ // + groupThreadID.y) * groupExtents.x
+ // + groupThreadID.x;
+ //
+
+ // `offset = groupThreadID.z`
+ auto zAxis = builder.getIntValue(intType, 2);
+ IRInst* offset = builder.emitElementExtract(uintType, groupThreadID, zAxis);
+
+ // `offset *= groupExtents.y`
+ // `offset += groupThreadID.y`
+ auto yAxis = builder.getIntValue(intType, 1);
+ offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, yAxis));
+ offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, yAxis));
+
+ // `offset *= groupExtents.x`
+ // `offset += groupThreadID.x`
+ auto xAxis = builder.getIntValue(intType, 0);
+ offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, xAxis));
+ offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, xAxis));
+
+ return offset;
+}
+
/// Context for the IR pass that legalizing entry-point
/// varying parameters for a target.
///
@@ -915,116 +1027,6 @@ protected:
return LegalizedVaryingVal();
}
-
- // There are some cases of system-value inputs that can be derived
- // from other inputs; notably compute shaders support `SV_DispatchThreadID`
- // and `SV_GroupIndex` which can both be derived from the more primitive
- // `SV_GroupID` and `SV_GroupThreadID`, together with the extents
- // of the thread group (which are specified with `[numthreads(...)]`).
- //
- // As a utilty to target-specific subtypes, we define helpers for
- // calculating the value of these derived system values from the
- // more primitive ones.
-
- /// Emit code to calculate `SV_DispatchThreadID`
- IRInst* emitCalcDispatchThreadID(
- IRBuilder& builder,
- IRType* type,
- IRInst* groupID,
- IRInst* groupThreadID,
- IRInst* groupExtents)
- {
- // The dispatch thread ID can be computed as:
- //
- // dispatchThreadID = groupID*groupExtents + groupThreadID
- //
- // where `groupExtents` is the X,Y,Z extents of
- // each thread group in threads (as given by
- // `[numthreads(X,Y,Z)]`).
-
- return builder.emitAdd(type,
- builder.emitMul(type,
- groupID,
- groupExtents),
- groupThreadID);
- }
-
- /// Emit code to calculate `SV_GroupIndex`
- IRInst* emitCalcGroupThreadIndex(
- IRBuilder& builder,
- IRInst* groupThreadID,
- IRInst* groupExtents)
- {
- auto intType = builder.getIntType();
- auto uintType = builder.getBasicType(BaseType::UInt);
-
- // The group thread index can be computed as:
- //
- // groupThreadIndex = groupThreadID.x
- // + groupThreadID.y*groupExtents.x
- // + groupThreadID.z*groupExtents.x*groupExtents.z;
- //
- // or equivalently (with one less multiply):
- //
- // groupThreadIndex = (groupThreadID.z * groupExtents.y
- // + groupThreadID.y) * groupExtents.x
- // + groupThreadID.x;
- //
-
- // `offset = groupThreadID.z`
- auto zAxis = builder.getIntValue(intType, 2);
- IRInst* offset = builder.emitElementExtract(uintType, groupThreadID, zAxis);
-
- // `offset *= groupExtents.y`
- // `offset += groupExtents.y`
- auto yAxis = builder.getIntValue(intType, 1);
- offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, yAxis));
- offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, yAxis));
-
- // `offset *= groupExtents.x`
- // `offset += groupExtents.x`
- auto xAxis = builder.getIntValue(intType, 0);
- offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, xAxis));
- offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, xAxis));
-
- return offset;
- }
-
- // Several of the derived calcluations rely on having
- // access to the "group extents" of a compute shader.
- // That information is expected to be present on
- // the entry point as a `[numthreads(...)]` attribute,
- // and we define a convenience routine for accessing
- // that information.
-
- IRInst* emitCalcGroupExtents(
- IRBuilder& builder,
- IRVectorType* type)
- {
- if(auto numThreadsDecor = m_entryPointFunc->findDecoration<IRNumThreadsDecoration>())
- {
- static const int kAxisCount = 3;
- IRInst* groupExtentAlongAxis[kAxisCount] = {};
-
- for( int axis = 0; axis < kAxisCount; axis++ )
- {
- auto litValue = as<IRIntLit>(numThreadsDecor->getExtentAlongAxis(axis));
- if(!litValue)
- return nullptr;
-
- groupExtentAlongAxis[axis] = builder.getIntValue(type->getElementType(), litValue->getValue());
- }
-
- return builder.emitMakeVector(type, kAxisCount, groupExtentAlongAxis);
- }
-
- // TODO: We may want to implement a backup option here,
- // in case we ever want to support compute shaders with
- // dynamic/flexible group size on targets that allow it.
- //
- SLANG_UNEXPECTED("Expected '[numthreads(...)]' attribute on compute entry point.");
- UNREACHABLE_RETURN(nullptr);
- }
};
// With the target-independent core of the pass out of the way, we can
@@ -1391,7 +1393,7 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize
// CPU target, we'd need to change it so that the thread-group size can
// be passed in as part of `ComputeVaryingThreadInput`.
//
- groupExtents = emitCalcGroupExtents(builder, uint3Type);
+ groupExtents = emitCalcGroupExtents(builder, m_entryPointFunc, uint3Type);
dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents);
diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h
index ff93f38dd..952192def 100644
--- a/source/slang/slang-ir-legalize-varying-params.h
+++ b/source/slang/slang-ir-legalize-varying-params.h
@@ -8,6 +8,10 @@ class DiagnosticSink;
struct IRFunc;
struct IRModule;
+struct IRInst;
+struct IRFunc;
+struct IRVectorType;
+struct IRBuilder;
void legalizeEntryPointVaryingParamsForCPU(
IRModule* module,
@@ -17,4 +21,13 @@ void legalizeEntryPointVaryingParamsForCUDA(
IRModule* module,
DiagnosticSink* sink);
+/// Emit code to calculate `SV_GroupIndex` from `groupThreadID` and `groupExtents`.
+IRInst* emitCalcGroupThreadIndex(
+ IRBuilder& builder,
+ IRInst* groupThreadID,
+ IRInst* groupExtents);
+
+/// Emit a vector of `type` holding the `[numthreads(...)]` extents of `entryPoint`.
+/// Yields null when an extent is not an integer literal.
+IRInst* emitCalcGroupExtents(
+ IRBuilder& builder,
+ IRFunc* entryPoint,
+ IRVectorType* type);
}
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index 6b6c86040..f771f7a33 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -5,9 +5,12 @@
#include "slang-ir-util.h"
#include "slang-ir-clone.h"
#include "slang-ir-specialize-address-space.h"
+#include "slang-ir-legalize-varying-params.h"
namespace Slang
{
+ const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid");
+
struct EntryPointInfo
{
IRFunc* entryPointFunc;
@@ -229,6 +232,11 @@ namespace Slang
bool isSpecial;
};
+ /// Returns the 3-element `UInt` vector type; this is the type required
+ /// for the `sv_groupthreadid` system value.
+ IRType* getGroupThreadIdType(IRBuilder& builder)
+ {
+ return builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
+ }
+
MetalSystemValueInfo getSystemValueInfo(IRBuilder& builder, String semanticName, UInt attrIndex)
{
SLANG_UNUSED(attrIndex);
@@ -288,7 +296,8 @@ namespace Slang
}
else if (semanticName == "sv_groupid")
{
- result.isSpecial = true;
+ result.metalSystemValueName = toSlice("threadgroup_position_in_grid");
+ result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
}
else if (semanticName == "sv_groupindex")
{
@@ -297,7 +306,7 @@ namespace Slang
else if (semanticName == "sv_groupthreadid")
{
result.metalSystemValueName = toSlice("thread_position_in_threadgroup");
- result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
+ result.requiredType = getGroupThreadIdType(builder);
}
else if (semanticName == "sv_gsinstanceid")
{
@@ -629,10 +638,8 @@ namespace Slang
UInt attrIndex;
};
List<SystemValLegalizationWorkItem> systemValWorkItems;
- List<SystemValLegalizationWorkItem> workList;
IRBuilder builder(entryPoint.entryPointFunc);
- List<IRParam*> params;
for (auto param : entryPoint.entryPointFunc->getParams())
{
@@ -655,8 +662,12 @@ namespace Slang
auto sysAttrIndex = sysValAttr->getIndex();
systemValWorkItems.add({ param, semanticName, sysAttrIndex });
}
- for (auto workItem : systemValWorkItems)
+
+ IRParam* groupThreadId = nullptr;
+ for (auto index = 0; index < systemValWorkItems.getCount(); index++)
{
+ auto workItem = systemValWorkItems[index];
+
auto param = workItem.param;
auto semanticName = workItem.attrName;
auto sysAttrIndex = workItem.attrIndex;
@@ -671,10 +682,62 @@ namespace Slang
param->replaceUsesWith(val);
param->removeAndDeallocate();
}
- else
+ else if (semanticName == "sv_groupindex")
{
- // Process special cases after trivial cases.
- workList.add(workItem);
+ // Ensure we have a cached "sv_groupthreadid"
+ if (!groupThreadId)
+ {
+ for (auto i : systemValWorkItems)
+ {
+ if (i.attrName == groupThreadIDString)
+ {
+ groupThreadId = i.param;
+ }
+ }
+ if (!groupThreadId)
+ {
+ // Add the missing groupthreadid needed to compute sv_groupindex
+ IRBuilder groupThreadIdBuilder(builder);
+ groupThreadIdBuilder.setInsertInto(entryPoint.entryPointFunc->getFirstBlock());
+ groupThreadId = groupThreadIdBuilder.emitParamAtHead(getGroupThreadIdType(groupThreadIdBuilder));
+ groupThreadIdBuilder.addNameHintDecoration(groupThreadId, groupThreadIDString);
+
+ // Since "sv_groupindex" will be translated out to a global var and no longer be considered a system value
+ // we can reuse its layout and semantic info
+ Index foundRequiredDecorations = 0;
+ IRLayoutDecoration* layoutDecoration = nullptr;
+ UInt semanticIndex = 0;
+ for (auto decoration : param->getDecorations())
+ {
+ if (auto layoutDecorationTmp = as<IRLayoutDecoration>(decoration))
+ {
+ layoutDecoration = layoutDecorationTmp;
+ foundRequiredDecorations++;
+ }
+ else if (auto semanticDecoration = as<IRSemanticDecoration>(decoration))
+ {
+ semanticIndex = semanticDecoration->getSemanticIndex();
+ groupThreadIdBuilder.addSemanticDecoration(groupThreadId, groupThreadIDString, (int)semanticIndex);
+ foundRequiredDecorations++;
+ }
+ if (foundRequiredDecorations >= 2)
+ break;
+ }
+ SLANG_ASSERT(layoutDecoration);
+ layoutDecoration->removeFromParent();
+ layoutDecoration->insertAtStart(groupThreadId);
+ systemValWorkItems.add({ groupThreadId, groupThreadIDString, semanticIndex });
+ }
+ }
+
+ IRBuilder svBuilder(builder.getModule());
+ svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst());
+ auto computeExtent = emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, builder.getVectorType(builder.getUIntType(), builder.getIntValue(builder.getIntType(), 3)));
+ auto groupIndexCalc = emitCalcGroupThreadIndex(svBuilder, groupThreadId, computeExtent);
+ svBuilder.addNameHintDecoration(groupIndexCalc, UnownedStringSlice("sv_groupindex"));
+
+ param->replaceUsesWith(groupIndexCalc);
+ param->removeAndDeallocate();
}
}
if (info.isUnsupported)
diff --git a/tests/bugs/frexp-double.slang b/tests/bugs/frexp-double.slang
index 40623a17b..c2b4f21c7 100644
--- a/tests/bugs/frexp-double.slang
+++ b/tests/bugs/frexp-double.slang
@@ -4,7 +4,8 @@
//TEST(smoke,compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -output-using-type -render-feature double
//TEST(smoke,compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -emit-spirv-directly -output-using-type -render-feature double
//TEST(smoke,compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-cuda -output-using-type
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//metal currently does not support `double`
+//DISABLE_TEST(smoke,compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -shaderobj -mtl -output-using-type
// BUF: type: int32_t
// BUF-NEXT: 1
diff --git a/tests/bugs/frexp.slang b/tests/bugs/frexp.slang
index 01c345d1e..60912297c 100644
--- a/tests/bugs/frexp.slang
+++ b/tests/bugs/frexp.slang
@@ -4,7 +4,7 @@
//TEST(smoke,compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -output-using-type
//TEST(smoke,compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -emit-spirv-directly -output-using-type
//TEST(smoke,compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-cuda -output-using-type
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(smoke,compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -shaderobj -mtl -output-using-type
// BUF: type: int32_t
// BUF-NEXT: 1
diff --git a/tests/bugs/gh-3980.slang b/tests/bugs/gh-3980.slang
index 509212ea9..57cd28dbe 100644
--- a/tests/bugs/gh-3980.slang
+++ b/tests/bugs/gh-3980.slang
@@ -4,7 +4,7 @@
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj -output-using-type -emit-spirv-directly
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -output-using-type
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-slang -shaderobj -mtl -output-using-type
// Slang removes parentheses characters for the bitwise operators when they are not needed.
// DXC prints warning messages even when the expression is correct.
diff --git a/tests/compute/texture-simpler.slang b/tests/compute/texture-simpler.slang
index 196e8e96c..0f9dfcbad 100644
--- a/tests/compute/texture-simpler.slang
+++ b/tests/compute/texture-simpler.slang
@@ -4,7 +4,7 @@
//TEST(smoke,compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -profile cs_6_0 -use-dxil -shaderobj -output-using-type
//TEST(smoke,compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -render-feature hardware-device
//TEST(smoke,compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(smoke,compute):COMPARE_COMPUTE:-slang -shaderobj -mtl -output-using-type
//TEST_INPUT: Texture2D(size=4, content = one):name t2D
Texture2D<float> t2D;
diff --git a/tests/language-feature/higher-order-functions/simple.slang b/tests/language-feature/higher-order-functions/simple.slang
index 8a3544b91..13fc16aa5 100644
--- a/tests/language-feature/higher-order-functions/simple.slang
+++ b/tests/language-feature/higher-order-functions/simple.slang
@@ -1,7 +1,7 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX():-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX():-slang -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj -output-using-type
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl -output-using-type
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<uint> outputBuffer;
diff --git a/tests/language-feature/swizzles/matrix-swizzle-write-array.slang b/tests/language-feature/swizzles/matrix-swizzle-write-array.slang
index e9510de06..616a19b19 100644
--- a/tests/language-feature/swizzles/matrix-swizzle-write-array.slang
+++ b/tests/language-feature/swizzles/matrix-swizzle-write-array.slang
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE: -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE: -vk -compute -shaderobj -output-using-type
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl -output-using-type
// Test that matrix swizzle writes work correctly
// Matrix swizzles can either be one or zero indexed
diff --git a/tests/language-feature/swizzles/matrix-swizzle-write-single.slang b/tests/language-feature/swizzles/matrix-swizzle-write-single.slang
index d8bb11ea5..34d54ac55 100644
--- a/tests/language-feature/swizzles/matrix-swizzle-write-single.slang
+++ b/tests/language-feature/swizzles/matrix-swizzle-write-single.slang
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE: -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE: -vk -compute -shaderobj -output-using-type
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl -output-using-type
// Test that writes to single matrix elements with swizzles work
diff --git a/tests/language-feature/swizzles/matrix-swizzle-write-swizzle.slang b/tests/language-feature/swizzles/matrix-swizzle-write-swizzle.slang
index b73b14249..f1ad7bc6a 100644
--- a/tests/language-feature/swizzles/matrix-swizzle-write-swizzle.slang
+++ b/tests/language-feature/swizzles/matrix-swizzle-write-swizzle.slang
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE: -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE: -vk -compute -shaderobj -output-using-type
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl -output-using-type
// Test that writing to swizzles of matrix swizzles works correctly
diff --git a/tests/language-feature/swizzles/matrix-swizzle-write.slang b/tests/language-feature/swizzles/matrix-swizzle-write.slang
index d467d33db..6ccb4c29c 100644
--- a/tests/language-feature/swizzles/matrix-swizzle-write.slang
+++ b/tests/language-feature/swizzles/matrix-swizzle-write.slang
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE: -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE: -vk -compute -shaderobj -output-using-type
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl -output-using-type
// Test that matrix swizzle writes work correctly
// Matrix swizzles can either be one or zero indexed
diff --git a/tests/metal/thread_position_in_threadgroup.slang b/tests/metal/thread_position_in_threadgroup.slang
new file mode 100644
index 000000000..a20fe2ec7
--- /dev/null
+++ b/tests/metal/thread_position_in_threadgroup.slang
@@ -0,0 +1,27 @@
+//TEST:SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target metal -D GROUPID
+//TEST:SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target metal
+
+//CHECK: computeMain
+
+// ensure we compute the SV_GroupIndex from SV_GroupThreadID and `numthreads`
+// CHECK: thread_position_in_threadgroup
+// CHECK-DAG: *{{.*}}2
+
+RWBuffer<uint> dst;
+
+void indirection(uint groupIndex)
+{
+ dst[groupIndex] = groupIndex;
+}
+
+#define THREAD_COUNT 2
+[numthreads(THREAD_COUNT, 1, 1)]
+#ifdef GROUPID
+void computeMain(uint GI : SV_GroupIndex, uint GTID : SV_GroupThreadID)
+#else
+void computeMain(uint GI : SV_GroupIndex)
+#endif
+{
+ dst[GI + THREAD_COUNT] = GI;
+ indirection(GI);
+} \ No newline at end of file
diff --git a/tests/pipeline/compute/compute-system-values.slang b/tests/pipeline/compute/compute-system-values.slang
index f7aef06ff..912827557 100644
--- a/tests/pipeline/compute/compute-system-values.slang
+++ b/tests/pipeline/compute/compute-system-values.slang
@@ -2,7 +2,7 @@
//TEST(compute):COMPARE_COMPUTE: -shaderobj
//TEST(compute):COMPARE_COMPUTE: -cpu -shaderobj
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer;