diff options
| author | Yong He <yonghe@outlook.com> | 2024-03-12 19:31:25 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-12 19:31:25 -0700 |
| commit | 6f7c8271710b43349d34b8f7569ceb6957400548 (patch) | |
| tree | 288c18bb4b9a2cf32de7e400c1fe8b56385b727e | |
| parent | eef7e208bf7436a4f111a9290f37204e3220d82b (diff) | |
Fix `sessionDesc.defaultMatrixLayoutMode` being ineffective. (#3753)
* Fix `sessionDesc.defaultMatrixLayoutMode` being ineffective.
* Fix matrix layout in buffer pointer.
* Attempt to fix.
* Fix buffer element type lowering for buffer pointers.
* Add comment.
* Fix test.
* Fix member lookup in `Ref<T>`.
* Fix validation error.
* Enhance test.
| -rw-r--r-- | build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj | 1 | ||||
| -rw-r--r-- | build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters | 3 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 30 | ||||
| -rw-r--r-- | source/slang/slang-check-type.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.cpp | 45 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 24 | ||||
| -rw-r--r-- | tests/front-end/typedef-matrix.slang | 13 | ||||
| -rw-r--r-- | tests/spirv/buffer-pointer-matrix-layout.slang | 34 | ||||
| -rw-r--r-- | tests/spirv/pointer-bug.slang | 2 | ||||
| -rw-r--r-- | tests/spirv/pointer.slang | 7 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-default-matrix-layout.cpp | 84 |
14 files changed, 232 insertions, 29 deletions
diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj index e1a07455e..2798b80e4 100644 --- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj +++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj @@ -292,6 +292,7 @@ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-command-line-args.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-compression.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-crypto.cpp" />
+ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-default-matrix-layout.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-file-system.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-find-type-by-name.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-free-list.cpp" />
diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters index 3c5302b7f..3fd04c077 100644 --- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters +++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters @@ -32,6 +32,9 @@ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-crypto.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-default-matrix-layout.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-file-system.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index ea6afdc1d..995b5e888 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -4164,15 +4164,39 @@ namespace Slang List<Val*> modifierVals; for( auto modifier : expr->modifiers ) { + if (auto matrixLayoutModifier = as<MatrixLayoutModifier>(modifier)) + { + if (auto matrixType = as<MatrixExpressionType>(baseType)) + { + if (as<ColumnMajorLayoutModifier>(matrixLayoutModifier)) + { + baseType = m_astBuilder->getMatrixType(matrixType->getElementType(), matrixType->getRowCount(), matrixType->getColumnCount(), + m_astBuilder->getIntVal(m_astBuilder->getIntType(), kMatrixLayoutMode_ColumnMajor)); + } + else + { + baseType = m_astBuilder->getMatrixType(matrixType->getElementType(), matrixType->getRowCount(), matrixType->getColumnCount(), + m_astBuilder->getIntVal(m_astBuilder->getIntType(), kMatrixLayoutMode_RowMajor)); + } + expr->type = m_astBuilder->getTypeType(baseType); + } + else + { + getSink()->diagnose(matrixLayoutModifier, Diagnostics::matrixLayoutModifierOnNonMatrixType, baseType); + } + continue; + } auto modifierVal = checkTypeModifier(modifier, baseType); if(!modifierVal) continue; modifierVals.add(modifierVal); } - auto modifiedType = m_astBuilder->getModifiedType(baseType, modifierVals); - expr->type = m_astBuilder->getTypeType(modifiedType); - + if (modifierVals.getCount()) + { + auto modifiedType = m_astBuilder->getModifiedType(baseType, modifierVals); + expr->type = m_astBuilder->getTypeType(modifiedType); + } return expr; } diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index 0ab3998d8..217d1b545 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -33,6 +33,10 @@ namespace Slang { return ptrType->getValueType(); } + else if (auto refType = as<RefType>(type)) + { + return refType->getValueType(); + } return nullptr; } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index d8cff1a9a..54761e772 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -662,6 +662,7 @@ DIAGNOSTIC(39024, Warning, cannotInferVulkanBindingWithoutRegisterModifier, "sha DIAGNOSTIC(39025, Error, conflictingVulkanInferredBindingForParameter, "conflicting vulkan inferred binding for parameter '$0' overlap is $1 and $2") +DIAGNOSTIC(39026, Error, matrixLayoutModifierOnNonMatrixType, "matrix layout modifier cannot be used on non-matrix type '$0'.") // // 4xxxx - IL code generation. diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 5c797ab23..eb5fba6a9 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -24,9 +24,10 @@ namespace Slang SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; TargetProgram* target; + bool lowerBufferPointer = false; - LoweredElementTypeContext(TargetProgram* target, SlangMatrixLayoutMode inDefaultMatrixLayout) - : target(target), defaultMatrixLayout(inDefaultMatrixLayout) + LoweredElementTypeContext(TargetProgram* target, bool lowerBufferPointer, SlangMatrixLayoutMode inDefaultMatrixLayout) + : target(target), defaultMatrixLayout(inDefaultMatrixLayout), lowerBufferPointer(lowerBufferPointer) {} IRFunc* createMatrixUnpackFunc( @@ -460,7 +461,15 @@ namespace Slang LoweredElementTypeInfo getLoweredTypeInfo(IRType* type, IRTypeLayoutRules* rules) { + // If `type` is already a lowered type, no more lowering is required. LoweredElementTypeInfo info; + if (auto pInfo = mapLoweredTypeToInfo->tryGetValue(type)) + { + info.originalType = type; + info.loweredType = type; + return info; + } + if (loweredTypeInfo[(int)rules->ruleName].tryGetValue(type, info)) return info; info = getLoweredTypeInfoImpl(type, rules); @@ -513,10 +522,21 @@ namespace Slang for (auto globalInst : module->getGlobalInsts()) { IRType* elementType = nullptr; - if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst)) - elementType = structBuffer->getElementType(); - else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst)) - elementType = constBuffer->getElementType(); + if (lowerBufferPointer) + { + if (auto ptrType = as<IRPtrType>(globalInst)) + { + if (ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer) + elementType = ptrType->getValueType(); + } + } + else + { + if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst)) + elementType = structBuffer->getElementType(); + else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst)) + elementType = constBuffer->getElementType(); + } if (as<IRTextureBufferType>(globalInst)) continue; if (!as<IRStructType>(elementType) && !as<IRMatrixType>(elementType) && !as<IRArrayType>(elementType) && !as<IRBoolType>(elementType)) @@ -654,17 +674,19 @@ namespace Slang } break; case kIROp_RWStructuredBufferGetElementPtr: + case kIROp_GetOffsetPtr: ptrValsWorkList.add(user); break; case kIROp_StructuredBufferGetDimensions: break; case kIROp_Call: { - // If a structured buffer typed value is used directly as an argument, + // If a structured buffer or pointer typed value is used directly as an argument, // we don't need to do any marshalling here. if (as<IRHLSLStructuredBufferTypeBase>(ptrVal->getDataType())) break; - + if (lowerBufferPointer && as<IRPtrType>(ptrVal->getDataType())) + break; // If we are calling a function with an l-value pointer from buffer access, // we need to materialize the object as a local variable, and pass the address // of the local variable to the function. @@ -681,7 +703,6 @@ namespace Slang } break; default: - SLANG_UNREACHABLE("unhandled inst of a buffer/pointer value that needs storage lowering."); break; } }); @@ -801,12 +822,12 @@ namespace Slang } }; - void lowerBufferElementTypeToStorageType(TargetProgram* target, IRModule* module) + void lowerBufferElementTypeToStorageType(TargetProgram* target, IRModule* module, bool lowerBufferPointer) { SlangMatrixLayoutMode defaultMatrixMode = (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode(); if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; - LoweredElementTypeContext context(target, defaultMatrixMode); + LoweredElementTypeContext context(target, lowerBufferPointer, defaultMatrixMode); context.processModule(module); } @@ -853,6 +874,8 @@ namespace Slang case kIROp_ConstantBufferType: case kIROp_ParameterBlockType: return IRTypeLayoutRules::getStd140(); + case kIROp_PtrType: + return IRTypeLayoutRules::getNatural(); } return IRTypeLayoutRules::getNatural(); } diff --git a/source/slang/slang-ir-lower-buffer-element-type.h b/source/slang/slang-ir-lower-buffer-element-type.h index 4d0a7eabe..95e6e6651 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.h +++ b/source/slang/slang-ir-lower-buffer-element-type.h @@ -15,7 +15,7 @@ namespace Slang // This pass needs to take place after type legalization, and before array return type lowering // because it may create functions that returns array typed values. // - void lowerBufferElementTypeToStorageType(TargetProgram* target, IRModule* module); + void lowerBufferElementTypeToStorageType(TargetProgram* target, IRModule* module, bool lowerBufferPointer = false); // Returns the type layout rules should be used for a buffer resource type. diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index d072e2da0..511c596a4 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -20,6 +20,7 @@ #include "slang-ir-peephole.h" #include "slang-ir-redundancy-removal.h" #include "slang-ir-loop-unroll.h" +#include "slang-ir-lower-buffer-element-type.h" namespace Slang { @@ -1853,8 +1854,6 @@ struct SPIRVLegalizationContext : public SourceEmitterBase void processModule() { - //convertCompositeTypeParametersToPointers(m_module); - // Process global params before anything else, so we don't generate inefficient // array marhalling code for array-typed global params. for (auto globalInst : m_module->getGlobalInsts()) @@ -1944,6 +1943,14 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // Some legalization processing may change the function parameter types, // so we need to update the function types to match that. updateFunctionTypes(); + + // Lower all loads/stores from buffer pointers to use correct storage types. + // We didn't do the lowering for buffer pointers because we don't know which pointer + // types are actual storage buffer pointers until we propagated the address space of + // pointers in this pass. In the future we should consider separate out IRAddress as + // the type for IRVar, and use IRPtrType to dedicate pointers in user code, so we can + // safely lower the pointer load stores early together with other buffer types. + lowerBufferElementTypeToStorageType(m_sharedContext->m_targetProgram, m_module, true); } void updateFunctionTypes() diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 69c0f0e14..978dff7c4 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -594,17 +594,6 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Session::createSession( slang::SessionDesc desc = makeFromSizeVersioned<slang::SessionDesc>((uint8_t*)&inDesc); RefPtr<Linkage> linkage = new Linkage(this, astBuilder, getBuiltinLinkage()); - linkage->m_optionSet.load(desc.compilerOptionEntryCount, desc.compilerOptionEntries); - - { - const Int targetCount = desc.targetCount; - const uint8_t* targetDescPtr = reinterpret_cast<const uint8_t*>(desc.targets); - for (Int ii = 0; ii < targetCount; ++ii, targetDescPtr += _getStructureSize(targetDescPtr)) - { - const auto targetDesc = makeFromSizeVersioned<slang::TargetDesc>(targetDescPtr); - linkage->addTarget(targetDesc); - } - } linkage->setMatrixLayoutMode(desc.defaultMatrixLayoutMode); @@ -630,6 +619,19 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Session::createSession( { linkage->m_optionSet.set(CompilerOptionName::EnableEffectAnnotations, desc.enableEffectAnnotations); } + + linkage->m_optionSet.load(desc.compilerOptionEntryCount, desc.compilerOptionEntries); + + { + const Int targetCount = desc.targetCount; + const uint8_t* targetDescPtr = reinterpret_cast<const uint8_t*>(desc.targets); + for (Int ii = 0; ii < targetCount; ++ii, targetDescPtr += _getStructureSize(targetDescPtr)) + { + const auto targetDesc = makeFromSizeVersioned<slang::TargetDesc>(targetDescPtr); + linkage->addTarget(targetDesc); + } + } + *outSession = asExternal(linkage.detach()); return SLANG_OK; } diff --git a/tests/front-end/typedef-matrix.slang b/tests/front-end/typedef-matrix.slang new file mode 100644 index 000000000..81188f806 --- /dev/null +++ b/tests/front-end/typedef-matrix.slang @@ -0,0 +1,13 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -entry main -stage compute -matrix-layout-row-major + +// CHECK: ColMajor + +typedef column_major float3x4 Mat; + +RWStructuredBuffer<float> output; + +[numthreads(1,1,1)] +void main(uniform Mat m) +{ + output[0] = m[0][0]; +}
\ No newline at end of file diff --git a/tests/spirv/buffer-pointer-matrix-layout.slang b/tests/spirv/buffer-pointer-matrix-layout.slang new file mode 100644 index 000000000..2ccde7b71 --- /dev/null +++ b/tests/spirv/buffer-pointer-matrix-layout.slang @@ -0,0 +1,34 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -stage compute -entry main -matrix-layout-column-major + +// CHECK: OpLoad %_MatrixStorage_float3x4_ColMajornatural {{.*}} Aligned 4 +// CHECK: OpLoad %_MatrixStorage_float3x4_ColMajornatural {{.*}} Aligned 4 + +struct Push +{ + float3x4* ptr; +}; + +[[vk::push_constant]] Push push; +[shader("compute")] +[numthreads(1, 1, 1)] +void main(uint3 dtid : SV_DispatchThreadID) +{ + // This matrix is in memry column major. Slang respects this here and load it properly! + float3x4 correctly_read_matrix = *push.ptr; + printf("(%f,%f,%f,%f)\n(%f,%f,%f,%f)\n", + correctly_read_matrix[0][0], correctly_read_matrix[0][1], correctly_read_matrix[0][2], correctly_read_matrix[0][3], + correctly_read_matrix[1][0], correctly_read_matrix[1][1], correctly_read_matrix[1][2], correctly_read_matrix[1][3] + ); + printf("(%f,%f,%f,%f)\n\n", + correctly_read_matrix[2][0], correctly_read_matrix[2][1], correctly_read_matrix[2][2], correctly_read_matrix[2][3] + ); + // With this syntax however, Slang ignores the column major setting and loads it as it it was row major! + float3x4 broken_matrix = push.ptr[0]; + printf("(%f,%f,%f,%f)\n(%f,%f,%f,%f)\n", + broken_matrix[0][0], broken_matrix[0][1], broken_matrix[0][2], broken_matrix[0][3], + broken_matrix[1][0], broken_matrix[1][1], broken_matrix[1][2], broken_matrix[1][3] + ); + printf("(%f,%f,%f,%f)\n\n", + broken_matrix[2][0], broken_matrix[2][1], broken_matrix[2][2], broken_matrix[2][3] + ); +}
\ No newline at end of file diff --git a/tests/spirv/pointer-bug.slang b/tests/spirv/pointer-bug.slang index 1668cec13..404da286f 100644 --- a/tests/spirv/pointer-bug.slang +++ b/tests/spirv/pointer-bug.slang @@ -7,7 +7,7 @@ struct Params { Foo *foo; }; -// CHECK: %_ptr_PhysicalStorageBuffer_Foo = OpTypePointer PhysicalStorageBuffer %Foo +// CHECK: OpTypePointer PushConstant %_ptr_PhysicalStorageBuffer_Foo_natural [[vk::push_constant]] Params params; diff --git a/tests/spirv/pointer.slang b/tests/spirv/pointer.slang index cb2d56f66..cd4845e4d 100644 --- a/tests/spirv/pointer.slang +++ b/tests/spirv/pointer.slang @@ -21,6 +21,11 @@ int* funcThatReturnsPointer(PP* p) return &p.data; } +void funcWithInOutParam(inout PP p) +{ + p.data = 0; +} + // CHECK: OpEntryPoint StructuredBuffer<Data> buffer; @@ -44,5 +49,7 @@ void main(int id : SV_DispatchThreadID) if (pData1 > pData) { funcThatTakesPointer(buffer[0].pNext); + output[1] = (*buffer[0].pNext).data; } + funcWithInOutParam(*buffer[0].pNext); } diff --git a/tools/slang-unit-test/unit-test-default-matrix-layout.cpp b/tools/slang-unit-test/unit-test-default-matrix-layout.cpp new file mode 100644 index 000000000..468b7d986 --- /dev/null +++ b/tools/slang-unit-test/unit-test-default-matrix-layout.cpp @@ -0,0 +1,84 @@ +// unit-test-default-matrix-layout.cpp + +#include <stdio.h> +#include <stdlib.h> + +#include "tools/unit-test/slang-unit-test.h" + +#include "../../slang.h" +#include "../../slang-com-helper.h" +#include "../../slang-com-ptr.h" + +#include "../../source/core/slang-list.h" + +namespace { + +using namespace Slang; + +struct DefaultMatrixLayoutTestContext +{ + DefaultMatrixLayoutTestContext(UnitTestContext* context): + m_unitTestContext(context) + { + slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession; + } + + SlangResult runTests() + { + slang::IGlobalSession* slangSession = m_unitTestContext->slangGlobalSession; + ComPtr<slang::ISession> session; + slang::SessionDesc sessionDesc{}; + sessionDesc.targetCount = 1; + slang::TargetDesc targetDesc{}; + targetDesc.format = SLANG_GLSL; + targetDesc.profile = slangSession->findProfile("glsl_460"); + sessionDesc.targets = &targetDesc; + sessionDesc.defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; + SLANG_RETURN_ON_FAIL(slangSession->createSession(sessionDesc, session.writeRef())); + + auto module = session->loadModuleFromSourceString("mymodule", "mymodule.slang", + R"( + RWStructuredBuffer<float> output; + [numthreads(1,1,1)] [shader("compute")] + void main(uniform float3x4 m) + { + output[0] = m[0][0]; + })"); + if (!module) + return SLANG_FAIL; + + ComPtr<slang::IEntryPoint> entryPoint; + SLANG_RETURN_ON_FAIL(module->findEntryPointByName("main", entryPoint.writeRef())); + + if (!entryPoint) + return SLANG_FAIL; + + slang::IComponentType* components[] = { module, entryPoint.get() }; + ComPtr<slang::IComponentType> composedProgram; + SLANG_RETURN_ON_FAIL(session->createCompositeComponentType(components, 2, composedProgram.writeRef())); + + ComPtr<slang::IComponentType> linkedProgram; + SLANG_RETURN_ON_FAIL(composedProgram->link(linkedProgram.writeRef())); + + ComPtr<slang::IBlob> outCode; + SLANG_RETURN_ON_FAIL(linkedProgram->getEntryPointCode(0, 0, outCode.writeRef())); + + const char* code = (const char*)outCode->getBufferPointer(); + if (strstr(code, "row_major") != nullptr) + return SLANG_OK; + return SLANG_FAIL; + } + + UnitTestContext* m_unitTestContext; +}; + +} // anonymous + +SLANG_UNIT_TEST(defaultMatrixLayout) +{ + DefaultMatrixLayoutTestContext context(unitTestContext); + + const auto result = context.runTests(); + + SLANG_CHECK(SLANG_SUCCEEDED(result)); +} |
