diff options
| author | Yong He <yonghe@outlook.com> | 2024-03-12 10:32:35 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-12 10:32:35 -0700 |
| commit | 1c4e1acdd48779b94c1008ba2456c63975e5fb7d (patch) | |
| tree | 8e3b870efe4fa9cc8bd7633d5b7f2885fe0568ad | |
| parent | d8eb701170bf6050718750e6a5e72aa55fb5bd45 (diff) | |
[SPIRV] Use VectorTimesScalar opcode. (#3737)
* [SPIRV] Use VectorTimesScalar opcode.
* Fix.
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 12 | ||||
| -rw-r--r-- | tests/spirv/vector-times-scalar.slang | 19 |
2 files changed, 31 insertions, 0 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 8063975e2..252f0e917 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4981,6 +4981,18 @@ struct SPIRVEmitContext const auto lVec = as<IRVectorType>(l->getDataType()); auto r = operands[1]; const auto rVec = as<IRVectorType>(r->getDataType()); + if (op == kIROp_Mul && isFloatingPoint) + { + if (lVec && !rVec) + { + return emitInst(parent, instToRegister, SpvOpVectorTimesScalar, type, kResultID, operands); + } + else if (!lVec && rVec) + { + IRInst* newOperands[2] = { operands[1], operands[0] }; + return emitInst(parent, instToRegister, SpvOpVectorTimesScalar, type, kResultID, ArrayView<IRInst*>(newOperands, 2)); + } + } const auto go = [&](const auto l, const auto r) { return emitInst(parent, instToRegister, opCode, type, kResultID, l, r); }; diff --git a/tests/spirv/vector-times-scalar.slang b/tests/spirv/vector-times-scalar.slang new file mode 100644 index 000000000..fb997d490 --- /dev/null +++ b/tests/spirv/vector-times-scalar.slang @@ -0,0 +1,19 @@ + +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -output-using-type +//TEST:SIMPLE(filecheck=CHECK): -target spirv -stage compute -entry computeMain -emit-spirv-directly + +//TEST_INPUT:set output = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<float> output; + +// Test that we are able to use the VectorTimesScalar opcode to simplify the resulting spirv. + +// CHECK: OpVectorTimesScalar + +[numthreads(1,1,1)] +void computeMain(int3 tid : SV_DispatchThreadID) +{ + float3 v = tid + 2.0; + float3 v1 = v * 0.5; + // BUFFER: 1.0 + output[0] = v1.x; +}
\ No newline at end of file |
