summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-03-12 10:32:35 -0700
committerGitHub <noreply@github.com>2024-03-12 10:32:35 -0700
commit1c4e1acdd48779b94c1008ba2456c63975e5fb7d (patch)
tree8e3b870efe4fa9cc8bd7633d5b7f2885fe0568ad
parentd8eb701170bf6050718750e6a5e72aa55fb5bd45 (diff)
[SPIRV] Use VectorTimesScalar opcode. (#3737)
* [SPIRV] Use VectorTimesScalar opcode. * Fix.
-rw-r--r--source/slang/slang-emit-spirv.cpp12
-rw-r--r--tests/spirv/vector-times-scalar.slang19
2 files changed, 31 insertions, 0 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 8063975e2..252f0e917 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -4981,6 +4981,18 @@ struct SPIRVEmitContext
const auto lVec = as<IRVectorType>(l->getDataType());
auto r = operands[1];
const auto rVec = as<IRVectorType>(r->getDataType());
+ if (op == kIROp_Mul && isFloatingPoint)
+ {
+ if (lVec && !rVec)
+ {
+ return emitInst(parent, instToRegister, SpvOpVectorTimesScalar, type, kResultID, operands);
+ }
+ else if (!lVec && rVec)
+ {
+ IRInst* newOperands[2] = { operands[1], operands[0] };
+ return emitInst(parent, instToRegister, SpvOpVectorTimesScalar, type, kResultID, ArrayView<IRInst*>(newOperands, 2));
+ }
+ }
const auto go = [&](const auto l, const auto r) {
return emitInst(parent, instToRegister, opCode, type, kResultID, l, r);
};
diff --git a/tests/spirv/vector-times-scalar.slang b/tests/spirv/vector-times-scalar.slang
new file mode 100644
index 000000000..fb997d490
--- /dev/null
+++ b/tests/spirv/vector-times-scalar.slang
@@ -0,0 +1,19 @@
+
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -output-using-type
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -stage compute -entry computeMain -emit-spirv-directly
+
+//TEST_INPUT:set output = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<float> output;
+
+// Test that we are able to use the VectorTimesScalar opcode to simplify the resulting spirv.
+
+// CHECK: OpVectorTimesScalar
+
+[numthreads(1,1,1)]
+void computeMain(int3 tid : SV_DispatchThreadID)
+{
+ float3 v = tid + 2.0;
+ float3 v1 = v * 0.5;
+ // BUFFER: 1.0
+ output[0] = v1.x;
+} \ No newline at end of file