diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2025-02-07 18:27:23 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-07 18:27:23 -0800 |
| commit | 57b09a8986668626c37055e431fa0ac6449d7214 (patch) | |
| tree | 8a96d7ea01a150d22ba47818655239b21d4ac25e | |
| parent | 79aebc18d54db3f0be8bd6529c0d79f4d8d4fc58 (diff) | |
Use and() and or() functions for logical-AND and OR (#6310)
* Use and() and or() functions for logical-AND and OR
With this commit, Slang will emit function calls to `and()` and `or()`
for the logical-AND and logical-OR when the operands are non-scalar and
the target profile is SM6.0 and above. This is required change from
SM6.0.
For WGSL, there is no operator overloadings of `&&` and `||` when the
operands are non-scalar. Unlike HLSL, WGSL also don't have `and()` nor
`or()`. Alternatively, we can use `select()`.
| -rw-r--r-- | lock | 0 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 8 | ||||
| -rw-r--r-- | source/slang/slang-emit-hlsl.cpp | 47 | ||||
| -rw-r--r-- | source/slang/slang-emit-wgsl.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-binary-operator.cpp | 97 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-binary-operator.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 14 | ||||
| -rw-r--r-- | tests/compute/logic-no-short-circuit-evaluation.slang | 66 | ||||
| -rw-r--r-- | tests/compute/logic-short-circuit-evaluation.slang | 15 | ||||
| -rw-r--r-- | tests/compute/logic-short-circuit-evaluation.slang.expected.txt | 16 |
12 files changed, 287 insertions, 23 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 2399df254..491f0ef4d 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6228,7 +6228,9 @@ bool all(vector<T,N> x) case hlsl: __intrinsic_asm "all"; case metal: - __intrinsic_asm "all"; + if (__isBool<T>()) + __intrinsic_asm "all"; + __intrinsic_asm "all(bool$N0($0))"; case glsl: __intrinsic_asm "all(bvec$N0($0))"; case spirv: @@ -6256,7 +6258,9 @@ bool all(vector<T,N> x) }; } case wgsl: - __intrinsic_asm "all"; + if (__isBool<T>()) + __intrinsic_asm "all"; + __intrinsic_asm "all(vec$N0<bool>($0))"; default: bool result = true; for(int i = 0; i < N; ++i) diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 9ebec0893..ff4514d69 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -821,6 +821,53 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu } break; } + case kIROp_And: + case kIROp_Or: + { + // SM6.0 requires to use `and()` and `or()` functions for the logical-AND and + // logical-OR, respectively, with non-scalar operands. + auto targetProfile = getTargetProgram()->getOptionSet().getProfile(); + if (targetProfile.getVersion() < ProfileVersion::DX_6_0) + return false; + + if (as<IRBasicType>(inst->getDataType())) + return false; + + if (inst->getOp() == kIROp_And) + { + m_writer->emit("and("); + } + else + { + m_writer->emit("or("); + } + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + case kIROp_Select: + { + // SM6.0 requires to use `select()` instead of the ternary operator "?:" when the + // operands are non-scalar. + auto targetProfile = getTargetProgram()->getOptionSet().getProfile(); + if (targetProfile.getVersion() < ProfileVersion::DX_6_0) + return false; + + if (as<IRBasicType>(inst->getDataType())) + return false; + + m_writer->emit("select("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + case kIROp_BitCast: { // For simplicity, we will handle all bit-cast operations diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index aea766f9f..13c79e9ac 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1312,6 +1312,40 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu } break; + case kIROp_And: + case kIROp_Or: + { + // WGSL doesn't have operator overloadings for `&&` and `||` when the operands are + // non-scalar. Unlike HLSL, WGSL doesn't have `and()` and `or()`. + auto vecType = as<IRVectorType>(inst->getDataType()); + if (!vecType) + return false; + + // The function signature for `select` in WGSL is different from others: + // @const @must_use fn select(f: T, t: T, cond: bool) -> T + if (inst->getOp() == kIROp_And) + { + m_writer->emit("select(vec"); + m_writer->emit(getIntVal(vecType->getElementCount())); + m_writer->emit("<bool>(false), "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + } + else + { + m_writer->emit("select("); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", vec"); + m_writer->emit(getIntVal(vecType->getElementCount())); + m_writer->emit("<bool>(true), "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + } + return true; + } + case kIROp_BitCast: { // In WGSL there is a built-in bitcast function! diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index f093104bd..e20a4a90f 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -52,6 +52,7 @@ #include "slang-ir-insts.h" #include "slang-ir-layout.h" #include "slang-ir-legalize-array-return-type.h" +#include "slang-ir-legalize-binary-operator.h" #include "slang-ir-legalize-global-values.h" #include "slang-ir-legalize-image-subscript.h" #include "slang-ir-legalize-mesh-outputs.h" @@ -1469,6 +1470,10 @@ Result linkAndOptimizeIR( floatNonUniformResourceIndex(irModule, NonUniformResourceIndexFloatMode::Textual); } + if (isD3DTarget(targetRequest) || isKhronosTarget(targetRequest) || + isWGPUTarget(targetRequest) || isMetalTarget(targetRequest)) + legalizeLogicalAndOr(irModule->getModuleInst()); + // Legalize non struct parameters that are expected to be structs for HLSL. if (isD3DTarget(targetRequest)) legalizeNonStructParameterToStructForHLSL(irModule); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 3249f34b0..2bc5bca1e 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4520,6 +4520,9 @@ public: IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1); IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1); + IRInst* emitAnd(IRType* type, IRInst* left, IRInst* right); + IRInst* emitOr(IRType* type, IRInst* left, IRInst* right); + IRSPIRVAsmOperand* emitSPIRVAsmOperandLiteral(IRInst* literal); IRSPIRVAsmOperand* emitSPIRVAsmOperandInst(IRInst* inst); IRSPIRVAsmOperand* createSPIRVAsmOperandInst(IRInst* inst); diff --git a/source/slang/slang-ir-legalize-binary-operator.cpp b/source/slang/slang-ir-legalize-binary-operator.cpp index a1affb7e9..1595aa130 100644 --- a/source/slang/slang-ir-legalize-binary-operator.cpp +++ b/source/slang/slang-ir-legalize-binary-operator.cpp @@ -118,4 +118,101 @@ void legalizeBinaryOp(IRInst* inst) } } +void legalizeLogicalAndOr(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_And: + case kIROp_Or: + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Logical-AND and logical-OR takes boolean types as its operands. + // If they are not, legalize them by casting to boolean type. + // + SLANG_ASSERT(inst->getOperandCount() == 2); + for (UInt i = 0; i < 2; i++) + { + auto operand = inst->getOperand(i); + auto operandDataType = operand->getDataType(); + + if (auto vecType = as<IRVectorType>(operandDataType)) + { + if (!as<IRBoolType>(vecType->getElementType())) + { + // Cast operand to vector<bool,N> + auto elemCount = vecType->getElementCount(); + auto vb = builder.getVectorType(builder.getBoolType(), elemCount); + auto v = builder.emitCast(vb, operand); + builder.replaceOperand(inst->getOperands() + i, v); + } + } + else if (!as<IRBoolType>(operandDataType)) + { + // Cast operand to bool + auto s = builder.emitCast(builder.getBoolType(), operand); + builder.replaceOperand(inst->getOperands() + i, s); + } + } + + // Legalize the return type; mostly for SPIRV. + // The return type of OpLogicalOr must be boolean type. + // If not, we need to recreate the instruction with boolean return type. + // Then, we have to cast it back to the original type so that other instrucitons that + // use have the matching types. + // + auto dataType = inst->getDataType(); + auto lhs = inst->getOperand(0); + auto rhs = inst->getOperand(1); + IRInst* newInst = nullptr; + + if (auto vecType = as<IRVectorType>(dataType)) + { + if (!as<IRBoolType>(vecType->getElementType())) + { + // Return type should be vector<bool,N> + auto elemCount = vecType->getElementCount(); + auto vb = builder.getVectorType(builder.getBoolType(), elemCount); + + if (inst->getOp() == kIROp_And) + { + newInst = builder.emitAnd(vb, lhs, rhs); + } + else + { + newInst = builder.emitOr(vb, lhs, rhs); + } + newInst = builder.emitCast(dataType, newInst); + } + } + else if (!as<IRBoolType>(dataType)) + { + // Return type should be bool + if (inst->getOp() == kIROp_And) + { + newInst = builder.emitAnd(builder.getBoolType(), lhs, rhs); + } + else + { + newInst = builder.emitOr(builder.getBoolType(), lhs, rhs); + } + newInst = builder.emitCast(dataType, newInst); + } + + if (newInst && inst != newInst) + { + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + } + } + break; + } + + for (auto child : inst->getModifiableChildren()) + { + legalizeLogicalAndOr(child); + } +} + } // namespace Slang diff --git a/source/slang/slang-ir-legalize-binary-operator.h b/source/slang/slang-ir-legalize-binary-operator.h index 71c319718..f9ebf90d8 100644 --- a/source/slang/slang-ir-legalize-binary-operator.h +++ b/source/slang/slang-ir-legalize-binary-operator.h @@ -13,4 +13,9 @@ struct IRInst; // signed operand is converted to unsigned. void legalizeBinaryOp(IRInst* inst); +// The logical binary operators such as AND and OR takes boolean types are its input. +// If they are in integer type, as an example, we need to explicitly cast to bool type. +// Also the return type from the logical operators should be a boolean type. +void legalizeLogicalAndOr(IRInst* inst); + } // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6cf7a1786..07e8b2742 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6020,6 +6020,20 @@ IRInst* IRBuilder::emitShl(IRType* type, IRInst* left, IRInst* right) return inst; } +IRInst* IRBuilder::emitAnd(IRType* type, IRInst* left, IRInst* right) +{ + auto inst = createInst<IRInst>(this, kIROp_And, type, left, right); + addInst(inst); + return inst; +} + +IRInst* IRBuilder::emitOr(IRType* type, IRInst* left, IRInst* right) +{ + auto inst = createInst<IRInst>(this, kIROp_Or, type, left, right); + addInst(inst); + return inst; +} + IRInst* IRBuilder::emitGetNativePtr(IRInst* value) { auto valueType = value->getDataType(); diff --git a/tests/compute/logic-no-short-circuit-evaluation.slang b/tests/compute/logic-no-short-circuit-evaluation.slang new file mode 100644 index 000000000..74351a505 --- /dev/null +++ b/tests/compute/logic-no-short-circuit-evaluation.slang @@ -0,0 +1,66 @@ +//TEST(compute):SIMPLE(filecheck=SM5):-target hlsl -profile cs_5_1 -entry computeMain +//TEST(compute):SIMPLE(filecheck=SM6):-target hlsl -profile cs_6_0 -entry computeMain +//TEST(compute):SIMPLE(filecheck=WGS):-target wgsl -stage compute -entry computeMain +//TEST(compute):SIMPLE(filecheck=MTL):-target metal -stage compute -entry computeMain +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-slang -compute -shaderobj -output-using-type -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -output-using-type -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj -output-using-type -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj -output-using-type -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -shaderobj -output-using-type -xslang -Wno-30056 + +// Testnig logical-AND, logical-OR and ternary operator with non-scalar operands + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +static int result = 0; + +bool2 assignFunc(int index) +{ + result += 10; + return bool2(true); +} + +[numthreads(4, 1, 1)] +void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) +{ + int index = dispatchThreadID.x; + + // No short-circuiting for vector types + + //SM5:(all({{.*}}&& + //SM6:(all(and( + //WGS:(all(select(vec2<bool>(false), + //MTL:(all({{.*}}&& + if (all(bool2(index >= 1) && assignFunc(index))) + { + result++; + } + + // Intentionally using non-boolean type for testing. + + //SM5:(all({{.*}}|| + //SM6:(or(vector<bool,2>( + //WGS:(select({{.*}}, vec2<bool>(true), vec2<bool>( + //MTL:(all(bool2({{.*}}|| + if (all(int2(index >= 2) || !assignFunc(index))) + { + result++; + } + + //SM5:(all({{.*}}?{{.*}}: + //SM6:(all(select( + //WGS:(all(select(vec2<bool>(false), + //MTL:(all(select(bool2(false) + if (all(bool2(index >= 3) ? assignFunc(index) : bool2(false))) + { + result++; + } + + outputBuffer[index] = result; + + //CHK:30 + //CHK-NEXT:31 + //CHK-NEXT:32 + //CHK-NEXT:33 +} diff --git a/tests/compute/logic-short-circuit-evaluation.slang b/tests/compute/logic-short-circuit-evaluation.slang index 585a04770..eed30898f 100644 --- a/tests/compute/logic-short-circuit-evaluation.slang +++ b/tests/compute/logic-short-circuit-evaluation.slang @@ -1,8 +1,9 @@ -//TEST(compute):COMPARE_COMPUTE:-dx12 -compute -shaderobj -//TEST(compute):COMPARE_COMPUTE:-vk -compute -shaderobj -//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -compile-arg -O3 -shaderobj -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-dx12 -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -compile-arg -O3 -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-slang -compute -shaderobj // Test doing vector comparisons @@ -25,4 +26,8 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) // Only the last 4 elements will be 1. (index < 12) || assignFunc(index); + + //CHK-COUNT-4: 1 + //CHK-COUNT-8: 0 + //CHK-COUNT-4: 1 } diff --git a/tests/compute/logic-short-circuit-evaluation.slang.expected.txt b/tests/compute/logic-short-circuit-evaluation.slang.expected.txt deleted file mode 100644 index 945f08f2c..000000000 --- a/tests/compute/logic-short-circuit-evaluation.slang.expected.txt +++ /dev/null @@ -1,16 +0,0 @@ -1 -1 -1 -1 -0 -0 -0 -0 -0 -0 -0 -0 -1 -1 -1 -1 |
