summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-30 12:03:23 -0700
committerGitHub <noreply@github.com>2024-08-30 12:03:23 -0700
commitde83628070614ec37349c9f334ed72a54a6889da (patch)
treebc97f74013a073fc958b75e68089696e14d71412 /source
parentf428a058ea48535a323c32d206ebc7e551c3c3e9 (diff)
Support specialization constants. (#4963)
* Support specialization constants. * Fix. * Fix. * Fix. * Fix. * Make sure specialization constants have names. * Clean up and support the dxc [vk::constant_id] syntax. * Fix. * Fix. * Fix.
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang11
-rw-r--r--source/slang/slang-ast-modifier.h18
-rw-r--r--source/slang/slang-check-impl.h4
-rw-r--r--source/slang/slang-check-modifier.cpp123
-rw-r--r--source/slang/slang-emit-c-like.cpp3
-rw-r--r--source/slang/slang-emit-c-like.h2
-rw-r--r--source/slang/slang-emit-glsl.cpp11
-rw-r--r--source/slang/slang-emit-glsl.h1
-rw-r--r--source/slang/slang-emit-spirv-ops.h24
-rw-r--r--source/slang/slang-emit-spirv.cpp141
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h6
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp9
-rw-r--r--source/slang/slang-lower-to-ir.cpp23
-rw-r--r--source/slang/slang-parameter-binding.cpp35
-rw-r--r--source/slang/slang-parser.cpp7
-rw-r--r--source/slang/slang-type-system-shared.h2
17 files changed, 325 insertions, 97 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 7d5f4087c..4e8529666 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2699,12 +2699,19 @@ attribute_syntax [vk_shader_record] : ShaderRecordAttr
__attributeTarget(VarDeclBase)
attribute_syntax [shader_record] : ShaderRecordAttribute;
-__attributeTarget(DeclBase)
+__attributeTarget(VarDeclBase)
attribute_syntax [vk_push_constant] : PushConstantAttribute;
-__attributeTarget(DeclBase)
+__attributeTarget(VarDeclBase)
attribute_syntax [push_constant] : PushConstantAttribute;
__attributeTarget(VarDeclBase)
+attribute_syntax[vk_specialization_constant] : SpecializationConstantAttribute;
+__attributeTarget(VarDeclBase)
+attribute_syntax[SpecializationConstant] : SpecializationConstantAttribute;
+__attributeTarget(VarDeclBase)
+attribute_syntax[vk_constant_id(location: int)] : VkConstantIdAttribute;
+
+__attributeTarget(VarDeclBase)
attribute_syntax [vk_location(location : int)] : GLSLLocationAttribute;
__attributeTarget(VarDeclBase)
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index a7af4a249..69f86a43d 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -235,10 +235,6 @@ class GLSLUnparsedLayoutModifier : public GLSLLayoutModifier
// Specific cases for known GLSL `layout` modifiers that we need to work with
-class GLSLConstantIDLayoutModifier : public GLSLParsedLayoutModifier
-{
- SLANG_AST_CLASS(GLSLConstantIDLayoutModifier)
-};
class GLSLLocationLayoutModifier : public GLSLParsedLayoutModifier
{
@@ -727,11 +723,23 @@ class FlagsAttribute : public Attribute
};
// [[vk_push_constant]] [[push_constant]]
-class PushConstantAttribute : public Attribute
+class PushConstantAttribute : public Attribute
{
SLANG_AST_CLASS(PushConstantAttribute)
};
+// [[vk_specialization_constant]] [[specialization_constant]]
+class SpecializationConstantAttribute : public Attribute
+{
+ SLANG_AST_CLASS(SpecializationConstantAttribute)
+};
+
+// [[vk_constant_id]]
+class VkConstantIdAttribute : public Attribute
+{
+ SLANG_AST_CLASS(VkConstantIdAttribute)
+ int location;
+};
// [[vk_shader_record]] [[shader_record]]
class ShaderRecordAttribute : public Attribute
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 96b467b0c..caec9dcee 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1615,7 +1615,9 @@ namespace Slang
bool getAttributeTargetSyntaxClasses(SyntaxClass<NodeBase> & cls, uint32_t typeFlags);
- bool validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget);
+ // Check an attribute, and return a checked modifier that represents it.
+ //
+ Modifier* validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget);
AttributeBase* checkAttribute(
UncheckedAttribute* uncheckedAttr,
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index d7f879c51..f05b58c34 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -330,7 +330,7 @@ namespace Slang
return false;
}
- bool SemanticsVisitor::validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget)
+ Modifier* SemanticsVisitor::validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget)
{
if (auto numThreadsAttr = as<NumThreadsAttribute>(attr))
{
@@ -348,14 +348,14 @@ namespace Slang
auto intValue = checkLinkTimeConstantIntVal(arg);
if (!intValue)
{
- return false;
+ return nullptr;
}
if (auto constIntVal = as<ConstantIntVal>(intValue))
{
if (constIntVal->getValue() < 1)
{
getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, constIntVal->getValue());
- return false;
+ return nullptr;
}
if (intValue->getType() != m_astBuilder->getIntType())
{
@@ -390,7 +390,7 @@ namespace Slang
auto intValue = checkLinkTimeConstantIntVal(arg);
if (!intValue)
{
- return false;
+ return nullptr;
}
if (auto constIntVal = as<ConstantIntVal>(intValue))
{
@@ -407,7 +407,7 @@ namespace Slang
if (!isValidWaveSize)
{
getSink()->diagnose(attr, Diagnostics::invalidWaveSize, constIntVal->getValue());
- return false;
+ return nullptr;
}
}
value = intValue;
@@ -426,20 +426,20 @@ namespace Slang
if (attr->args.getCount() != 1)
{
- return false;
+ return nullptr;
}
auto value = checkConstantIntVal(attr->args[0]);
if (value == nullptr)
{
- return false;
+ return nullptr;
}
const IRIntegerValue kMaxAnyValueSize = 0x7FFF;
if (value->getValue() > kMaxAnyValueSize)
{
getSink()->diagnose(anyValueSizeAttr->loc, Diagnostics::anyValueSizeExceedsLimit, kMaxAnyValueSize);
- return false;
+ return nullptr;
}
anyValueSizeAttr->size = int32_t(value->getValue());
@@ -448,16 +448,16 @@ namespace Slang
{
if (attr->args.getCount() != 1)
{
- return false;
+ return nullptr;
}
auto value = checkConstantIntVal(attr->args[0]);
if (value == nullptr)
{
- return false;
+ return nullptr;
}
if (value->getValue() < 0)
{
- return false;
+ return nullptr;
}
glslRequireShaderInputParameter->parameterNumber = int32_t(value->getValue());
}
@@ -465,23 +465,23 @@ namespace Slang
{
if (attr->args.getCount() != 1)
{
- return false;
+ return nullptr;
}
auto rank = checkConstantIntVal(attr->args[0]);
if (rank == nullptr)
{
- return false;
+ return nullptr;
}
overloadRankAttr->rank = int32_t(rank->getValue());
}
else if (auto inputAttachmentIndexLayoutAttribute = as<GLSLInputAttachmentIndexLayoutAttribute>(attr))
{
if (attr->args.getCount() != 1)
- return false;
+ return nullptr;
auto location = checkConstantIntVal(attr->args[0]);
if(!location)
- return false;
+ return nullptr;
inputAttachmentIndexLayoutAttribute->location = location->getValue();
}
@@ -492,7 +492,7 @@ namespace Slang
// in core.meta.slang, but that's not completely implemented. So for now we check here.
if (attr->args.getCount() != 2)
{
- return false;
+ return nullptr;
}
// TODO(JS): Prior validation currently doesn't ensure both args are ints (as specified in core.meta.slang), so check here
@@ -502,7 +502,7 @@ namespace Slang
if (binding == nullptr || set == nullptr)
{
- return false;
+ return nullptr;
}
bindingAttr->binding = int32_t(binding->getValue());
@@ -515,13 +515,13 @@ namespace Slang
if (attr->args.getCount() != 1)
{
- return false;
+ return nullptr;
}
auto value = checkConstantIntVal(attr->args[0]);
if (value == nullptr)
{
- return false;
+ return nullptr;
}
simpleLayoutAttr->value = int32_t(value->getValue());
@@ -531,7 +531,7 @@ namespace Slang
SLANG_ASSERT(attr->args.getCount() == 1);
auto val = checkConstantIntVal(attr->args[0]);
- if (!val) return false;
+ if (!val) return nullptr;
maxVertexCountAttr->value = (int32_t)val->getValue();
}
@@ -540,7 +540,7 @@ namespace Slang
SLANG_ASSERT(attr->args.getCount() == 1);
auto val = checkConstantIntVal(attr->args[0]);
- if (!val) return false;
+ if (!val) return nullptr;
instanceAttr->value = (int32_t)val->getValue();
}
@@ -551,7 +551,7 @@ namespace Slang
String capNameString;
if (!checkLiteralStringVal(attr->args[0], &capNameString))
{
- return false;
+ return nullptr;
}
CapabilityName capName = findCapabilityName(capNameString.getUnownedSlice());
@@ -587,7 +587,7 @@ namespace Slang
{
// always diagnose this error since nothing can compile with an invalid capability
getSink()->diagnose(attr, Diagnostics::unknownCapability, capNameString);
- return false;
+ return nullptr;
}
}
else if ((as<DomainAttribute>(attr)) ||
@@ -627,21 +627,6 @@ namespace Slang
getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName);
}
}
- else if (as<PushConstantAttribute>(attr))
- {
- // Has no args
- SLANG_ASSERT(attr->args.getCount() == 0);
- }
- else if (as<ShaderRecordAttribute>(attr))
- {
- // Has no args
- SLANG_ASSERT(attr->args.getCount() == 0);
- }
- else if (as<EarlyDepthStencilAttribute>(attr))
- {
- // Has no args
- SLANG_ASSERT(attr->args.getCount() == 0);
- }
else if (auto attrUsageAttr = as<AttributeUsageAttribute>(attr))
{
uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None;
@@ -655,13 +640,13 @@ namespace Slang
else
{
getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName);
- return false;
+ return nullptr;
}
}
if (!getAttributeTargetSyntaxClasses(attrUsageAttr->targetSyntaxClass, targetClassId))
{
getSink()->diagnose(attr, Diagnostics::invalidAttributeTarget);
- return false;
+ return nullptr;
}
}
else if (const auto unrollAttr = as<UnrollAttribute>(attr))
@@ -753,7 +738,7 @@ namespace Slang
String formatName;
if(!checkLiteralStringVal(attr->args[0], &formatName))
{
- return false;
+ return nullptr;
}
ImageFormat format = ImageFormat::unknown;
@@ -782,7 +767,7 @@ namespace Slang
String diagnosticName;
if(!checkLiteralStringVal(attr->args[0], &diagnosticName))
{
- return false;
+ return nullptr;
}
auto diagnosticInfo = findDiagnosticByName(diagnosticName.getUnownedSlice());
@@ -800,14 +785,14 @@ namespace Slang
String libraryName;
if (!checkLiteralStringVal(dllImportAttr->args[0], &libraryName))
{
- return false;
+ return nullptr;
}
dllImportAttr->modulePath = libraryName;
String functionName;
if (dllImportAttr->args.getCount() == 2 && !checkLiteralStringVal(dllImportAttr->args[1], &functionName))
{
- return false;
+ return nullptr;
}
dllImportAttr->functionName = functionName;
}
@@ -816,7 +801,7 @@ namespace Slang
SLANG_ASSERT(attr->args.getCount() == 1);
auto val = checkConstantIntVal(attr->args[0]);
- if (!val) return false;
+ if (!val) return nullptr;
rayPayloadAttr->location = (int32_t)val->getValue();
}
@@ -824,7 +809,7 @@ namespace Slang
{
SLANG_ASSERT(attr->args.getCount() == 1);
auto val = checkConstantIntVal(attr->args[0]);
- if (!val) return false;
+ if (!val) return nullptr;
rayPayloadInAttr->location = (int32_t)val->getValue();
}
else if (auto callablePayloadAttr = as<VulkanCallablePayloadAttribute>(attr))
@@ -832,7 +817,7 @@ namespace Slang
SLANG_ASSERT(attr->args.getCount() == 1);
auto val = checkConstantIntVal(attr->args[0]);
- if (!val) return false;
+ if (!val) return nullptr;
callablePayloadAttr->location = (int32_t)val->getValue();
}
@@ -840,7 +825,7 @@ namespace Slang
{
SLANG_ASSERT(attr->args.getCount() == 1);
auto val = checkConstantIntVal(attr->args[0]);
- if (!val) return false;
+ if (!val) return nullptr;
callablePayloadInAttr->location = (int32_t)val->getValue();
}
else if (auto hitObjectAttributesAttr = as<VulkanHitObjectAttributesAttribute>(attr))
@@ -848,10 +833,18 @@ namespace Slang
SLANG_ASSERT(attr->args.getCount() == 1);
auto val = checkConstantIntVal(attr->args[0]);
- if (!val) return false;
+ if (!val) return nullptr;
hitObjectAttributesAttr->location = (int32_t)val->getValue();
}
+ else if (auto constantIdAttr = as<VkConstantIdAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
+ auto val = checkConstantIntVal(attr->args[0]);
+
+ if (!val) return nullptr;
+ constantIdAttr->location = (int32_t)val->getValue();
+ }
else if (as<UserDefinedDerivativeAttribute>(attr) || as<PrimalSubstituteAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
@@ -876,7 +869,7 @@ namespace Slang
SLANG_ASSERT(as<Decl>(attrTarget));
auto val = checkConstantIntVal(attr->args[0]);
- if (!val) return false;
+ if (!val) return nullptr;
preferRecomputeAttr->sideEffectBehavior = (PreferRecomputeAttribute::SideEffectBehavior) val->getValue();
}
@@ -886,7 +879,7 @@ namespace Slang
String guid;
if (!checkLiteralStringVal(comInterfaceAttr->args[0], &guid))
{
- return false;
+ return nullptr;
}
StringBuilder resultGUID;
for (auto ch : guid)
@@ -902,14 +895,14 @@ namespace Slang
else
{
getSink()->diagnose(attr, Diagnostics::invalidGUID, guid);
- return false;
+ return nullptr;
}
}
comInterfaceAttr->guid = resultGUID.toString();
if (comInterfaceAttr->guid.getLength() != 32)
{
getSink()->diagnose(attr, Diagnostics::invalidGUID, guid);
- return false;
+ return nullptr;
}
}
else if (const auto derivativeMemberAttr = as<DerivativeMemberAttribute>(attr))
@@ -918,7 +911,7 @@ namespace Slang
if (!varDecl)
{
getSink()->diagnose(attr, Diagnostics::attributeNotApplicable, attr->getKeywordName());
- return false;
+ return nullptr;
}
}
else if (auto deprecatedAttr = as<DeprecatedAttribute>(attr))
@@ -928,7 +921,7 @@ namespace Slang
String message;
if(!checkLiteralStringVal(attr->args[0], &message))
{
- return false;
+ return nullptr;
}
deprecatedAttr->message = message;
@@ -940,7 +933,7 @@ namespace Slang
String name;
if(!checkLiteralStringVal(attr->args[0], &name))
{
- return false;
+ return nullptr;
}
knownBuiltinAttr->name = name;
@@ -953,7 +946,7 @@ namespace Slang
String name;
if(!checkLiteralStringVal(attr->args[0], &name))
{
- return false;
+ return nullptr;
}
pyExportAttr->name = name;
@@ -980,17 +973,17 @@ namespace Slang
if (attr->args.getCount() > 2)
{
getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), 0);
- return false;
+ return nullptr;
}
else if (attr->args.getCount() < 2)
{
getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 2);
- return false;
+ return nullptr;
}
CapabilityName capName;
if (!checkCapabilityName(attr->args[0], capName))
{
- return false;
+ return nullptr;
}
requirePreludeAttr->capabilitySet = CapabilitySet(capName);
if (auto stringLitExpr = as<StringLiteralExpr>(attr->args[1]))
@@ -1000,9 +993,9 @@ namespace Slang
else
{
getSink()->diagnose(attr->args[1], Diagnostics::expectedAStringLiteral);
- return false;
+ return nullptr;
}
- return true;
+ return attr;
}
else
{
@@ -1016,11 +1009,11 @@ namespace Slang
// We should be special-casing the checking of any attribute
// with a non-zero number of arguments.
getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), 0);
- return false;
+ return nullptr;
}
}
- return true;
+ return attr;
}
AttributeBase* SemanticsVisitor::checkAttribute(
@@ -1156,7 +1149,6 @@ namespace Slang
// Modifiers that are their own exclusive group.
case ASTNodeType::GLSLLayoutModifier:
case ASTNodeType::GLSLParsedLayoutModifier:
- case ASTNodeType::GLSLConstantIDLayoutModifier:
case ASTNodeType::GLSLLocationLayoutModifier:
case ASTNodeType::GLSLInputAttachmentIndexLayoutAttribute:
case ASTNodeType::GLSLOffsetLayoutAttribute:
@@ -1236,7 +1228,6 @@ namespace Slang
case ASTNodeType::OutModifier:
case ASTNodeType::GLSLLayoutModifier:
case ASTNodeType::GLSLParsedLayoutModifier:
- case ASTNodeType::GLSLConstantIDLayoutModifier:
case ASTNodeType::GLSLLocationLayoutModifier:
case ASTNodeType::GLSLInputAttachmentIndexLayoutAttribute:
case ASTNodeType::GLSLOffsetLayoutAttribute:
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index fadf76dd2..ad365581a 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -4178,6 +4178,9 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl)
emitLayoutSemantics(varDecl, "register");
+ // If the parameter has a default value, we may need to emit it.
+ emitGlobalParamDefaultVal(varDecl);
+
// A shader parameter cannot have an initializer,
// so we do need to consider emitting one here.
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index c866456a2..756b59913 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -468,7 +468,7 @@ public:
protected:
-
+ virtual void emitGlobalParamDefaultVal(IRGlobalParam* inst) { SLANG_UNUSED(inst); }
virtual void emitPostDeclarationAttributesForType(IRInst* type) { SLANG_UNUSED(type); }
virtual bool doesTargetSupportPtrTypes() { return false; }
virtual void emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling, EmitLayoutSemanticOption layoutSemanticOption) { SLANG_UNUSED(inst); SLANG_UNUSED(uniformSemanticSpelling); SLANG_UNUSED(layoutSemanticOption); }
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index 902030791..fff1cddbe 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -385,6 +385,15 @@ void GLSLSourceEmitter::_emitGLSLSSBO(IRGlobalParam* varDecl, IRGLSLShaderStorag
m_writer->emit(";\n");
}
+void GLSLSourceEmitter::emitGlobalParamDefaultVal(IRGlobalParam* param)
+{
+ if (auto defaultValDecor = param->findDecoration<IRDefaultValueDecoration>())
+ {
+ m_writer->emit(" = ");
+ emitInstExpr(defaultValDecor->getOperand(0), EmitOpInfo());
+ }
+}
+
void GLSLSourceEmitter::_emitGLSLParameterGroup(IRGlobalParam* varDecl, IRUniformParameterGroupType* type)
{
auto varLayout = getVarLayout(varDecl);
@@ -418,6 +427,8 @@ void GLSLSourceEmitter::_emitGLSLParameterGroup(IRGlobalParam* varDecl, IRUnifor
}
_emitGLSLLayoutQualifier(LayoutResourceKind::PushConstantBuffer, &containerChain);
+ _emitGLSLLayoutQualifier(LayoutResourceKind::SpecializationConstant, &containerChain);
+
bool isShaderRecord = _emitGLSLLayoutQualifier(LayoutResourceKind::ShaderRecord, &containerChain);
if (isShaderRecord)
diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h
index 1911749ec..cc8a4b02d 100644
--- a/source/slang/slang-emit-glsl.h
+++ b/source/slang/slang-emit-glsl.h
@@ -46,6 +46,7 @@ protected:
virtual void emitTypeImpl(IRType* type, const StringSliceLoc* nameAndLoc) SLANG_OVERRIDE;
virtual void emitParamTypeImpl(IRType* type, String const& name) SLANG_OVERRIDE;
virtual void emitFuncDecorationImpl(IRDecoration* decoration) SLANG_OVERRIDE;
+ virtual void emitGlobalParamDefaultVal(IRGlobalParam* decl) SLANG_OVERRIDE;
virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE;
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index fa5fc7cd1..3d6bf846f 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -455,6 +455,30 @@ SpvInst* emitOpVariable(
);
}
+template<typename T, typename TOperand>
+SpvInst* emitOpSpecConstant(SpvInstParent* parent, IRInst* inst, const T& idResultType, TOperand operand)
+{
+ return emitInst(parent, inst, SpvOpSpecConstant, idResultType, kResultID, operand);
+}
+
+template<typename T, typename Ts>
+SpvInst* emitOpSpecConstantComposite(SpvInstParent* parent, IRInst* inst, const T& idResultType, const Ts& constituents)
+{
+ return emitInst(parent, inst, SpvOpSpecConstantComposite, idResultType, kResultID, constituents);
+}
+
+template<typename T>
+SpvInst* emitOpSpecConstantTrue(SpvInstParent* parent, IRInst* inst, const T& idResultType)
+{
+ return emitInst(parent, inst, SpvOpSpecConstantTrue, idResultType, kResultID);
+}
+
+template<typename T>
+SpvInst* emitOpSpecConstantFalse(SpvInstParent* parent, IRInst* inst, const T& idResultType)
+{
+ return emitInst(parent, inst, SpvOpSpecConstantFalse, idResultType, kResultID);
+}
+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpLoad
template<typename T1, typename T2>
SpvInst* emitOpLoad(
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 8f4501c2b..bb1f1378f 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -918,6 +918,51 @@ struct SPIRVEmitContext
};
Dictionary<ConstantValueKey<IRIntegerValue>, SpvInst*> m_spvIntConstants;
Dictionary<ConstantValueKey<IRFloatingPointValue>, SpvInst*> m_spvFloatConstants;
+
+ // Get an SpvLiteralBits from an IRConstant.
+ SpvLiteralBits getLiteralBits(IRInst* type, IRInst* inst)
+ {
+ switch (type->getOp())
+ {
+ case kIROp_DoubleType:
+ {
+ if (auto fval = as<IRFloatLit>(inst))
+ return SpvLiteralBits::from64(DoubleAsInt64(fval->getValue()));
+ break;
+ }
+ case kIROp_HalfType:
+ {
+ if (auto fval = as<IRFloatLit>(inst))
+ return SpvLiteralBits::from32(uint32_t(FloatToHalf((float)fval->getValue())));
+ break;
+ }
+ case kIROp_FloatType:
+ {
+ if (auto fval = as<IRFloatLit>(inst))
+ return SpvLiteralBits::from32(FloatAsInt((float)fval->getValue()));
+ break;
+ }
+ case kIROp_Int64Type:
+ case kIROp_UInt64Type:
+#if SLANG_PTR_IS_64
+ case kIROp_PtrType:
+ case kIROp_UIntPtrType:
+#endif
+ {
+ if (auto val = as<IRIntLit>(inst))
+ return SpvLiteralBits::from64(uint64_t(val->getValue()));
+ break;
+ }
+ default:
+ {
+ if (auto val = as<IRIntLit>(inst))
+ return SpvLiteralBits::from32(uint32_t(val->getValue()));
+ break;
+ }
+ }
+ return SpvLiteralBits::from32(0);
+ }
+
SpvInst* emitIntConstant(IRIntegerValue val, IRType* type, IRInst* inst = nullptr)
{
ConstantValueKey<IRIntegerValue> key;
@@ -1269,6 +1314,7 @@ struct SPIRVEmitContext
return SpvStorageClassPhysicalStorageBuffer;
case AddressSpace::Global:
case AddressSpace::MetalObjectData:
+ case AddressSpace::SpecializationConstant:
// msvc is limiting us from putting the UNEXPECTED macro here, so
// just fall out
;
@@ -2403,9 +2449,88 @@ struct SPIRVEmitContext
}
}
+ /// Emit a specialization constant.
+ SpvInst* emitSpecializationConstant(IRGlobalParam* param, IRVarOffsetAttr* offset)
+ {
+ IRInst* defaultVal = nullptr;
+ if (auto defaultValDecor = param->findDecoration<IRDefaultValueDecoration>())
+ {
+ defaultVal = defaultValDecor->getOperand(0);
+ }
+ else
+ {
+ IRBuilder builder(param);
+ builder.setInsertBefore(param);
+ defaultVal = builder.emitDefaultConstruct(param->getDataType());
+ }
+
+ SpvInst* result = nullptr;
+ if (as<IRBoolType>(defaultVal->getDataType()))
+ {
+ bool value = false;
+ if (auto boolLit = as<IRBoolLit>(defaultVal))
+ {
+ value = boolLit->getValue();
+ }
+ if (value)
+ {
+ result = emitOpSpecConstantTrue(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ param,
+ param->getDataType());
+ }
+ else
+ {
+ result = emitOpSpecConstantFalse(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ param,
+ param->getDataType());
+ }
+ }
+ else if (auto type = as<IRBasicType>(defaultVal->getDataType()))
+ {
+ result = emitOpSpecConstant(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ param,
+ param->getDataType(),
+ getLiteralBits(type, defaultVal));
+ }
+ else if (as<IRVectorType>(defaultVal->getDataType()))
+ {
+ List<IRInst*> operands;
+ auto makeVector = as<IRMakeVector>(defaultVal);
+ for (UInt i = 0; i < makeVector->getOperandCount(); i++)
+ {
+ operands.add(makeVector->getOperand(i));
+ }
+ result = emitOpSpecConstantComposite(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ param,
+ param->getDataType(),
+ operands);
+ }
+
+ emitOpDecorateSpecId(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ result,
+ SpvLiteralInteger::from32((uint32_t)offset->getOffset()));
+
+ maybeEmitName(result, param);
+ return result;
+ }
+
/// Emit a global parameter definition.
SpvInst* emitGlobalParam(IRGlobalParam* param)
{
+ auto layout = getVarLayout(param);
+ if (layout)
+ {
+ if (auto offset = layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant))
+ {
+ return emitSpecializationConstant(param, offset);
+ }
+ }
auto storageClass = SpvStorageClassUniform;
if (auto ptrType = as<IRPtrTypeBase>(param->getDataType()))
{
@@ -2425,7 +2550,7 @@ struct SPIRVEmitContext
storageClass
);
maybeEmitPointerDecoration(varInst, param);
- if (auto layout = getVarLayout(param))
+ if (layout)
emitVarLayout(param, varInst, layout);
emitDecorations(param, getID(varInst));
return varInst;
@@ -3419,8 +3544,20 @@ struct SPIRVEmitContext
}
}
paramsSet.add(spvGlobalInst);
- params.add(spvGlobalInst);
referencedBuiltinIRVars.add(globalInst);
+
+ // Don't add a global param to the interface if it is a specialization constant.
+ switch (spvGlobalInst->opcode)
+ {
+ case SpvOpSpecConstant:
+ case SpvOpSpecConstantFalse:
+ case SpvOpSpecConstantTrue:
+ case SpvOpSpecConstantComposite:
+ break;
+ default:
+ params.add(spvGlobalInst);
+ break;
+ }
}
}
break;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 09cb7952c..f308f340d 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -783,6 +783,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
INST(HasExplicitHLSLBindingDecoration, HasExplicitHLSLBinding, 0, 0)
+ INST(DefaultValueDecoration, DefaultValue, 1, 0)
INST(ReadNoneDecoration, readNone, 0, 0)
INST(VulkanCallablePayloadDecoration, vulkanCallablePayload, 0, 0)
INST(VulkanCallablePayloadInDecoration, vulkanCallablePayloadIn, 0, 0)
@@ -956,6 +957,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
INST(ConstructorDecoration, constructor, 1, 0)
INST(MethodDecoration, method, 0, 0)
INST(PackOffsetDecoration, packoffset, 2, 0)
+ INST(SpecializationConstantDecoration, SpecializationConstantDecoration, 1, 0)
// Reflection metadata for a shader parameter that provides the original type name.
INST(UserTypeNameDecoration, UserTypeName, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 9b94005d9..8f648aabd 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -330,6 +330,8 @@ IR_SIMPLE_DECORATION(PerVertexDecoration)
IR_SIMPLE_DECORATION(SPIRVBlockDecoration)
+IR_SIMPLE_DECORATION(DefaultValueDecoration)
+
struct IRRequireGLSLVersionDecoration : IRDecoration
{
enum { kOp = kIROp_RequireGLSLVersionDecoration };
@@ -4974,6 +4976,10 @@ public:
{
addDecoration(value, kIROp_HasExplicitHLSLBindingDecoration);
}
+ void addDefaultValueDecoration(IRInst* value, IRInst* defaultValue)
+ {
+ addDecoration(value, kIROp_DefaultValueDecoration, defaultValue);
+ }
void addNVAPIMagicDecoration(IRInst* value, UnownedStringSlice const& name)
{
addDecoration(value, kIROp_NVAPIMagicDecoration, getStringValue(name));
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 11fbb5b4b..c3d0874e7 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -568,6 +568,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}
+ // Don't do any processing for specialization constants.
+ if (addressSpace == AddressSpace::SpecializationConstant)
+ {
+ return;
+ }
+
// Opaque resource handles can't be in Uniform for Vulkan, if they are
// placed here then put them in UniformConstant instead
if (isSpirvUniformConstantType(inst->getDataType()))
@@ -728,6 +734,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
case LayoutResourceKind::PushConstantBuffer:
addressSpace = AddressSpace::PushConstant;
break;
+ case LayoutResourceKind::SpecializationConstant:
+ addressSpace = AddressSpace::SpecializationConstant;
+ break;
case LayoutResourceKind::RayPayload:
addressSpace = AddressSpace::IncomingRayPayload;
break;
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ab9e2a540..f62bb631e 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8021,14 +8021,33 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
addTargetIntrinsicDecorations(nullptr, irParam, decl);
- if (decl->findModifier<HLSLLayoutSemantic>())
+
+ bool hasLayoutSemantic = false;
+ bool isSpecializationConstant = false;
+ for (auto modifier : decl->modifiers)
{
- builder->addHasExplicitHLSLBindingDecoration(irParam);
+ if (as<HLSLLayoutSemantic>(modifier))
+ {
+ hasLayoutSemantic = true;
+ }
+ else if (as<SpecializationConstantAttribute>(modifier) || as<VkConstantIdAttribute>(modifier))
+ {
+ isSpecializationConstant = true;
+ }
}
+ if (hasLayoutSemantic)
+ builder->addHasExplicitHLSLBindingDecoration(irParam);
+
// A global variable's SSA value is a *pointer* to
// the underlying storage.
context->setGlobalValue(decl, paramVal);
+ if (isSpecializationConstant && decl->initExpr)
+ {
+ auto initVal = getSimpleVal(context, lowerRValueExpr(context, decl->initExpr));
+ builder->addDefaultValueDecoration(irParam, initVal);
+ }
+
irParam->moveToEnd();
return paramVal;
diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp
index 6f8504a31..115ccc55e 100644
--- a/source/slang/slang-parameter-binding.cpp
+++ b/source/slang/slang-parameter-binding.cpp
@@ -703,20 +703,21 @@ RefPtr<TypeLayout> getTypeLayoutForGlobalShaderParameter(
type);
}
- // TODO: The cases below for detecting globals that aren't actually
- // shader parameters should be redundant now that the semantic
- // checking logic is responsible for populating the list of
- // parameters on a `Program`. We should be able to clean up
- // the code by removing these two cases, and the related null
- // pointer checks in the code that calls this.
-
- // HLSL `static` modifier indicates "thread local"
- if(varDecl->hasModifier<HLSLStaticModifier>())
- return nullptr;
-
- // HLSL `groupshared` modifier indicates "thread-group local"
- if(varDecl->hasModifier<HLSLGroupSharedModifier>())
- return nullptr;
+ if (varDecl->hasModifier<SpecializationConstantAttribute>() ||
+ varDecl->hasModifier<VkConstantIdAttribute>())
+ {
+ auto specializationConstantRule = rules->getSpecializationConstantRules();
+ if (!specializationConstantRule)
+ {
+ // If the target doesn't support specialization constants, then we will
+ // layout them as ordinary uniform data.
+ specializationConstantRule = rules->getConstantBufferRules(context->getTargetRequest()->getOptionSet());
+ }
+ return createTypeLayoutWith(
+ layoutContext,
+ specializationConstantRule,
+ type);
+ }
// TODO(tfoley): there may be other cases that we need to handle here
@@ -1143,10 +1144,10 @@ static void addExplicitParameterBindings_GLSL(
else if(auto foundSpecializationConstant = typeLayout->FindResourceInfo(LayoutResourceKind::SpecializationConstant))
{
info[kResInfo].resInfo = foundSpecializationConstant;
- DeclRef<Decl> varDecl2(varDecl);
- // Try to find `constant_id` binding
- if(!findLayoutArg<GLSLConstantIDLayoutModifier>(varDecl2, &info[kResInfo].semanticInfo.index))
+ if (auto layoutAttr = varDecl.getDecl()->findModifier<VkConstantIdAttribute>())
+ info[kResInfo].semanticInfo.index = layoutAttr->location;
+ else
return;
}
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index b3a18f8a8..04ada006c 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -8190,7 +8190,7 @@ namespace Slang
CASE(push_constant, PushConstantAttribute)
CASE(shaderRecordNV, ShaderRecordAttribute)
CASE(shaderRecordEXT, ShaderRecordAttribute)
- CASE(constant_id, GLSLConstantIDLayoutModifier)
+ CASE(constant_id, VkConstantIdAttribute)
CASE(std140, GLSLStd140Modifier)
CASE(std430, GLSLStd430Modifier)
CASE(scalar, GLSLScalarModifier)
@@ -8228,6 +8228,11 @@ namespace Slang
parser->diagnose(modifier->loc, Diagnostics::missingLayoutBindingModifier);
}
}
+ else if (auto specConstAttr = as<VkConstantIdAttribute>(modifier))
+ {
+ parser->ReadToken(TokenType::OpAssign);
+ specConstAttr->location = (int)getIntegerLiteralValue(parser->ReadToken(TokenType::IntegerLiteral));
+ }
listBuilder.add(modifier);
}
diff --git a/source/slang/slang-type-system-shared.h b/source/slang/slang-type-system-shared.h
index adf6e26f8..7fc9eeb6b 100644
--- a/source/slang/slang-type-system-shared.h
+++ b/source/slang/slang-type-system-shared.h
@@ -101,6 +101,8 @@ FOREACH_BASE_TYPE(DEFINE_BASE_TYPE)
UniformConstant,
// Corresponds to SPIR-V's SpvStorageClassImage
Image,
+ // Represents a SPIR-V specialization constant
+ SpecializationConstant,
// Default address space for a user-defined pointer
UserPointer = 0x100000001ULL,