From 82f308ca692878bfe9844b86629c6536b4cd0f0a Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Wed, 10 Jul 2024 14:10:33 -0700 Subject: Fix the invalid spirv generation for matrix cast (#4588) SPIR-V doesn't have an instruction to do a float cast on a matrix type, so we have to convert the matrix row by row and then construct a new matrix from the converted rows. Update the unit test to make sure the cast won't miss any elements. Co-authored-by: Yong He --- source/slang/slang-emit-spirv.cpp | 33 +++++++++++++++++++++++++++++++++ tests/bugs/gh-4556.slang | 15 ++++++++++++--- 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 8a0b9f2ed..8c7232963 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -5190,6 +5190,32 @@ struct SPIRVEmitContext SLANG_UNREACHABLE(__func__); } + SpvInst* emitFloatCastForMatrix(SpvInstParent* parent, IRFloatCast* inst, IRMatrixType* fromTypeM, IRMatrixType* toTypeM) + { + // Because there is no spirv instruction to convert matrix to matrix, we need to convert it row by row. + auto rowCount = getIntVal(fromTypeM->getRowCount()); + auto colCount = getIntVal(fromTypeM->getColumnCount()); + + IRBuilder builder(m_irModule); + // Get from and to type of the row vector + auto fromTypeV = builder.getVectorType(fromTypeM->getElementType(), colCount); + auto toVectorV = builder.getVectorType(toTypeM->getElementType(), colCount); + + List<SpvInst*> rowVectorsConverted; + // convert each row vector to toType. + for (uint32_t i = 0; i < rowCount; i++) + { + auto rowVector = emitOpCompositeExtract(parent, nullptr, fromTypeV, + inst->getOperand(0), makeArray(SpvLiteralInteger::from32(i))); + + auto rowVectorConverted = emitOpFConvert(parent, nullptr, toVectorV, rowVector); + rowVectorsConverted.add(rowVectorConverted); + } + + // construct a matrix from the converted row vectors. 
+ return emitCompositeConstruct(parent, inst, toTypeM, rowVectorsConverted); + } + SpvInst* emitFloatCast(SpvInstParent* parent, IRFloatCast* inst) { const auto fromTypeV = inst->getOperand(0)->getDataType(); @@ -5198,6 +5224,7 @@ struct SPIRVEmitContext IRType* fromType = nullptr; IRType* toType = nullptr; + bool isMatrixCast = false; if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV)) { fromType = getVectorElementType(fromTypeV); @@ -5207,6 +5234,7 @@ struct SPIRVEmitContext { fromType = getMatrixElementType(fromTypeV); toType = getMatrixElementType(toTypeV); + isMatrixCast = true; } else { @@ -5228,6 +5256,11 @@ struct SPIRVEmitContext SLANG_ASSERT(isFloatingType(toType)); SLANG_ASSERT(!isTypeEqual(fromType, toType)); + if (isMatrixCast) + { + return emitFloatCastForMatrix(parent, inst, as<IRMatrixType>(fromTypeV), as<IRMatrixType>(toTypeV)); + } + return emitOpFConvert(parent, inst, toTypeV, inst->getOperand(0)); } diff --git a/tests/bugs/gh-4556.slang b/tests/bugs/gh-4556.slang index e5d938840..1f779e199 100644 --- a/tests/bugs/gh-4556.slang +++ b/tests/bugs/gh-4556.slang @@ -1,12 +1,15 @@ //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -compute -output-using-type -shaderobj -//DISABLE_TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -compute -output-using-type -shaderobj -//DISABLE_TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -glsl -compute -output-using-type -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -compute -output-using-type -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -glsl -compute -output-using-type -shaderobj +//TEST(compute):SIMPLE(filecheck=SPIRV): -target spirv-asm -stage compute //DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-mtl -compute -output-using-type -shaderobj //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -output-using-type -shaderobj -//TEST_INPUT:ubuffer(data=[0.0 0.0], stride=4):out,name=outputBuffer 
+//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; +//SPIRV-NOT: Validation of generated SPIR-V failed + [shader("compute")] [numthreads(1, 1, 1)] void computeMain(uint3 id: SV_DispatchThreadID) @@ -18,4 +21,10 @@ void computeMain(uint3 id: SV_DispatchThreadID) outputBuffer[0] = (float)b[0][0]; // CHECK: 2.000000 outputBuffer[1] = (float)b[0][1]; + + // CHECK: 7.000000 + outputBuffer[2] = (float)b[1][2]; + + // CHECK: 12.000000 + outputBuffer[3] = (float)b[2][3]; } -- cgit v1.2.3