summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2025-02-14 01:55:28 +0800
committerGitHub <noreply@github.com>2025-02-13 09:55:28 -0800
commit1ea2ab1b638b0e6d2c385b2b06157e6109417e6b (patch)
tree438aede974cc87fffbe58e9c99d99719bb25680a
parentccc75cdd9508a4e19efa22e7c911cc2013f514fa (diff)
Disallow only resources in constant buffers in parameterblocks on metal (#6342)
* Neaten metal parameter block checking * Disallow only resources in constant buffers in parameterblocks on metal closes https://github.com/shader-slang/slang/issues/6200 * add unit test for metal parameterblock cbuffer --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/slang-diagnostic-defs.h6
-rw-r--r--source/slang/slang-ir-check-shader-parameter-type.cpp114
-rw-r--r--tests/diagnostics/nested-constant-buffer-in-parameter-block.slang27
-rw-r--r--tools/slang-unit-test/unit-test-metal-parameter-block-constant-buffer.cpp133
4 files changed, 231 insertions, 49 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 4f72332ef..ce6217825 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -2560,6 +2560,12 @@ DIAGNOSTIC(
constantBufferInParameterBlockNotAllowedOnMetal,
"nested 'ConstantBuffer' inside a 'ParameterBlock' is not supported on Metal, use "
"'ParameterBlock' instead.")
+DIAGNOSTIC(
+ 56101,
+ Error,
+ resourceTypesInConstantBufferInParameterBlockNotAllowedOnMetal,
+ "nesting a 'ConstantBuffer' containing resource types inside a 'ParameterBlock' is not "
+ "supported on Metal, please use 'ParameterBlock' instead.")
DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0")
DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.")
diff --git a/source/slang/slang-ir-check-shader-parameter-type.cpp b/source/slang/slang-ir-check-shader-parameter-type.cpp
index 6f3161110..ad61e9540 100644
--- a/source/slang/slang-ir-check-shader-parameter-type.cpp
+++ b/source/slang/slang-ir-check-shader-parameter-type.cpp
@@ -4,63 +4,93 @@
namespace Slang
{
-void checkForInvalidShaderParameterTypeForMetal(IRModule* module, DiagnosticSink* sink)
+
+template<typename P>
+auto isOrContains(P predicate, IRType* type) -> decltype(predicate(type))
{
- HashSet<IRInst*> workListSet;
- List<IRInst*> workList;
- for (auto inst : module->getGlobalInsts())
+ HashSet<IRType*> visited;
+
+ auto go = [&visited, &predicate](auto&& self, IRType* type) -> decltype(predicate(type))
{
- if (inst->getOp() == kIROp_ParameterBlockType)
+ // Prevent infinite recursion by tracking visited types
+ if (!visited.add(type))
+ return {};
+
+ // Check if the current type matches the predicate
+ if (auto result = predicate(type))
+ return result;
+
+ // Recursively check struct fields
+ if (auto structType = as<IRStructType>(type))
{
- auto type = inst->getOperand(0);
- if (workListSet.add(type))
- workList.add(type);
- // Diagnose an error on `ParameterBlock<ConstantBuffer<T>>`.
- if (type->getOp() == kIROp_ConstantBufferType)
+ for (auto field : structType->getFields())
{
- bool foundUseSite = false;
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- auto user = use->getUser();
- if (user->sourceLoc.isValid())
- {
- sink->diagnose(
- user,
- Diagnostics::constantBufferInParameterBlockNotAllowedOnMetal);
- foundUseSite = true;
- break;
- }
- }
- if (!foundUseSite)
- sink->diagnose(
- inst,
- Diagnostics::constantBufferInParameterBlockNotAllowedOnMetal);
+ auto fieldType = field->getFieldType();
+ if (auto result = self(self, fieldType))
+ return result;
}
}
- }
- // Diagnose an error any any struct fields whose type is `ConstantBuffer<T>` if the
- // struct is used inside a `ParameterBlock`.
- for (Index i = 0; i < workList.getCount(); i++)
+
+ return {};
+ };
+
+ return go(go, type);
+}
+
+void checkForInvalidShaderParameterTypeForMetal(IRModule* module, DiagnosticSink* sink)
+{
+ auto isConstantBufferWithResource = [](IRType* type) -> std::optional<IRType*>
{
- auto type = workList[i];
- if (auto structType = as<IRStructType>(type))
+ if (type->getOp() == kIROp_ConstantBufferType)
{
- for (auto field : structType->getFields())
+ // Get the type inside the constant buffer
+ auto innerType = as<IRType>(type->getOperand(0));
+
+ // Check if the inner type contains any resource types
+ auto hasResource = [](IRType* t) -> std::optional<IRType*>
{
- auto fieldType = field->getFieldType();
- if (fieldType->getOp() == kIROp_ConstantBufferType)
+ if (isResourceType(t))
+ return t;
+ return {};
+ };
+
+ if (auto resourceType = isOrContains(hasResource, innerType))
+ return type; // Return the constant buffer type if it contains a resource
+ }
+ return {};
+ };
+
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (inst->getOp() != kIROp_ParameterBlockType)
+ continue;
+
+ auto type = as<IRType>(inst->getOperand(0));
+ if (auto invalidCBType = isOrContains(isConstantBufferWithResource, type))
+ {
+ // Try to find a valid source location from uses
+ bool foundUseSite = false;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (user->sourceLoc.isValid())
{
sink->diagnose(
- field->getKey(),
- Diagnostics::constantBufferInParameterBlockNotAllowedOnMetal);
+ user,
+ Diagnostics::
+ resourceTypesInConstantBufferInParameterBlockNotAllowedOnMetal);
+ foundUseSite = true;
+ break;
}
- if (workListSet.add(fieldType))
- workList.add(fieldType);
}
+
+ if (!foundUseSite)
+ sink->diagnose(
+ inst,
+ Diagnostics::resourceTypesInConstantBufferInParameterBlockNotAllowedOnMetal);
}
}
}
-
void checkForInvalidShaderParameterType(
TargetRequest* target,
IRModule* module,
@@ -69,4 +99,4 @@ void checkForInvalidShaderParameterType(
if (isMetalTarget(target))
checkForInvalidShaderParameterTypeForMetal(module, sink);
}
-} // namespace Slang \ No newline at end of file
+} // namespace Slang
diff --git a/tests/diagnostics/nested-constant-buffer-in-parameter-block.slang b/tests/diagnostics/nested-constant-buffer-in-parameter-block.slang
index eb9ecdd14..9fb83f034 100644
--- a/tests/diagnostics/nested-constant-buffer-in-parameter-block.slang
+++ b/tests/diagnostics/nested-constant-buffer-in-parameter-block.slang
@@ -1,20 +1,33 @@
//TEST:SIMPLE(filecheck=CHECK): -target metal
+struct T
+{
+ RWStructuredBuffer<int> buf;
+}
+
+struct U
+{
+ ConstantBuffer<T> t;
+}
+
struct S
{
- // CHECK-DAG: ([[# @LINE+1]]): error 56100:
- ConstantBuffer<int> cb;
+ ConstantBuffer<RWStructuredBuffer<int>> cb;
}
-ParameterBlock<S> s;
+// CHECK-DAG: ([[# @LINE+1]]): error 56101:
+ParameterBlock<S> s1;
+
+// CHECK-DAG: ([[# @LINE+1]]): error 56101:
+ParameterBlock<U> s2;
-// CHECK-DAG: ([[# @LINE+1]]): error 56100:
-ParameterBlock<ConstantBuffer<int>> s2;
+// CHECK-DAG: ([[# @LINE+1]]): error 56101:
+ParameterBlock<ConstantBuffer<RWStructuredBuffer<int>>> s3;
RWStructuredBuffer<int> outputBuffer;
[numthreads(1,1,1)]
void kernelMain()
{
- outputBuffer[0] = s.cb + s2;
-} \ No newline at end of file
+ outputBuffer[0] = s1.cb[0] + s2.t.buf[0] + s3[0];
+}
diff --git a/tools/slang-unit-test/unit-test-metal-parameter-block-constant-buffer.cpp b/tools/slang-unit-test/unit-test-metal-parameter-block-constant-buffer.cpp
new file mode 100644
index 000000000..d04cea974
--- /dev/null
+++ b/tools/slang-unit-test/unit-test-metal-parameter-block-constant-buffer.cpp
@@ -0,0 +1,133 @@
+// unit-test-ptr-layout.cpp
+
+#include "slang-com-ptr.h"
+#include "slang.h"
+#include "unit-test/slang-unit-test.h"
+
+#include <stdlib.h>
+
+using namespace Slang;
+
+SLANG_UNIT_TEST(metalConstantBufferInParameterBlockLayout)
+{
+ const char* testSource = R"(
+ struct T
+ {
+ float4 m0;
+ float m1;
+ float3 m2;
+ };
+
+ ParameterBlock<ConstantBuffer<T>> params;
+ )";
+
+ ComPtr<slang::IGlobalSession> globalSession;
+ SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
+
+ slang::TargetDesc targetDesc = {};
+ targetDesc.format = SLANG_METAL;
+ targetDesc.profile = globalSession->findProfile("metal");
+
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ sessionDesc.targets = &targetDesc;
+
+ ComPtr<slang::ISession> session;
+ SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto module = session->loadModuleFromSourceString(
+ "test",
+ "test.slang",
+ testSource,
+ diagnosticBlob.writeRef());
+ SLANG_CHECK(module != nullptr);
+
+ auto testBody = [&]()
+ {
+ auto reflection = module->getLayout();
+
+ // Collect our layouts
+ auto paramBlockType = reflection->findTypeByName("ParameterBlock<ConstantBuffer<T>>");
+ SLANG_CHECK(paramBlockType != nullptr);
+ auto paramBlockLayout = reflection->getTypeLayout(paramBlockType);
+ SLANG_CHECK(paramBlockLayout != nullptr);
+ auto cbufferLayout = paramBlockLayout->getElementTypeLayout();
+ SLANG_CHECK(cbufferLayout != nullptr);
+ auto structLayout = cbufferLayout->getElementTypeLayout();
+ SLANG_CHECK(structLayout != nullptr);
+
+ // Check offsets follow constant buffer rules (uniform alignment)
+ // m0 : float4 should be at offset 0
+ // m1 : float should be at offset 16 (after float4)
+ // m2 : float3 should be at offset 32 (aligned to 16-byte boundary)
+ SLANG_CHECK(structLayout->getFieldCount() == 3);
+ SLANG_CHECK(structLayout->getFieldByIndex(0)->getOffset() == 0);
+ SLANG_CHECK(structLayout->getFieldByIndex(1)->getOffset() == 16);
+ SLANG_CHECK(structLayout->getFieldByIndex(2)->getOffset() == 32);
+ };
+
+ testBody();
+}
+
+SLANG_UNIT_TEST(metalArgumentBufferLayout)
+{
+ const char* testSource = R"(
+ struct T
+ {
+ float4 m0;
+ float m1;
+ float3 m2;
+ };
+
+ // Using ParameterBlock directly without ConstantBuffer wrapper
+ ParameterBlock<T> params;
+ )";
+
+ ComPtr<slang::IGlobalSession> globalSession;
+ SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
+
+ slang::TargetDesc targetDesc = {};
+ targetDesc.format = SLANG_METAL;
+ targetDesc.profile = globalSession->findProfile("metal");
+
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ sessionDesc.targets = &targetDesc;
+
+ ComPtr<slang::ISession> session;
+ SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto module = session->loadModuleFromSourceString(
+ "test",
+ "test.slang",
+ testSource,
+ diagnosticBlob.writeRef());
+ SLANG_CHECK(module != nullptr);
+
+ auto testBody = [&]()
+ {
+ auto reflection = module->getLayout();
+
+ // Collect our layouts
+ auto paramBlockType = reflection->findTypeByName("ParameterBlock<T>");
+ SLANG_CHECK(paramBlockType != nullptr);
+ auto paramBlockLayout = reflection->getTypeLayout(paramBlockType);
+ SLANG_CHECK(paramBlockLayout != nullptr);
+ auto structLayout = paramBlockLayout->getElementTypeLayout();
+ SLANG_CHECK(structLayout != nullptr);
+
+ // Check that offsets follow Metal argument buffer rules
+ // Fields should have 0 offset and meaningful binding indices
+ SLANG_CHECK(structLayout->getFieldCount() == 3);
+ SLANG_CHECK(structLayout->getFieldByIndex(0)->getOffset() == 0);
+ SLANG_CHECK(structLayout->getFieldByIndex(1)->getOffset() == 0);
+ SLANG_CHECK(structLayout->getFieldByIndex(2)->getOffset() == 0);
+ SLANG_CHECK(structLayout->getFieldByIndex(0)->getBindingIndex() == 0);
+ SLANG_CHECK(structLayout->getFieldByIndex(1)->getBindingIndex() == 1);
+ SLANG_CHECK(structLayout->getFieldByIndex(2)->getBindingIndex() == 2);
+ };
+
+ testBody();
+}