summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDarren Wihandi <65404740+fairywreath@users.noreply.github.com>2025-05-14 03:01:47 -0400
committerGitHub <noreply@github.com>2025-05-14 07:01:47 +0000
commiteb5648b41d0718648477cbcf941fb3c6edf6dfc7 (patch)
tree73f19cd4b4f66d0ce6fbbd61c9a1253969cc3139
parent04ba87e23435e76583c05d4530d63686f9af712f (diff)
Error out on invalid vector sizes (#7076)
* Error out on invalid vector sizes * Remove unnecessary include * Fix incorrect assert * Add test
-rw-r--r--source/slang/slang-diagnostic-defs.h6
-rw-r--r--source/slang/slang-emit.cpp101
-rw-r--r--source/slang/slang-ir-validate.cpp120
-rw-r--r--source/slang/slang-ir-validate.h6
-rw-r--r--tests/diagnostics/invalid-vector-element-count.slang15
5 files changed, 148 insertions, 100 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index ac4008b7f..dbd4da588 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -1987,6 +1987,12 @@ DIAGNOSTIC(
vectorWithDisallowedElementTypeEncountered,
"vector with disallowed element type '$0' encountered")
+DIAGNOSTIC(
+ 38203,
+ Error,
+ vectorWithInvalidElementCountEncountered,
+ "vector has invalid element count '$0', valid values are between '$1' and '$2' inclusive")
+
// 39xxx - Type layout and parameter binding.
DIAGNOSTIC(
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 9410a8f45..02b5a44ed 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -614,105 +614,6 @@ static void unexportNonEmbeddableIR(CodeGenTarget target, IRModule* irModule)
}
}
-static void validateVectorOrMatrixElementType(
- DiagnosticSink* sink,
- SourceLoc sourceLoc,
- IRType* elementType,
- uint32_t allowedWidths,
- const DiagnosticInfo& disallowedElementTypeEncountered)
-{
- if (!isFloatingType(elementType))
- {
- if (isIntegralType(elementType))
- {
- IntInfo info = getIntTypeInfo(elementType);
- if (allowedWidths == 0U)
- {
- sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType);
- }
- else
- {
- bool widthAllowed = false;
- SLANG_ASSERT((allowedWidths & ~(0xfU << 3)) == 0U);
- for (uint32_t p = 3U; p <= 6U; p++)
- {
- uint32_t width = 1U << p;
- if (!(allowedWidths & width))
- continue;
- widthAllowed = widthAllowed || (info.width == width);
- }
- if (!widthAllowed)
- {
- sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType);
- }
- }
- }
- else if (!as<IRBoolType>(elementType))
- {
- sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType);
- }
- }
-}
-
-static void validateVectorsAndMatrices(
- DiagnosticSink* sink,
- IRModule* module,
- TargetRequest* targetRequest)
-{
- for (auto globalInst : module->getGlobalInsts())
- {
- if (auto matrixType = as<IRMatrixType>(globalInst))
- {
- // Matrices with row/col dimension 1 are only well-supported on D3D targets
- if (!isD3DTarget(targetRequest))
- {
- // Verify that neither row nor col count is 1
- auto colCount = as<IRIntLit>(matrixType->getColumnCount());
- auto rowCount = as<IRIntLit>(matrixType->getRowCount());
-
- if ((rowCount && (rowCount->getValue() == 1)) ||
- (colCount && (colCount->getValue() == 1)))
- {
- sink->diagnose(matrixType->sourceLoc, Diagnostics::matrixColumnOrRowCountIsOne);
- }
- }
-
- // Verify that the element type is a floating point type, or an allowed integral type
- auto elementType = matrixType->getElementType();
- uint32_t allowedWidths = 0U;
- if (isCPUTarget(targetRequest))
- allowedWidths = 8U | 16U | 32U | 64U;
- else if (isCUDATarget(targetRequest))
- allowedWidths = 32U | 64U;
- else if (isD3DTarget(targetRequest))
- allowedWidths = 16U | 32U;
- validateVectorOrMatrixElementType(
- sink,
- matrixType->sourceLoc,
- elementType,
- allowedWidths,
- Diagnostics::matrixWithDisallowedElementTypeEncountered);
- }
- else if (auto vectorType = as<IRVectorType>(globalInst))
- {
- // Verify that the element type is a floating point type, or an allowed integral type
- auto elementType = vectorType->getElementType();
- uint32_t allowedWidths = 0U;
- if (isWGPUTarget(targetRequest))
- allowedWidths = 32U;
- else
- allowedWidths = 8U | 16U | 32U | 64U;
-
- validateVectorOrMatrixElementType(
- sink,
- vectorType->sourceLoc,
- elementType,
- allowedWidths,
- Diagnostics::vectorWithDisallowedElementTypeEncountered);
- }
- }
-}
-
Result linkAndOptimizeIR(
CodeGenContext* codeGenContext,
LinkingAndOptimizationOptions const& options,
@@ -1730,7 +1631,7 @@ Result linkAndOptimizeIR(
validateIRModuleIfEnabled(codeGenContext, irModule);
// Validate vectors and matrices according to what the target allows
- validateVectorsAndMatrices(sink, irModule, targetRequest);
+ validateVectorsAndMatrices(irModule, sink, targetRequest);
// The resource-based specialization pass above
// may create specialized versions of functions, but
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index 19de64618..565ae97d8 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -511,4 +511,124 @@ void validateAtomicOperations(bool skipFuncParamValidation, DiagnosticSink* sink
}
}
+static void validateVectorOrMatrixElementType(
+ DiagnosticSink* sink,
+ SourceLoc sourceLoc,
+ IRType* elementType,
+ uint32_t allowedWidths,
+ const DiagnosticInfo& disallowedElementTypeEncountered)
+{
+ if (!isFloatingType(elementType))
+ {
+ if (isIntegralType(elementType))
+ {
+ IntInfo info = getIntTypeInfo(elementType);
+ if (allowedWidths == 0U)
+ {
+ sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType);
+ }
+ else
+ {
+ bool widthAllowed = false;
+ SLANG_ASSERT((allowedWidths & ~(0xfU << 3)) == 0U);
+ for (uint32_t p = 3U; p <= 6U; p++)
+ {
+ uint32_t width = 1U << p;
+ if (!(allowedWidths & width))
+ continue;
+ widthAllowed = widthAllowed || (info.width == width);
+ }
+ if (!widthAllowed)
+ {
+ sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType);
+ }
+ }
+ }
+ else if (!as<IRBoolType>(elementType))
+ {
+ sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType);
+ }
+ }
+}
+
+static void validateVectorElementCount(DiagnosticSink* sink, IRVectorType* vectorType)
+{
+ const auto elementCount = as<IRIntLit>(vectorType->getElementCount())->getValue();
+
+ // 1-vectors are supported and are legalized/transformed properly when targetting unsupported
+ // backends.
+ const IRIntegerValue minCount = 1;
+ const IRIntegerValue maxCount = 4;
+ if ((elementCount < minCount) || (elementCount > maxCount))
+ {
+ sink->diagnose(
+ vectorType->sourceLoc,
+ Diagnostics::vectorWithInvalidElementCountEncountered,
+ elementCount,
+ "1",
+ maxCount);
+ }
+}
+
+void validateVectorsAndMatrices(
+ IRModule* module,
+ DiagnosticSink* sink,
+ TargetRequest* targetRequest)
+{
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto matrixType = as<IRMatrixType>(globalInst))
+ {
+ // Matrices with row/col dimension 1 are only well-supported on D3D targets
+ if (!isD3DTarget(targetRequest))
+ {
+ // Verify that neither row nor col count is 1
+ auto colCount = as<IRIntLit>(matrixType->getColumnCount());
+ auto rowCount = as<IRIntLit>(matrixType->getRowCount());
+
+ if ((rowCount && (rowCount->getValue() == 1)) ||
+ (colCount && (colCount->getValue() == 1)))
+ {
+ sink->diagnose(matrixType->sourceLoc, Diagnostics::matrixColumnOrRowCountIsOne);
+ }
+ }
+
+ // Verify that the element type is a floating point type, or an allowed integral type
+ auto elementType = matrixType->getElementType();
+ uint32_t allowedWidths = 0U;
+ if (isCPUTarget(targetRequest))
+ allowedWidths = 8U | 16U | 32U | 64U;
+ else if (isCUDATarget(targetRequest))
+ allowedWidths = 32U | 64U;
+ else if (isD3DTarget(targetRequest))
+ allowedWidths = 16U | 32U;
+ validateVectorOrMatrixElementType(
+ sink,
+ matrixType->sourceLoc,
+ elementType,
+ allowedWidths,
+ Diagnostics::matrixWithDisallowedElementTypeEncountered);
+ }
+ else if (auto vectorType = as<IRVectorType>(globalInst))
+ {
+ // Verify that the element type is a floating point type, or an allowed integral type
+ auto elementType = vectorType->getElementType();
+ uint32_t allowedWidths = 0U;
+ if (isWGPUTarget(targetRequest))
+ allowedWidths = 32U;
+ else
+ allowedWidths = 8U | 16U | 32U | 64U;
+
+ validateVectorOrMatrixElementType(
+ sink,
+ vectorType->sourceLoc,
+ elementType,
+ allowedWidths,
+ Diagnostics::vectorWithDisallowedElementTypeEncountered);
+
+ validateVectorElementCount(sink, vectorType);
+ }
+ }
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-validate.h b/source/slang/slang-ir-validate.h
index dc7a52cee..722359452 100644
--- a/source/slang/slang-ir-validate.h
+++ b/source/slang/slang-ir-validate.h
@@ -6,6 +6,7 @@ namespace Slang
struct CodeGenContext;
class CompileRequestBase;
class DiagnosticSink;
+class TargetRequest;
struct IRModule;
struct IRInst;
@@ -49,4 +50,9 @@ void enableIRValidationAtInsert();
// lead back to in/inout parameters that we can't validate.
void validateAtomicOperations(bool skipFuncParamValidation, DiagnosticSink* sink, IRInst* inst);
+void validateVectorsAndMatrices(
+ IRModule* module,
+ DiagnosticSink* sink,
+ TargetRequest* targetRequest);
+
} // namespace Slang
diff --git a/tests/diagnostics/invalid-vector-element-count.slang b/tests/diagnostics/invalid-vector-element-count.slang
new file mode 100644
index 000000000..b801dfd07
--- /dev/null
+++ b/tests/diagnostics/invalid-vector-element-count.slang
@@ -0,0 +1,15 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -stage compute -entry computeMain -target spirv
+
+RWStructuredBuffer<vector<float, 8>> bufferIn1;
+RWStructuredBuffer<vector<float, 0>> bufferIn2;
+RWStructuredBuffer<float> resultOut;
+
+[shader("compute")]
+[numthreads(32,1,1)]
+void computeMain(uint3 threadId : SV_DispatchThreadID)
+{
+ // CHECK: error 38203: vector has invalid element count '0', valid values are between '1' and '4' inclusive
+ // CHECK: error 38203: vector has invalid element count '8', valid values are between '1' and '4' inclusive
+ uint index = threadId.x;
+ resultOut[index] = bufferIn1[index][0] + bufferIn2[index][0];
+}