summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2025-05-14 12:11:53 -0500
committerGitHub <noreply@github.com>2025-05-14 10:11:53 -0700
commit375ecfe2903b09f07abeba2eafb88d9a564c1458 (patch)
treea507ffcdbe118f5d69ffb3e6c341d8f954e0bfef /source/slang
parent39c9e25f6d728e970b68a9452330e754991b4ac5 (diff)
support specialization constant sized array (#6871)
Close #6859 Goal of this PR We want to support an array whose size can be specialization constant for shared/global variable e.g. layout (constant_id = 0) const uint BLOCK_SIZE = 64; shared float buf_a[(BLOCK_SIZE + 5) * 4]; Overview of the solution: During IndexExpr check, we will loose the restriction to allow SpecConst passing, but the size parameter will not be a constant value because it cannot be folded into a constant, so we will make it follow the same logic as generic parameter value, and the size will be represented by FuncCallIntVal/PolynomialIntVal/DeclRefIntVal. During IR lowering, we will detect whether there is spec constant in the IntVal, and wrap the IRInst with a SpecConstRateType, and propagate the type though the lowering logic, such that the IntVal representing the array size will have SpecConstRateType. During spirv emit stage, if we detect that a IRInst has SpecConstRateType, we will emit it as SpecConstantOp. We have to implement new logic to emit OpSpecConstantOp, the existing emit logic doesn't support emitting OpSpecConstantOp, especially this op can embed arithmetic operation at global scope, where we can only emit arithmetic instruct at local. But there are only few instructs we need to support. Overview of the solution: This PR doesn't support generic, and we will create a separate PR to extend that, tracked in #6840.
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-builder.cpp1
-rw-r--r--source/slang/slang-ast-decl-ref.cpp2
-rw-r--r--source/slang/slang-ast-val.cpp10
-rw-r--r--source/slang/slang-ast-val.h17
-rw-r--r--source/slang/slang-check-constraint.cpp4
-rw-r--r--source/slang/slang-check-decl.cpp6
-rw-r--r--source/slang/slang-check-expr.cpp17
-rw-r--r--source/slang/slang-check-impl.h1
-rw-r--r--source/slang/slang-check-type.cpp4
-rw-r--r--source/slang/slang-doc-markdown-writer.cpp2
-rw-r--r--source/slang/slang-emit-spirv.cpp325
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h1
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp7
-rw-r--r--source/slang/slang-ir-util.cpp23
-rw-r--r--source/slang/slang-ir-util.h3
-rw-r--r--source/slang/slang-ir.cpp21
-rw-r--r--source/slang/slang-ir.h6
-rw-r--r--source/slang/slang-lower-to-ir.cpp54
-rw-r--r--source/slang/slang-mangle.cpp2
-rw-r--r--source/slang/slang-reflection-api.cpp2
-rw-r--r--source/slang/slang-type-layout.cpp6
-rw-r--r--source/slang/slang.cpp2
23 files changed, 363 insertions, 154 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 5abef94b3..893d5e6d7 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -347,6 +347,7 @@ ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* element
{
if (!elementCount)
elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength);
+
if (elementCount->getType() != getIntType())
{
// Canonicalize constant elementCount to int.
diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp
index 1881f1b3c..89fa52b09 100644
--- a/source/slang/slang-ast-decl-ref.cpp
+++ b/source/slang/slang-ast-decl-ref.cpp
@@ -41,7 +41,7 @@ DeclRefBase* _getDeclRefFromVal(Val* val)
{
if (auto declRefType = as<DeclRefType>(val))
return declRefType->getDeclRef();
- else if (auto genParamIntVal = as<GenericParamIntVal>(val))
+ else if (auto genParamIntVal = as<DeclRefIntVal>(val))
return genParamIntVal->getDeclRef();
else if (auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(val))
return declaredSubtypeWitness->getDeclRef();
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index efb87b831..1cdca0440 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -176,9 +176,9 @@ void ConstantIntVal::_toTextOverride(StringBuilder& out)
out << getValue();
}
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericParamIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclRefIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-void GenericParamIntVal::_toTextOverride(StringBuilder& out)
+void DeclRefIntVal::_toTextOverride(StringBuilder& out)
{
Name* name = getDeclRef().getName();
if (name)
@@ -248,7 +248,7 @@ Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet
return paramVal;
}
-Val* GenericParamIntVal::_substituteImplOverride(
+Val* DeclRefIntVal::_substituteImplOverride(
ASTBuilder* /* astBuilder */,
SubstitutionSet subst,
int* ioDiff)
@@ -259,12 +259,12 @@ Val* GenericParamIntVal::_substituteImplOverride(
return this;
}
-bool GenericParamIntVal::_isLinkTimeValOverride()
+bool DeclRefIntVal::_isLinkTimeValOverride()
{
return getDeclRef().getDecl()->hasModifier<ExternModifier>();
}
-Val* GenericParamIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map)
+Val* DeclRefIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map)
{
auto name = getMangledName(getCurrentASTBuilder(), getDeclRef().declRefBase);
IntVal* v;
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index cdfb0b51f..2b4c7ed22 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -168,7 +168,7 @@ class ConstantIntVal : public IntVal
// The logical "value" of a reference to a generic value parameter
FIDDLE()
-class GenericParamIntVal : public IntVal
+class DeclRefIntVal : public IntVal
{
FIDDLE(...)
DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(1)); }
@@ -177,10 +177,7 @@ class GenericParamIntVal : public IntVal
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
- GenericParamIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef)
- {
- setOperands(inType, inDeclRef);
- }
+ DeclRefIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef) { setOperands(inType, inDeclRef); }
bool _isLinkTimeValOverride();
Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map);
@@ -319,9 +316,9 @@ public:
// for sorting only.
bool operator<(const PolynomialIntValFactor& other) const
{
- if (auto thisGenParam = as<GenericParamIntVal>(getParam()))
+ if (auto thisGenParam = as<DeclRefIntVal>(getParam()))
{
- if (auto thatGenParam = as<GenericParamIntVal>(other.getParam()))
+ if (auto thatGenParam = as<DeclRefIntVal>(other.getParam()))
{
if (thisGenParam->equals(thatGenParam))
return getPower() < other.getPower();
@@ -336,7 +333,7 @@ public:
}
else
{
- if (const auto thatGenParam = as<GenericParamIntVal>(other.getParam()))
+ if (const auto thatGenParam = as<DeclRefIntVal>(other.getParam()))
{
return false;
}
@@ -347,9 +344,9 @@ public:
// for sorting only.
bool operator==(const PolynomialIntValFactor& other) const
{
- if (auto thisGenParam = as<GenericParamIntVal>(getParam()))
+ if (auto thisGenParam = as<DeclRefIntVal>(getParam()))
{
- if (auto thatGenParam = as<GenericParamIntVal>(other.getParam()))
+ if (auto thatGenParam = as<DeclRefIntVal>(other.getParam()))
{
if (thisGenParam->equals(thatGenParam) && getPower() == other.getPower())
return true;
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 642a4bf6a..6f9191135 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -819,7 +819,7 @@ bool SemanticsVisitor::TryUnifyVals(
{
if (const auto c = as<TypeCastIntVal>(i))
i = as<IntVal>(c->getBase());
- return as<GenericParamIntVal>(i);
+ return as<DeclRefIntVal>(i);
};
auto fstParam = paramUnderCast(fstInt);
auto sndParam = paramUnderCast(sndInt);
@@ -1196,7 +1196,7 @@ void SemanticsVisitor::maybeUnifyUnconstraintIntParam(
{
param = as<IntVal>(typeCastParam->getBase());
}
- auto intParam = as<GenericParamIntVal>(param);
+ auto intParam = as<DeclRefIntVal>(param);
if (!intParam)
return;
for (auto c : constraints.constraints)
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 1e524e27f..dbd52ebea 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -3921,7 +3921,7 @@ bool SemanticsVisitor::doesGenericSignatureMatchRequirement(
auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as<GenericValueParamDecl>();
SLANG_ASSERT(satisfyingValueParamDeclRef);
- auto satisfyingVal = m_astBuilder->getOrCreate<GenericParamIntVal>(
+ auto satisfyingVal = m_astBuilder->getOrCreate<DeclRefIntVal>(
requiredValueParamDeclRef.getDecl()->getType(),
satisfyingValueParamDeclRef);
satisfyingVal->getDeclRef() = satisfyingValueParamDeclRef;
@@ -8513,7 +8513,7 @@ List<Val*> getDefaultSubstitutionArgs(
if (semantics)
semantics->ensureDecl(genericValueParamDecl, DeclCheckState::ReadyForLookup);
- args.add(astBuilder->getOrCreate<GenericParamIntVal>(
+ args.add(astBuilder->getOrCreate<DeclRefIntVal>(
genericValueParamDecl->getType(),
astBuilder->getDirectDeclRef(genericValueParamDecl)));
}
@@ -11769,7 +11769,7 @@ void checkDerivativeAttributeImpl(
appExpr->arguments.add(baseTypeExpr);
}
- else if (auto genericValParam = as<GenericParamIntVal>(arg))
+ else if (auto genericValParam = as<DeclRefIntVal>(arg))
{
auto declRef = genericValParam->getDeclRef();
appExpr->arguments.add(
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 48f32952b..d151d37be 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1972,9 +1972,14 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef(
// The values of specialization constants aren't known at compile time even
// if they're marked `const`.
- if (decl->hasModifier<SpecializationConstantAttribute>() ||
- decl->hasModifier<VkConstantIdAttribute>())
- return nullptr;
+ if ((decl->hasModifier<SpecializationConstantAttribute>() ||
+ decl->hasModifier<VkConstantIdAttribute>()) &&
+ kind == ConstantFoldingKind::SpecializationConstant)
+ {
+ return m_astBuilder->getOrCreate<DeclRefIntVal>(
+ declRef.substitute(m_astBuilder, declRef.getDecl()->getType()),
+ declRef);
+ }
if (decl->hasModifier<ExternModifier>())
{
@@ -1982,7 +1987,7 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef(
if (kind == ConstantFoldingKind::CompileTime)
return nullptr;
// But if we are OK with link-time constants, we can still fold it into a val.
- auto rs = m_astBuilder->getOrCreate<GenericParamIntVal>(
+ auto rs = m_astBuilder->getOrCreate<DeclRefIntVal>(
declRef.substitute(m_astBuilder, declRef.getDecl()->getType()),
declRef);
return rs;
@@ -2067,7 +2072,7 @@ IntVal* SemanticsVisitor::tryConstantFoldExpr(
if (auto genericValParamRef = declRef.as<GenericValueParamDecl>())
{
- Val* valResult = m_astBuilder->getOrCreate<GenericParamIntVal>(
+ Val* valResult = m_astBuilder->getOrCreate<DeclRefIntVal>(
declRef.substitute(m_astBuilder, genericValParamRef.getDecl()->getType()),
genericValParamRef);
valResult = valResult->substitute(m_astBuilder, expr.getSubsts());
@@ -2383,7 +2388,7 @@ Expr* SemanticsExprVisitor::visitIndexExpr(IndexExpr* subscriptExpr)
subscriptExpr->indexExprs[0],
IntegerConstantExpressionCoercionType::AnyInteger,
nullptr,
- ConstantFoldingKind::LinkTime);
+ ConstantFoldingKind::SpecializationConstant);
// Validate that array size is greater than zero
if (auto constElementCount = as<ConstantIntVal>(elementCount))
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index a910a3722..6c9a0409d 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2110,6 +2110,7 @@ public:
{
CompileTime,
LinkTime,
+ SpecializationConstant
};
Expr* checkExpressionAndExpectIntegerConstant(
Expr* expr,
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index db753713b..172d09ac2 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -453,9 +453,9 @@ bool SemanticsVisitor::ValuesAreEqual(IntVal* left, IntVal* right)
}
}
- if (auto leftVar = as<GenericParamIntVal>(left))
+ if (auto leftVar = as<DeclRefIntVal>(left))
{
- if (auto rightVar = as<GenericParamIntVal>(right))
+ if (auto rightVar = as<DeclRefIntVal>(right))
{
return leftVar->getDeclRef().equals(rightVar->getDeclRef());
}
diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp
index d2e68ccc8..50fd739cb 100644
--- a/source/slang/slang-doc-markdown-writer.cpp
+++ b/source/slang/slang-doc-markdown-writer.cpp
@@ -782,7 +782,7 @@ void DocMarkdownWriter::writeExtensionConditions(
{
genericParamDecl = extTypeParamDecl.getDecl();
}
- else if (auto extValueParamVal = as<GenericParamIntVal>(arg))
+ else if (auto extValueParamVal = as<DeclRefIntVal>(arg))
{
genericParamDecl = extValueParamVal->getDeclRef().getDecl();
}
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 096e7d8bc..32d3ba7c3 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -775,6 +775,147 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
m_operandStack.setCount(operandsStartIndex);
}
+ SpvOp _specConstantOpcodeConvert(IROp irOpCode, IRBasicType* basicType)
+ {
+ SpvOp opCode = SpvOpUndef;
+ opCode = _arithmeticOpCodeConvert(irOpCode, basicType);
+ if (opCode == SpvOpUndef)
+ {
+ switch (irOpCode)
+ {
+ case kIROp_IntCast:
+ {
+ auto typeStyle = getTypeStyle(basicType->getBaseType());
+ if (typeStyle == kIROp_FloatType)
+ {
+ return SpvOpConvertFToU;
+ }
+ else if (typeStyle == kIROp_IntType)
+ {
+ return SpvOpUConvert;
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ return opCode;
+ }
+ return opCode;
+ }
+
+ SpvOp _arithmeticOpCodeConvert(IROp irOpCode, IRBasicType* basicType)
+ {
+ bool isFloatingPoint = false;
+ bool isBool = false;
+ switch (basicType->getBaseType())
+ {
+ case BaseType::Float:
+ case BaseType::Double:
+ case BaseType::Half:
+ isFloatingPoint = true;
+ break;
+ case BaseType::Bool:
+ isBool = true;
+ break;
+ default:
+ break;
+ }
+ bool isSigned = isSignedType(basicType);
+ SpvOp opCode = SpvOpUndef;
+ switch (irOpCode)
+ {
+ case kIROp_Add:
+ opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd;
+ break;
+ case kIROp_Sub:
+ opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub;
+ break;
+ case kIROp_Mul:
+ opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul;
+ break;
+ case kIROp_Div:
+ opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv;
+ break;
+ case kIROp_IRem:
+ opCode = isSigned ? SpvOpSRem : SpvOpUMod;
+ break;
+ case kIROp_FRem:
+ opCode = SpvOpFRem;
+ break;
+ case kIROp_Less:
+ opCode = isFloatingPoint ? SpvOpFOrdLessThan
+ : isSigned ? SpvOpSLessThan
+ : SpvOpULessThan;
+ break;
+ case kIROp_Leq:
+ opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual
+ : isSigned ? SpvOpSLessThanEqual
+ : SpvOpULessThanEqual;
+ break;
+ case kIROp_Eql:
+ opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual;
+ break;
+ case kIROp_Neq:
+ opCode = isFloatingPoint ? SpvOpFUnordNotEqual
+ : isBool ? SpvOpLogicalNotEqual
+ : SpvOpINotEqual;
+ break;
+ case kIROp_Geq:
+ opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual
+ : isSigned ? SpvOpSGreaterThanEqual
+ : SpvOpUGreaterThanEqual;
+ break;
+ case kIROp_Greater:
+ opCode = isFloatingPoint ? SpvOpFOrdGreaterThan
+ : isSigned ? SpvOpSGreaterThan
+ : SpvOpUGreaterThan;
+ break;
+ case kIROp_Neg:
+ opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate;
+ break;
+ case kIROp_And:
+ opCode = SpvOpLogicalAnd;
+ break;
+ case kIROp_Or:
+ opCode = SpvOpLogicalOr;
+ break;
+ case kIROp_Not:
+ opCode = SpvOpLogicalNot;
+ break;
+ case kIROp_BitAnd:
+ if (isBool)
+ opCode = SpvOpLogicalAnd;
+ else
+ opCode = SpvOpBitwiseAnd;
+ break;
+ case kIROp_BitOr:
+ if (isBool)
+ opCode = SpvOpLogicalOr;
+ else
+ opCode = SpvOpBitwiseOr;
+ break;
+ case kIROp_BitXor:
+ if (isBool)
+ opCode = SpvOpLogicalNotEqual;
+ else
+ opCode = SpvOpBitwiseXor;
+ break;
+ case kIROp_BitNot:
+ if (isBool)
+ opCode = SpvOpLogicalNot;
+ else
+ opCode = SpvOpNot;
+ break;
+ case kIROp_Rsh:
+ opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical;
+ break;
+ case kIROp_Lsh:
+ opCode = SpvOpShiftLeftLogical;
+ break;
+ }
+ return opCode;
+ }
/// Ensure that an instruction has been emitted
SpvInst* ensureInst(IRInst* irInst)
{
@@ -1972,8 +2113,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
as<IRDebugInlinedAt>(inst));
default:
{
- if (as<IRSPIRVAsmOperand>(inst))
+ if (isSpecConstRateType(inst->getFullType()))
+ return emitSpecializationConstantOp(inst);
+
+ else if (as<IRSPIRVAsmOperand>(inst))
return nullptr;
+
String e = "Unhandled global inst in spirv-emit:\n" +
dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0});
SLANG_UNIMPLEMENTED_X(e.begin());
@@ -2756,6 +2901,66 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return result;
}
+ SpvInst* emitSpecializationConstantOp(IRInst* inst)
+ {
+ SpvInst* spv = nullptr;
+ if (m_mapIRInstToSpvInst.tryGetValue(inst, spv))
+ return spv;
+
+ // For each OpSpecConstantOp, the operand must be:
+ // 1. A specialization constant
+ // 2. A literal constant
+ // 3. Another OpSpecConstantOp
+
+ // For 1 and 2, we can just emit the specialization constant or literal constant.
+ if (auto param = as<IRGlobalParam>(inst))
+ {
+ auto layout = getVarLayout(param);
+ if (layout)
+ {
+ if (auto offset =
+ layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant))
+ {
+ return emitSpecializationConstant(param, offset);
+ }
+ }
+ SLANG_UNREACHABLE("Non specialization constant used in OpSpecConstantOp\n");
+ }
+ else if (as<IRConstant>(inst))
+ {
+ // We need to emit the constant as a specialization constant
+ return emitLit(inst);
+ }
+
+ IRType* type = inst->getOperand(0)->getDataType();
+ IRBasicType* basicType = as<IRBasicType>(type);
+ SpvOp opCode = _specConstantOpcodeConvert(inst->getOp(), basicType);
+ if (opCode == SpvOpUndef)
+ {
+ String e = "Unhandled inst in spirv-emit:\n" +
+ dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0});
+ SLANG_UNIMPLEMENTED_X(e.getBuffer());
+ }
+
+ Array<SpvInst*, 3> operands;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto operand = inst->getOperand(i);
+ SpvInst* spvInst = emitSpecializationConstantOp(operand);
+ operands.add(spvInst);
+ }
+
+ auto resultType = inst->getFullType();
+ return emitInst(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ inst,
+ SpvOpSpecConstantOp,
+ resultType,
+ kResultID,
+ opCode,
+ operands);
+ }
+
/// Emit a global parameter definition.
SpvInst* emitGlobalParam(IRGlobalParam* param)
{
@@ -7197,117 +7402,13 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType());
IRBasicType* basicType = as<IRBasicType>(elementType);
- bool isFloatingPoint = false;
- bool isBool = false;
- switch (basicType->getBaseType())
- {
- case BaseType::Float:
- case BaseType::Double:
- case BaseType::Half:
- isFloatingPoint = true;
- break;
- case BaseType::Bool:
- isBool = true;
- break;
- default:
- break;
- }
- SpvOp opCode = SpvOpUndef;
- bool isSigned = isSignedType(basicType);
- switch (op)
- {
- case kIROp_Add:
- opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd;
- break;
- case kIROp_Sub:
- opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub;
- break;
- case kIROp_Mul:
- opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul;
- break;
- case kIROp_Div:
- opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv;
- break;
- case kIROp_IRem:
- opCode = isSigned ? SpvOpSRem : SpvOpUMod;
- break;
- case kIROp_FRem:
- opCode = SpvOpFRem;
- break;
- case kIROp_Less:
- opCode = isFloatingPoint ? SpvOpFOrdLessThan
- : isSigned ? SpvOpSLessThan
- : SpvOpULessThan;
- break;
- case kIROp_Leq:
- opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual
- : isSigned ? SpvOpSLessThanEqual
- : SpvOpULessThanEqual;
- break;
- case kIROp_Eql:
- opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual;
- break;
- case kIROp_Neq:
- opCode = isFloatingPoint ? SpvOpFUnordNotEqual
- : isBool ? SpvOpLogicalNotEqual
- : SpvOpINotEqual;
- break;
- case kIROp_Geq:
- opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual
- : isSigned ? SpvOpSGreaterThanEqual
- : SpvOpUGreaterThanEqual;
- break;
- case kIROp_Greater:
- opCode = isFloatingPoint ? SpvOpFOrdGreaterThan
- : isSigned ? SpvOpSGreaterThan
- : SpvOpUGreaterThan;
- break;
- case kIROp_Neg:
- opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate;
- break;
- case kIROp_And:
- opCode = SpvOpLogicalAnd;
- break;
- case kIROp_Or:
- opCode = SpvOpLogicalOr;
- break;
- case kIROp_Not:
- opCode = SpvOpLogicalNot;
- break;
- case kIROp_BitAnd:
- if (isBool)
- opCode = SpvOpLogicalAnd;
- else
- opCode = SpvOpBitwiseAnd;
- break;
- case kIROp_BitOr:
- if (isBool)
- opCode = SpvOpLogicalOr;
- else
- opCode = SpvOpBitwiseOr;
- break;
- case kIROp_BitXor:
- if (isBool)
- opCode = SpvOpLogicalNotEqual;
- else
- opCode = SpvOpBitwiseXor;
- break;
- case kIROp_BitNot:
- if (isBool)
- opCode = SpvOpLogicalNot;
- else
- opCode = SpvOpNot;
- break;
- case kIROp_Rsh:
- opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical;
- break;
- case kIROp_Lsh:
- opCode = SpvOpShiftLeftLogical;
- break;
- default:
+
+ SpvOp opCode = _arithmeticOpCodeConvert(op, basicType);
+ if (opCode == SpvOpUndef)
SLANG_ASSERT(!"unknown arithmetic opcode");
- break;
- }
+
+ bool isFloatingPoint = (getTypeStyle(basicType->getBaseType()) == kIROp_FloatType);
+
if (operandCount == 1)
{
return emitInst(parent, instToRegister, opCode, type, kResultID, operands);
@@ -7846,7 +7947,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
emitDebugType(arrayType->getElementType()),
sizedArrayType ? builder.getIntValue(
builder.getUIntType(),
- getIntVal(sizedArrayType->getElementCount()))
+ getArraySizeVal(sizedArrayType->getElementCount()))
: builder.getIntValue(builder.getUIntType(), 0));
}
else if (auto vectorType = as<IRVectorType>(type))
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 5a62c8063..f863858e4 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -95,6 +95,7 @@ INST(Nop, nop, 0, 0)
/* Rate */
INST(ConstExprRate, ConstExpr, 0, HOISTABLE)
+ INST(SpecConstRate, SpecConst, 0, HOISTABLE)
INST(GroupSharedRate, GroupShared, 0, HOISTABLE)
INST(ActualGlobalRate, ActualGlobalRate, 0, HOISTABLE)
INST_RANGE(Rate, ConstExprRate, GroupSharedRate)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 268929fb9..3280dc35c 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3938,6 +3938,7 @@ public:
IRConstExprRate* getConstExprRate();
IRGroupSharedRate* getGroupSharedRate();
IRActualGlobalRate* getActualGlobalRate();
+ IRSpecConstRate* getSpecConstRate();
IRRateQualifiedType* getRateQualifiedType(IRRate* rate, IRType* dataType);
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 6f0e22a57..1294b400d 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -309,7 +309,7 @@ struct LoweredElementTypeContext
builder.emitBlock();
auto packedParam = builder.emitParam(refStructType);
auto packedArray = builder.emitFieldAddress(packedParam, dataKey);
- auto count = getIntVal(arrayType->getElementCount());
+ auto count = getArraySizeVal(arrayType->getElementCount());
IRInst* result = nullptr;
if (count <= kMaxArraySizeToUnroll)
{
@@ -374,7 +374,7 @@ struct LoweredElementTypeContext
builder.emitBlock();
auto outParam = builder.emitParam(outLoweredType);
auto originalParam = builder.emitParam(arrayType);
- auto count = getIntVal(arrayType->getElementCount());
+ auto count = getArraySizeVal(arrayType->getElementCount());
auto destArray = builder.emitFieldAddress(outParam, arrayStructKey);
if (count <= kMaxArraySizeToUnroll)
{
@@ -602,7 +602,8 @@ struct LoweredElementTypeContext
StringBuilder nameSB;
nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_";
getTypeNameHint(nameSB, arrayType->getElementType());
- nameSB << getIntVal(arrayType->getElementCount());
+ nameSB << getArraySizeVal(arrayType->getElementCount());
+
builder.addNameHintDecoration(
loweredType,
nameSB.produceString().getUnownedSlice());
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 5aae53747..9d8773237 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -2250,4 +2250,27 @@ bool isFirstBlock(IRInst* inst)
return block->getParent()->getFirstBlock() == block;
}
+bool isSpecConstRateType(IRType* type)
+{
+ if (auto rateQualifiedType = as<IRRateQualifiedType>(type))
+ {
+ if (as<IRSpecConstRate>(rateQualifiedType->getRate()))
+ {
+ return true;
+ }
+ }
+ return false;
+}
+void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst)
+{
+ IRInst* moduleInst = builder->getModule()->getModuleInst();
+ UInt operandCount = inst->getOperandCount();
+ for (UInt ii = 0; ii < operandCount; ++ii)
+ {
+ auto operand = inst->getOperand(ii);
+ if (operand->parent != moduleInst)
+ hoistInstAndOperandsToGlobal(builder, operand);
+ }
+ inst->insertAt(IRInsertLoc::atStart(moduleInst));
+}
} // namespace Slang
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index b111f8abf..900e22c76 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -388,6 +388,9 @@ void legalizeDefUse(IRGlobalValueWithCode* func);
UnownedStringSlice getMangledName(IRInst* inst);
bool isFirstBlock(IRInst* inst);
+
+bool isSpecConstRateType(IRType* type);
+void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst);
} // namespace Slang
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index fb7d752d5..9c4cb98c0 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -314,6 +314,23 @@ IRIntegerValue getIntVal(IRInst* inst)
}
}
+IRIntegerValue getArraySizeVal(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_IntLit:
+ return static_cast<IRConstant*>(inst)->value.intVal;
+ break;
+ default:
+ // Treat specialization constant array as the unsized array here.
+ if (isSpecConstRateType(inst->getFullType()))
+ return kUnsizedArrayMagicLength;
+
+ SLANG_UNEXPECTED("needed a known integer value");
+ UNREACHABLE_RETURN(0);
+ }
+}
+
// IRCapabilitySet
CapabilitySet IRCapabilitySet::getCaps()
@@ -3194,6 +3211,10 @@ IRActualGlobalRate* IRBuilder::getActualGlobalRate()
{
return (IRActualGlobalRate*)getType(kIROp_ActualGlobalRate);
}
+IRSpecConstRate* IRBuilder::getSpecConstRate()
+{
+ return (IRSpecConstRate*)getType(kIROp_SpecConstRate);
+}
IRRateQualifiedType* IRBuilder::getRateQualifiedType(IRRate* rate, IRType* dataType)
{
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 91c2f018a..461ed567a 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1162,6 +1162,11 @@ struct IRBoolLit : IRConstant
// if it has one, and assert-fail otherwise.
IRIntegerValue getIntVal(IRInst* inst);
+// If it's a specialization constant sized array or unsized array, returns
+// kUnsizedArrayMagicLength if it's an unsized array. Otherwise just returns
+// the actual size.
+IRIntegerValue getArraySizeVal(IRInst* inst);
+
struct IRStringLit : IRConstant
{
@@ -1644,6 +1649,7 @@ struct IRAtomicType : IRType
SIMPLE_IR_PARENT_TYPE(Rate, Type)
SIMPLE_IR_TYPE(ConstExprRate, Rate)
+SIMPLE_IR_TYPE(SpecConstRate, Rate)
SIMPLE_IR_TYPE(GroupSharedRate, Rate)
SIMPLE_IR_TYPE(ActualGlobalRate, Rate)
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index a21c93f06..af285f221 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -492,6 +492,8 @@ struct SharedIRGenContext
Dictionary<SourceFile*, IRInst*> mapSourceFileToDebugSourceInst;
Dictionary<String, IRInst*> mapSourcePathToDebugSourceInst;
+ Dictionary<IntVal*, IRInst*> mapSpecConstValToIRInst;
+
void setGlobalValue(Decl* decl, LoweredValInfo value)
{
globalEnv.mapDeclToValue[decl] = value;
@@ -1552,6 +1554,14 @@ static bool _isTrivialLookupFromInterfaceThis(IRGenContext* context, DeclRefBase
//
+static void maybePropagateRate(IRBuilder* builder, IRType* rateQulifiedType, IRInst* inst)
+{
+ if (isSpecConstRateType(rateQulifiedType))
+ {
+ inst->setFullType(
+ builder->getRateQualifiedType(builder->getSpecConstRate(), inst->getFullType()));
+ }
+}
struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredValInfo>
{
@@ -1565,7 +1575,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
UNREACHABLE_RETURN(LoweredValInfo());
}
- LoweredValInfo visitGenericParamIntVal(GenericParamIntVal* val)
+ LoweredValInfo visitDeclRefIntVal(DeclRefIntVal* val)
{
return emitDeclRef(
context,
@@ -1577,27 +1587,35 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
{
TryClauseEnvironment tryEnv;
List<IRInst*> args;
+ IRType* specConstRateType = nullptr;
for (auto arg : val->getArgs())
{
auto loweredArg = lowerVal(context, arg);
args.add(loweredArg.val);
+ if (!specConstRateType && isSpecConstRateType(loweredArg.val->getFullType()))
+ specConstRateType = loweredArg.val->getFullType();
}
auto funcType = lowerType(context, val->getFuncType());
- return emitCallToDeclRef(
+ auto resVal = emitCallToDeclRef(
context,
as<IRFuncType>(funcType)->getResultType(),
val->getFuncDeclRef(),
funcType,
args,
tryEnv);
+ maybePropagateRate(getBuilder(), specConstRateType, resVal.val);
+ return resVal;
}
LoweredValInfo visitTypeCastIntVal(TypeCastIntVal* val)
{
auto baseVal = lowerVal(context, val->getBase());
+
SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
auto type = lowerType(context, val->getType());
- return LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val));
+ auto resVal = LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val));
+ maybePropagateRate(getBuilder(), baseVal.val->getFullType(), resVal.val);
+ return resVal;
}
LoweredValInfo visitWitnessLookupIntVal(WitnessLookupIntVal* val)
@@ -1625,8 +1643,10 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
{
termVal = irBuilder->emitMul(factorVal->getDataType(), termVal, factorVal);
}
+ maybePropagateRate(getBuilder(), factorVal->getFullType(), termVal);
}
resultVal = irBuilder->emitAdd(termVal->getDataType(), resultVal, termVal);
+ maybePropagateRate(getBuilder(), termVal->getFullType(), resultVal);
}
return LoweredValInfo::simple(resultVal);
}
@@ -2056,7 +2076,18 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
auto elementType = lowerType(context, type->getElementType());
if (!type->isUnsized())
{
- auto elementCount = lowerSimpleVal(context, type->getElementCount());
+ IRInst* elementCount = nullptr;
+ auto sizeVal = type->getElementCount();
+ auto sharedContext = context->shared;
+ if (!sharedContext->mapSpecConstValToIRInst.tryGetValue(sizeVal, elementCount))
+ {
+ elementCount = lowerSimpleVal(context, sizeVal);
+ if (isSpecConstRateType(elementCount->getFullType()))
+ {
+ sharedContext->mapSpecConstValToIRInst.add(sizeVal, elementCount);
+ hoistInstAndOperandsToGlobal(getBuilder(), elementCount);
+ }
+ }
return getBuilder()->getArrayType(elementType, elementCount);
}
else
@@ -2446,6 +2477,13 @@ void maybeSetRate(IRGenContext* context, IRInst* inst, Decl* decl)
inst->setFullType(
builder->getRateQualifiedType(builder->getActualGlobalRate(), inst->getFullType()));
}
+ else if (
+ decl->hasModifier<SpecializationConstantAttribute>() ||
+ decl->hasModifier<VkConstantIdAttribute>())
+ {
+ inst->setFullType(
+ builder->getRateQualifiedType(builder->getSpecConstRate(), inst->getFullType()));
+ }
}
static String getNameForNameHint(IRGenContext* context, Decl* decl)
@@ -11846,9 +11884,15 @@ RefPtr<IRModule> generateIRForTranslationUnit(
}
#if 0
+ if (compileRequest->optionSet.shouldDumpIR())
{
DiagnosticSinkWriter writer(compileRequest->getSink());
- dumpIR(module, &writer, "GENERATED");
+ dumpIR(
+ module,
+ compileRequest->m_irDumpOptions,
+ "GENERATED",
+ compileRequest->getSourceManager(),
+ &writer);
}
#endif
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index f08ffd75d..056c7accb 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -323,7 +323,7 @@ void emitVal(ManglingContext* context, Val* val)
// to mangle in the constraints even when
// the whole thing is specialized...
}
- else if (auto genericParamIntVal = dynamicCast<GenericParamIntVal>(val))
+ else if (auto genericParamIntVal = dynamicCast<DeclRefIntVal>(val))
{
// TODO: we shouldn't be including the names of generic parameters
// anywhere in mangled names, since changing parameter names
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 065c2c3f6..258266da5 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -3906,7 +3906,7 @@ SLANG_API int64_t spReflectionGeneric_GetConcreteIntVal(
auto valueParamDeclRef = convert(valueParam);
- Val* valResult = astBuilder->getOrCreate<GenericParamIntVal>(
+ Val* valResult = astBuilder->getOrCreate<DeclRefIntVal>(
valueParamDeclRef.substitute(
astBuilder,
as<GenericValueParamDecl>(valueParamDeclRef.getDecl())->getType()),
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index 9dea0f167..5f2cf4ddf 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -2336,7 +2336,7 @@ static LayoutSize GetElementCount(IntVal* val)
return LayoutSize::infinite();
return LayoutSize(LayoutSize::RawValue(constantVal->getValue()));
}
- else if (const auto varRefVal = as<GenericParamIntVal>(val))
+ else if (const auto varRefVal = as<DeclRefIntVal>(val))
{
// TODO: We want to treat the case where the number of
// elements in an array depends on a generic parameter
@@ -2352,6 +2352,10 @@ static LayoutSize GetElementCount(IntVal* val)
{
return 0;
}
+ else if (as<FuncCallIntVal>(val))
+ {
+ return 0;
+ }
SLANG_UNEXPECTED("unhandled integer literal kind");
UNREACHABLE_RETURN(LayoutSize(0));
}
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 602446cda..67d13c34b 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -6079,7 +6079,7 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor
{
collectReferencedModules(type);
}
- else if (auto declRefVal = as<GenericParamIntVal>(val))
+ else if (auto declRefVal = as<DeclRefIntVal>(val))
{
collectReferencedModules(declRefVal->getDeclRef());
}