summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-emit-spirv.cpp33
-rw-r--r--tests/bugs/gh-4556.slang15
2 files changed, 45 insertions, 3 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 8a0b9f2ed..8c7232963 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -5190,6 +5190,32 @@ struct SPIRVEmitContext
SLANG_UNREACHABLE(__func__);
}
+ SpvInst* emitFloatCastForMatrix(SpvInstParent* parent, IRFloatCast* inst, IRMatrixType* fromTypeM, IRMatrixType* toTypeM)
+ {
+ // Because there is no spirv instruction to convert matrix to matrix, we need to convert it row by row.
+ auto rowCount = getIntVal(fromTypeM->getRowCount());
+ auto colCount = getIntVal(fromTypeM->getColumnCount());
+
+ IRBuilder builder(m_irModule);
+ // Get from and to type of the row vector
+ auto fromTypeV = builder.getVectorType(fromTypeM->getElementType(), colCount);
+ auto toVectorV = builder.getVectorType(toTypeM->getElementType(), colCount);
+
+ List<SpvInst*> rowVectorsConverted;
+ // convert each row vector to toType.
+ for (uint32_t i = 0; i < rowCount; i++)
+ {
+ auto rowVector = emitOpCompositeExtract(parent, nullptr, fromTypeV,
+ inst->getOperand(0), makeArray(SpvLiteralInteger::from32(i)));
+
+ auto rowVectorConverted = emitOpFConvert(parent, nullptr, toVectorV, rowVector);
+ rowVectorsConverted.add(rowVectorConverted);
+ }
+
+ // construct a matrix from the converted row vectors.
+ return emitCompositeConstruct(parent, inst, toTypeM, rowVectorsConverted);
+ }
+
SpvInst* emitFloatCast(SpvInstParent* parent, IRFloatCast* inst)
{
const auto fromTypeV = inst->getOperand(0)->getDataType();
@@ -5198,6 +5224,7 @@ struct SPIRVEmitContext
IRType* fromType = nullptr;
IRType* toType = nullptr;
+ bool isMatrixCast = false;
if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV))
{
fromType = getVectorElementType(fromTypeV);
@@ -5207,6 +5234,7 @@ struct SPIRVEmitContext
{
fromType = getMatrixElementType(fromTypeV);
toType = getMatrixElementType(toTypeV);
+ isMatrixCast = true;
}
else
{
@@ -5228,6 +5256,11 @@ struct SPIRVEmitContext
SLANG_ASSERT(isFloatingType(toType));
SLANG_ASSERT(!isTypeEqual(fromType, toType));
+ if (isMatrixCast)
+ {
+ return emitFloatCastForMatrix(parent, inst, as<IRMatrixType>(fromTypeV), as<IRMatrixType>(toTypeV));
+ }
+
return emitOpFConvert(parent, inst, toTypeV, inst->getOperand(0));
}
diff --git a/tests/bugs/gh-4556.slang b/tests/bugs/gh-4556.slang
index e5d938840..1f779e199 100644
--- a/tests/bugs/gh-4556.slang
+++ b/tests/bugs/gh-4556.slang
@@ -1,12 +1,15 @@
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -compute -output-using-type -shaderobj
-//DISABLE_TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -compute -output-using-type -shaderobj
-//DISABLE_TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -glsl -compute -output-using-type -shaderobj
+//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -compute -output-using-type -shaderobj
+//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -glsl -compute -output-using-type -shaderobj
+//TEST(compute):SIMPLE(filecheck=SPIRV): -target spirv-asm -stage compute
//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-mtl -compute -output-using-type -shaderobj
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -output-using-type -shaderobj
-//TEST_INPUT:ubuffer(data=[0.0 0.0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
+//SPIRV-NOT: Validation of generated SPIR-V failed
+
[shader("compute")]
[numthreads(1, 1, 1)]
void computeMain(uint3 id: SV_DispatchThreadID)
@@ -18,4 +21,10 @@ void computeMain(uint3 id: SV_DispatchThreadID)
outputBuffer[0] = (float)b[0][0];
// CHECK: 2.000000
outputBuffer[1] = (float)b[0][1];
+
+ // CHECK: 7.000000
+ outputBuffer[2] = (float)b[1][2];
+
+ // CHECK: 12.000000
+ outputBuffer[3] = (float)b[2][3];
}