diff options
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-ir-clone.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 62 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 2 | ||||
| -rw-r--r-- | tests/spirv/spec-constant-int-val-float-to-int-cast.slang | 15 | ||||
| -rw-r--r-- | tests/spirv/spec-constant-operations.slang | 84 |
9 files changed, 175 insertions, 41 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 205575a81..9472138c3 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1980,8 +1980,21 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef( decl->hasModifier<VkConstantIdAttribute>()) && kind == ConstantFoldingKind::SpecializationConstant) { + // Float-to-inst casts cannot be`OpSpecConstOp` operations in SPIR-V, + // which means they need to be local instructions can cannot be hoisted to the + // global scope. Deduplication logic is run for `IntVal`s however and without hoisting + // instructions using this `IntVal` will trigger error. Hence we emit error here + // to not allow such cases. + // + // Note that float-to-inst casts for non-`IntVal`s are allowed. + if (!isScalarIntegerType(decl->getType())) + { + getSink()->diagnose(declRef, Diagnostics::intValFromNonIntSpecConstEncountered); + return nullptr; + } + return m_astBuilder->getOrCreate<DeclRefIntVal>( - declRef.substitute(m_astBuilder, declRef.getDecl()->getType()), + declRef.substitute(m_astBuilder, decl->getType()), declRef); } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index fc7a4d5bb..4aadfd78d 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2058,6 +2058,12 @@ DIAGNOSTIC( nonUniformEntryPointParameterTreatedAsUniform, "parameter '$0' is treated as 'uniform' because it does not have a system-value semantic.") +DIAGNOSTIC( + 38041, + Error, + intValFromNonIntSpecConstEncountered, + "cannot cast non-integer specialization constant to compile-time integer") + DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") DIAGNOSTIC( diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 57ad1a988..0a3dab78a 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -784,18 +784,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex switch (irOpCode) { case kIROp_IntCast: - { - auto typeStyle = getTypeStyle(basicType->getBaseType()); - if (typeStyle == kIROp_FloatType) - { - return SpvOpConvertFToU; - } - else if (typeStyle == kIROp_IntType) - { - return SpvOpUConvert; - } - break; - } + return SpvOpUConvert; + case kIROp_FloatCast: + return SpvOpFConvert; default: break; } diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp index 1a020ec26..74a972c1d 100644 --- a/source/slang/slang-ir-clone.cpp +++ b/source/slang/slang-ir-clone.cpp @@ -79,6 +79,12 @@ IRInst* cloneInstAndOperands(IRCloneEnv* env, IRBuilder* builder, IRInst* oldIns // SLANG_ASSERT(!as<IRConstant>(oldInst)); + const auto canBeSpecConst = canOperationBeSpecConst( + oldInst->getOp(), + oldInst->getDataType(), + nullptr, + oldInst->getOperands()); + // Next we will iterate over the operands of `oldInst` // to find their replacements and install them as // the operands of `newInst`. @@ -94,7 +100,7 @@ IRInst* cloneInstAndOperands(IRCloneEnv* env, IRBuilder* builder, IRInst* oldIns newOperands[ii] = newOperand; - if (isArithmeticInst(oldInst)) + if (canBeSpecConst) newType = maybeAddRateType(builder, newOperand->getFullType(), newType); } diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index c8faec73b..13742711c 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -2276,8 +2276,14 @@ IRType* maybeAddRateType(IRBuilder* builder, IRType* rateQulifiedType, IRType* o return oldType; } -bool isArithmeticInst(IROp op) +bool canOperationBeSpecConst(IROp op, IRType* resultType, IRInst* const* fixedArgs, IRUse* operands) { + // Returns true for ops that can be declared as an operation under `OpSpecConstantOp`. + // + // Integer arithmetic and comparison operations can be `OpSpecConstantOp` with the `Shader` + // capability, while floating-point arithmetic and comparison operations require the `Kernel` + // capability. We only support `Shader` capability for now, return false when floating-point + // arithmetic/comparison is encountered. switch (op) { case kIROp_Add: @@ -2285,51 +2291,61 @@ bool isArithmeticInst(IROp op) case kIROp_Mul: case kIROp_Div: case kIROp_Neg: - case kIROp_Not: + return !isFloatingType(resultType); + case kIROp_Eql: case kIROp_Neq: case kIROp_Leq: case kIROp_Geq: case kIROp_Less: - case kIROp_IRem: - case kIROp_FRem: case kIROp_Greater: + { + IRInst* operand1; + IRInst* operand2; + if (fixedArgs) + { + operand1 = fixedArgs[0]; + operand2 = fixedArgs[1]; + } + else + { + operand1 = operands[0].get(); + operand2 = operands[1].get(); + } + return !isFloatingType(operand1->getDataType()) && + !isFloatingType(operand2->getDataType()); + } + + case kIROp_Not: + case kIROp_IRem: case kIROp_Lsh: case kIROp_Rsh: case kIROp_BitAnd: case kIROp_BitOr: case kIROp_BitXor: case kIROp_BitNot: - case kIROp_BitCast: - case kIROp_CastIntToFloat: - case kIROp_CastFloatToInt: case kIROp_IntCast: case kIROp_FloatCast: case kIROp_Select: return true; + default: return false; } } -bool isArithmeticInst(IRInst* inst) + +bool isSpecConstOpHoistable(IROp op, IRType* type, IRInst* const* fixedArgs) { - return isArithmeticInst(inst->getOp()); + auto rateType = as<IRRateQualifiedType>(type); + return rateType && as<IRSpecConstRate>(rateType->getRate()) && + canOperationBeSpecConst(op, rateType->getValueType(), fixedArgs, nullptr); } -bool isInstHoistable(IROp op, IRType* type) -{ - if ((getIROpInfo(op).flags & kIROpFlag_Hoistable)) - { - return true; - } - if (isArithmeticInst(op)) - { - if (type && isSpecConstRateType(type)) - { - return true; - } - } - return false; +bool isInstHoistable(IROp op, IRType* type, IRInst* const* fixedArgs) +{ + return (getIROpInfo(op).flags & kIROpFlag_Hoistable) || + isSpecConstOpHoistable(op, type, fixedArgs); } + } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 1e5a5eb2a..aa1ae3989 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -392,9 +392,12 @@ bool isFirstBlock(IRInst* inst); bool isSpecConstRateType(IRType* type); void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst); IRType* maybeAddRateType(IRBuilder* builder, IRType* rateQulifiedType, IRType* oldType); -bool isArithmeticInst(IRInst* inst); -bool isArithmeticInst(IROp op); -bool isInstHoistable(IROp op, IRType* type); +bool canOperationBeSpecConst( + IROp op, + IRType* resultType, + IRInst* const* fixedArgs, + IRUse* operands); +bool isInstHoistable(IROp op, IRType* type, IRInst* const* fixedArgs); } // namespace Slang #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 85fe2fa04..f571ec20b 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -1827,7 +1827,7 @@ IRInst* IRBuilder::_createInst( m_dedupContext->getInstReplacementMap().tryGetValue(type, instReplacement); type = (IRType*)instReplacement; - if (isInstHoistable(op, type)) + if (isInstHoistable(op, type, fixedArgs)) { return _findOrEmitHoistableInst( type, diff --git a/tests/spirv/spec-constant-int-val-float-to-int-cast.slang b/tests/spirv/spec-constant-int-val-float-to-int-cast.slang new file mode 100644 index 000000000..9f9f96178 --- /dev/null +++ b/tests/spirv/spec-constant-int-val-float-to-int-cast.slang @@ -0,0 +1,15 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target spirv + +// CHECK: error 38041: cannot cast non-integer specialization +// CHECK-NEXT: const float X + +[[SpecializationConstant]] +const float X = 10.0; + +[shader("compute")] +[numthreads(32, 1, 1)] +void computeMain() : SV_Target +{ + float arr[int(X)]; + float a = arr[0]; +} diff --git a/tests/spirv/spec-constant-operations.slang b/tests/spirv/spec-constant-operations.slang new file mode 100644 index 000000000..86d16ef34 --- /dev/null +++ b/tests/spirv/spec-constant-operations.slang @@ -0,0 +1,84 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv +//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -output-using-type -emit-spirv-directly + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +// `OpSpecConstantOp` can only contain integer operations when targeting Vulkan SPIRV, not floating-point operations. +// This test checks that floating-point operations that strictly contain specialization constant variables are not declared with `OpSpecContantOp`, +// while integer operations that strictly contain specializaton constant operands are declared as `OpSpecConstantOp`. + +// CHECK-DAG: OpSpecConstant %float 1 +// CHECK-DAG: OpSpecConstant %ulong 256 +// CHECK-DAG: OpSpecConstant %float 100 +// CHECK-DAG: OpSpecConstantOp %half FConvert +// CHECK-DAG: OpSpecConstantOp %int UConvert + +// CHECK-NOT: OpSpecConstantOp {{.*}} FAdd +// CHECK-NOT: OpSpecConstantOp {{.*}} FSub +// CHECK-NOT: OpSpecConstantOp {{.*}} FMul +// CHECK-NOT: OpSpecConstantOp {{.*}} FDiv +// CHECK-NOT: OpSpecConstantOp {{.*}} SpvOpConvertUToF +// CHECK-NOT: OpSpecConstantOp {{.*}} SpvOpConvertFToU + +[[SpecializationConstant]] +const float X = 1.0; +[[SpecializationConstant]] +const uint64_t Y = 256; +[[SpecializationConstant]] +const float Z = 100.0; + +int func1() +{ + // Test float-to-float and int-to-int conversions. + int a = int(Y); + half b = half(X); + int16_t c = int16_t(Y); + + // Test comparisons. + if (X < 2.0) + { + a = 3; + } + else if (X > 5.0) + { + a = 5; + } + + if (Y < 200) + { + b = 2.0h; + } + else if (Y > 500) + { + b = 5.0h; + } + + return a + int(b) + int(c); +} + +float func2() +{ + // Test floating-point arithmetic. + float a = X + Z; + a += (X - Z); + a += (X * Z); + a += (X / Z); + + return a; +} + +float func3() +{ + // Test float-to-int and int-to-float conversions. + int a = int(Z) * 2; + return float(Y) + float(a); +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + // BUF: 818.01 + outputBuffer[0] = float(func1()) + func2() + func3(); +} |
