summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-type-layout.cpp75
-rw-r--r--tools/slang-unit-test/unit-test-argument-buffer-tier-2-reflection.cpp70
2 files changed, 143 insertions, 2 deletions
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index 20a18a5ac..a412bf5b2 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -661,7 +661,7 @@ struct MetalLayoutRulesImpl : public CPULayoutRulesImpl
auto alignedElementCount = 1 << Math::Log2Ceil((uint32_t)elementCount);
// Metal aligns vectors to 2/4 element boundaries.
- size_t size = elementSize * elementCount;
+ size_t size = alignedElementCount * elementSize;
size_t alignment = alignedElementCount * elementSize;
SimpleLayoutInfo vectorInfo;
@@ -1147,6 +1147,14 @@ struct MetalLayoutRulesFamilyImpl : LayoutRulesFamilyImpl
LayoutRulesImpl* getStructuredBufferRules(CompilerOptionSet& compilerOptions) override;
};
+struct MetalArgumentBufferTier2LayoutRulesFamilyImpl : MetalLayoutRulesFamilyImpl
+{
+ virtual LayoutRulesImpl* getConstantBufferRules(
+ CompilerOptionSet& compilerOptions,
+ Type* containerType) override;
+ virtual LayoutRulesImpl* getParameterBlockRules(CompilerOptionSet& compilerOptions) override;
+};
+
struct WGSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl
{
virtual LayoutRulesImpl* getAnyValueRules() override;
@@ -1175,6 +1183,7 @@ HLSLLayoutRulesFamilyImpl kHLSLLayoutRulesFamilyImpl;
CPULayoutRulesFamilyImpl kCPULayoutRulesFamilyImpl;
CUDALayoutRulesFamilyImpl kCUDALayoutRulesFamilyImpl;
MetalLayoutRulesFamilyImpl kMetalLayoutRulesFamilyImpl;
+MetalArgumentBufferTier2LayoutRulesFamilyImpl kMetalArgumentBufferTier2LayoutRulesFamilyImpl;
WGSLLayoutRulesFamilyImpl kWGSLLayoutRulesFamilyImpl;
// CPU case
@@ -1969,8 +1978,44 @@ struct MetalArgumentBufferElementLayoutRulesImpl : ObjectLayoutRulesImpl, Defaul
}
};
+struct MetalTier2ObjectLayoutRulesImpl : ObjectLayoutRulesImpl
+{
+ virtual ObjectLayoutInfo GetObjectLayout(ShaderParameterKind kind, const Options& /* options */)
+ override
+ {
+ switch (kind)
+ {
+ case ShaderParameterKind::ConstantBuffer:
+ case ShaderParameterKind::ParameterBlock:
+ case ShaderParameterKind::StructuredBuffer:
+ case ShaderParameterKind::MutableStructuredBuffer:
+ case ShaderParameterKind::RawBuffer:
+ case ShaderParameterKind::Buffer:
+ case ShaderParameterKind::MutableRawBuffer:
+ case ShaderParameterKind::MutableBuffer:
+ case ShaderParameterKind::ShaderStorageBuffer:
+ case ShaderParameterKind::AccelerationStructure:
+ return SimpleLayoutInfo(LayoutResourceKind::Uniform, 8, 8);
+ case ShaderParameterKind::AppendConsumeStructuredBuffer:
+ return SimpleLayoutInfo(LayoutResourceKind::Uniform, 16, 8);
+ case ShaderParameterKind::MutableTexture:
+ case ShaderParameterKind::TextureUniformBuffer:
+ case ShaderParameterKind::Texture:
+ case ShaderParameterKind::SamplerState:
+ return SimpleLayoutInfo(LayoutResourceKind::Uniform, 8, 8);
+ case ShaderParameterKind::TextureSampler:
+ case ShaderParameterKind::MutableTextureSampler:
+ return SimpleLayoutInfo(LayoutResourceKind::Uniform, 16, 8);
+ default:
+ SLANG_UNEXPECTED("unhandled shader parameter kind");
+ UNREACHABLE_RETURN(SimpleLayoutInfo());
+ }
+ }
+};
+
static MetalObjectLayoutRulesImpl kMetalObjectLayoutRulesImpl;
static MetalArgumentBufferElementLayoutRulesImpl kMetalArgumentBufferElementLayoutRulesImpl;
+static MetalTier2ObjectLayoutRulesImpl kMetalTier2ObjectLayoutRulesImpl;
static MetalLayoutRulesImpl kMetalLayoutRulesImpl;
LayoutRulesImpl kMetalAnyValueLayoutRulesImpl_ = {
@@ -1991,6 +2036,18 @@ LayoutRulesImpl kMetalParameterBlockLayoutRulesImpl_ = {
&kMetalArgumentBufferElementLayoutRulesImpl,
};
+LayoutRulesImpl kMetalTier2ConstantBufferLayoutRulesImpl_ = {
+ &kMetalLayoutRulesFamilyImpl,
+ &kMetalLayoutRulesImpl,
+ &kMetalTier2ObjectLayoutRulesImpl,
+};
+
+LayoutRulesImpl kMetalTier2ParameterBlockLayoutRulesImpl_ = {
+ &kMetalLayoutRulesFamilyImpl,
+ &kMetalLayoutRulesImpl,
+ &kMetalTier2ObjectLayoutRulesImpl,
+};
+
LayoutRulesImpl kMetalStructuredBufferLayoutRulesImpl_ = {
&kMetalLayoutRulesFamilyImpl,
&kMetalLayoutRulesImpl,
@@ -2079,6 +2136,20 @@ LayoutRulesImpl* MetalLayoutRulesFamilyImpl::getHitAttributesParameterRules()
return nullptr;
}
+LayoutRulesImpl* MetalArgumentBufferTier2LayoutRulesFamilyImpl::getConstantBufferRules(
+ CompilerOptionSet&,
+ Type*)
+{
+ return &kMetalTier2ConstantBufferLayoutRulesImpl_;
+}
+
+LayoutRulesImpl* MetalArgumentBufferTier2LayoutRulesFamilyImpl::getParameterBlockRules(
+ CompilerOptionSet&)
+{
+ return &kMetalTier2ParameterBlockLayoutRulesImpl_;
+}
+
+
// WGSL Family
LayoutRulesImpl kWGSLConstantBufferLayoutRulesImpl_ = {
@@ -2229,7 +2300,7 @@ TypeLayoutContext getInitialLayoutContextForTarget(
rulesFamily = getDefaultLayoutRulesFamilyForTarget(targetReq);
break;
case slang::LayoutRules::MetalArgumentBufferTier2:
- rulesFamily = &kCPULayoutRulesFamilyImpl;
+ rulesFamily = &kMetalArgumentBufferTier2LayoutRulesFamilyImpl;
break;
}
diff --git a/tools/slang-unit-test/unit-test-argument-buffer-tier-2-reflection.cpp b/tools/slang-unit-test/unit-test-argument-buffer-tier-2-reflection.cpp
new file mode 100644
index 000000000..a63f38f89
--- /dev/null
+++ b/tools/slang-unit-test/unit-test-argument-buffer-tier-2-reflection.cpp
@@ -0,0 +1,70 @@
+// unit-test-argument-buffer-tier-2-reflection.cpp
+
+#include "../../source/core/slang-io.h"
+#include "../../source/core/slang-process.h"
+#include "slang-com-ptr.h"
+#include "slang.h"
+#include "unit-test/slang-unit-test.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+using namespace Slang;
+
+// Test metal argument buffer tier2 layout rules.
+
+SLANG_UNIT_TEST(metalArgumentBufferTier2Reflection)
+{
+ const char* userSourceBody = R"(
+ struct A
+ {
+ float3 one;
+ float3 two;
+ float three;
+ }
+
+ struct Args{
+ ParameterBlock<A> a;
+ }
+ ParameterBlock<Args> argument_buffer;
+ RWStructuredBuffer<float> outputBuffer;
+
+ [numthreads(1,1,1)]
+ void computeMain()
+ {
+ outputBuffer[0] = argument_buffer.a.two.x;
+ }
+ )";
+
+ auto moduleName = "moduleG" + String(Process::getId());
+ String userSource = "import " + moduleName + ";\n" + userSourceBody;
+ ComPtr<slang::IGlobalSession> globalSession;
+ SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
+ slang::TargetDesc targetDesc = {};
+ targetDesc.format = SLANG_SPIRV;
+ targetDesc.profile = globalSession->findProfile("spirv_1_5");
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ sessionDesc.targets = &targetDesc;
+ ComPtr<slang::ISession> session;
+ SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto module = session->loadModuleFromSourceString(
+ "m",
+ "m.slang",
+ userSourceBody,
+ diagnosticBlob.writeRef());
+ SLANG_CHECK(module != nullptr);
+
+ auto layout = module->getLayout();
+
+ auto type = layout->findTypeByName("A");
+ auto typeLayout = layout->getTypeLayout(type, slang::LayoutRules::MetalArgumentBufferTier2);
+ SLANG_CHECK(typeLayout->getFieldByIndex(0)->getOffset() == 0);
+ SLANG_CHECK(typeLayout->getFieldByIndex(0)->getTypeLayout()->getSize() == 16);
+ SLANG_CHECK(typeLayout->getFieldByIndex(1)->getOffset() == 16);
+ SLANG_CHECK(typeLayout->getFieldByIndex(1)->getTypeLayout()->getSize() == 16);
+ SLANG_CHECK(typeLayout->getFieldByIndex(2)->getOffset() == 32);
+ SLANG_CHECK(typeLayout->getFieldByIndex(2)->getTypeLayout()->getSize() == 4);
+}