summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-builder.cpp1
-rw-r--r--source/slang/slang-ast-decl-ref.cpp2
-rw-r--r--source/slang/slang-ast-val.cpp10
-rw-r--r--source/slang/slang-ast-val.h17
-rw-r--r--source/slang/slang-check-constraint.cpp4
-rw-r--r--source/slang/slang-check-decl.cpp6
-rw-r--r--source/slang/slang-check-expr.cpp17
-rw-r--r--source/slang/slang-check-impl.h1
-rw-r--r--source/slang/slang-check-type.cpp4
-rw-r--r--source/slang/slang-doc-markdown-writer.cpp2
-rw-r--r--source/slang/slang-emit-spirv.cpp325
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h1
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp7
-rw-r--r--source/slang/slang-ir-util.cpp23
-rw-r--r--source/slang/slang-ir-util.h3
-rw-r--r--source/slang/slang-ir.cpp21
-rw-r--r--source/slang/slang-ir.h6
-rw-r--r--source/slang/slang-lower-to-ir.cpp54
-rw-r--r--source/slang/slang-mangle.cpp2
-rw-r--r--source/slang/slang-reflection-api.cpp2
-rw-r--r--source/slang/slang-type-layout.cpp6
-rw-r--r--source/slang/slang.cpp2
23 files changed, 363 insertions, 154 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 5abef94b3..893d5e6d7 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -347,6 +347,7 @@ ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* element
{
if (!elementCount)
elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength);
+
if (elementCount->getType() != getIntType())
{
// Canonicalize constant elementCount to int.
diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp
index 1881f1b3c..89fa52b09 100644
--- a/source/slang/slang-ast-decl-ref.cpp
+++ b/source/slang/slang-ast-decl-ref.cpp
@@ -41,7 +41,7 @@ DeclRefBase* _getDeclRefFromVal(Val* val)
{
if (auto declRefType = as<DeclRefType>(val))
return declRefType->getDeclRef();
- else if (auto genParamIntVal = as<GenericParamIntVal>(val))
+ else if (auto genParamIntVal = as<DeclRefIntVal>(val))
return genParamIntVal->getDeclRef();
else if (auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(val))
return declaredSubtypeWitness->getDeclRef();
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index efb87b831..1cdca0440 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -176,9 +176,9 @@ void ConstantIntVal::_toTextOverride(StringBuilder& out)
out << getValue();
}
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericParamIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclRefIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-void GenericParamIntVal::_toTextOverride(StringBuilder& out)
+void DeclRefIntVal::_toTextOverride(StringBuilder& out)
{
Name* name = getDeclRef().getName();
if (name)
@@ -248,7 +248,7 @@ Val* maybeSubstituteGenericParam(Val* paramVal, Decl* paramDecl, SubstitutionSet
return paramVal;
}
-Val* GenericParamIntVal::_substituteImplOverride(
+Val* DeclRefIntVal::_substituteImplOverride(
ASTBuilder* /* astBuilder */,
SubstitutionSet subst,
int* ioDiff)
@@ -259,12 +259,12 @@ Val* GenericParamIntVal::_substituteImplOverride(
return this;
}
-bool GenericParamIntVal::_isLinkTimeValOverride()
+bool DeclRefIntVal::_isLinkTimeValOverride()
{
return getDeclRef().getDecl()->hasModifier<ExternModifier>();
}
-Val* GenericParamIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map)
+Val* DeclRefIntVal::_linkTimeResolveOverride(Dictionary<String, IntVal*>& map)
{
auto name = getMangledName(getCurrentASTBuilder(), getDeclRef().declRefBase);
IntVal* v;
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index cdfb0b51f..2b4c7ed22 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -168,7 +168,7 @@ class ConstantIntVal : public IntVal
// The logical "value" of a reference to a generic value parameter
FIDDLE()
-class GenericParamIntVal : public IntVal
+class DeclRefIntVal : public IntVal
{
FIDDLE(...)
DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(1)); }
@@ -177,10 +177,7 @@ class GenericParamIntVal : public IntVal
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
- GenericParamIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef)
- {
- setOperands(inType, inDeclRef);
- }
+ DeclRefIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef) { setOperands(inType, inDeclRef); }
bool _isLinkTimeValOverride();
Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map);
@@ -319,9 +316,9 @@ public:
// for sorting only.
bool operator<(const PolynomialIntValFactor& other) const
{
- if (auto thisGenParam = as<GenericParamIntVal>(getParam()))
+ if (auto thisGenParam = as<DeclRefIntVal>(getParam()))
{
- if (auto thatGenParam = as<GenericParamIntVal>(other.getParam()))
+ if (auto thatGenParam = as<DeclRefIntVal>(other.getParam()))
{
if (thisGenParam->equals(thatGenParam))
return getPower() < other.getPower();
@@ -336,7 +333,7 @@ public:
}
else
{
- if (const auto thatGenParam = as<GenericParamIntVal>(other.getParam()))
+ if (const auto thatGenParam = as<DeclRefIntVal>(other.getParam()))
{
return false;
}
@@ -347,9 +344,9 @@ public:
// for sorting only.
bool operator==(const PolynomialIntValFactor& other) const
{
- if (auto thisGenParam = as<GenericParamIntVal>(getParam()))
+ if (auto thisGenParam = as<DeclRefIntVal>(getParam()))
{
- if (auto thatGenParam = as<GenericParamIntVal>(other.getParam()))
+ if (auto thatGenParam = as<DeclRefIntVal>(other.getParam()))
{
if (thisGenParam->equals(thatGenParam) && getPower() == other.getPower())
return true;
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 642a4bf6a..6f9191135 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -819,7 +819,7 @@ bool SemanticsVisitor::TryUnifyVals(
{
if (const auto c = as<TypeCastIntVal>(i))
i = as<IntVal>(c->getBase());
- return as<GenericParamIntVal>(i);
+ return as<DeclRefIntVal>(i);
};
auto fstParam = paramUnderCast(fstInt);
auto sndParam = paramUnderCast(sndInt);
@@ -1196,7 +1196,7 @@ void SemanticsVisitor::maybeUnifyUnconstraintIntParam(
{
param = as<IntVal>(typeCastParam->getBase());
}
- auto intParam = as<GenericParamIntVal>(param);
+ auto intParam = as<DeclRefIntVal>(param);
if (!intParam)
return;
for (auto c : constraints.constraints)
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 1e524e27f..dbd52ebea 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -3921,7 +3921,7 @@ bool SemanticsVisitor::doesGenericSignatureMatchRequirement(
auto satisfyingValueParamDeclRef = satisfyingMemberDeclRef.as<GenericValueParamDecl>();
SLANG_ASSERT(satisfyingValueParamDeclRef);
- auto satisfyingVal = m_astBuilder->getOrCreate<GenericParamIntVal>(
+ auto satisfyingVal = m_astBuilder->getOrCreate<DeclRefIntVal>(
requiredValueParamDeclRef.getDecl()->getType(),
satisfyingValueParamDeclRef);
satisfyingVal->getDeclRef() = satisfyingValueParamDeclRef;
@@ -8513,7 +8513,7 @@ List<Val*> getDefaultSubstitutionArgs(
if (semantics)
semantics->ensureDecl(genericValueParamDecl, DeclCheckState::ReadyForLookup);
- args.add(astBuilder->getOrCreate<GenericParamIntVal>(
+ args.add(astBuilder->getOrCreate<DeclRefIntVal>(
genericValueParamDecl->getType(),
astBuilder->getDirectDeclRef(genericValueParamDecl)));
}
@@ -11769,7 +11769,7 @@ void checkDerivativeAttributeImpl(
appExpr->arguments.add(baseTypeExpr);
}
- else if (auto genericValParam = as<GenericParamIntVal>(arg))
+ else if (auto genericValParam = as<DeclRefIntVal>(arg))
{
auto declRef = genericValParam->getDeclRef();
appExpr->arguments.add(
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 48f32952b..d151d37be 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1972,9 +1972,14 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef(
// The values of specialization constants aren't known at compile time even
// if they're marked `const`.
- if (decl->hasModifier<SpecializationConstantAttribute>() ||
- decl->hasModifier<VkConstantIdAttribute>())
- return nullptr;
+ if ((decl->hasModifier<SpecializationConstantAttribute>() ||
+ decl->hasModifier<VkConstantIdAttribute>()) &&
+ kind == ConstantFoldingKind::SpecializationConstant)
+ {
+ return m_astBuilder->getOrCreate<DeclRefIntVal>(
+ declRef.substitute(m_astBuilder, declRef.getDecl()->getType()),
+ declRef);
+ }
if (decl->hasModifier<ExternModifier>())
{
@@ -1982,7 +1987,7 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef(
if (kind == ConstantFoldingKind::CompileTime)
return nullptr;
// But if we are OK with link-time constants, we can still fold it into a val.
- auto rs = m_astBuilder->getOrCreate<GenericParamIntVal>(
+ auto rs = m_astBuilder->getOrCreate<DeclRefIntVal>(
declRef.substitute(m_astBuilder, declRef.getDecl()->getType()),
declRef);
return rs;
@@ -2067,7 +2072,7 @@ IntVal* SemanticsVisitor::tryConstantFoldExpr(
if (auto genericValParamRef = declRef.as<GenericValueParamDecl>())
{
- Val* valResult = m_astBuilder->getOrCreate<GenericParamIntVal>(
+ Val* valResult = m_astBuilder->getOrCreate<DeclRefIntVal>(
declRef.substitute(m_astBuilder, genericValParamRef.getDecl()->getType()),
genericValParamRef);
valResult = valResult->substitute(m_astBuilder, expr.getSubsts());
@@ -2383,7 +2388,7 @@ Expr* SemanticsExprVisitor::visitIndexExpr(IndexExpr* subscriptExpr)
subscriptExpr->indexExprs[0],
IntegerConstantExpressionCoercionType::AnyInteger,
nullptr,
- ConstantFoldingKind::LinkTime);
+ ConstantFoldingKind::SpecializationConstant);
// Validate that array size is greater than zero
if (auto constElementCount = as<ConstantIntVal>(elementCount))
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index a910a3722..6c9a0409d 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2110,6 +2110,7 @@ public:
{
CompileTime,
LinkTime,
+ SpecializationConstant
};
Expr* checkExpressionAndExpectIntegerConstant(
Expr* expr,
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index db753713b..172d09ac2 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -453,9 +453,9 @@ bool SemanticsVisitor::ValuesAreEqual(IntVal* left, IntVal* right)
}
}
- if (auto leftVar = as<GenericParamIntVal>(left))
+ if (auto leftVar = as<DeclRefIntVal>(left))
{
- if (auto rightVar = as<GenericParamIntVal>(right))
+ if (auto rightVar = as<DeclRefIntVal>(right))
{
return leftVar->getDeclRef().equals(rightVar->getDeclRef());
}
diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp
index d2e68ccc8..50fd739cb 100644
--- a/source/slang/slang-doc-markdown-writer.cpp
+++ b/source/slang/slang-doc-markdown-writer.cpp
@@ -782,7 +782,7 @@ void DocMarkdownWriter::writeExtensionConditions(
{
genericParamDecl = extTypeParamDecl.getDecl();
}
- else if (auto extValueParamVal = as<GenericParamIntVal>(arg))
+ else if (auto extValueParamVal = as<DeclRefIntVal>(arg))
{
genericParamDecl = extValueParamVal->getDeclRef().getDecl();
}
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 096e7d8bc..32d3ba7c3 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -775,6 +775,147 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
m_operandStack.setCount(operandsStartIndex);
}
+ SpvOp _specConstantOpcodeConvert(IROp irOpCode, IRBasicType* basicType)
+ {
+ SpvOp opCode = SpvOpUndef;
+ opCode = _arithmeticOpCodeConvert(irOpCode, basicType);
+ if (opCode == SpvOpUndef)
+ {
+ switch (irOpCode)
+ {
+ case kIROp_IntCast:
+ {
+ auto typeStyle = getTypeStyle(basicType->getBaseType());
+ if (typeStyle == kIROp_FloatType)
+ {
+ return SpvOpConvertFToU;
+ }
+ else if (typeStyle == kIROp_IntType)
+ {
+ return SpvOpUConvert;
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ return opCode;
+ }
+ return opCode;
+ }
+
+ SpvOp _arithmeticOpCodeConvert(IROp irOpCode, IRBasicType* basicType)
+ {
+ bool isFloatingPoint = false;
+ bool isBool = false;
+ switch (basicType->getBaseType())
+ {
+ case BaseType::Float:
+ case BaseType::Double:
+ case BaseType::Half:
+ isFloatingPoint = true;
+ break;
+ case BaseType::Bool:
+ isBool = true;
+ break;
+ default:
+ break;
+ }
+ bool isSigned = isSignedType(basicType);
+ SpvOp opCode = SpvOpUndef;
+ switch (irOpCode)
+ {
+ case kIROp_Add:
+ opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd;
+ break;
+ case kIROp_Sub:
+ opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub;
+ break;
+ case kIROp_Mul:
+ opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul;
+ break;
+ case kIROp_Div:
+ opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv;
+ break;
+ case kIROp_IRem:
+ opCode = isSigned ? SpvOpSRem : SpvOpUMod;
+ break;
+ case kIROp_FRem:
+ opCode = SpvOpFRem;
+ break;
+ case kIROp_Less:
+ opCode = isFloatingPoint ? SpvOpFOrdLessThan
+ : isSigned ? SpvOpSLessThan
+ : SpvOpULessThan;
+ break;
+ case kIROp_Leq:
+ opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual
+ : isSigned ? SpvOpSLessThanEqual
+ : SpvOpULessThanEqual;
+ break;
+ case kIROp_Eql:
+ opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual;
+ break;
+ case kIROp_Neq:
+ opCode = isFloatingPoint ? SpvOpFUnordNotEqual
+ : isBool ? SpvOpLogicalNotEqual
+ : SpvOpINotEqual;
+ break;
+ case kIROp_Geq:
+ opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual
+ : isSigned ? SpvOpSGreaterThanEqual
+ : SpvOpUGreaterThanEqual;
+ break;
+ case kIROp_Greater:
+ opCode = isFloatingPoint ? SpvOpFOrdGreaterThan
+ : isSigned ? SpvOpSGreaterThan
+ : SpvOpUGreaterThan;
+ break;
+ case kIROp_Neg:
+ opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate;
+ break;
+ case kIROp_And:
+ opCode = SpvOpLogicalAnd;
+ break;
+ case kIROp_Or:
+ opCode = SpvOpLogicalOr;
+ break;
+ case kIROp_Not:
+ opCode = SpvOpLogicalNot;
+ break;
+ case kIROp_BitAnd:
+ if (isBool)
+ opCode = SpvOpLogicalAnd;
+ else
+ opCode = SpvOpBitwiseAnd;
+ break;
+ case kIROp_BitOr:
+ if (isBool)
+ opCode = SpvOpLogicalOr;
+ else
+ opCode = SpvOpBitwiseOr;
+ break;
+ case kIROp_BitXor:
+ if (isBool)
+ opCode = SpvOpLogicalNotEqual;
+ else
+ opCode = SpvOpBitwiseXor;
+ break;
+ case kIROp_BitNot:
+ if (isBool)
+ opCode = SpvOpLogicalNot;
+ else
+ opCode = SpvOpNot;
+ break;
+ case kIROp_Rsh:
+ opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical;
+ break;
+ case kIROp_Lsh:
+ opCode = SpvOpShiftLeftLogical;
+ break;
+ }
+ return opCode;
+ }
/// Ensure that an instruction has been emitted
SpvInst* ensureInst(IRInst* irInst)
{
@@ -1972,8 +2113,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
as<IRDebugInlinedAt>(inst));
default:
{
- if (as<IRSPIRVAsmOperand>(inst))
+ if (isSpecConstRateType(inst->getFullType()))
+ return emitSpecializationConstantOp(inst);
+
+ else if (as<IRSPIRVAsmOperand>(inst))
return nullptr;
+
String e = "Unhandled global inst in spirv-emit:\n" +
dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0});
SLANG_UNIMPLEMENTED_X(e.begin());
@@ -2756,6 +2901,66 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return result;
}
+ SpvInst* emitSpecializationConstantOp(IRInst* inst)
+ {
+ SpvInst* spv = nullptr;
+ if (m_mapIRInstToSpvInst.tryGetValue(inst, spv))
+ return spv;
+
+ // For each OpSpecConstantOp, the operand must be:
+ // 1. A specialization constant
+ // 2. A literal constant
+ // 3. Another OpSpecConstantOp
+
+ // For 1 and 2, we can just emit the specialization constant or literal constant.
+ if (auto param = as<IRGlobalParam>(inst))
+ {
+ auto layout = getVarLayout(param);
+ if (layout)
+ {
+ if (auto offset =
+ layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant))
+ {
+ return emitSpecializationConstant(param, offset);
+ }
+ }
+ SLANG_UNREACHABLE("Non specialization constant used in OpSpecConstantOp\n");
+ }
+ else if (as<IRConstant>(inst))
+ {
+ // We need to emit the constant as a specialization constant
+ return emitLit(inst);
+ }
+
+ IRType* type = inst->getOperand(0)->getDataType();
+ IRBasicType* basicType = as<IRBasicType>(type);
+ SpvOp opCode = _specConstantOpcodeConvert(inst->getOp(), basicType);
+ if (opCode == SpvOpUndef)
+ {
+ String e = "Unhandled inst in spirv-emit:\n" +
+ dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0});
+ SLANG_UNIMPLEMENTED_X(e.getBuffer());
+ }
+
+ Array<SpvInst*, 3> operands;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto operand = inst->getOperand(i);
+ SpvInst* spvInst = emitSpecializationConstantOp(operand);
+ operands.add(spvInst);
+ }
+
+ auto resultType = inst->getFullType();
+ return emitInst(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ inst,
+ SpvOpSpecConstantOp,
+ resultType,
+ kResultID,
+ opCode,
+ operands);
+ }
+
/// Emit a global parameter definition.
SpvInst* emitGlobalParam(IRGlobalParam* param)
{
@@ -7197,117 +7402,13 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType());
IRBasicType* basicType = as<IRBasicType>(elementType);
- bool isFloatingPoint = false;
- bool isBool = false;
- switch (basicType->getBaseType())
- {
- case BaseType::Float:
- case BaseType::Double:
- case BaseType::Half:
- isFloatingPoint = true;
- break;
- case BaseType::Bool:
- isBool = true;
- break;
- default:
- break;
- }
- SpvOp opCode = SpvOpUndef;
- bool isSigned = isSignedType(basicType);
- switch (op)
- {
- case kIROp_Add:
- opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd;
- break;
- case kIROp_Sub:
- opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub;
- break;
- case kIROp_Mul:
- opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul;
- break;
- case kIROp_Div:
- opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv;
- break;
- case kIROp_IRem:
- opCode = isSigned ? SpvOpSRem : SpvOpUMod;
- break;
- case kIROp_FRem:
- opCode = SpvOpFRem;
- break;
- case kIROp_Less:
- opCode = isFloatingPoint ? SpvOpFOrdLessThan
- : isSigned ? SpvOpSLessThan
- : SpvOpULessThan;
- break;
- case kIROp_Leq:
- opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual
- : isSigned ? SpvOpSLessThanEqual
- : SpvOpULessThanEqual;
- break;
- case kIROp_Eql:
- opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual;
- break;
- case kIROp_Neq:
- opCode = isFloatingPoint ? SpvOpFUnordNotEqual
- : isBool ? SpvOpLogicalNotEqual
- : SpvOpINotEqual;
- break;
- case kIROp_Geq:
- opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual
- : isSigned ? SpvOpSGreaterThanEqual
- : SpvOpUGreaterThanEqual;
- break;
- case kIROp_Greater:
- opCode = isFloatingPoint ? SpvOpFOrdGreaterThan
- : isSigned ? SpvOpSGreaterThan
- : SpvOpUGreaterThan;
- break;
- case kIROp_Neg:
- opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate;
- break;
- case kIROp_And:
- opCode = SpvOpLogicalAnd;
- break;
- case kIROp_Or:
- opCode = SpvOpLogicalOr;
- break;
- case kIROp_Not:
- opCode = SpvOpLogicalNot;
- break;
- case kIROp_BitAnd:
- if (isBool)
- opCode = SpvOpLogicalAnd;
- else
- opCode = SpvOpBitwiseAnd;
- break;
- case kIROp_BitOr:
- if (isBool)
- opCode = SpvOpLogicalOr;
- else
- opCode = SpvOpBitwiseOr;
- break;
- case kIROp_BitXor:
- if (isBool)
- opCode = SpvOpLogicalNotEqual;
- else
- opCode = SpvOpBitwiseXor;
- break;
- case kIROp_BitNot:
- if (isBool)
- opCode = SpvOpLogicalNot;
- else
- opCode = SpvOpNot;
- break;
- case kIROp_Rsh:
- opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical;
- break;
- case kIROp_Lsh:
- opCode = SpvOpShiftLeftLogical;
- break;
- default:
+
+ SpvOp opCode = _arithmeticOpCodeConvert(op, basicType);
+ if (opCode == SpvOpUndef)
SLANG_ASSERT(!"unknown arithmetic opcode");
- break;
- }
+
+ bool isFloatingPoint = (getTypeStyle(basicType->getBaseType()) == kIROp_FloatType);
+
if (operandCount == 1)
{
return emitInst(parent, instToRegister, opCode, type, kResultID, operands);
@@ -7846,7 +7947,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
emitDebugType(arrayType->getElementType()),
sizedArrayType ? builder.getIntValue(
builder.getUIntType(),
- getIntVal(sizedArrayType->getElementCount()))
+ getArraySizeVal(sizedArrayType->getElementCount()))
: builder.getIntValue(builder.getUIntType(), 0));
}
else if (auto vectorType = as<IRVectorType>(type))
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 5a62c8063..f863858e4 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -95,6 +95,7 @@ INST(Nop, nop, 0, 0)
/* Rate */
INST(ConstExprRate, ConstExpr, 0, HOISTABLE)
+ INST(SpecConstRate, SpecConst, 0, HOISTABLE)
INST(GroupSharedRate, GroupShared, 0, HOISTABLE)
INST(ActualGlobalRate, ActualGlobalRate, 0, HOISTABLE)
INST_RANGE(Rate, ConstExprRate, GroupSharedRate)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 268929fb9..3280dc35c 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3938,6 +3938,7 @@ public:
IRConstExprRate* getConstExprRate();
IRGroupSharedRate* getGroupSharedRate();
IRActualGlobalRate* getActualGlobalRate();
+ IRSpecConstRate* getSpecConstRate();
IRRateQualifiedType* getRateQualifiedType(IRRate* rate, IRType* dataType);
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 6f0e22a57..1294b400d 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -309,7 +309,7 @@ struct LoweredElementTypeContext
builder.emitBlock();
auto packedParam = builder.emitParam(refStructType);
auto packedArray = builder.emitFieldAddress(packedParam, dataKey);
- auto count = getIntVal(arrayType->getElementCount());
+ auto count = getArraySizeVal(arrayType->getElementCount());
IRInst* result = nullptr;
if (count <= kMaxArraySizeToUnroll)
{
@@ -374,7 +374,7 @@ struct LoweredElementTypeContext
builder.emitBlock();
auto outParam = builder.emitParam(outLoweredType);
auto originalParam = builder.emitParam(arrayType);
- auto count = getIntVal(arrayType->getElementCount());
+ auto count = getArraySizeVal(arrayType->getElementCount());
auto destArray = builder.emitFieldAddress(outParam, arrayStructKey);
if (count <= kMaxArraySizeToUnroll)
{
@@ -602,7 +602,8 @@ struct LoweredElementTypeContext
StringBuilder nameSB;
nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_";
getTypeNameHint(nameSB, arrayType->getElementType());
- nameSB << getIntVal(arrayType->getElementCount());
+ nameSB << getArraySizeVal(arrayType->getElementCount());
+
builder.addNameHintDecoration(
loweredType,
nameSB.produceString().getUnownedSlice());
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 5aae53747..9d8773237 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -2250,4 +2250,27 @@ bool isFirstBlock(IRInst* inst)
return block->getParent()->getFirstBlock() == block;
}
+bool isSpecConstRateType(IRType* type)
+{
+ if (auto rateQualifiedType = as<IRRateQualifiedType>(type))
+ {
+ if (as<IRSpecConstRate>(rateQualifiedType->getRate()))
+ {
+ return true;
+ }
+ }
+ return false;
+}
+void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst)
+{
+ IRInst* moduleInst = builder->getModule()->getModuleInst();
+ UInt operandCount = inst->getOperandCount();
+ for (UInt ii = 0; ii < operandCount; ++ii)
+ {
+ auto operand = inst->getOperand(ii);
+ if (operand->parent != moduleInst)
+ hoistInstAndOperandsToGlobal(builder, operand);
+ }
+ inst->insertAt(IRInsertLoc::atStart(moduleInst));
+}
} // namespace Slang
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index b111f8abf..900e22c76 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -388,6 +388,9 @@ void legalizeDefUse(IRGlobalValueWithCode* func);
UnownedStringSlice getMangledName(IRInst* inst);
bool isFirstBlock(IRInst* inst);
+
+bool isSpecConstRateType(IRType* type);
+void hoistInstAndOperandsToGlobal(IRBuilder* builder, IRInst* inst);
} // namespace Slang
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index fb7d752d5..9c4cb98c0 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -314,6 +314,23 @@ IRIntegerValue getIntVal(IRInst* inst)
}
}
+IRIntegerValue getArraySizeVal(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_IntLit:
+ return static_cast<IRConstant*>(inst)->value.intVal;
+ break;
+ default:
+ // Treat specialization constant array as the unsized array here.
+ if (isSpecConstRateType(inst->getFullType()))
+ return kUnsizedArrayMagicLength;
+
+ SLANG_UNEXPECTED("needed a known integer value");
+ UNREACHABLE_RETURN(0);
+ }
+}
+
// IRCapabilitySet
CapabilitySet IRCapabilitySet::getCaps()
@@ -3194,6 +3211,10 @@ IRActualGlobalRate* IRBuilder::getActualGlobalRate()
{
return (IRActualGlobalRate*)getType(kIROp_ActualGlobalRate);
}
+IRSpecConstRate* IRBuilder::getSpecConstRate()
+{
+ return (IRSpecConstRate*)getType(kIROp_SpecConstRate);
+}
IRRateQualifiedType* IRBuilder::getRateQualifiedType(IRRate* rate, IRType* dataType)
{
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 91c2f018a..461ed567a 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1162,6 +1162,11 @@ struct IRBoolLit : IRConstant
// if it has one, and assert-fail otherwise.
IRIntegerValue getIntVal(IRInst* inst);
+// If it's a specialization constant sized array or unsized array, returns
+// kUnsizedArrayMagicLength if it's an unsized array. Otherwise just returns
+// the actual size.
+IRIntegerValue getArraySizeVal(IRInst* inst);
+
struct IRStringLit : IRConstant
{
@@ -1644,6 +1649,7 @@ struct IRAtomicType : IRType
SIMPLE_IR_PARENT_TYPE(Rate, Type)
SIMPLE_IR_TYPE(ConstExprRate, Rate)
+SIMPLE_IR_TYPE(SpecConstRate, Rate)
SIMPLE_IR_TYPE(GroupSharedRate, Rate)
SIMPLE_IR_TYPE(ActualGlobalRate, Rate)
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index a21c93f06..af285f221 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -492,6 +492,8 @@ struct SharedIRGenContext
Dictionary<SourceFile*, IRInst*> mapSourceFileToDebugSourceInst;
Dictionary<String, IRInst*> mapSourcePathToDebugSourceInst;
+ Dictionary<IntVal*, IRInst*> mapSpecConstValToIRInst;
+
void setGlobalValue(Decl* decl, LoweredValInfo value)
{
globalEnv.mapDeclToValue[decl] = value;
@@ -1552,6 +1554,14 @@ static bool _isTrivialLookupFromInterfaceThis(IRGenContext* context, DeclRefBase
//
+static void maybePropagateRate(IRBuilder* builder, IRType* rateQulifiedType, IRInst* inst)
+{
+ if (isSpecConstRateType(rateQulifiedType))
+ {
+ inst->setFullType(
+ builder->getRateQualifiedType(builder->getSpecConstRate(), inst->getFullType()));
+ }
+}
struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredValInfo>
{
@@ -1565,7 +1575,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
UNREACHABLE_RETURN(LoweredValInfo());
}
- LoweredValInfo visitGenericParamIntVal(GenericParamIntVal* val)
+ LoweredValInfo visitDeclRefIntVal(DeclRefIntVal* val)
{
return emitDeclRef(
context,
@@ -1577,27 +1587,35 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
{
TryClauseEnvironment tryEnv;
List<IRInst*> args;
+ IRType* specConstRateType = nullptr;
for (auto arg : val->getArgs())
{
auto loweredArg = lowerVal(context, arg);
args.add(loweredArg.val);
+ if (!specConstRateType && isSpecConstRateType(loweredArg.val->getFullType()))
+ specConstRateType = loweredArg.val->getFullType();
}
auto funcType = lowerType(context, val->getFuncType());
- return emitCallToDeclRef(
+ auto resVal = emitCallToDeclRef(
context,
as<IRFuncType>(funcType)->getResultType(),
val->getFuncDeclRef(),
funcType,
args,
tryEnv);
+ maybePropagateRate(getBuilder(), specConstRateType, resVal.val);
+ return resVal;
}
LoweredValInfo visitTypeCastIntVal(TypeCastIntVal* val)
{
auto baseVal = lowerVal(context, val->getBase());
+
SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
auto type = lowerType(context, val->getType());
- return LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val));
+ auto resVal = LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val));
+ maybePropagateRate(getBuilder(), baseVal.val->getFullType(), resVal.val);
+ return resVal;
}
LoweredValInfo visitWitnessLookupIntVal(WitnessLookupIntVal* val)
@@ -1625,8 +1643,10 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
{
termVal = irBuilder->emitMul(factorVal->getDataType(), termVal, factorVal);
}
+ maybePropagateRate(getBuilder(), factorVal->getFullType(), termVal);
}
resultVal = irBuilder->emitAdd(termVal->getDataType(), resultVal, termVal);
+ maybePropagateRate(getBuilder(), termVal->getFullType(), resultVal);
}
return LoweredValInfo::simple(resultVal);
}
@@ -2056,7 +2076,18 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
auto elementType = lowerType(context, type->getElementType());
if (!type->isUnsized())
{
- auto elementCount = lowerSimpleVal(context, type->getElementCount());
+ IRInst* elementCount = nullptr;
+ auto sizeVal = type->getElementCount();
+ auto sharedContext = context->shared;
+ if (!sharedContext->mapSpecConstValToIRInst.tryGetValue(sizeVal, elementCount))
+ {
+ elementCount = lowerSimpleVal(context, sizeVal);
+ if (isSpecConstRateType(elementCount->getFullType()))
+ {
+ sharedContext->mapSpecConstValToIRInst.add(sizeVal, elementCount);
+ hoistInstAndOperandsToGlobal(getBuilder(), elementCount);
+ }
+ }
return getBuilder()->getArrayType(elementType, elementCount);
}
else
@@ -2446,6 +2477,13 @@ void maybeSetRate(IRGenContext* context, IRInst* inst, Decl* decl)
inst->setFullType(
builder->getRateQualifiedType(builder->getActualGlobalRate(), inst->getFullType()));
}
+ else if (
+ decl->hasModifier<SpecializationConstantAttribute>() ||
+ decl->hasModifier<VkConstantIdAttribute>())
+ {
+ inst->setFullType(
+ builder->getRateQualifiedType(builder->getSpecConstRate(), inst->getFullType()));
+ }
}
static String getNameForNameHint(IRGenContext* context, Decl* decl)
@@ -11846,9 +11884,15 @@ RefPtr<IRModule> generateIRForTranslationUnit(
}
#if 0
+ if (compileRequest->optionSet.shouldDumpIR())
{
DiagnosticSinkWriter writer(compileRequest->getSink());
- dumpIR(module, &writer, "GENERATED");
+ dumpIR(
+ module,
+ compileRequest->m_irDumpOptions,
+ "GENERATED",
+ compileRequest->getSourceManager(),
+ &writer);
}
#endif
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index f08ffd75d..056c7accb 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -323,7 +323,7 @@ void emitVal(ManglingContext* context, Val* val)
// to mangle in the constraints even when
// the whole thing is specialized...
}
- else if (auto genericParamIntVal = dynamicCast<GenericParamIntVal>(val))
+ else if (auto genericParamIntVal = dynamicCast<DeclRefIntVal>(val))
{
// TODO: we shouldn't be including the names of generic parameters
// anywhere in mangled names, since changing parameter names
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 065c2c3f6..258266da5 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -3906,7 +3906,7 @@ SLANG_API int64_t spReflectionGeneric_GetConcreteIntVal(
auto valueParamDeclRef = convert(valueParam);
- Val* valResult = astBuilder->getOrCreate<GenericParamIntVal>(
+ Val* valResult = astBuilder->getOrCreate<DeclRefIntVal>(
valueParamDeclRef.substitute(
astBuilder,
as<GenericValueParamDecl>(valueParamDeclRef.getDecl())->getType()),
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index 9dea0f167..5f2cf4ddf 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -2336,7 +2336,7 @@ static LayoutSize GetElementCount(IntVal* val)
return LayoutSize::infinite();
return LayoutSize(LayoutSize::RawValue(constantVal->getValue()));
}
- else if (const auto varRefVal = as<GenericParamIntVal>(val))
+ else if (const auto varRefVal = as<DeclRefIntVal>(val))
{
// TODO: We want to treat the case where the number of
// elements in an array depends on a generic parameter
@@ -2352,6 +2352,10 @@ static LayoutSize GetElementCount(IntVal* val)
{
return 0;
}
+ else if (as<FuncCallIntVal>(val))
+ {
+ return 0;
+ }
SLANG_UNEXPECTED("unhandled integer literal kind");
UNREACHABLE_RETURN(LayoutSize(0));
}
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 602446cda..67d13c34b 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -6079,7 +6079,7 @@ struct SpecializationArgModuleCollector : ComponentTypeVisitor
{
collectReferencedModules(type);
}
- else if (auto declRefVal = as<GenericParamIntVal>(val))
+ else if (auto declRefVal = as<DeclRefIntVal>(val))
{
collectReferencedModules(declRefVal->getDeclRef());
}