summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-03-12 19:31:25 -0700
committerGitHub <noreply@github.com>2024-03-12 19:31:25 -0700
commit6f7c8271710b43349d34b8f7569ceb6957400548 (patch)
tree288c18bb4b9a2cf32de7e400c1fe8b56385b727e
parenteef7e208bf7436a4f111a9290f37204e3220d82b (diff)
Fix `sessionDesc.defaultMatrixLayoutMode` being ineffective. (#3753)
* Fix `sessionDesc.defaultMatrixLayoutMode` being ineffective. * Fix matrix layout in buffer pointer. * Attempt to fix. * Fix buffer element type lowering for buffer pointers. * Add comment. * Fix test. * Fix member lookup in `Ref<T>`. * Fix validation error. * Enhance test.
-rw-r--r--build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj1
-rw-r--r--build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters3
-rw-r--r--source/slang/slang-check-expr.cpp30
-rw-r--r--source/slang/slang-check-type.cpp4
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp45
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.h2
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp11
-rw-r--r--source/slang/slang.cpp24
-rw-r--r--tests/front-end/typedef-matrix.slang13
-rw-r--r--tests/spirv/buffer-pointer-matrix-layout.slang34
-rw-r--r--tests/spirv/pointer-bug.slang2
-rw-r--r--tests/spirv/pointer.slang7
-rw-r--r--tools/slang-unit-test/unit-test-default-matrix-layout.cpp84
14 files changed, 232 insertions, 29 deletions
diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj
index e1a07455e..2798b80e4 100644
--- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj
+++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj
@@ -292,6 +292,7 @@
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-command-line-args.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-compression.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-crypto.cpp" />
+ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-default-matrix-layout.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-file-system.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-type-by-name.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-free-list.cpp" />
diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters
index 3c5302b7f..3fd04c077 100644
--- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters
+++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters
@@ -32,6 +32,9 @@
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-crypto.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-default-matrix-layout.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-file-system.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index ea6afdc1d..995b5e888 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -4164,15 +4164,39 @@ namespace Slang
List<Val*> modifierVals;
for( auto modifier : expr->modifiers )
{
+ if (auto matrixLayoutModifier = as<MatrixLayoutModifier>(modifier))
+ {
+ if (auto matrixType = as<MatrixExpressionType>(baseType))
+ {
+ if (as<ColumnMajorLayoutModifier>(matrixLayoutModifier))
+ {
+ baseType = m_astBuilder->getMatrixType(matrixType->getElementType(), matrixType->getRowCount(), matrixType->getColumnCount(),
+ m_astBuilder->getIntVal(m_astBuilder->getIntType(), kMatrixLayoutMode_ColumnMajor));
+ }
+ else
+ {
+ baseType = m_astBuilder->getMatrixType(matrixType->getElementType(), matrixType->getRowCount(), matrixType->getColumnCount(),
+ m_astBuilder->getIntVal(m_astBuilder->getIntType(), kMatrixLayoutMode_RowMajor));
+ }
+ expr->type = m_astBuilder->getTypeType(baseType);
+ }
+ else
+ {
+ getSink()->diagnose(matrixLayoutModifier, Diagnostics::matrixLayoutModifierOnNonMatrixType, baseType);
+ }
+ continue;
+ }
auto modifierVal = checkTypeModifier(modifier, baseType);
if(!modifierVal)
continue;
modifierVals.add(modifierVal);
}
- auto modifiedType = m_astBuilder->getModifiedType(baseType, modifierVals);
- expr->type = m_astBuilder->getTypeType(modifiedType);
-
+ if (modifierVals.getCount())
+ {
+ auto modifiedType = m_astBuilder->getModifiedType(baseType, modifierVals);
+ expr->type = m_astBuilder->getTypeType(modifiedType);
+ }
return expr;
}
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index 0ab3998d8..217d1b545 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -33,6 +33,10 @@ namespace Slang
{
return ptrType->getValueType();
}
+ else if (auto refType = as<RefType>(type))
+ {
+ return refType->getValueType();
+ }
return nullptr;
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index d8cff1a9a..54761e772 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -662,6 +662,7 @@ DIAGNOSTIC(39024, Warning, cannotInferVulkanBindingWithoutRegisterModifier, "sha
DIAGNOSTIC(39025, Error, conflictingVulkanInferredBindingForParameter, "conflicting vulkan inferred binding for parameter '$0' overlap is $1 and $2")
+DIAGNOSTIC(39026, Error, matrixLayoutModifierOnNonMatrixType, "matrix layout modifier cannot be used on non-matrix type '$0'.")
//
// 4xxxx - IL code generation.
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 5c797ab23..eb5fba6a9 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -24,9 +24,10 @@ namespace Slang
SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
TargetProgram* target;
+ bool lowerBufferPointer = false;
- LoweredElementTypeContext(TargetProgram* target, SlangMatrixLayoutMode inDefaultMatrixLayout)
- : target(target), defaultMatrixLayout(inDefaultMatrixLayout)
+ LoweredElementTypeContext(TargetProgram* target, bool lowerBufferPointer, SlangMatrixLayoutMode inDefaultMatrixLayout)
+ : target(target), defaultMatrixLayout(inDefaultMatrixLayout), lowerBufferPointer(lowerBufferPointer)
{}
IRFunc* createMatrixUnpackFunc(
@@ -460,7 +461,15 @@ namespace Slang
LoweredElementTypeInfo getLoweredTypeInfo(IRType* type, IRTypeLayoutRules* rules)
{
+ // If `type` is already a lowered type, no more lowering is required.
LoweredElementTypeInfo info;
+ if (auto pInfo = mapLoweredTypeToInfo->tryGetValue(type))
+ {
+ info.originalType = type;
+ info.loweredType = type;
+ return info;
+ }
+
if (loweredTypeInfo[(int)rules->ruleName].tryGetValue(type, info))
return info;
info = getLoweredTypeInfoImpl(type, rules);
@@ -513,10 +522,21 @@ namespace Slang
for (auto globalInst : module->getGlobalInsts())
{
IRType* elementType = nullptr;
- if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst))
- elementType = structBuffer->getElementType();
- else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst))
- elementType = constBuffer->getElementType();
+ if (lowerBufferPointer)
+ {
+ if (auto ptrType = as<IRPtrType>(globalInst))
+ {
+ if (ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer)
+ elementType = ptrType->getValueType();
+ }
+ }
+ else
+ {
+ if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst))
+ elementType = structBuffer->getElementType();
+ else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst))
+ elementType = constBuffer->getElementType();
+ }
if (as<IRTextureBufferType>(globalInst))
continue;
if (!as<IRStructType>(elementType) && !as<IRMatrixType>(elementType) && !as<IRArrayType>(elementType) && !as<IRBoolType>(elementType))
@@ -654,17 +674,19 @@ namespace Slang
}
break;
case kIROp_RWStructuredBufferGetElementPtr:
+ case kIROp_GetOffsetPtr:
ptrValsWorkList.add(user);
break;
case kIROp_StructuredBufferGetDimensions:
break;
case kIROp_Call:
{
- // If a structured buffer typed value is used directly as an argument,
+ // If a structured buffer or pointer typed value is used directly as an argument,
// we don't need to do any marshalling here.
if (as<IRHLSLStructuredBufferTypeBase>(ptrVal->getDataType()))
break;
-
+ if (lowerBufferPointer && as<IRPtrType>(ptrVal->getDataType()))
+ break;
// If we are calling a function with an l-value pointer from buffer access,
// we need to materialize the object as a local variable, and pass the address
// of the local variable to the function.
@@ -681,7 +703,6 @@ namespace Slang
}
break;
default:
- SLANG_UNREACHABLE("unhandled inst of a buffer/pointer value that needs storage lowering.");
break;
}
});
@@ -801,12 +822,12 @@ namespace Slang
}
};
- void lowerBufferElementTypeToStorageType(TargetProgram* target, IRModule* module)
+ void lowerBufferElementTypeToStorageType(TargetProgram* target, IRModule* module, bool lowerBufferPointer)
{
SlangMatrixLayoutMode defaultMatrixMode = (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode();
if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)
defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
- LoweredElementTypeContext context(target, defaultMatrixMode);
+ LoweredElementTypeContext context(target, lowerBufferPointer, defaultMatrixMode);
context.processModule(module);
}
@@ -853,6 +874,8 @@ namespace Slang
case kIROp_ConstantBufferType:
case kIROp_ParameterBlockType:
return IRTypeLayoutRules::getStd140();
+ case kIROp_PtrType:
+ return IRTypeLayoutRules::getNatural();
}
return IRTypeLayoutRules::getNatural();
}
diff --git a/source/slang/slang-ir-lower-buffer-element-type.h b/source/slang/slang-ir-lower-buffer-element-type.h
index 4d0a7eabe..95e6e6651 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.h
+++ b/source/slang/slang-ir-lower-buffer-element-type.h
@@ -15,7 +15,7 @@ namespace Slang
// This pass needs to take place after type legalization, and before array return type lowering
// because it may create functions that returns array typed values.
//
- void lowerBufferElementTypeToStorageType(TargetProgram* target, IRModule* module);
+ void lowerBufferElementTypeToStorageType(TargetProgram* target, IRModule* module, bool lowerBufferPointer = false);
// Returns the type layout rules should be used for a buffer resource type.
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index d072e2da0..511c596a4 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -20,6 +20,7 @@
#include "slang-ir-peephole.h"
#include "slang-ir-redundancy-removal.h"
#include "slang-ir-loop-unroll.h"
+#include "slang-ir-lower-buffer-element-type.h"
namespace Slang
{
@@ -1853,8 +1854,6 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
void processModule()
{
- //convertCompositeTypeParametersToPointers(m_module);
-
// Process global params before anything else, so we don't generate inefficient
// array marhalling code for array-typed global params.
for (auto globalInst : m_module->getGlobalInsts())
@@ -1944,6 +1943,14 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// Some legalization processing may change the function parameter types,
// so we need to update the function types to match that.
updateFunctionTypes();
+
+ // Lower all loads/stores from buffer pointers to use correct storage types.
+ // We didn't do the lowering for buffer pointers because we don't know which pointer
+ // types are actual storage buffer pointers until we propagated the address space of
+ // pointers in this pass. In the future we should consider separate out IRAddress as
+ // the type for IRVar, and use IRPtrType to dedicate pointers in user code, so we can
+ // safely lower the pointer load stores early together with other buffer types.
+ lowerBufferElementTypeToStorageType(m_sharedContext->m_targetProgram, m_module, true);
}
void updateFunctionTypes()
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 69c0f0e14..978dff7c4 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -594,17 +594,6 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Session::createSession(
slang::SessionDesc desc = makeFromSizeVersioned<slang::SessionDesc>((uint8_t*)&inDesc);
RefPtr<Linkage> linkage = new Linkage(this, astBuilder, getBuiltinLinkage());
- linkage->m_optionSet.load(desc.compilerOptionEntryCount, desc.compilerOptionEntries);
-
- {
- const Int targetCount = desc.targetCount;
- const uint8_t* targetDescPtr = reinterpret_cast<const uint8_t*>(desc.targets);
- for (Int ii = 0; ii < targetCount; ++ii, targetDescPtr += _getStructureSize(targetDescPtr))
- {
- const auto targetDesc = makeFromSizeVersioned<slang::TargetDesc>(targetDescPtr);
- linkage->addTarget(targetDesc);
- }
- }
linkage->setMatrixLayoutMode(desc.defaultMatrixLayoutMode);
@@ -630,6 +619,19 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Session::createSession(
{
linkage->m_optionSet.set(CompilerOptionName::EnableEffectAnnotations, desc.enableEffectAnnotations);
}
+
+ linkage->m_optionSet.load(desc.compilerOptionEntryCount, desc.compilerOptionEntries);
+
+ {
+ const Int targetCount = desc.targetCount;
+ const uint8_t* targetDescPtr = reinterpret_cast<const uint8_t*>(desc.targets);
+ for (Int ii = 0; ii < targetCount; ++ii, targetDescPtr += _getStructureSize(targetDescPtr))
+ {
+ const auto targetDesc = makeFromSizeVersioned<slang::TargetDesc>(targetDescPtr);
+ linkage->addTarget(targetDesc);
+ }
+ }
+
*outSession = asExternal(linkage.detach());
return SLANG_OK;
}
diff --git a/tests/front-end/typedef-matrix.slang b/tests/front-end/typedef-matrix.slang
new file mode 100644
index 000000000..81188f806
--- /dev/null
+++ b/tests/front-end/typedef-matrix.slang
@@ -0,0 +1,13 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -entry main -stage compute -matrix-layout-row-major
+
+// CHECK: ColMajor
+
+typedef column_major float3x4 Mat;
+
+RWStructuredBuffer<float> output;
+
+[numthreads(1,1,1)]
+void main(uniform Mat m)
+{
+ output[0] = m[0][0];
+} \ No newline at end of file
diff --git a/tests/spirv/buffer-pointer-matrix-layout.slang b/tests/spirv/buffer-pointer-matrix-layout.slang
new file mode 100644
index 000000000..2ccde7b71
--- /dev/null
+++ b/tests/spirv/buffer-pointer-matrix-layout.slang
@@ -0,0 +1,34 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -stage compute -entry main -matrix-layout-column-major
+
+// CHECK: OpLoad %_MatrixStorage_float3x4_ColMajornatural {{.*}} Aligned 4
+// CHECK: OpLoad %_MatrixStorage_float3x4_ColMajornatural {{.*}} Aligned 4
+
+struct Push
+{
+ float3x4* ptr;
+};
+
+[[vk::push_constant]] Push push;
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void main(uint3 dtid : SV_DispatchThreadID)
+{
+ // This matrix is in memry column major. Slang respects this here and load it properly!
+ float3x4 correctly_read_matrix = *push.ptr;
+ printf("(%f,%f,%f,%f)\n(%f,%f,%f,%f)\n",
+ correctly_read_matrix[0][0], correctly_read_matrix[0][1], correctly_read_matrix[0][2], correctly_read_matrix[0][3],
+ correctly_read_matrix[1][0], correctly_read_matrix[1][1], correctly_read_matrix[1][2], correctly_read_matrix[1][3]
+ );
+ printf("(%f,%f,%f,%f)\n\n",
+ correctly_read_matrix[2][0], correctly_read_matrix[2][1], correctly_read_matrix[2][2], correctly_read_matrix[2][3]
+ );
+ // With this syntax however, Slang ignores the column major setting and loads it as it it was row major!
+ float3x4 broken_matrix = push.ptr[0];
+ printf("(%f,%f,%f,%f)\n(%f,%f,%f,%f)\n",
+ broken_matrix[0][0], broken_matrix[0][1], broken_matrix[0][2], broken_matrix[0][3],
+ broken_matrix[1][0], broken_matrix[1][1], broken_matrix[1][2], broken_matrix[1][3]
+ );
+ printf("(%f,%f,%f,%f)\n\n",
+ broken_matrix[2][0], broken_matrix[2][1], broken_matrix[2][2], broken_matrix[2][3]
+ );
+} \ No newline at end of file
diff --git a/tests/spirv/pointer-bug.slang b/tests/spirv/pointer-bug.slang
index 1668cec13..404da286f 100644
--- a/tests/spirv/pointer-bug.slang
+++ b/tests/spirv/pointer-bug.slang
@@ -7,7 +7,7 @@ struct Params {
Foo *foo;
};
-// CHECK: %_ptr_PhysicalStorageBuffer_Foo = OpTypePointer PhysicalStorageBuffer %Foo
+// CHECK: OpTypePointer PushConstant %_ptr_PhysicalStorageBuffer_Foo_natural
[[vk::push_constant]] Params params;
diff --git a/tests/spirv/pointer.slang b/tests/spirv/pointer.slang
index cb2d56f66..cd4845e4d 100644
--- a/tests/spirv/pointer.slang
+++ b/tests/spirv/pointer.slang
@@ -21,6 +21,11 @@ int* funcThatReturnsPointer(PP* p)
return &p.data;
}
+void funcWithInOutParam(inout PP p)
+{
+ p.data = 0;
+}
+
// CHECK: OpEntryPoint
StructuredBuffer<Data> buffer;
@@ -44,5 +49,7 @@ void main(int id : SV_DispatchThreadID)
if (pData1 > pData)
{
funcThatTakesPointer(buffer[0].pNext);
+ output[1] = (*buffer[0].pNext).data;
}
+ funcWithInOutParam(*buffer[0].pNext);
}
diff --git a/tools/slang-unit-test/unit-test-default-matrix-layout.cpp b/tools/slang-unit-test/unit-test-default-matrix-layout.cpp
new file mode 100644
index 000000000..468b7d986
--- /dev/null
+++ b/tools/slang-unit-test/unit-test-default-matrix-layout.cpp
@@ -0,0 +1,84 @@
+// unit-test-default-matrix-layout.cpp
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "tools/unit-test/slang-unit-test.h"
+
+#include "../../slang.h"
+#include "../../slang-com-helper.h"
+#include "../../slang-com-ptr.h"
+
+#include "../../source/core/slang-list.h"
+
+namespace {
+
+using namespace Slang;
+
+struct DefaultMatrixLayoutTestContext
+{
+ DefaultMatrixLayoutTestContext(UnitTestContext* context):
+ m_unitTestContext(context)
+ {
+ slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession;
+ }
+
+ SlangResult runTests()
+ {
+ slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession;
+ ComPtr<slang::ISession> session;
+ slang::SessionDesc sessionDesc{};
+ sessionDesc.targetCount = 1;
+ slang::TargetDesc targetDesc{};
+ targetDesc.format = SLANG_GLSL;
+ targetDesc.profile = slangSession->findProfile("glsl_460");
+ sessionDesc.targets = &targetDesc;
+ sessionDesc.defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
+ SLANG_RETURN_ON_FAIL(slangSession->createSession(sessionDesc, session.writeRef()));
+
+ auto module = session->loadModuleFromSourceString("mymodule", "mymodule.slang",
+ R"(
+ RWStructuredBuffer<float> output;
+ [numthreads(1,1,1)] [shader("compute")]
+ void main(uniform float3x4 m)
+ {
+ output[0] = m[0][0];
+ })");
+ if (!module)
+ return SLANG_FAIL;
+
+ ComPtr<slang::IEntryPoint> entryPoint;
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("main", entryPoint.writeRef()));
+
+ if (!entryPoint)
+ return SLANG_FAIL;
+
+ slang::IComponentType* components[] = { module, entryPoint.get() };
+ ComPtr<slang::IComponentType> composedProgram;
+ SLANG_RETURN_ON_FAIL(session->createCompositeComponentType(components, 2, composedProgram.writeRef()));
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ SLANG_RETURN_ON_FAIL(composedProgram->link(linkedProgram.writeRef()));
+
+ ComPtr<slang::IBlob> outCode;
+ SLANG_RETURN_ON_FAIL(linkedProgram->getEntryPointCode(0, 0, outCode.writeRef()));
+
+ const char* code = (const char*)outCode->getBufferPointer();
+ if (strstr(code, "row_major") != nullptr)
+ return SLANG_OK;
+ return SLANG_FAIL;
+ }
+
+ UnitTestContext* m_unitTestContext;
+};
+
+} // anonymous
+
+SLANG_UNIT_TEST(defaultMatrixLayout)
+{
+ DefaultMatrixLayoutTestContext context(unitTestContext);
+
+ const auto result = context.runTests();
+
+ SLANG_CHECK(SLANG_SUCCEEDED(result));
+}