diff options
| author | ArielG-NV <159081215+ArielG-NV@users.noreply.github.com> | 2025-08-14 12:27:55 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-08-14 19:27:55 +0000 |
| commit | dd06524f523cdac9c753801ce9c3992f66ae5576 (patch) | |
| tree | a73cad28075d1beaa04ba50dc4b668f3097c87b2 /source/slang | |
| parent | ceb2e8d04885d55dd9685a38977a55c4f53f202f (diff) | |
[Capability System] Fix bug where capabilities do not correctly propegate if AST-parent has target+set the AST-child does not (#8175)
Fixes: #8174
Changes:
* To determine if we propagate capabilities, we need to ensure that a
`join` will do nothing (optimization since `join` is expensive + caching
data for the `join` adds up to be expensive). This logic was changed in
`slang-check-decl.cpp` since the current logic was incorrect.
* A parent could have the set `metal+glsl` and the use-site could have
`glsl`. In this case, we will not remove `metal` from the parent since
`{metal+glsl}.implies({glsl})` is true.
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/glsl.meta.slang | 28 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 55 | ||||
| -rw-r--r-- | source/slang/slang-capability.cpp | 54 | ||||
| -rw-r--r-- | source/slang/slang-capability.h | 24 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 6 |
5 files changed, 119 insertions, 48 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index eeaf2a58c..5a78a9960 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -301,7 +301,7 @@ public extension vector<T, 3> [ForceInline] [OverloadRank(15)] -[require(cpp_cuda_glsl_hlsl_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] public bool operator==<T:__BuiltinArithmeticType, let N:int>(vector<T, N> left, vector<T, N> right) { return all(equal(left, right)); @@ -309,7 +309,7 @@ public bool operator==<T:__BuiltinArithmeticType, let N:int>(vector<T, N> left, [ForceInline] [OverloadRank(15)] -[require(cpp_cuda_glsl_hlsl_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] public bool operator!=<T:__BuiltinArithmeticType, let N:int>(vector<T, N> left, vector<T, N> right) { return any(notEqual(left, right)); @@ -317,7 +317,7 @@ public bool operator!=<T:__BuiltinArithmeticType, let N:int>(vector<T, N> left, [ForceInline] [OverloadRank(14)] -[require(cpp_cuda_glsl_hlsl_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] public bool operator==<T:__BuiltinFloatingPointType, let N:int>(vector<T, N> left, vector<T, N> right) { return all(equal(left, right)); @@ -325,7 +325,7 @@ public bool operator==<T:__BuiltinFloatingPointType, let N:int>(vector<T, N> lef [ForceInline] [OverloadRank(14)] -[require(cpp_cuda_glsl_hlsl_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] public bool operator!=<T:__BuiltinFloatingPointType, let N:int>(vector<T, N> left, vector<T, N> right) { return any(notEqual(left, right)); @@ -333,7 +333,7 @@ public bool operator!=<T:__BuiltinFloatingPointType, let N:int>(vector<T, N> lef [ForceInline] [OverloadRank(14)] -[require(cpp_cuda_glsl_hlsl_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] public bool operator==<T:__BuiltinLogicalType, let N:int>(vector<T, N> left, vector<T, N> right) { return all(equal(left, right)); @@ -341,7 +341,7 @@ public bool operator==<T:__BuiltinLogicalType, let N:int>(vector<T, N> left, vec [ForceInline] [OverloadRank(14)] -[require(cpp_cuda_glsl_hlsl_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] public bool operator!=<T:__BuiltinLogicalType, let N:int>(vector<T, N> left, vector<T, N> right) { return any(notEqual(left, right)); @@ -354,7 +354,7 @@ for (auto type : kBaseTypes) { }}}} [ForceInline] [OverloadRank(15)] -[require(cpp_cuda_glsl_hlsl_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] public bool operator==<let N:int>(vector<$(typeName), N> left, vector<$(typeName), N> right) { return all(equal(left, right)); @@ -362,7 +362,7 @@ public bool operator==<let N:int>(vector<$(typeName), N> left, vector<$(typeName [ForceInline] [OverloadRank(15)] -[require(cpp_cuda_glsl_hlsl_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] public bool operator!=<let N:int>(vector<$(typeName), N> left, vector<$(typeName), N> right) { return any(notEqual(left, right)); @@ -6264,7 +6264,7 @@ public void traceRayMotionNV( // GL_KHR_shader_subgroup_basic Built-in Variables -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)] void requireGLSLExtForSubgroupBasicBuiltin() { __target_switch { @@ -6288,7 +6288,7 @@ void setupExtForSubgroupBasicBuiltIn() { } __spirv_version(1.3) -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)] void requireGLSLExtForSubgroupBallotBuiltin() { __target_switch { @@ -6301,7 +6301,7 @@ void requireGLSLExtForSubgroupBallotBuiltin() { } __spirv_version(1.3) -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)] void setupExtForSubgroupBallotBuiltIn() { __target_switch { @@ -8204,7 +8204,7 @@ __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv, subgroup_quad)] +[require(glsl_hlsl_metal_spirv, subgroup_quad)] public T subgroupQuadBroadcast(T value, uint id) { shader_subgroup_preamble<T>(); @@ -8247,9 +8247,9 @@ public T subgroupQuadSwapDiagonal(T value) __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) -__glsl_extension(GL_KHR_shader_subgroup_quad) + __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv, subgroup_quad)] + [require(glsl_hlsl_metal_spirv, subgroup_quad)] public vector<T,N> subgroupQuadBroadcast(vector<T,N> value, uint id) { shader_subgroup_preamble<T>(); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index b8243d6e4..740f48320 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -4804,11 +4804,11 @@ __intrinsic_op($(kIROp_ByteAddressBufferLoad)) T __byteAddressBufferLoad<T>(RasterizerOrderedByteAddressBuffer buffer, uint offset, uint alignment); __intrinsic_op($(kIROp_ByteAddressBufferStore)) -[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer_rw)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer_rw)] void __byteAddressBufferStore<T>(RWByteAddressBuffer buffer, uint offset, uint alignment, T value); __intrinsic_op($(kIROp_ByteAddressBufferStore)) -[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer_rw)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer_rw)] void __byteAddressBufferStore<T>(RasterizerOrderedByteAddressBuffer buffer, uint offset, uint alignment, T value); __intrinsic_op($(kIROp_GetUntypedBufferPtr)) @@ -5095,7 +5095,7 @@ struct $(item.name) /// When targeting non-HLSL, the status is always 0. [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] uint2 Load2(uint location) { __target_switch @@ -5108,7 +5108,7 @@ struct $(item.name) [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] uint2 Load2Aligned(uint location, uint alignment) { __target_switch @@ -5125,7 +5125,7 @@ struct $(item.name) ///@return `uint2` Two 32-bit unsigned integers loaded from the buffer. [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] uint2 Load2Aligned(uint location) { __target_switch @@ -5163,7 +5163,7 @@ struct $(item.name) /// When targeting non-HLSL, the status is always 0. [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] uint3 Load3(uint location) { __target_switch @@ -5176,7 +5176,7 @@ struct $(item.name) [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] uint3 Load3Aligned(uint location, uint alignment) { __target_switch @@ -5193,7 +5193,7 @@ struct $(item.name) ///@return `uint3` Three 32-bit unsigned integer value loaded from the buffer. [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] uint3 Load3Aligned(uint location) { __target_switch @@ -5230,7 +5230,7 @@ struct $(item.name) /// If any values were taken from an unmapped tile, `CheckAccessFullyMapped` returns FALSE. [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] uint4 Load4(uint location) { __target_switch @@ -5243,7 +5243,7 @@ struct $(item.name) [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] uint4 Load4Aligned(uint location, uint alignment) { __target_switch @@ -5260,7 +5260,7 @@ struct $(item.name) ///@return `uint4` Four 32-bit unsigned integer value loaded from the buffer. [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] uint4 Load4Aligned(uint location) { __target_switch @@ -5284,7 +5284,7 @@ struct $(item.name) [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] T Load<T>(uint location) { return __byteAddressBufferLoad<T>(this, location, 0); @@ -5292,7 +5292,7 @@ struct $(item.name) [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] T LoadAligned<T>(uint location, uint alignment) { return __byteAddressBufferLoad<T>(this, location, alignment); @@ -5305,7 +5305,7 @@ struct $(item.name) ///Currently, this function only supports when `T` is scalar, vector, or matrix type. [__NoSideEffect] [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] T LoadAligned<T>(uint location) { return __byteAddressBufferLoad<T>(this, location, __naturalStrideOf<T>()); @@ -5828,7 +5828,7 @@ ${{{{ ///@param address The input address in bytes, which must be a multiple of 4. ///@param alignment Specifies the alignment of the location, which must be a multiple of 4. [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void Store(uint address, uint value) { __target_switch @@ -5845,7 +5845,7 @@ ${{{{ ///@param value Two input values. ///@param alignment Specifies the alignment of the location, which must be a multiple of 4. [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void Store2(uint address, uint2 value) { __target_switch @@ -5858,7 +5858,7 @@ ${{{{ [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void Store2(uint address, uint2 value, uint alignment) { __target_switch @@ -5874,7 +5874,7 @@ ${{{{ ///@param address The input address in bytes, which must be a multiple of 8. ///@param value Two input values. [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void Store2Aligned(uint address, uint2 value) { __target_switch @@ -5890,7 +5890,7 @@ ${{{{ ///@param value Three input values. ///@param alignment Specifies the alignment of the location, which must be a multiple of 4. [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void Store3(uint address, uint3 value) { __target_switch @@ -5902,7 +5902,7 @@ ${{{{ } [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void Store3(uint address, uint3 value, uint alignment) { __target_switch @@ -5918,7 +5918,7 @@ ${{{{ ///@param address The input address in bytes, which must be a multiple of 12. ///@param value Three input values. [ForceInline] - [require(cpp_cuda_glsl_hlsl_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void Store3Aligned(uint address, uint3 value) { __target_switch @@ -5934,7 +5934,7 @@ ${{{{ ///@param value Four input values. ///@param alignment Specifies the alignment of the location, which must be a multiple of 4. [ForceInline] - [require(cpp_cuda_glsl_hlsl_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void Store4(uint address, uint4 value) { __target_switch @@ -5947,7 +5947,7 @@ ${{{{ [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void Store4(uint address, uint4 value, uint alignment) { __target_switch @@ -5963,7 +5963,7 @@ ${{{{ ///@param address The input address in bytes, which must be a multiple of 16. ///@param value Four input values. [ForceInline] - [require(cpp_cuda_glsl_hlsl_metal_spirv)] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void Store4Aligned(uint address, uint4 value) { __target_switch @@ -5975,12 +5975,14 @@ ${{{{ } [ForceInline] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void Store<T>(uint address, T value) { __byteAddressBufferStore(this, address, 0, value); } [ForceInline] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void Store<T>(uint address, T value, uint alignment) { __byteAddressBufferStore(this, address, alignment, value); @@ -5992,6 +5994,7 @@ ${{{{ ///@param address The input address in bytes, which must be a multiple of size of `T`. ///@param value The input value. [ForceInline] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] void StoreAligned<T>(uint address, T value) { __byteAddressBufferStore(this, address, __naturalStrideOf<T>(), value); @@ -8806,7 +8809,7 @@ vector<T,N> fdim(vector<T,N> x, vector<T,N> y) /// @category math __generic<T : __BuiltinFloatingPointType> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] T divide(T x, T y) { __target_switch @@ -8819,7 +8822,7 @@ T divide(T x, T y) __generic<T : __BuiltinFloatingPointType, let N: int> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] vector<T,N> divide(vector<T,N> x, vector<T,N> y) { __target_switch diff --git a/source/slang/slang-capability.cpp b/source/slang/slang-capability.cpp index a965ecd93..0c883b3b8 100644 --- a/source/slang/slang-capability.cpp +++ b/source/slang/slang-capability.cpp @@ -529,36 +529,68 @@ bool CapabilitySet::implies(CapabilityAtom atom) const return this->implies(tmpSet); } +// Implication depends heavily on context as per the `ImpliesFlags`. CapabilitySet::ImpliesReturnFlags CapabilitySet::_implies( CapabilitySet const& otherSet, ImpliesFlags flags) const { - // x implies (c | d) only if (x implies c) and (x implies d). + // By default (`ImpliesFlags::None`): x implies (c | d) only if (x implies c) and (x implies d). bool onlyRequireSingleImply = ((int)flags & (int)ImpliesFlags::OnlyRequireASingleValidImply); + bool cannotHaveMoreTargetAndStageSets = + ((int)flags & (int)ImpliesFlags::CannotHaveMoreTargetAndStageSets); + bool canHaveSubsetOfTargetAndStageSets = + ((int)flags & (int)ImpliesFlags::CanHaveSubsetOfTargetAndStageSets); + int flagsCollected = (int)CapabilitySet::ImpliesReturnFlags::NotImplied; if (otherSet.isEmpty()) return CapabilitySet::ImpliesReturnFlags::Implied; - for (const auto& otherTarget : otherSet.m_targetSets) + // If empty, and the other is not empty, it does not matter what flags are used, + // `this` is considered to "not imply" another set. This is important since + // `T.join(U)` causes `T == U`. + if (this->isEmpty()) + return CapabilitySet::ImpliesReturnFlags::NotImplied; + + if (cannotHaveMoreTargetAndStageSets && + this->getCapabilityTargetSets().getCount() > otherSet.getCapabilityTargetSets().getCount()) { - auto thisTarget = this->m_targetSets.tryGetValue(otherTarget.first); + return CapabilitySet::ImpliesReturnFlags::NotImplied; + } + + for (const auto& otherTargetPair : otherSet.m_targetSets) + { + auto thisTarget = this->m_targetSets.tryGetValue(otherTargetPair.first); + const auto& otherTarget = otherTargetPair.second; if (!thisTarget) { if (onlyRequireSingleImply) continue; + + if (canHaveSubsetOfTargetAndStageSets) + continue; // 'this' lacks a target 'other' has. return CapabilitySet::ImpliesReturnFlags::NotImplied; } - for (const auto& otherStage : otherTarget.second.shaderStageSets) + if (cannotHaveMoreTargetAndStageSets && thisTarget->getShaderStageSets().getCount() > + otherTarget.getShaderStageSets().getCount()) + { + return CapabilitySet::ImpliesReturnFlags::NotImplied; + } + + for (const auto& otherStagePair : otherTarget.getShaderStageSets()) { - auto thisStage = thisTarget->shaderStageSets.tryGetValue(otherStage.first); + auto thisStage = thisTarget->shaderStageSets.tryGetValue(otherStagePair.first); + const auto& otherStage = otherStagePair.second; if (!thisStage) { if (onlyRequireSingleImply) continue; + + if (canHaveSubsetOfTargetAndStageSets) + continue; // 'this' lacks a stage 'other' has. return CapabilitySet::ImpliesReturnFlags::NotImplied; } @@ -567,9 +599,9 @@ CapabilitySet::ImpliesReturnFlags CapabilitySet::_implies( if (thisStage->atomSet) { auto& thisStageSet = thisStage->atomSet.value(); - if (otherStage.second.atomSet) + if (otherStage.atomSet) { - auto contained = thisStageSet.contains(otherStage.second.atomSet.value()); + auto contained = thisStageSet.contains(otherStage.atomSet.value()); if (!onlyRequireSingleImply && !contained) { return CapabilitySet::ImpliesReturnFlags::NotImplied; @@ -593,12 +625,20 @@ bool CapabilitySet::implies(CapabilitySet const& other) const return (int)_implies(other, ImpliesFlags::None) & (int)CapabilitySet::ImpliesReturnFlags::Implied; } + CapabilitySet::ImpliesReturnFlags CapabilitySet::atLeastOneSetImpliedInOther( CapabilitySet const& other) const { return _implies(other, ImpliesFlags::OnlyRequireASingleValidImply); } +bool CapabilitySet::joinWithOtherWillChangeThis(CapabilitySet const& other) const +{ + return !( + (int)_implies(other, ImpliesFlags::CannotHaveMoreTargetAndStageSets) & + (int)CapabilitySet::ImpliesReturnFlags::Implied); +} + void CapabilityTargetSet::unionWith(const CapabilityTargetSet& other) { for (auto otherStageSet : other.shaderStageSets) diff --git a/source/slang/slang-capability.h b/source/slang/slang-capability.h index 43f933620..c72d1d7d6 100644 --- a/source/slang/slang-capability.h +++ b/source/slang/slang-capability.h @@ -168,9 +168,12 @@ public: Implied = 1 << 0, }; /// Does this capability set imply all the capabilities in `other`? + /// `this` can have excess target+stage sets. bool implies(CapabilitySet const& other) const; /// Does this capability set imply at least 1 set in other. ImpliesReturnFlags atLeastOneSetImpliedInOther(CapabilitySet const& other) const; + /// Will a `join` with `other` change `this`? + bool joinWithOtherWillChangeThis(CapabilitySet const& other) const; /// Does this capability set imply the atomic capability `other`? bool implies(CapabilityAtom other) const; @@ -353,8 +356,29 @@ private: enum class ImpliesFlags { + // All permutations of target+stage from `other` must be implied by a target+stage + // in `this`. None = 0, + // Given a single target+stage permutation, if 1 permutation is implied in `other`, + // return true. OnlyRequireASingleValidImply = 1 << 0, + // The target+stage permuations in `this` cannot have extra permutations + // relative to `other`. + // Ex: `{metal|glsl}.implies({glsl})` is false + // `{glsl}.implies({glsl|metal})` is false + // `{glsl}.implies({glsl|glsl})` is true + CannotHaveMoreTargetAndStageSets = 1 << 1, + // The target+stage permuations in `this` can have less permutations + // than `other`. This means, only for the shared permutations of `this` + // and `other` does `thisSet[target][stage].imply(otherSet)` have to be + // true. + // If `this` is empty, `this` is not able to imply `other` unless `other` + // is empty. + // Ex: `{glsl}.implies({glsl|metal})` is true since we only compare shared-permutations. + CanHaveSubsetOfTargetAndStageSets = 1 << 2, + + WillAJoinWithOtherModifyThis = + CannotHaveMoreTargetAndStageSets | CanHaveSubsetOfTargetAndStageSets }; ImpliesReturnFlags _implies(CapabilitySet const& other, ImpliesFlags flags) const; }; diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 0a9853012..507e12fa6 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -14268,7 +14268,11 @@ static void _propagateRequirement( ensureDecl(visitor, referencedDecl, DeclCheckState::CapabilityChecked); } - if (resultCaps.implies(nodeCaps)) + // If we do not have the same target+stage, we need to `join` to remove excess target+stage. + // + // If we have the same target+stage but current capabilities do not imply incoming capabilities, + // we need to `join`. + if (!resultCaps.joinWithOtherWillChangeThis(nodeCaps)) return; auto oldCaps = resultCaps; |
