summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-type.cpp2
-rw-r--r--source/slang/slang-ir-clone.cpp3
-rw-r--r--source/slang/slang-ir-specialize.cpp8
-rw-r--r--source/slang/slang-ir-util.cpp75
-rw-r--r--source/slang/slang-ir-util.h4
-rw-r--r--source/slang/slang-ir.cpp5
-rw-r--r--source/slang/slang-lower-to-ir.cpp49
-rw-r--r--tests/spirv/spec-constant-generic.slang53
8 files changed, 150 insertions, 49 deletions
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index 172d09ac2..d32903175 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -153,7 +153,7 @@ IntVal* SemanticsVisitor::ExtractGenericArgInteger(
genericParamType ? IntegerConstantExpressionCoercionType::SpecificType
: IntegerConstantExpressionCoercionType::AnyInteger,
genericParamType,
- ConstantFoldingKind::LinkTime,
+ ConstantFoldingKind::SpecializationConstant,
sink);
if (val)
return val;
diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp
index 5bb1c1210..1a020ec26 100644
--- a/source/slang/slang-ir-clone.cpp
+++ b/source/slang/slang-ir-clone.cpp
@@ -93,6 +93,9 @@ IRInst* cloneInstAndOperands(IRCloneEnv* env, IRBuilder* builder, IRInst* oldIns
auto newOperand = findCloneForOperand(env, oldOperand);
newOperands[ii] = newOperand;
+
+ if (isArithmeticInst(oldInst))
+ newType = maybeAddRateType(builder, newOperand->getFullType(), newType);
}
// Finally we create the inst with the updated operands.
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 2f51b28a2..266c1aa99 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -101,7 +101,13 @@ struct SpecializationContext
case kIROp_IntCast:
case kIROp_FloatCast:
case kIROp_Select:
- return true;
+ {
+ if (isSpecConstRateType(inst->getFullType()))
+ {
+ return false;
+ }
+ return true;
+ }
default:
return false;
}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 9d8773237..c8faec73b 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -2261,16 +2261,75 @@ bool isSpecConstRateType(IRType* type)
}
return false;
}
-void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst)
+
+IRType* maybeAddRateType(IRBuilder* builder, IRType* rateQulifiedType, IRType* oldType)
{
- IRInst* moduleInst = builder->getModule()->getModuleInst();
- UInt operandCount = inst->getOperandCount();
- for (UInt ii = 0; ii < operandCount; ++ii)
+ if (as<IRRateQualifiedType>(oldType))
{
- auto operand = inst->getOperand(ii);
- if (operand->parent != moduleInst)
- hoistInstAndOperandsToGlobal(builder, operand);
+ return oldType;
}
- inst->insertAt(IRInsertLoc::atStart(moduleInst));
+
+ if (isSpecConstRateType(rateQulifiedType))
+ {
+ return builder->getRateQualifiedType(builder->getSpecConstRate(), oldType);
+ }
+ return oldType;
+}
+
+bool isArithmeticInst(IROp op)
+{
+ switch (op)
+ {
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_Neg:
+ case kIROp_Not:
+ case kIROp_Eql:
+ case kIROp_Neq:
+ case kIROp_Leq:
+ case kIROp_Geq:
+ case kIROp_Less:
+ case kIROp_IRem:
+ case kIROp_FRem:
+ case kIROp_Greater:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_BitAnd:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_BitNot:
+ case kIROp_BitCast:
+ case kIROp_CastIntToFloat:
+ case kIROp_CastFloatToInt:
+ case kIROp_IntCast:
+ case kIROp_FloatCast:
+ case kIROp_Select:
+ return true;
+ default:
+ return false;
+ }
+}
+bool isArithmeticInst(IRInst* inst)
+{
+ return isArithmeticInst(inst->getOp());
+}
+
+bool isInstHoistable(IROp op, IRType* type)
+{
+ if ((getIROpInfo(op).flags & kIROpFlag_Hoistable))
+ {
+ return true;
+ }
+
+ if (isArithmeticInst(op))
+ {
+ if (type && isSpecConstRateType(type))
+ {
+ return true;
+ }
+ }
+ return false;
}
} // namespace Slang
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 900e22c76..1e5a5eb2a 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -391,6 +391,10 @@ bool isFirstBlock(IRInst* inst);
bool isSpecConstRateType(IRType* type);
void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst);
+IRType* maybeAddRateType(IRBuilder* builder, IRType* rateQulifiedType, IRType* oldType);
+bool isArithmeticInst(IRInst* inst);
+bool isArithmeticInst(IROp op);
+bool isInstHoistable(IROp op, IRType* type);
} // namespace Slang
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 9c4cb98c0..c44196bc5 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -1821,7 +1821,7 @@ IRInst* IRBuilder::_createInst(
m_dedupContext->getInstReplacementMap().tryGetValue(type, instReplacement);
type = (IRType*)instReplacement;
- if (getIROpInfo(op).flags & kIROpFlag_Hoistable)
+ if (isInstHoistable(op, type))
{
return _findOrEmitHoistableInst(
type,
@@ -2527,7 +2527,8 @@ static void addGlobalValue(IRBuilder* builder, IRInst* value)
//
if (value->parent)
{
- SLANG_ASSERT(getIROpInfo(value->getOp()).isHoistable());
+ SLANG_ASSERT(
+ getIROpInfo(value->getOp()).isHoistable() || isSpecConstRateType(value->getFullType()));
return;
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index af285f221..f8946f5dc 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1552,17 +1552,6 @@ static bool _isTrivialLookupFromInterfaceThis(IRGenContext* context, DeclRefBase
return context->thisTypeWitness == nullptr;
}
-
-//
-static void maybePropagateRate(IRBuilder* builder, IRType* rateQulifiedType, IRInst* inst)
-{
- if (isSpecConstRateType(rateQulifiedType))
- {
- inst->setFullType(
- builder->getRateQualifiedType(builder->getSpecConstRate(), inst->getFullType()));
- }
-}
-
struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredValInfo>
{
IRGenContext* context;
@@ -1596,14 +1585,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
specConstRateType = loweredArg.val->getFullType();
}
auto funcType = lowerType(context, val->getFuncType());
- auto resVal = emitCallToDeclRef(
- context,
- as<IRFuncType>(funcType)->getResultType(),
- val->getFuncDeclRef(),
- funcType,
- args,
- tryEnv);
- maybePropagateRate(getBuilder(), specConstRateType, resVal.val);
+ auto funcResType = maybeAddRateType(
+ getBuilder(),
+ specConstRateType,
+ as<IRFuncType>(funcType)->getResultType());
+ auto resVal =
+ emitCallToDeclRef(context, funcResType, val->getFuncDeclRef(), funcType, args, tryEnv);
return resVal;
}
@@ -1613,8 +1600,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
auto type = lowerType(context, val->getType());
+ type = maybeAddRateType(getBuilder(), baseVal.val->getFullType(), type);
auto resVal = LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val));
- maybePropagateRate(getBuilder(), baseVal.val->getFullType(), resVal.val);
return resVal;
}
@@ -1641,12 +1628,10 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
auto factorVal = lowerVal(context, factor->getParam()).val;
for (IntegerLiteralValue i = 0; i < factor->getPower(); i++)
{
- termVal = irBuilder->emitMul(factorVal->getDataType(), termVal, factorVal);
+ termVal = irBuilder->emitMul(factorVal->getFullType(), termVal, factorVal);
}
- maybePropagateRate(getBuilder(), factorVal->getFullType(), termVal);
}
- resultVal = irBuilder->emitAdd(termVal->getDataType(), resultVal, termVal);
- maybePropagateRate(getBuilder(), termVal->getFullType(), resultVal);
+ resultVal = irBuilder->emitAdd(termVal->getFullType(), resultVal, termVal);
}
return LoweredValInfo::simple(resultVal);
}
@@ -2076,19 +2061,9 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
auto elementType = lowerType(context, type->getElementType());
if (!type->isUnsized())
{
- IRInst* elementCount = nullptr;
- auto sizeVal = type->getElementCount();
- auto sharedContext = context->shared;
- if (!sharedContext->mapSpecConstValToIRInst.tryGetValue(sizeVal, elementCount))
- {
- elementCount = lowerSimpleVal(context, sizeVal);
- if (isSpecConstRateType(elementCount->getFullType()))
- {
- sharedContext->mapSpecConstValToIRInst.add(sizeVal, elementCount);
- hoistInstAndOperandsToGlobal(getBuilder(), elementCount);
- }
- }
- return getBuilder()->getArrayType(elementType, elementCount);
+ return getBuilder()->getArrayType(
+ elementType,
+ lowerSimpleVal(context, type->getElementCount()));
}
else
{
diff --git a/tests/spirv/spec-constant-generic.slang b/tests/spirv/spec-constant-generic.slang
new file mode 100644
index 000000000..65eed2810
--- /dev/null
+++ b/tests/spirv/spec-constant-generic.slang
@@ -0,0 +1,53 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -output-using-type
+
+// CHECK: %[[C0:[0-9A-Za-z_]+]] = OpSpecConstant %int 32
+// CHECK: %[[C1:[0-9A-Za-z_]+]] = OpSpecConstant %int 2
+// CHECK: %[[COP0:[0-9A-Za-z_]+]] = OpSpecConstantOp %int SDiv %[[C0]] %[[C1]]
+// CHECK: %[[ARR_TYPE:[0-9A-Za-z_]+]] = OpTypeArray %float %[[COP0]]
+// CHECK: %[[PT_TYPE:[0-9A-Za-z_]+]] = OpTypePointer Function %[[ARR_TYPE]]
+
+[SpecializationConstant]
+const int constValue0 = 32;
+
+[SpecializationConstant]
+const int constValue1 = 2;
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+void func(out float buffer[constValue0 / constValue1])
+{
+ for (uint i = 0; i < constValue0 / constValue1; i++)
+ {
+ buffer[i] = i;
+ }
+}
+
+struct MyStruct<let N: int>
+{
+ float buffer[N / constValue1];
+}
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain()
+{
+ // This test checks we can use spec constants for generic arguments, and also
+ // we can show that the array size is computed correctly.
+ // The function call shows that the two arrays are the same type.
+ MyStruct<constValue0> s;
+ // CHECK: OpVariable %[[PT_TYPE]] Function
+
+ func(s.buffer);
+
+ float temp = 0.0f;
+ for (uint i = 0; i < constValue0 / constValue1; i++)
+ {
+ temp += s.buffer[i] * 2;
+ }
+
+ // Result will be (0 + localConst-1) * localConst = 15 * 16 = 240
+ outputBuffer[0] = temp;
+ // BUF: 240
+}