diff options
| author | Darren Wihandi <65404740+fairywreath@users.noreply.github.com> | 2025-05-14 03:01:47 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-14 07:01:47 +0000 |
| commit | eb5648b41d0718648477cbcf941fb3c6edf6dfc7 (patch) | |
| tree | 73f19cd4b4f66d0ce6fbbd61c9a1253969cc3139 | |
| parent | 04ba87e23435e76583c05d4530d63686f9af712f (diff) | |
Error out on invalid vector sizes (#7076)
* Error out on invalid vector sizes
* Remove unnecessary include
* Fix incorrect assert
* Add test
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 101 | ||||
| -rw-r--r-- | source/slang/slang-ir-validate.cpp | 120 | ||||
| -rw-r--r-- | source/slang/slang-ir-validate.h | 6 | ||||
| -rw-r--r-- | tests/diagnostics/invalid-vector-element-count.slang | 15 |
5 files changed, 148 insertions, 100 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index ac4008b7f..dbd4da588 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -1987,6 +1987,12 @@ DIAGNOSTIC( vectorWithDisallowedElementTypeEncountered, "vector with disallowed element type '$0' encountered") +DIAGNOSTIC( + 38203, + Error, + vectorWithInvalidElementCountEncountered, + "vector has invalid element count '$0', valid values are between '$1' and '$2' inclusive") + // 39xxx - Type layout and parameter binding. DIAGNOSTIC( diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 9410a8f45..02b5a44ed 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -614,105 +614,6 @@ static void unexportNonEmbeddableIR(CodeGenTarget target, IRModule* irModule) } } -static void validateVectorOrMatrixElementType( - DiagnosticSink* sink, - SourceLoc sourceLoc, - IRType* elementType, - uint32_t allowedWidths, - const DiagnosticInfo& disallowedElementTypeEncountered) -{ - if (!isFloatingType(elementType)) - { - if (isIntegralType(elementType)) - { - IntInfo info = getIntTypeInfo(elementType); - if (allowedWidths == 0U) - { - sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType); - } - else - { - bool widthAllowed = false; - SLANG_ASSERT((allowedWidths & ~(0xfU << 3)) == 0U); - for (uint32_t p = 3U; p <= 6U; p++) - { - uint32_t width = 1U << p; - if (!(allowedWidths & width)) - continue; - widthAllowed = widthAllowed || (info.width == width); - } - if (!widthAllowed) - { - sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType); - } - } - } - else if (!as<IRBoolType>(elementType)) - { - sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType); - } - } -} - -static void validateVectorsAndMatrices( - DiagnosticSink* sink, - IRModule* module, - TargetRequest* targetRequest) -{ - for (auto globalInst : module->getGlobalInsts()) - { - if (auto matrixType = as<IRMatrixType>(globalInst)) - { - // Matrices with row/col dimension 1 are only well-supported on D3D targets - if (!isD3DTarget(targetRequest)) - { - // Verify that neither row nor col count is 1 - auto colCount = as<IRIntLit>(matrixType->getColumnCount()); - auto rowCount = as<IRIntLit>(matrixType->getRowCount()); - - if ((rowCount && (rowCount->getValue() == 1)) || - (colCount && (colCount->getValue() == 1))) - { - sink->diagnose(matrixType->sourceLoc, Diagnostics::matrixColumnOrRowCountIsOne); - } - } - - // Verify that the element type is a floating point type, or an allowed integral type - auto elementType = matrixType->getElementType(); - uint32_t allowedWidths = 0U; - if (isCPUTarget(targetRequest)) - allowedWidths = 8U | 16U | 32U | 64U; - else if (isCUDATarget(targetRequest)) - allowedWidths = 32U | 64U; - else if (isD3DTarget(targetRequest)) - allowedWidths = 16U | 32U; - validateVectorOrMatrixElementType( - sink, - matrixType->sourceLoc, - elementType, - allowedWidths, - Diagnostics::matrixWithDisallowedElementTypeEncountered); - } - else if (auto vectorType = as<IRVectorType>(globalInst)) - { - // Verify that the element type is a floating point type, or an allowed integral type - auto elementType = vectorType->getElementType(); - uint32_t allowedWidths = 0U; - if (isWGPUTarget(targetRequest)) - allowedWidths = 32U; - else - allowedWidths = 8U | 16U | 32U | 64U; - - validateVectorOrMatrixElementType( - sink, - vectorType->sourceLoc, - elementType, - allowedWidths, - Diagnostics::vectorWithDisallowedElementTypeEncountered); - } - } -} - Result linkAndOptimizeIR( CodeGenContext* codeGenContext, LinkingAndOptimizationOptions const& options, @@ -1730,7 +1631,7 @@ Result linkAndOptimizeIR( validateIRModuleIfEnabled(codeGenContext, irModule); // Validate vectors and matrices according to what the target allows - validateVectorsAndMatrices(sink, irModule, targetRequest); + validateVectorsAndMatrices(irModule, sink, targetRequest); // The resource-based specialization pass above // may create specialized versions of functions, but diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index 19de64618..565ae97d8 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -511,4 +511,124 @@ void validateAtomicOperations(bool skipFuncParamValidation, DiagnosticSink* sink } } +static void validateVectorOrMatrixElementType( + DiagnosticSink* sink, + SourceLoc sourceLoc, + IRType* elementType, + uint32_t allowedWidths, + const DiagnosticInfo& disallowedElementTypeEncountered) +{ + if (!isFloatingType(elementType)) + { + if (isIntegralType(elementType)) + { + IntInfo info = getIntTypeInfo(elementType); + if (allowedWidths == 0U) + { + sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType); + } + else + { + bool widthAllowed = false; + SLANG_ASSERT((allowedWidths & ~(0xfU << 3)) == 0U); + for (uint32_t p = 3U; p <= 6U; p++) + { + uint32_t width = 1U << p; + if (!(allowedWidths & width)) + continue; + widthAllowed = widthAllowed || (info.width == width); + } + if (!widthAllowed) + { + sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType); + } + } + } + else if (!as<IRBoolType>(elementType)) + { + sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType); + } + } +} + +static void validateVectorElementCount(DiagnosticSink* sink, IRVectorType* vectorType) +{ + const auto elementCount = as<IRIntLit>(vectorType->getElementCount())->getValue(); + + // 1-vectors are supported and are legalized/transformed properly when targetting unsupported + // backends. + const IRIntegerValue minCount = 1; + const IRIntegerValue maxCount = 4; + if ((elementCount < minCount) || (elementCount > maxCount)) + { + sink->diagnose( + vectorType->sourceLoc, + Diagnostics::vectorWithInvalidElementCountEncountered, + elementCount, + "1", + maxCount); + } +} + +void validateVectorsAndMatrices( + IRModule* module, + DiagnosticSink* sink, + TargetRequest* targetRequest) +{ + for (auto globalInst : module->getGlobalInsts()) + { + if (auto matrixType = as<IRMatrixType>(globalInst)) + { + // Matrices with row/col dimension 1 are only well-supported on D3D targets + if (!isD3DTarget(targetRequest)) + { + // Verify that neither row nor col count is 1 + auto colCount = as<IRIntLit>(matrixType->getColumnCount()); + auto rowCount = as<IRIntLit>(matrixType->getRowCount()); + + if ((rowCount && (rowCount->getValue() == 1)) || + (colCount && (colCount->getValue() == 1))) + { + sink->diagnose(matrixType->sourceLoc, Diagnostics::matrixColumnOrRowCountIsOne); + } + } + + // Verify that the element type is a floating point type, or an allowed integral type + auto elementType = matrixType->getElementType(); + uint32_t allowedWidths = 0U; + if (isCPUTarget(targetRequest)) + allowedWidths = 8U | 16U | 32U | 64U; + else if (isCUDATarget(targetRequest)) + allowedWidths = 32U | 64U; + else if (isD3DTarget(targetRequest)) + allowedWidths = 16U | 32U; + validateVectorOrMatrixElementType( + sink, + matrixType->sourceLoc, + elementType, + allowedWidths, + Diagnostics::matrixWithDisallowedElementTypeEncountered); + } + else if (auto vectorType = as<IRVectorType>(globalInst)) + { + // Verify that the element type is a floating point type, or an allowed integral type + auto elementType = vectorType->getElementType(); + uint32_t allowedWidths = 0U; + if (isWGPUTarget(targetRequest)) + allowedWidths = 32U; + else + allowedWidths = 8U | 16U | 32U | 64U; + + validateVectorOrMatrixElementType( + sink, + vectorType->sourceLoc, + elementType, + allowedWidths, + Diagnostics::vectorWithDisallowedElementTypeEncountered); + + validateVectorElementCount(sink, vectorType); + } + } +} + } // namespace Slang diff --git a/source/slang/slang-ir-validate.h b/source/slang/slang-ir-validate.h index dc7a52cee..722359452 100644 --- a/source/slang/slang-ir-validate.h +++ b/source/slang/slang-ir-validate.h @@ -6,6 +6,7 @@ namespace Slang struct CodeGenContext; class CompileRequestBase; class DiagnosticSink; +class TargetRequest; struct IRModule; struct IRInst; @@ -49,4 +50,9 @@ void enableIRValidationAtInsert(); // lead back to in/inout parameters that we can't validate. void validateAtomicOperations(bool skipFuncParamValidation, DiagnosticSink* sink, IRInst* inst); +void validateVectorsAndMatrices( + IRModule* module, + DiagnosticSink* sink, + TargetRequest* targetRequest); + } // namespace Slang diff --git a/tests/diagnostics/invalid-vector-element-count.slang b/tests/diagnostics/invalid-vector-element-count.slang new file mode 100644 index 000000000..b801dfd07 --- /dev/null +++ b/tests/diagnostics/invalid-vector-element-count.slang @@ -0,0 +1,15 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -stage compute -entry computeMain -target spirv + +RWStructuredBuffer<vector<float, 8>> bufferIn1; +RWStructuredBuffer<vector<float, 0>> bufferIn2; +RWStructuredBuffer<float> resultOut; + +[shader("compute")] +[numthreads(32,1,1)] +void computeMain(uint3 threadId : SV_DispatchThreadID) +{ + // CHECK: error 38203: vector has invalid element count '0', valid values are between '1' and '4' inclusive + // CHECK: error 38203: vector has invalid element count '8', valid values are between '1' and '4' inclusive + uint index = threadId.x; + resultOut[index] = bufferIn1[index][0] + bufferIn2[index][0]; +} |
