diff options
| author | Yong He <yonghe@outlook.com> | 2024-08-30 12:03:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-30 12:03:23 -0700 |
| commit | de83628070614ec37349c9f334ed72a54a6889da (patch) | |
| tree | bc97f74013a073fc958b75e68089696e14d71412 | |
| parent | f428a058ea48535a323c32d206ebc7e551c3c3e9 (diff) | |
Support specialization constants. (#4963)
* Support specialization constants.
* Fix.
* Fix.
* Fix.
* Fix.
* Make sure specialization constants have names.
* Clean up and support the dxc [vk::constant_id] syntax.
* Fix.
* Fix.
* Fix.
| -rw-r--r-- | source/slang/core.meta.slang | 11 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 18 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 123 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-glsl.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-emit-glsl.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv-ops.h | 24 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 141 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 23 | ||||
| -rw-r--r-- | source/slang/slang-parameter-binding.cpp | 35 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-type-system-shared.h | 2 | ||||
| -rw-r--r-- | tests/language-feature/zero-initialize/struct.slang | 1 | ||||
| -rw-r--r-- | tests/spirv/specialization-constant.slang | 48 |
19 files changed, 373 insertions, 98 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 7d5f4087c..4e8529666 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2699,12 +2699,19 @@ attribute_syntax [vk_shader_record] : ShaderRecordAttr __attributeTarget(VarDeclBase) attribute_syntax [shader_record] : ShaderRecordAttribute; -__attributeTarget(DeclBase) +__attributeTarget(VarDeclBase) attribute_syntax [vk_push_constant] : PushConstantAttribute; -__attributeTarget(DeclBase) +__attributeTarget(VarDeclBase) attribute_syntax [push_constant] : PushConstantAttribute; __attributeTarget(VarDeclBase) +attribute_syntax[vk_specialization_constant] : SpecializationConstantAttribute; +__attributeTarget(VarDeclBase) +attribute_syntax[SpecializationConstant] : SpecializationConstantAttribute; +__attributeTarget(VarDeclBase) +attribute_syntax[vk_constant_id(location: int)] : VkConstantIdAttribute; + +__attributeTarget(VarDeclBase) attribute_syntax [vk_location(location : int)] : GLSLLocationAttribute; __attributeTarget(VarDeclBase) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index a7af4a249..69f86a43d 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -235,10 +235,6 @@ class GLSLUnparsedLayoutModifier : public GLSLLayoutModifier // Specific cases for known GLSL `layout` modifiers that we need to work with -class GLSLConstantIDLayoutModifier : public GLSLParsedLayoutModifier -{ - SLANG_AST_CLASS(GLSLConstantIDLayoutModifier) -}; class GLSLLocationLayoutModifier : public GLSLParsedLayoutModifier { @@ -727,11 +723,23 @@ class FlagsAttribute : public Attribute }; // [[vk_push_constant]] [[push_constant]] -class PushConstantAttribute : public Attribute +class PushConstantAttribute : public Attribute { SLANG_AST_CLASS(PushConstantAttribute) }; +// [[vk_specialization_constant]] [[specialization_constant]] +class SpecializationConstantAttribute : public Attribute +{ + SLANG_AST_CLASS(SpecializationConstantAttribute) +}; + +// [[vk_constant_id]] +class VkConstantIdAttribute : public Attribute +{ + SLANG_AST_CLASS(VkConstantIdAttribute) + int location; +}; // [[vk_shader_record]] [[shader_record]] class ShaderRecordAttribute : public Attribute diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 96b467b0c..caec9dcee 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1615,7 +1615,9 @@ namespace Slang bool getAttributeTargetSyntaxClasses(SyntaxClass<NodeBase> & cls, uint32_t typeFlags); - bool validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget); + // Check an attribute, and return a checked modifier that represents it. + // + Modifier* validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget); AttributeBase* checkAttribute( UncheckedAttribute* uncheckedAttr, diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index d7f879c51..f05b58c34 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -330,7 +330,7 @@ namespace Slang return false; } - bool SemanticsVisitor::validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget) + Modifier* SemanticsVisitor::validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget) { if (auto numThreadsAttr = as<NumThreadsAttribute>(attr)) { @@ -348,14 +348,14 @@ namespace Slang auto intValue = checkLinkTimeConstantIntVal(arg); if (!intValue) { - return false; + return nullptr; } if (auto constIntVal = as<ConstantIntVal>(intValue)) { if (constIntVal->getValue() < 1) { getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, constIntVal->getValue()); - return false; + return nullptr; } if (intValue->getType() != m_astBuilder->getIntType()) { @@ -390,7 +390,7 @@ namespace Slang auto intValue = checkLinkTimeConstantIntVal(arg); if (!intValue) { - return false; + return nullptr; } if (auto constIntVal = as<ConstantIntVal>(intValue)) { @@ -407,7 +407,7 @@ namespace Slang if (!isValidWaveSize) { getSink()->diagnose(attr, Diagnostics::invalidWaveSize, constIntVal->getValue()); - return false; + return nullptr; } } value = intValue; @@ -426,20 +426,20 @@ namespace Slang if (attr->args.getCount() != 1) { - return false; + return nullptr; } auto value = checkConstantIntVal(attr->args[0]); if (value == nullptr) { - return false; + return nullptr; } const IRIntegerValue kMaxAnyValueSize = 0x7FFF; if (value->getValue() > kMaxAnyValueSize) { getSink()->diagnose(anyValueSizeAttr->loc, Diagnostics::anyValueSizeExceedsLimit, kMaxAnyValueSize); - return false; + return nullptr; } anyValueSizeAttr->size = int32_t(value->getValue()); @@ -448,16 +448,16 @@ namespace Slang { if (attr->args.getCount() != 1) { - return false; + return nullptr; } auto value = checkConstantIntVal(attr->args[0]); if (value == nullptr) { - return false; + return nullptr; } if (value->getValue() < 0) { - return false; + return nullptr; } glslRequireShaderInputParameter->parameterNumber = int32_t(value->getValue()); } @@ -465,23 +465,23 @@ namespace Slang { if (attr->args.getCount() != 1) { - return false; + return nullptr; } auto rank = checkConstantIntVal(attr->args[0]); if (rank == nullptr) { - return false; + return nullptr; } overloadRankAttr->rank = int32_t(rank->getValue()); } else if (auto inputAttachmentIndexLayoutAttribute = as<GLSLInputAttachmentIndexLayoutAttribute>(attr)) { if (attr->args.getCount() != 1) - return false; + return nullptr; auto location = checkConstantIntVal(attr->args[0]); if(!location) - return false; + return nullptr; inputAttachmentIndexLayoutAttribute->location = location->getValue(); } @@ -492,7 +492,7 @@ namespace Slang // in core.meta.slang, but that's not completely implemented. So for now we check here. if (attr->args.getCount() != 2) { - return false; + return nullptr; } // TODO(JS): Prior validation currently doesn't ensure both args are ints (as specified in core.meta.slang), so check here @@ -502,7 +502,7 @@ namespace Slang if (binding == nullptr || set == nullptr) { - return false; + return nullptr; } bindingAttr->binding = int32_t(binding->getValue()); @@ -515,13 +515,13 @@ namespace Slang if (attr->args.getCount() != 1) { - return false; + return nullptr; } auto value = checkConstantIntVal(attr->args[0]); if (value == nullptr) { - return false; + return nullptr; } simpleLayoutAttr->value = int32_t(value->getValue()); @@ -531,7 +531,7 @@ namespace Slang SLANG_ASSERT(attr->args.getCount() == 1); auto val = checkConstantIntVal(attr->args[0]); - if (!val) return false; + if (!val) return nullptr; maxVertexCountAttr->value = (int32_t)val->getValue(); } @@ -540,7 +540,7 @@ namespace Slang SLANG_ASSERT(attr->args.getCount() == 1); auto val = checkConstantIntVal(attr->args[0]); - if (!val) return false; + if (!val) return nullptr; instanceAttr->value = (int32_t)val->getValue(); } @@ -551,7 +551,7 @@ namespace Slang String capNameString; if (!checkLiteralStringVal(attr->args[0], &capNameString)) { - return false; + return nullptr; } CapabilityName capName = findCapabilityName(capNameString.getUnownedSlice()); @@ -587,7 +587,7 @@ namespace Slang { // always diagnose this error since nothing can compile with an invalid capability getSink()->diagnose(attr, Diagnostics::unknownCapability, capNameString); - return false; + return nullptr; } } else if ((as<DomainAttribute>(attr)) || @@ -627,21 +627,6 @@ namespace Slang getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName); } } - else if (as<PushConstantAttribute>(attr)) - { - // Has no args - SLANG_ASSERT(attr->args.getCount() == 0); - } - else if (as<ShaderRecordAttribute>(attr)) - { - // Has no args - SLANG_ASSERT(attr->args.getCount() == 0); - } - else if (as<EarlyDepthStencilAttribute>(attr)) - { - // Has no args - SLANG_ASSERT(attr->args.getCount() == 0); - } else if (auto attrUsageAttr = as<AttributeUsageAttribute>(attr)) { uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None; @@ -655,13 +640,13 @@ namespace Slang else { getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName); - return false; + return nullptr; } } if (!getAttributeTargetSyntaxClasses(attrUsageAttr->targetSyntaxClass, targetClassId)) { getSink()->diagnose(attr, Diagnostics::invalidAttributeTarget); - return false; + return nullptr; } } else if (const auto unrollAttr = as<UnrollAttribute>(attr)) @@ -753,7 +738,7 @@ namespace Slang String formatName; if(!checkLiteralStringVal(attr->args[0], &formatName)) { - return false; + return nullptr; } ImageFormat format = ImageFormat::unknown; @@ -782,7 +767,7 @@ namespace Slang String diagnosticName; if(!checkLiteralStringVal(attr->args[0], &diagnosticName)) { - return false; + return nullptr; } auto diagnosticInfo = findDiagnosticByName(diagnosticName.getUnownedSlice()); @@ -800,14 +785,14 @@ namespace Slang String libraryName; if (!checkLiteralStringVal(dllImportAttr->args[0], &libraryName)) { - return false; + return nullptr; } dllImportAttr->modulePath = libraryName; String functionName; if (dllImportAttr->args.getCount() == 2 && !checkLiteralStringVal(dllImportAttr->args[1], &functionName)) { - return false; + return nullptr; } dllImportAttr->functionName = functionName; } @@ -816,7 +801,7 @@ namespace Slang SLANG_ASSERT(attr->args.getCount() == 1); auto val = checkConstantIntVal(attr->args[0]); - if (!val) return false; + if (!val) return nullptr; rayPayloadAttr->location = (int32_t)val->getValue(); } @@ -824,7 +809,7 @@ namespace Slang { SLANG_ASSERT(attr->args.getCount() == 1); auto val = checkConstantIntVal(attr->args[0]); - if (!val) return false; + if (!val) return nullptr; rayPayloadInAttr->location = (int32_t)val->getValue(); } else if (auto callablePayloadAttr = as<VulkanCallablePayloadAttribute>(attr)) @@ -832,7 +817,7 @@ namespace Slang SLANG_ASSERT(attr->args.getCount() == 1); auto val = checkConstantIntVal(attr->args[0]); - if (!val) return false; + if (!val) return nullptr; callablePayloadAttr->location = (int32_t)val->getValue(); } @@ -840,7 +825,7 @@ namespace Slang { SLANG_ASSERT(attr->args.getCount() == 1); auto val = checkConstantIntVal(attr->args[0]); - if (!val) return false; + if (!val) return nullptr; callablePayloadInAttr->location = (int32_t)val->getValue(); } else if (auto hitObjectAttributesAttr = as<VulkanHitObjectAttributesAttribute>(attr)) @@ -848,10 +833,18 @@ namespace Slang SLANG_ASSERT(attr->args.getCount() == 1); auto val = checkConstantIntVal(attr->args[0]); - if (!val) return false; + if (!val) return nullptr; hitObjectAttributesAttr->location = (int32_t)val->getValue(); } + else if (auto constantIdAttr = as<VkConstantIdAttribute>(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto val = checkConstantIntVal(attr->args[0]); + + if (!val) return nullptr; + constantIdAttr->location = (int32_t)val->getValue(); + } else if (as<UserDefinedDerivativeAttribute>(attr) || as<PrimalSubstituteAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); @@ -876,7 +869,7 @@ namespace Slang SLANG_ASSERT(as<Decl>(attrTarget)); auto val = checkConstantIntVal(attr->args[0]); - if (!val) return false; + if (!val) return nullptr; preferRecomputeAttr->sideEffectBehavior = (PreferRecomputeAttribute::SideEffectBehavior) val->getValue(); } @@ -886,7 +879,7 @@ namespace Slang String guid; if (!checkLiteralStringVal(comInterfaceAttr->args[0], &guid)) { - return false; + return nullptr; } StringBuilder resultGUID; for (auto ch : guid) @@ -902,14 +895,14 @@ namespace Slang else { getSink()->diagnose(attr, Diagnostics::invalidGUID, guid); - return false; + return nullptr; } } comInterfaceAttr->guid = resultGUID.toString(); if (comInterfaceAttr->guid.getLength() != 32) { getSink()->diagnose(attr, Diagnostics::invalidGUID, guid); - return false; + return nullptr; } } else if (const auto derivativeMemberAttr = as<DerivativeMemberAttribute>(attr)) @@ -918,7 +911,7 @@ namespace Slang if (!varDecl) { getSink()->diagnose(attr, Diagnostics::attributeNotApplicable, attr->getKeywordName()); - return false; + return nullptr; } } else if (auto deprecatedAttr = as<DeprecatedAttribute>(attr)) @@ -928,7 +921,7 @@ namespace Slang String message; if(!checkLiteralStringVal(attr->args[0], &message)) { - return false; + return nullptr; } deprecatedAttr->message = message; @@ -940,7 +933,7 @@ namespace Slang String name; if(!checkLiteralStringVal(attr->args[0], &name)) { - return false; + return nullptr; } knownBuiltinAttr->name = name; @@ -953,7 +946,7 @@ namespace Slang String name; if(!checkLiteralStringVal(attr->args[0], &name)) { - return false; + return nullptr; } pyExportAttr->name = name; @@ -980,17 +973,17 @@ namespace Slang if (attr->args.getCount() > 2) { getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), 0); - return false; + return nullptr; } else if (attr->args.getCount() < 2) { getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 2); - return false; + return nullptr; } CapabilityName capName; if (!checkCapabilityName(attr->args[0], capName)) { - return false; + return nullptr; } requirePreludeAttr->capabilitySet = CapabilitySet(capName); if (auto stringLitExpr = as<StringLiteralExpr>(attr->args[1])) @@ -1000,9 +993,9 @@ namespace Slang else { getSink()->diagnose(attr->args[1], Diagnostics::expectedAStringLiteral); - return false; + return nullptr; } - return true; + return attr; } else { @@ -1016,11 +1009,11 @@ namespace Slang // We should be special-casing the checking of any attribute // with a non-zero number of arguments. getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), 0); - return false; + return nullptr; } } - return true; + return attr; } AttributeBase* SemanticsVisitor::checkAttribute( @@ -1156,7 +1149,6 @@ namespace Slang // Modifiers that are their own exclusive group. case ASTNodeType::GLSLLayoutModifier: case ASTNodeType::GLSLParsedLayoutModifier: - case ASTNodeType::GLSLConstantIDLayoutModifier: case ASTNodeType::GLSLLocationLayoutModifier: case ASTNodeType::GLSLInputAttachmentIndexLayoutAttribute: case ASTNodeType::GLSLOffsetLayoutAttribute: @@ -1236,7 +1228,6 @@ namespace Slang case ASTNodeType::OutModifier: case ASTNodeType::GLSLLayoutModifier: case ASTNodeType::GLSLParsedLayoutModifier: - case ASTNodeType::GLSLConstantIDLayoutModifier: case ASTNodeType::GLSLLocationLayoutModifier: case ASTNodeType::GLSLInputAttachmentIndexLayoutAttribute: case ASTNodeType::GLSLOffsetLayoutAttribute: diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index fadf76dd2..ad365581a 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -4178,6 +4178,9 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl) emitLayoutSemantics(varDecl, "register"); + // If the parameter has a default value, we may need to emit it. + emitGlobalParamDefaultVal(varDecl); + // A shader parameter cannot have an initializer, // so we do need to consider emitting one here. diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index c866456a2..756b59913 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -468,7 +468,7 @@ public: protected: - + virtual void emitGlobalParamDefaultVal(IRGlobalParam* inst) { SLANG_UNUSED(inst); } virtual void emitPostDeclarationAttributesForType(IRInst* type) { SLANG_UNUSED(type); } virtual bool doesTargetSupportPtrTypes() { return false; } virtual void emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling, EmitLayoutSemanticOption layoutSemanticOption) { SLANG_UNUSED(inst); SLANG_UNUSED(uniformSemanticSpelling); SLANG_UNUSED(layoutSemanticOption); } diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 902030791..fff1cddbe 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -385,6 +385,15 @@ void GLSLSourceEmitter::_emitGLSLSSBO(IRGlobalParam* varDecl, IRGLSLShaderStorag m_writer->emit(";\n"); } +void GLSLSourceEmitter::emitGlobalParamDefaultVal(IRGlobalParam* param) +{ + if (auto defaultValDecor = param->findDecoration<IRDefaultValueDecoration>()) + { + m_writer->emit(" = "); + emitInstExpr(defaultValDecor->getOperand(0), EmitOpInfo()); + } +} + void GLSLSourceEmitter::_emitGLSLParameterGroup(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) { auto varLayout = getVarLayout(varDecl); @@ -418,6 +427,8 @@ void GLSLSourceEmitter::_emitGLSLParameterGroup(IRGlobalParam* varDecl, IRUnifor } _emitGLSLLayoutQualifier(LayoutResourceKind::PushConstantBuffer, &containerChain); + _emitGLSLLayoutQualifier(LayoutResourceKind::SpecializationConstant, &containerChain); + bool isShaderRecord = _emitGLSLLayoutQualifier(LayoutResourceKind::ShaderRecord, &containerChain); if (isShaderRecord) diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h index 1911749ec..cc8a4b02d 100644 --- a/source/slang/slang-emit-glsl.h +++ b/source/slang/slang-emit-glsl.h @@ -46,6 +46,7 @@ protected: virtual void emitTypeImpl(IRType* type, const StringSliceLoc* nameAndLoc) SLANG_OVERRIDE; virtual void emitParamTypeImpl(IRType* type, String const& name) SLANG_OVERRIDE; virtual void emitFuncDecorationImpl(IRDecoration* decoration) SLANG_OVERRIDE; + virtual void emitGlobalParamDefaultVal(IRGlobalParam* decl) SLANG_OVERRIDE; virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index fa5fc7cd1..3d6bf846f 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -455,6 +455,30 @@ SpvInst* emitOpVariable( ); } +template<typename T, typename TOperand> +SpvInst* emitOpSpecConstant(SpvInstParent* parent, IRInst* inst, const T& idResultType, TOperand operand) +{ + return emitInst(parent, inst, SpvOpSpecConstant, idResultType, kResultID, operand); +} + +template<typename T, typename Ts> +SpvInst* emitOpSpecConstantComposite(SpvInstParent* parent, IRInst* inst, const T& idResultType, const Ts& constituents) +{ + return emitInst(parent, inst, SpvOpSpecConstantComposite, idResultType, kResultID, constituents); +} + +template<typename T> +SpvInst* emitOpSpecConstantTrue(SpvInstParent* parent, IRInst* inst, const T& idResultType) +{ + return emitInst(parent, inst, SpvOpSpecConstantTrue, idResultType, kResultID); +} + +template<typename T> +SpvInst* emitOpSpecConstantFalse(SpvInstParent* parent, IRInst* inst, const T& idResultType) +{ + return emitInst(parent, inst, SpvOpSpecConstantFalse, idResultType, kResultID); +} + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpLoad template<typename T1, typename T2> SpvInst* emitOpLoad( diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 8f4501c2b..bb1f1378f 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -918,6 +918,51 @@ struct SPIRVEmitContext }; Dictionary<ConstantValueKey<IRIntegerValue>, SpvInst*> m_spvIntConstants; Dictionary<ConstantValueKey<IRFloatingPointValue>, SpvInst*> m_spvFloatConstants; + + // Get an SpvLiteralBits from an IRConstant. + SpvLiteralBits getLiteralBits(IRInst* type, IRInst* inst) + { + switch (type->getOp()) + { + case kIROp_DoubleType: + { + if (auto fval = as<IRFloatLit>(inst)) + return SpvLiteralBits::from64(DoubleAsInt64(fval->getValue())); + break; + } + case kIROp_HalfType: + { + if (auto fval = as<IRFloatLit>(inst)) + return SpvLiteralBits::from32(uint32_t(FloatToHalf((float)fval->getValue()))); + break; + } + case kIROp_FloatType: + { + if (auto fval = as<IRFloatLit>(inst)) + return SpvLiteralBits::from32(FloatAsInt((float)fval->getValue())); + break; + } + case kIROp_Int64Type: + case kIROp_UInt64Type: +#if SLANG_PTR_IS_64 + case kIROp_PtrType: + case kIROp_UIntPtrType: +#endif + { + if (auto val = as<IRIntLit>(inst)) + return SpvLiteralBits::from64(uint64_t(val->getValue())); + break; + } + default: + { + if (auto val = as<IRIntLit>(inst)) + return SpvLiteralBits::from32(uint32_t(val->getValue())); + break; + } + } + return SpvLiteralBits::from32(0); + } + SpvInst* emitIntConstant(IRIntegerValue val, IRType* type, IRInst* inst = nullptr) { ConstantValueKey<IRIntegerValue> key; @@ -1269,6 +1314,7 @@ struct SPIRVEmitContext return SpvStorageClassPhysicalStorageBuffer; case AddressSpace::Global: case AddressSpace::MetalObjectData: + case AddressSpace::SpecializationConstant: // msvc is limiting us from putting the UNEXPECTED macro here, so // just fall out ; @@ -2403,9 +2449,88 @@ struct SPIRVEmitContext } } + /// Emit a specialization constant. + SpvInst* emitSpecializationConstant(IRGlobalParam* param, IRVarOffsetAttr* offset) + { + IRInst* defaultVal = nullptr; + if (auto defaultValDecor = param->findDecoration<IRDefaultValueDecoration>()) + { + defaultVal = defaultValDecor->getOperand(0); + } + else + { + IRBuilder builder(param); + builder.setInsertBefore(param); + defaultVal = builder.emitDefaultConstruct(param->getDataType()); + } + + SpvInst* result = nullptr; + if (as<IRBoolType>(defaultVal->getDataType())) + { + bool value = false; + if (auto boolLit = as<IRBoolLit>(defaultVal)) + { + value = boolLit->getValue(); + } + if (value) + { + result = emitOpSpecConstantTrue( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + param, + param->getDataType()); + } + else + { + result = emitOpSpecConstantFalse( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + param, + param->getDataType()); + } + } + else if (auto type = as<IRBasicType>(defaultVal->getDataType())) + { + result = emitOpSpecConstant( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + param, + param->getDataType(), + getLiteralBits(type, defaultVal)); + } + else if (as<IRVectorType>(defaultVal->getDataType())) + { + List<IRInst*> operands; + auto makeVector = as<IRMakeVector>(defaultVal); + for (UInt i = 0; i < makeVector->getOperandCount(); i++) + { + operands.add(makeVector->getOperand(i)); + } + result = emitOpSpecConstantComposite( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + param, + param->getDataType(), + operands); + } + + emitOpDecorateSpecId( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + result, + SpvLiteralInteger::from32((uint32_t)offset->getOffset())); + + maybeEmitName(result, param); + return result; + } + /// Emit a global parameter definition. SpvInst* emitGlobalParam(IRGlobalParam* param) { + auto layout = getVarLayout(param); + if (layout) + { + if (auto offset = layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant)) + { + return emitSpecializationConstant(param, offset); + } + } auto storageClass = SpvStorageClassUniform; if (auto ptrType = as<IRPtrTypeBase>(param->getDataType())) { @@ -2425,7 +2550,7 @@ struct SPIRVEmitContext storageClass ); maybeEmitPointerDecoration(varInst, param); - if (auto layout = getVarLayout(param)) + if (layout) emitVarLayout(param, varInst, layout); emitDecorations(param, getID(varInst)); return varInst; @@ -3419,8 +3544,20 @@ struct SPIRVEmitContext } } paramsSet.add(spvGlobalInst); - params.add(spvGlobalInst); referencedBuiltinIRVars.add(globalInst); + + // Don't add a global param to the interface if it is a specialization constant. + switch (spvGlobalInst->opcode) + { + case SpvOpSpecConstant: + case SpvOpSpecConstantFalse: + case SpvOpSpecConstantTrue: + case SpvOpSpecConstantComposite: + break; + default: + params.add(spvGlobalInst); + break; + } } } break; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 09cb7952c..f308f340d 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -783,6 +783,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) INST(HasExplicitHLSLBindingDecoration, HasExplicitHLSLBinding, 0, 0) + INST(DefaultValueDecoration, DefaultValue, 1, 0) INST(ReadNoneDecoration, readNone, 0, 0) INST(VulkanCallablePayloadDecoration, vulkanCallablePayload, 0, 0) INST(VulkanCallablePayloadInDecoration, vulkanCallablePayloadIn, 0, 0) @@ -956,6 +957,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) INST(ConstructorDecoration, constructor, 1, 0) INST(MethodDecoration, method, 0, 0) INST(PackOffsetDecoration, packoffset, 2, 0) + INST(SpecializationConstantDecoration, SpecializationConstantDecoration, 1, 0) // Reflection metadata for a shader parameter that provides the original type name. INST(UserTypeNameDecoration, UserTypeName, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9b94005d9..8f648aabd 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -330,6 +330,8 @@ IR_SIMPLE_DECORATION(PerVertexDecoration) IR_SIMPLE_DECORATION(SPIRVBlockDecoration) +IR_SIMPLE_DECORATION(DefaultValueDecoration) + struct IRRequireGLSLVersionDecoration : IRDecoration { enum { kOp = kIROp_RequireGLSLVersionDecoration }; @@ -4974,6 +4976,10 @@ public: { addDecoration(value, kIROp_HasExplicitHLSLBindingDecoration); } + void addDefaultValueDecoration(IRInst* value, IRInst* defaultValue) + { + addDecoration(value, kIROp_DefaultValueDecoration, defaultValue); + } void addNVAPIMagicDecoration(IRInst* value, UnownedStringSlice const& name) { addDecoration(value, kIROp_NVAPIMagicDecoration, getStringValue(name)); diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 11fbb5b4b..c3d0874e7 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -568,6 +568,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } + // Don't do any processing for specialization constants. + if (addressSpace == AddressSpace::SpecializationConstant) + { + return; + } + // Opaque resource handles can't be in Uniform for Vulkan, if they are // placed here then put them in UniformConstant instead if (isSpirvUniformConstantType(inst->getDataType())) @@ -728,6 +734,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case LayoutResourceKind::PushConstantBuffer: addressSpace = AddressSpace::PushConstant; break; + case LayoutResourceKind::SpecializationConstant: + addressSpace = AddressSpace::SpecializationConstant; + break; case LayoutResourceKind::RayPayload: addressSpace = AddressSpace::IncomingRayPayload; break; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ab9e2a540..f62bb631e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8021,14 +8021,33 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } addTargetIntrinsicDecorations(nullptr, irParam, decl); - if (decl->findModifier<HLSLLayoutSemantic>()) + + bool hasLayoutSemantic = false; + bool isSpecializationConstant = false; + for (auto modifier : decl->modifiers) { - builder->addHasExplicitHLSLBindingDecoration(irParam); + if (as<HLSLLayoutSemantic>(modifier)) + { + hasLayoutSemantic = true; + } + else if (as<SpecializationConstantAttribute>(modifier) || as<VkConstantIdAttribute>(modifier)) + { + isSpecializationConstant = true; + } } + if (hasLayoutSemantic) + builder->addHasExplicitHLSLBindingDecoration(irParam); + // A global variable's SSA value is a *pointer* to // the underlying storage. context->setGlobalValue(decl, paramVal); + if (isSpecializationConstant && decl->initExpr) + { + auto initVal = getSimpleVal(context, lowerRValueExpr(context, decl->initExpr)); + builder->addDefaultValueDecoration(irParam, initVal); + } + irParam->moveToEnd(); return paramVal; diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 6f8504a31..115ccc55e 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -703,20 +703,21 @@ RefPtr<TypeLayout> getTypeLayoutForGlobalShaderParameter( type); } - // TODO: The cases below for detecting globals that aren't actually - // shader parameters should be redundant now that the semantic - // checking logic is responsible for populating the list of - // parameters on a `Program`. We should be able to clean up - // the code by removing these two cases, and the related null - // pointer checks in the code that calls this. - - // HLSL `static` modifier indicates "thread local" - if(varDecl->hasModifier<HLSLStaticModifier>()) - return nullptr; - - // HLSL `groupshared` modifier indicates "thread-group local" - if(varDecl->hasModifier<HLSLGroupSharedModifier>()) - return nullptr; + if (varDecl->hasModifier<SpecializationConstantAttribute>() || + varDecl->hasModifier<VkConstantIdAttribute>()) + { + auto specializationConstantRule = rules->getSpecializationConstantRules(); + if (!specializationConstantRule) + { + // If the target doesn't support specialization constants, then we will + // layout them as ordinary uniform data. + specializationConstantRule = rules->getConstantBufferRules(context->getTargetRequest()->getOptionSet()); + } + return createTypeLayoutWith( + layoutContext, + specializationConstantRule, + type); + } // TODO(tfoley): there may be other cases that we need to handle here @@ -1143,10 +1144,10 @@ static void addExplicitParameterBindings_GLSL( else if(auto foundSpecializationConstant = typeLayout->FindResourceInfo(LayoutResourceKind::SpecializationConstant)) { info[kResInfo].resInfo = foundSpecializationConstant; - DeclRef<Decl> varDecl2(varDecl); - // Try to find `constant_id` binding - if(!findLayoutArg<GLSLConstantIDLayoutModifier>(varDecl2, &info[kResInfo].semanticInfo.index)) + if (auto layoutAttr = varDecl.getDecl()->findModifier<VkConstantIdAttribute>()) + info[kResInfo].semanticInfo.index = layoutAttr->location; + else return; } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index b3a18f8a8..04ada006c 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -8190,7 +8190,7 @@ namespace Slang CASE(push_constant, PushConstantAttribute) CASE(shaderRecordNV, ShaderRecordAttribute) CASE(shaderRecordEXT, ShaderRecordAttribute) - CASE(constant_id, GLSLConstantIDLayoutModifier) + CASE(constant_id, VkConstantIdAttribute) CASE(std140, GLSLStd140Modifier) CASE(std430, GLSLStd430Modifier) CASE(scalar, GLSLScalarModifier) @@ -8228,6 +8228,11 @@ namespace Slang parser->diagnose(modifier->loc, Diagnostics::missingLayoutBindingModifier); } } + else if (auto specConstAttr = as<VkConstantIdAttribute>(modifier)) + { + parser->ReadToken(TokenType::OpAssign); + specConstAttr->location = (int)getIntegerLiteralValue(parser->ReadToken(TokenType::IntegerLiteral)); + } listBuilder.add(modifier); } diff --git a/source/slang/slang-type-system-shared.h b/source/slang/slang-type-system-shared.h index adf6e26f8..7fc9eeb6b 100644 --- a/source/slang/slang-type-system-shared.h +++ b/source/slang/slang-type-system-shared.h @@ -101,6 +101,8 @@ FOREACH_BASE_TYPE(DEFINE_BASE_TYPE) UniformConstant, // Corresponds to SPIR-V's SpvStorageClassImage Image, + // Represents a SPIR-V specialization constant + SpecializationConstant, // Default address space for a user-defined pointer UserPointer = 0x100000001ULL, diff --git a/tests/language-feature/zero-initialize/struct.slang b/tests/language-feature/zero-initialize/struct.slang index 62f403903..efb1d8fc3 100644 --- a/tests/language-feature/zero-initialize/struct.slang +++ b/tests/language-feature/zero-initialize/struct.slang @@ -2,7 +2,6 @@ // CHECK-COUNT-3: {{.* }}= 0; //TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -xslang -zero-initialize -//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -entry computeMain -emit-spirv-directly -allow-glsl -xslang -zero-initialize //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -entry computeMain -allow-glsl -xslang -zero-initialize //TEST(smoke,compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-dx12 -use-dxil -compute -entry computeMain -allow-glsl -xslang -zero-initialize diff --git a/tests/spirv/specialization-constant.slang b/tests/spirv/specialization-constant.slang new file mode 100644 index 000000000..63141d25e --- /dev/null +++ b/tests/spirv/specialization-constant.slang @@ -0,0 +1,48 @@ +//TEST:SIMPLE(filecheck=GLSL): -target glsl +//TEST:SIMPLE(filecheck=CHECK): -target spirv + +// CHECK-DAG: OpDecorate %[[C0:[0-9A-Za-z_]+]] SpecId 0 +// CHECK-DAG: %[[C0]] = OpSpecConstant %int 1 + +// CHECK-DAG: OpDecorate %[[C1:[0-9A-Za-z_]+]] SpecId 7 +// CHECK-DAG: %[[C1]] = OpSpecConstant %float 3 + +// CHECK-DAG: OpDecorate %[[C2:[0-9A-Za-z_]+]] SpecId 1 +// CHECK-DAG: %[[C2]] = OpSpecConstantTrue %bool + +// CHECK-DAG: OpDecorate %[[C3:[0-9A-Za-z_]+]] SpecId 9 +// CHECK-DAG: %[[C3]] = OpSpecConstant %int 111 + +// GLSL-DAG: layout(constant_id = 0) +// GLSL-DAG: int constValue0_0 = 1; + +// GLSL-DAG: layout(constant_id = 7) +// GLSL-DAG: float constValue1_0 = 3.0; + +// GLSL-DAG: layout(constant_id = 1) +// GLSL-DAG: bool constValue2_0 = true; + +// GLSL-DAG: layout(constant_id = 9) +// GLSL-DAG: int constValue3_0 = 111; + +[vk::specialization_constant] +const int constValue0 = 1; + +[vk::constant_id(7)] +const float constValue1 = 3.0f; + +[SpecializationConstant] +const bool constValue2 = true; + +layout(constant_id = 9) const int constValue3 = 111; + +RWStructuredBuffer<float> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain() +{ + if (constValue2) + outputBuffer[0] = constValue0 + (int)constValue1; + else + outputBuffer[0] = constValue3; +}
\ No newline at end of file |
