summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2025-08-14 12:27:55 -0700
committerGitHub <noreply@github.com>2025-08-14 19:27:55 +0000
commitdd06524f523cdac9c753801ce9c3992f66ae5576 (patch)
treea73cad28075d1beaa04ba50dc4b668f3097c87b2 /source/slang
parentceb2e8d04885d55dd9685a38977a55c4f53f202f (diff)
[Capability System] Fix bug where capabilities do not correctly propegate if AST-parent has target+set the AST-child does not (#8175)
Fixes: #8174 Changes: * To determine if we propagate capabilities, we need to ensure that a `join` will do nothing (optimization since `join` is expensive + caching data for the `join` adds up to be expensive). This logic was changed in `slang-check-decl.cpp` since the current logic was incorrect. * A parent could have the set `metal+glsl` and the use-site could have `glsl`. In this case, we will not remove `metal` from the parent since `{metal+glsl}.implies({glsl})` is true. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/glsl.meta.slang28
-rw-r--r--source/slang/hlsl.meta.slang55
-rw-r--r--source/slang/slang-capability.cpp54
-rw-r--r--source/slang/slang-capability.h24
-rw-r--r--source/slang/slang-check-decl.cpp6
5 files changed, 119 insertions, 48 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang
index eeaf2a58c..5a78a9960 100644
--- a/source/slang/glsl.meta.slang
+++ b/source/slang/glsl.meta.slang
@@ -301,7 +301,7 @@ public extension vector<T, 3>
[ForceInline]
[OverloadRank(15)]
-[require(cpp_cuda_glsl_hlsl_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
public bool operator==<T:__BuiltinArithmeticType, let N:int>(vector<T, N> left, vector<T, N> right)
{
return all(equal(left, right));
@@ -309,7 +309,7 @@ public bool operator==<T:__BuiltinArithmeticType, let N:int>(vector<T, N> left,
[ForceInline]
[OverloadRank(15)]
-[require(cpp_cuda_glsl_hlsl_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
public bool operator!=<T:__BuiltinArithmeticType, let N:int>(vector<T, N> left, vector<T, N> right)
{
return any(notEqual(left, right));
@@ -317,7 +317,7 @@ public bool operator!=<T:__BuiltinArithmeticType, let N:int>(vector<T, N> left,
[ForceInline]
[OverloadRank(14)]
-[require(cpp_cuda_glsl_hlsl_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
public bool operator==<T:__BuiltinFloatingPointType, let N:int>(vector<T, N> left, vector<T, N> right)
{
return all(equal(left, right));
@@ -325,7 +325,7 @@ public bool operator==<T:__BuiltinFloatingPointType, let N:int>(vector<T, N> lef
[ForceInline]
[OverloadRank(14)]
-[require(cpp_cuda_glsl_hlsl_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
public bool operator!=<T:__BuiltinFloatingPointType, let N:int>(vector<T, N> left, vector<T, N> right)
{
return any(notEqual(left, right));
@@ -333,7 +333,7 @@ public bool operator!=<T:__BuiltinFloatingPointType, let N:int>(vector<T, N> lef
[ForceInline]
[OverloadRank(14)]
-[require(cpp_cuda_glsl_hlsl_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
public bool operator==<T:__BuiltinLogicalType, let N:int>(vector<T, N> left, vector<T, N> right)
{
return all(equal(left, right));
@@ -341,7 +341,7 @@ public bool operator==<T:__BuiltinLogicalType, let N:int>(vector<T, N> left, vec
[ForceInline]
[OverloadRank(14)]
-[require(cpp_cuda_glsl_hlsl_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
public bool operator!=<T:__BuiltinLogicalType, let N:int>(vector<T, N> left, vector<T, N> right)
{
return any(notEqual(left, right));
@@ -354,7 +354,7 @@ for (auto type : kBaseTypes) {
}}}}
[ForceInline]
[OverloadRank(15)]
-[require(cpp_cuda_glsl_hlsl_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
public bool operator==<let N:int>(vector<$(typeName), N> left, vector<$(typeName), N> right)
{
return all(equal(left, right));
@@ -362,7 +362,7 @@ public bool operator==<let N:int>(vector<$(typeName), N> left, vector<$(typeName
[ForceInline]
[OverloadRank(15)]
-[require(cpp_cuda_glsl_hlsl_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
public bool operator!=<let N:int>(vector<$(typeName), N> left, vector<$(typeName), N> right)
{
return any(notEqual(left, right));
@@ -6264,7 +6264,7 @@ public void traceRayMotionNV(
// GL_KHR_shader_subgroup_basic Built-in Variables
-[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
void requireGLSLExtForSubgroupBasicBuiltin() {
__target_switch
{
@@ -6288,7 +6288,7 @@ void setupExtForSubgroupBasicBuiltIn() {
}
__spirv_version(1.3)
-[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
void requireGLSLExtForSubgroupBallotBuiltin() {
__target_switch
{
@@ -6301,7 +6301,7 @@ void requireGLSLExtForSubgroupBallotBuiltin() {
}
__spirv_version(1.3)
-[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
void setupExtForSubgroupBallotBuiltIn() {
__target_switch
{
@@ -8204,7 +8204,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv, subgroup_quad)]
public T subgroupQuadBroadcast(T value, uint id)
{
shader_subgroup_preamble<T>();
@@ -8247,9 +8247,9 @@ public T subgroupQuadSwapDiagonal(T value)
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
-__glsl_extension(GL_KHR_shader_subgroup_quad)
+ __glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+ [require(glsl_hlsl_metal_spirv, subgroup_quad)]
public vector<T,N> subgroupQuadBroadcast(vector<T,N> value, uint id)
{
shader_subgroup_preamble<T>();
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index b8243d6e4..740f48320 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -4804,11 +4804,11 @@ __intrinsic_op($(kIROp_ByteAddressBufferLoad))
T __byteAddressBufferLoad<T>(RasterizerOrderedByteAddressBuffer buffer, uint offset, uint alignment);
__intrinsic_op($(kIROp_ByteAddressBufferStore))
-[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer_rw)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer_rw)]
void __byteAddressBufferStore<T>(RWByteAddressBuffer buffer, uint offset, uint alignment, T value);
__intrinsic_op($(kIROp_ByteAddressBufferStore))
-[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer_rw)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, byteaddressbuffer_rw)]
void __byteAddressBufferStore<T>(RasterizerOrderedByteAddressBuffer buffer, uint offset, uint alignment, T value);
__intrinsic_op($(kIROp_GetUntypedBufferPtr))
@@ -5095,7 +5095,7 @@ struct $(item.name)
/// When targeting non-HLSL, the status is always 0.
[__NoSideEffect]
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
uint2 Load2(uint location)
{
__target_switch
@@ -5108,7 +5108,7 @@ struct $(item.name)
[__NoSideEffect]
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
uint2 Load2Aligned(uint location, uint alignment)
{
__target_switch
@@ -5125,7 +5125,7 @@ struct $(item.name)
///@return `uint2` Two 32-bit unsigned integers loaded from the buffer.
[__NoSideEffect]
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
uint2 Load2Aligned(uint location)
{
__target_switch
@@ -5163,7 +5163,7 @@ struct $(item.name)
/// When targeting non-HLSL, the status is always 0.
[__NoSideEffect]
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
uint3 Load3(uint location)
{
__target_switch
@@ -5176,7 +5176,7 @@ struct $(item.name)
[__NoSideEffect]
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
uint3 Load3Aligned(uint location, uint alignment)
{
__target_switch
@@ -5193,7 +5193,7 @@ struct $(item.name)
///@return `uint3` Three 32-bit unsigned integer value loaded from the buffer.
[__NoSideEffect]
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
uint3 Load3Aligned(uint location)
{
__target_switch
@@ -5230,7 +5230,7 @@ struct $(item.name)
/// If any values were taken from an unmapped tile, `CheckAccessFullyMapped` returns FALSE.
[__NoSideEffect]
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
uint4 Load4(uint location)
{
__target_switch
@@ -5243,7 +5243,7 @@ struct $(item.name)
[__NoSideEffect]
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
uint4 Load4Aligned(uint location, uint alignment)
{
__target_switch
@@ -5260,7 +5260,7 @@ struct $(item.name)
///@return `uint4` Four 32-bit unsigned integer value loaded from the buffer.
[__NoSideEffect]
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
uint4 Load4Aligned(uint location)
{
__target_switch
@@ -5284,7 +5284,7 @@ struct $(item.name)
[__NoSideEffect]
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
T Load<T>(uint location)
{
return __byteAddressBufferLoad<T>(this, location, 0);
@@ -5292,7 +5292,7 @@ struct $(item.name)
[__NoSideEffect]
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
T LoadAligned<T>(uint location, uint alignment)
{
return __byteAddressBufferLoad<T>(this, location, alignment);
@@ -5305,7 +5305,7 @@ struct $(item.name)
///Currently, this function only supports when `T` is scalar, vector, or matrix type.
[__NoSideEffect]
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
T LoadAligned<T>(uint location)
{
return __byteAddressBufferLoad<T>(this, location, __naturalStrideOf<T>());
@@ -5828,7 +5828,7 @@ ${{{{
///@param address The input address in bytes, which must be a multiple of 4.
///@param alignment Specifies the alignment of the location, which must be a multiple of 4.
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void Store(uint address, uint value)
{
__target_switch
@@ -5845,7 +5845,7 @@ ${{{{
///@param value Two input values.
///@param alignment Specifies the alignment of the location, which must be a multiple of 4.
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void Store2(uint address, uint2 value)
{
__target_switch
@@ -5858,7 +5858,7 @@ ${{{{
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void Store2(uint address, uint2 value, uint alignment)
{
__target_switch
@@ -5874,7 +5874,7 @@ ${{{{
///@param address The input address in bytes, which must be a multiple of 8.
///@param value Two input values.
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void Store2Aligned(uint address, uint2 value)
{
__target_switch
@@ -5890,7 +5890,7 @@ ${{{{
///@param value Three input values.
///@param alignment Specifies the alignment of the location, which must be a multiple of 4.
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void Store3(uint address, uint3 value)
{
__target_switch
@@ -5902,7 +5902,7 @@ ${{{{
}
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void Store3(uint address, uint3 value, uint alignment)
{
__target_switch
@@ -5918,7 +5918,7 @@ ${{{{
///@param address The input address in bytes, which must be a multiple of 12.
///@param value Three input values.
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void Store3Aligned(uint address, uint3 value)
{
__target_switch
@@ -5934,7 +5934,7 @@ ${{{{
///@param value Four input values.
///@param alignment Specifies the alignment of the location, which must be a multiple of 4.
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void Store4(uint address, uint4 value)
{
__target_switch
@@ -5947,7 +5947,7 @@ ${{{{
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void Store4(uint address, uint4 value, uint alignment)
{
__target_switch
@@ -5963,7 +5963,7 @@ ${{{{
///@param address The input address in bytes, which must be a multiple of 16.
///@param value Four input values.
[ForceInline]
- [require(cpp_cuda_glsl_hlsl_metal_spirv)]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void Store4Aligned(uint address, uint4 value)
{
__target_switch
@@ -5975,12 +5975,14 @@ ${{{{
}
[ForceInline]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void Store<T>(uint address, T value)
{
__byteAddressBufferStore(this, address, 0, value);
}
[ForceInline]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void Store<T>(uint address, T value, uint alignment)
{
__byteAddressBufferStore(this, address, alignment, value);
@@ -5992,6 +5994,7 @@ ${{{{
///@param address The input address in bytes, which must be a multiple of size of `T`.
///@param value The input value.
[ForceInline]
+ [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
void StoreAligned<T>(uint address, T value)
{
__byteAddressBufferStore(this, address, __naturalStrideOf<T>(), value);
@@ -8806,7 +8809,7 @@ vector<T,N> fdim(vector<T,N> x, vector<T,N> y)
/// @category math
__generic<T : __BuiltinFloatingPointType>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_metal_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
T divide(T x, T y)
{
__target_switch
@@ -8819,7 +8822,7 @@ T divide(T x, T y)
__generic<T : __BuiltinFloatingPointType, let N: int>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_metal_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
vector<T,N> divide(vector<T,N> x, vector<T,N> y)
{
__target_switch
diff --git a/source/slang/slang-capability.cpp b/source/slang/slang-capability.cpp
index a965ecd93..0c883b3b8 100644
--- a/source/slang/slang-capability.cpp
+++ b/source/slang/slang-capability.cpp
@@ -529,36 +529,68 @@ bool CapabilitySet::implies(CapabilityAtom atom) const
return this->implies(tmpSet);
}
+// Implication depends heavily on context as per the `ImpliesFlags`.
CapabilitySet::ImpliesReturnFlags CapabilitySet::_implies(
CapabilitySet const& otherSet,
ImpliesFlags flags) const
{
- // x implies (c | d) only if (x implies c) and (x implies d).
+ // By default (`ImpliesFlags::None`): x implies (c | d) only if (x implies c) and (x implies d).
bool onlyRequireSingleImply = ((int)flags & (int)ImpliesFlags::OnlyRequireASingleValidImply);
+ bool cannotHaveMoreTargetAndStageSets =
+ ((int)flags & (int)ImpliesFlags::CannotHaveMoreTargetAndStageSets);
+ bool canHaveSubsetOfTargetAndStageSets =
+ ((int)flags & (int)ImpliesFlags::CanHaveSubsetOfTargetAndStageSets);
+
int flagsCollected = (int)CapabilitySet::ImpliesReturnFlags::NotImplied;
if (otherSet.isEmpty())
return CapabilitySet::ImpliesReturnFlags::Implied;
- for (const auto& otherTarget : otherSet.m_targetSets)
+ // If empty, and the other is not empty, it does not matter what flags are used,
+ // `this` is considered to "not imply" another set. This is important since
+ // `T.join(U)` causes `T == U`.
+ if (this->isEmpty())
+ return CapabilitySet::ImpliesReturnFlags::NotImplied;
+
+ if (cannotHaveMoreTargetAndStageSets &&
+ this->getCapabilityTargetSets().getCount() > otherSet.getCapabilityTargetSets().getCount())
{
- auto thisTarget = this->m_targetSets.tryGetValue(otherTarget.first);
+ return CapabilitySet::ImpliesReturnFlags::NotImplied;
+ }
+
+ for (const auto& otherTargetPair : otherSet.m_targetSets)
+ {
+ auto thisTarget = this->m_targetSets.tryGetValue(otherTargetPair.first);
+ const auto& otherTarget = otherTargetPair.second;
if (!thisTarget)
{
if (onlyRequireSingleImply)
continue;
+
+ if (canHaveSubsetOfTargetAndStageSets)
+ continue;
// 'this' lacks a target 'other' has.
return CapabilitySet::ImpliesReturnFlags::NotImplied;
}
- for (const auto& otherStage : otherTarget.second.shaderStageSets)
+ if (cannotHaveMoreTargetAndStageSets && thisTarget->getShaderStageSets().getCount() >
+ otherTarget.getShaderStageSets().getCount())
+ {
+ return CapabilitySet::ImpliesReturnFlags::NotImplied;
+ }
+
+ for (const auto& otherStagePair : otherTarget.getShaderStageSets())
{
- auto thisStage = thisTarget->shaderStageSets.tryGetValue(otherStage.first);
+ auto thisStage = thisTarget->shaderStageSets.tryGetValue(otherStagePair.first);
+ const auto& otherStage = otherStagePair.second;
if (!thisStage)
{
if (onlyRequireSingleImply)
continue;
+
+ if (canHaveSubsetOfTargetAndStageSets)
+ continue;
// 'this' lacks a stage 'other' has.
return CapabilitySet::ImpliesReturnFlags::NotImplied;
}
@@ -567,9 +599,9 @@ CapabilitySet::ImpliesReturnFlags CapabilitySet::_implies(
if (thisStage->atomSet)
{
auto& thisStageSet = thisStage->atomSet.value();
- if (otherStage.second.atomSet)
+ if (otherStage.atomSet)
{
- auto contained = thisStageSet.contains(otherStage.second.atomSet.value());
+ auto contained = thisStageSet.contains(otherStage.atomSet.value());
if (!onlyRequireSingleImply && !contained)
{
return CapabilitySet::ImpliesReturnFlags::NotImplied;
@@ -593,12 +625,20 @@ bool CapabilitySet::implies(CapabilitySet const& other) const
return (int)_implies(other, ImpliesFlags::None) &
(int)CapabilitySet::ImpliesReturnFlags::Implied;
}
+
CapabilitySet::ImpliesReturnFlags CapabilitySet::atLeastOneSetImpliedInOther(
CapabilitySet const& other) const
{
return _implies(other, ImpliesFlags::OnlyRequireASingleValidImply);
}
+bool CapabilitySet::joinWithOtherWillChangeThis(CapabilitySet const& other) const
+{
+ return !(
+ (int)_implies(other, ImpliesFlags::CannotHaveMoreTargetAndStageSets) &
+ (int)CapabilitySet::ImpliesReturnFlags::Implied);
+}
+
void CapabilityTargetSet::unionWith(const CapabilityTargetSet& other)
{
for (auto otherStageSet : other.shaderStageSets)
diff --git a/source/slang/slang-capability.h b/source/slang/slang-capability.h
index 43f933620..c72d1d7d6 100644
--- a/source/slang/slang-capability.h
+++ b/source/slang/slang-capability.h
@@ -168,9 +168,12 @@ public:
Implied = 1 << 0,
};
/// Does this capability set imply all the capabilities in `other`?
+ /// `this` can have excess target+stage sets.
bool implies(CapabilitySet const& other) const;
/// Does this capability set imply at least 1 set in other.
ImpliesReturnFlags atLeastOneSetImpliedInOther(CapabilitySet const& other) const;
+ /// Will a `join` with `other` change `this`?
+ bool joinWithOtherWillChangeThis(CapabilitySet const& other) const;
/// Does this capability set imply the atomic capability `other`?
bool implies(CapabilityAtom other) const;
@@ -353,8 +356,29 @@ private:
enum class ImpliesFlags
{
+ // All permutations of target+stage from `other` must be implied by a target+stage
+ // in `this`.
None = 0,
+ // Given a single target+stage permutation, if 1 permutation is implied in `other`,
+ // return true.
OnlyRequireASingleValidImply = 1 << 0,
+ // The target+stage permuations in `this` cannot have extra permutations
+ // relative to `other`.
+ // Ex: `{metal|glsl}.implies({glsl})` is false
+ // `{glsl}.implies({glsl|metal})` is false
+ // `{glsl}.implies({glsl|glsl})` is true
+ CannotHaveMoreTargetAndStageSets = 1 << 1,
+ // The target+stage permuations in `this` can have less permutations
+ // than `other`. This means, only for the shared permutations of `this`
+ // and `other` does `thisSet[target][stage].imply(otherSet)` have to be
+ // true.
+ // If `this` is empty, `this` is not able to imply `other` unless `other`
+ // is empty.
+ // Ex: `{glsl}.implies({glsl|metal})` is true since we only compare shared-permutations.
+ CanHaveSubsetOfTargetAndStageSets = 1 << 2,
+
+ WillAJoinWithOtherModifyThis =
+ CannotHaveMoreTargetAndStageSets | CanHaveSubsetOfTargetAndStageSets
};
ImpliesReturnFlags _implies(CapabilitySet const& other, ImpliesFlags flags) const;
};
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 0a9853012..507e12fa6 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -14268,7 +14268,11 @@ static void _propagateRequirement(
ensureDecl(visitor, referencedDecl, DeclCheckState::CapabilityChecked);
}
- if (resultCaps.implies(nodeCaps))
+ // If we do not have the same target+stage, we need to `join` to remove excess target+stage.
+ //
+ // If we have the same target+stage but current capabilities do not imply incoming capabilities,
+ // we need to `join`.
+ if (!resultCaps.joinWithOtherWillChangeThis(nodeCaps))
return;
auto oldCaps = resultCaps;