summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2025-05-15 11:59:15 -0500
committerGitHub <noreply@github.com>2025-05-15 09:59:15 -0700
commitb325474c4aba52cca7e0bcd4eae02d23ca4ab9a3 (patch)
treed0790bb56bec865743c95895aedaeb14f63312d7
parented837e205f3e67c4ae112f544cfe486ca3cc8455 (diff)
Implement spec const for generic parameter (#7121)
Close #6840. This PR add supports to use specialize constant in generic parameter, and that parameter can also be used as array size, e.g. following code should work: ``` struct MyStruct<let N: int> { float buffer[N]; } MyStruct<SpecConstVar> s; ``` - Loose the restriction from Link-Time to SpecializationConstant when extract generic argument - Tweak the logic of how we decide whether a inst is hoistable. Besides checking existing hoistable flag of each IRInst, when we detect a IRInst's type is SpecConstRateType, we will treat that inst hoistable. Because IRInst in global scope can be deduplicated, and every SpecConstRateType inst should be in the global scope or IRGeneric scope (which will be at global scope after specialization). - Remove the SpecConstIntVal to IRInst map in IR lowering logic, because we already have way to deduplicate the global scope IR.
-rw-r--r--source/slang/slang-check-type.cpp2
-rw-r--r--source/slang/slang-ir-clone.cpp3
-rw-r--r--source/slang/slang-ir-specialize.cpp8
-rw-r--r--source/slang/slang-ir-util.cpp75
-rw-r--r--source/slang/slang-ir-util.h4
-rw-r--r--source/slang/slang-ir.cpp5
-rw-r--r--source/slang/slang-lower-to-ir.cpp49
-rw-r--r--tests/spirv/spec-constant-generic.slang53
8 files changed, 150 insertions, 49 deletions
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index 172d09ac2..d32903175 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -153,7 +153,7 @@ IntVal* SemanticsVisitor::ExtractGenericArgInteger(
genericParamType ? IntegerConstantExpressionCoercionType::SpecificType
: IntegerConstantExpressionCoercionType::AnyInteger,
genericParamType,
- ConstantFoldingKind::LinkTime,
+ ConstantFoldingKind::SpecializationConstant,
sink);
if (val)
return val;
diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp
index 5bb1c1210..1a020ec26 100644
--- a/source/slang/slang-ir-clone.cpp
+++ b/source/slang/slang-ir-clone.cpp
@@ -93,6 +93,9 @@ IRInst* cloneInstAndOperands(IRCloneEnv* env, IRBuilder* builder, IRInst* oldIns
auto newOperand = findCloneForOperand(env, oldOperand);
newOperands[ii] = newOperand;
+
+ if (isArithmeticInst(oldInst))
+ newType = maybeAddRateType(builder, newOperand->getFullType(), newType);
}
// Finally we create the inst with the updated operands.
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 2f51b28a2..266c1aa99 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -101,7 +101,13 @@ struct SpecializationContext
case kIROp_IntCast:
case kIROp_FloatCast:
case kIROp_Select:
- return true;
+ {
+ if (isSpecConstRateType(inst->getFullType()))
+ {
+ return false;
+ }
+ return true;
+ }
default:
return false;
}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 9d8773237..c8faec73b 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -2261,16 +2261,75 @@ bool isSpecConstRateType(IRType* type)
}
return false;
}
-void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst)
+
+IRType* maybeAddRateType(IRBuilder* builder, IRType* rateQulifiedType, IRType* oldType)
{
- IRInst* moduleInst = builder->getModule()->getModuleInst();
- UInt operandCount = inst->getOperandCount();
- for (UInt ii = 0; ii < operandCount; ++ii)
+ if (as<IRRateQualifiedType>(oldType))
{
- auto operand = inst->getOperand(ii);
- if (operand->parent != moduleInst)
- hoistInstAndOperandsToGlobal(builder, operand);
+ return oldType;
}
- inst->insertAt(IRInsertLoc::atStart(moduleInst));
+
+ if (isSpecConstRateType(rateQulifiedType))
+ {
+ return builder->getRateQualifiedType(builder->getSpecConstRate(), oldType);
+ }
+ return oldType;
+}
+
+bool isArithmeticInst(IROp op)
+{
+ switch (op)
+ {
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_Neg:
+ case kIROp_Not:
+ case kIROp_Eql:
+ case kIROp_Neq:
+ case kIROp_Leq:
+ case kIROp_Geq:
+ case kIROp_Less:
+ case kIROp_IRem:
+ case kIROp_FRem:
+ case kIROp_Greater:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_BitAnd:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_BitNot:
+ case kIROp_BitCast:
+ case kIROp_CastIntToFloat:
+ case kIROp_CastFloatToInt:
+ case kIROp_IntCast:
+ case kIROp_FloatCast:
+ case kIROp_Select:
+ return true;
+ default:
+ return false;
+ }
+}
+bool isArithmeticInst(IRInst* inst)
+{
+ return isArithmeticInst(inst->getOp());
+}
+
+bool isInstHoistable(IROp op, IRType* type)
+{
+ if ((getIROpInfo(op).flags & kIROpFlag_Hoistable))
+ {
+ return true;
+ }
+
+ if (isArithmeticInst(op))
+ {
+ if (type && isSpecConstRateType(type))
+ {
+ return true;
+ }
+ }
+ return false;
}
} // namespace Slang
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 900e22c76..1e5a5eb2a 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -391,6 +391,10 @@ bool isFirstBlock(IRInst* inst);
bool isSpecConstRateType(IRType* type);
void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst);
+IRType* maybeAddRateType(IRBuilder* builder, IRType* rateQulifiedType, IRType* oldType);
+bool isArithmeticInst(IRInst* inst);
+bool isArithmeticInst(IROp op);
+bool isInstHoistable(IROp op, IRType* type);
} // namespace Slang
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 9c4cb98c0..c44196bc5 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -1821,7 +1821,7 @@ IRInst* IRBuilder::_createInst(
m_dedupContext->getInstReplacementMap().tryGetValue(type, instReplacement);
type = (IRType*)instReplacement;
- if (getIROpInfo(op).flags & kIROpFlag_Hoistable)
+ if (isInstHoistable(op, type))
{
return _findOrEmitHoistableInst(
type,
@@ -2527,7 +2527,8 @@ static void addGlobalValue(IRBuilder* builder, IRInst* value)
//
if (value->parent)
{
- SLANG_ASSERT(getIROpInfo(value->getOp()).isHoistable());
+ SLANG_ASSERT(
+ getIROpInfo(value->getOp()).isHoistable() || isSpecConstRateType(value->getFullType()));
return;
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index af285f221..f8946f5dc 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1552,17 +1552,6 @@ static bool _isTrivialLookupFromInterfaceThis(IRGenContext* context, DeclRefBase
return context->thisTypeWitness == nullptr;
}
-
-//
-static void maybePropagateRate(IRBuilder* builder, IRType* rateQulifiedType, IRInst* inst)
-{
- if (isSpecConstRateType(rateQulifiedType))
- {
- inst->setFullType(
- builder->getRateQualifiedType(builder->getSpecConstRate(), inst->getFullType()));
- }
-}
-
struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredValInfo>
{
IRGenContext* context;
@@ -1596,14 +1585,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
specConstRateType = loweredArg.val->getFullType();
}
auto funcType = lowerType(context, val->getFuncType());
- auto resVal = emitCallToDeclRef(
- context,
- as<IRFuncType>(funcType)->getResultType(),
- val->getFuncDeclRef(),
- funcType,
- args,
- tryEnv);
- maybePropagateRate(getBuilder(), specConstRateType, resVal.val);
+ auto funcResType = maybeAddRateType(
+ getBuilder(),
+ specConstRateType,
+ as<IRFuncType>(funcType)->getResultType());
+ auto resVal =
+ emitCallToDeclRef(context, funcResType, val->getFuncDeclRef(), funcType, args, tryEnv);
return resVal;
}
@@ -1613,8 +1600,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
auto type = lowerType(context, val->getType());
+ type = maybeAddRateType(getBuilder(), baseVal.val->getFullType(), type);
auto resVal = LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val));
- maybePropagateRate(getBuilder(), baseVal.val->getFullType(), resVal.val);
return resVal;
}
@@ -1641,12 +1628,10 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
auto factorVal = lowerVal(context, factor->getParam()).val;
for (IntegerLiteralValue i = 0; i < factor->getPower(); i++)
{
- termVal = irBuilder->emitMul(factorVal->getDataType(), termVal, factorVal);
+ termVal = irBuilder->emitMul(factorVal->getFullType(), termVal, factorVal);
}
- maybePropagateRate(getBuilder(), factorVal->getFullType(), termVal);
}
- resultVal = irBuilder->emitAdd(termVal->getDataType(), resultVal, termVal);
- maybePropagateRate(getBuilder(), termVal->getFullType(), resultVal);
+ resultVal = irBuilder->emitAdd(termVal->getFullType(), resultVal, termVal);
}
return LoweredValInfo::simple(resultVal);
}
@@ -2076,19 +2061,9 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
auto elementType = lowerType(context, type->getElementType());
if (!type->isUnsized())
{
- IRInst* elementCount = nullptr;
- auto sizeVal = type->getElementCount();
- auto sharedContext = context->shared;
- if (!sharedContext->mapSpecConstValToIRInst.tryGetValue(sizeVal, elementCount))
- {
- elementCount = lowerSimpleVal(context, sizeVal);
- if (isSpecConstRateType(elementCount->getFullType()))
- {
- sharedContext->mapSpecConstValToIRInst.add(sizeVal, elementCount);
- hoistInstAndOperandsToGlobal(getBuilder(), elementCount);
- }
- }
- return getBuilder()->getArrayType(elementType, elementCount);
+ return getBuilder()->getArrayType(
+ elementType,
+ lowerSimpleVal(context, type->getElementCount()));
}
else
{
diff --git a/tests/spirv/spec-constant-generic.slang b/tests/spirv/spec-constant-generic.slang
new file mode 100644
index 000000000..65eed2810
--- /dev/null
+++ b/tests/spirv/spec-constant-generic.slang
@@ -0,0 +1,53 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -output-using-type
+
+// CHECK: %[[C0:[0-9A-Za-z_]+]] = OpSpecConstant %int 32
+// CHECK: %[[C1:[0-9A-Za-z_]+]] = OpSpecConstant %int 2
+// CHECK: %[[COP0:[0-9A-Za-z_]+]] = OpSpecConstantOp %int SDiv %[[C0]] %[[C1]]
+// CHECK: %[[ARR_TYPE:[0-9A-Za-z_]+]] = OpTypeArray %float %[[COP0]]
+// CHECK: %[[PT_TYPE:[0-9A-Za-z_]+]] = OpTypePointer Function %[[ARR_TYPE]]
+
+[SpecializationConstant]
+const int constValue0 = 32;
+
+[SpecializationConstant]
+const int constValue1 = 2;
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+void func(out float buffer[constValue0 / constValue1])
+{
+ for (uint i = 0; i < constValue0 / constValue1; i++)
+ {
+ buffer[i] = i;
+ }
+}
+
+struct MyStruct<let N: int>
+{
+ float buffer[N / constValue1];
+}
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain()
+{
+ // This test checks we can use spec constants for generic arguments, and also
+ // we can show that the array size is computed correctly.
+ // The function call shows that the two arrays are the same type.
+ MyStruct<constValue0> s;
+ // CHECK: OpVariable %[[PT_TYPE]] Function
+
+ func(s.buffer);
+
+ float temp = 0.0f;
+ for (uint i = 0; i < constValue0 / constValue1; i++)
+ {
+ temp += s.buffer[i] * 2;
+ }
+
+ // Result will be (0 + localConst-1) * localConst = 15 * 16 = 240
+ outputBuffer[0] = temp;
+ // BUF: 240
+}