summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-expr.cpp15
-rw-r--r--source/slang/slang-diagnostic-defs.h6
-rw-r--r--source/slang/slang-emit-spirv.cpp15
-rw-r--r--source/slang/slang-ir-clone.cpp8
-rw-r--r--source/slang/slang-ir-util.cpp62
-rw-r--r--source/slang/slang-ir-util.h9
-rw-r--r--source/slang/slang-ir.cpp2
-rw-r--r--tests/spirv/spec-constant-int-val-float-to-int-cast.slang15
-rw-r--r--tests/spirv/spec-constant-operations.slang84
9 files changed, 175 insertions, 41 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 205575a81..9472138c3 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1980,8 +1980,21 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef(
decl->hasModifier<VkConstantIdAttribute>()) &&
kind == ConstantFoldingKind::SpecializationConstant)
{
+ // Float-to-inst casts cannot be`OpSpecConstOp` operations in SPIR-V,
+ // which means they need to be local instructions can cannot be hoisted to the
+ // global scope. Deduplication logic is run for `IntVal`s however and without hoisting
+ // instructions using this `IntVal` will trigger error. Hence we emit error here
+ // to not allow such cases.
+ //
+ // Note that float-to-inst casts for non-`IntVal`s are allowed.
+ if (!isScalarIntegerType(decl->getType()))
+ {
+ getSink()->diagnose(declRef, Diagnostics::intValFromNonIntSpecConstEncountered);
+ return nullptr;
+ }
+
return m_astBuilder->getOrCreate<DeclRefIntVal>(
- declRef.substitute(m_astBuilder, declRef.getDecl()->getType()),
+ declRef.substitute(m_astBuilder, decl->getType()),
declRef);
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index fc7a4d5bb..4aadfd78d 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -2058,6 +2058,12 @@ DIAGNOSTIC(
nonUniformEntryPointParameterTreatedAsUniform,
"parameter '$0' is treated as 'uniform' because it does not have a system-value semantic.")
+DIAGNOSTIC(
+ 38041,
+ Error,
+ intValFromNonIntSpecConstEncountered,
+ "cannot cast non-integer specialization constant to compile-time integer")
+
DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself")
DIAGNOSTIC(
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 57ad1a988..0a3dab78a 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -784,18 +784,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
switch (irOpCode)
{
case kIROp_IntCast:
- {
- auto typeStyle = getTypeStyle(basicType->getBaseType());
- if (typeStyle == kIROp_FloatType)
- {
- return SpvOpConvertFToU;
- }
- else if (typeStyle == kIROp_IntType)
- {
- return SpvOpUConvert;
- }
- break;
- }
+ return SpvOpUConvert;
+ case kIROp_FloatCast:
+ return SpvOpFConvert;
default:
break;
}
diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp
index 1a020ec26..74a972c1d 100644
--- a/source/slang/slang-ir-clone.cpp
+++ b/source/slang/slang-ir-clone.cpp
@@ -79,6 +79,12 @@ IRInst* cloneInstAndOperands(IRCloneEnv* env, IRBuilder* builder, IRInst* oldIns
//
SLANG_ASSERT(!as<IRConstant>(oldInst));
+ const auto canBeSpecConst = canOperationBeSpecConst(
+ oldInst->getOp(),
+ oldInst->getDataType(),
+ nullptr,
+ oldInst->getOperands());
+
// Next we will iterate over the operands of `oldInst`
// to find their replacements and install them as
// the operands of `newInst`.
@@ -94,7 +100,7 @@ IRInst* cloneInstAndOperands(IRCloneEnv* env, IRBuilder* builder, IRInst* oldIns
newOperands[ii] = newOperand;
- if (isArithmeticInst(oldInst))
+ if (canBeSpecConst)
newType = maybeAddRateType(builder, newOperand->getFullType(), newType);
}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index c8faec73b..13742711c 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -2276,8 +2276,14 @@ IRType* maybeAddRateType(IRBuilder* builder, IRType* rateQulifiedType, IRType* o
return oldType;
}
-bool isArithmeticInst(IROp op)
+bool canOperationBeSpecConst(IROp op, IRType* resultType, IRInst* const* fixedArgs, IRUse* operands)
{
+ // Returns true for ops that can be declared as an operation under `OpSpecConstantOp`.
+ //
+ // Integer arithmetic and comparison operations can be `OpSpecConstantOp` with the `Shader`
+ // capability, while floating-point arithmetic and comparison operations require the `Kernel`
+ // capability. We only support `Shader` capability for now, return false when floating-point
+ // arithmetic/comparison is encountered.
switch (op)
{
case kIROp_Add:
@@ -2285,51 +2291,61 @@ bool isArithmeticInst(IROp op)
case kIROp_Mul:
case kIROp_Div:
case kIROp_Neg:
- case kIROp_Not:
+ return !isFloatingType(resultType);
+
case kIROp_Eql:
case kIROp_Neq:
case kIROp_Leq:
case kIROp_Geq:
case kIROp_Less:
- case kIROp_IRem:
- case kIROp_FRem:
case kIROp_Greater:
+ {
+ IRInst* operand1;
+ IRInst* operand2;
+ if (fixedArgs)
+ {
+ operand1 = fixedArgs[0];
+ operand2 = fixedArgs[1];
+ }
+ else
+ {
+ operand1 = operands[0].get();
+ operand2 = operands[1].get();
+ }
+ return !isFloatingType(operand1->getDataType()) &&
+ !isFloatingType(operand2->getDataType());
+ }
+
+ case kIROp_Not:
+ case kIROp_IRem:
case kIROp_Lsh:
case kIROp_Rsh:
case kIROp_BitAnd:
case kIROp_BitOr:
case kIROp_BitXor:
case kIROp_BitNot:
- case kIROp_BitCast:
- case kIROp_CastIntToFloat:
- case kIROp_CastFloatToInt:
case kIROp_IntCast:
case kIROp_FloatCast:
case kIROp_Select:
return true;
+
default:
return false;
}
}
-bool isArithmeticInst(IRInst* inst)
+
+bool isSpecConstOpHoistable(IROp op, IRType* type, IRInst* const* fixedArgs)
{
- return isArithmeticInst(inst->getOp());
+ auto rateType = as<IRRateQualifiedType>(type);
+ return rateType && as<IRSpecConstRate>(rateType->getRate()) &&
+ canOperationBeSpecConst(op, rateType->getValueType(), fixedArgs, nullptr);
}
-bool isInstHoistable(IROp op, IRType* type)
-{
- if ((getIROpInfo(op).flags & kIROpFlag_Hoistable))
- {
- return true;
- }
- if (isArithmeticInst(op))
- {
- if (type && isSpecConstRateType(type))
- {
- return true;
- }
- }
- return false;
+bool isInstHoistable(IROp op, IRType* type, IRInst* const* fixedArgs)
+{
+ return (getIROpInfo(op).flags & kIROpFlag_Hoistable) ||
+ isSpecConstOpHoistable(op, type, fixedArgs);
}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 1e5a5eb2a..aa1ae3989 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -392,9 +392,12 @@ bool isFirstBlock(IRInst* inst);
bool isSpecConstRateType(IRType* type);
void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst);
IRType* maybeAddRateType(IRBuilder* builder, IRType* rateQulifiedType, IRType* oldType);
-bool isArithmeticInst(IRInst* inst);
-bool isArithmeticInst(IROp op);
-bool isInstHoistable(IROp op, IRType* type);
+bool canOperationBeSpecConst(
+ IROp op,
+ IRType* resultType,
+ IRInst* const* fixedArgs,
+ IRUse* operands);
+bool isInstHoistable(IROp op, IRType* type, IRInst* const* fixedArgs);
} // namespace Slang
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 85fe2fa04..f571ec20b 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -1827,7 +1827,7 @@ IRInst* IRBuilder::_createInst(
m_dedupContext->getInstReplacementMap().tryGetValue(type, instReplacement);
type = (IRType*)instReplacement;
- if (isInstHoistable(op, type))
+ if (isInstHoistable(op, type, fixedArgs))
{
return _findOrEmitHoistableInst(
type,
diff --git a/tests/spirv/spec-constant-int-val-float-to-int-cast.slang b/tests/spirv/spec-constant-int-val-float-to-int-cast.slang
new file mode 100644
index 000000000..9f9f96178
--- /dev/null
+++ b/tests/spirv/spec-constant-int-val-float-to-int-cast.slang
@@ -0,0 +1,15 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target spirv
+
+// CHECK: error 38041: cannot cast non-integer specialization
+// CHECK-NEXT: const float X
+
+[[SpecializationConstant]]
+const float X = 10.0;
+
+[shader("compute")]
+[numthreads(32, 1, 1)]
+void computeMain() : SV_Target
+{
+ float arr[int(X)];
+ float a = arr[0];
+}
diff --git a/tests/spirv/spec-constant-operations.slang b/tests/spirv/spec-constant-operations.slang
new file mode 100644
index 000000000..86d16ef34
--- /dev/null
+++ b/tests/spirv/spec-constant-operations.slang
@@ -0,0 +1,84 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -output-using-type -emit-spirv-directly
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+// `OpSpecConstantOp` can only contain integer operations when targeting Vulkan SPIRV, not floating-point operations.
+// This test checks that floating-point operations that strictly contain specialization constant variables are not declared with `OpSpecContantOp`,
+// while integer operations that strictly contain specializaton constant operands are declared as `OpSpecConstantOp`.
+
+// CHECK-DAG: OpSpecConstant %float 1
+// CHECK-DAG: OpSpecConstant %ulong 256
+// CHECK-DAG: OpSpecConstant %float 100
+// CHECK-DAG: OpSpecConstantOp %half FConvert
+// CHECK-DAG: OpSpecConstantOp %int UConvert
+
+// CHECK-NOT: OpSpecConstantOp {{.*}} FAdd
+// CHECK-NOT: OpSpecConstantOp {{.*}} FSub
+// CHECK-NOT: OpSpecConstantOp {{.*}} FMul
+// CHECK-NOT: OpSpecConstantOp {{.*}} FDiv
+// CHECK-NOT: OpSpecConstantOp {{.*}} SpvOpConvertUToF
+// CHECK-NOT: OpSpecConstantOp {{.*}} SpvOpConvertFToU
+
+[[SpecializationConstant]]
+const float X = 1.0;
+[[SpecializationConstant]]
+const uint64_t Y = 256;
+[[SpecializationConstant]]
+const float Z = 100.0;
+
+int func1()
+{
+ // Test float-to-float and int-to-int conversions.
+ int a = int(Y);
+ half b = half(X);
+ int16_t c = int16_t(Y);
+
+ // Test comparisons.
+ if (X < 2.0)
+ {
+ a = 3;
+ }
+ else if (X > 5.0)
+ {
+ a = 5;
+ }
+
+ if (Y < 200)
+ {
+ b = 2.0h;
+ }
+ else if (Y > 500)
+ {
+ b = 5.0h;
+ }
+
+ return a + int(b) + int(c);
+}
+
+float func2()
+{
+ // Test floating-point arithmetic.
+ float a = X + Z;
+ a += (X - Z);
+ a += (X * Z);
+ a += (X / Z);
+
+ return a;
+}
+
+float func3()
+{
+ // Test float-to-int and int-to-float conversions.
+ int a = int(Z) * 2;
+ return float(Y) + float(a);
+}
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain()
+{
+ // BUF: 818.01
+ outputBuffer[0] = float(func1()) + func2() + func3();
+}