From e036c7f38e3b8b1f0b0e0ac1b4ef22fc6f16963b Mon Sep 17 00:00:00 2001 From: ArielG-NV <159081215+ArielG-NV@users.noreply.github.com> Date: Tue, 8 Jul 2025 17:18:02 -0700 Subject: Fix `extension` incorrectly interacting with `equality` and `type-coercion` constraints (#7578) * fix problem * cleanup comment * format code * make change more restrictive * format code * push logic update * format code * push test fix * make test more general --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> --- source/slang/slang-check-constraint.cpp | 4 ++ source/slang/slang-check-decl.cpp | 4 +- source/slang/slang-check-impl.h | 6 +++ .../extensions/extension-with-where-clause-1.slang | 49 +++++++++++++++++++++ .../extensions/extension-with-where-clause-2.slang | 50 ++++++++++++++++++++++ .../extensions/extension-with-where-clause-3.slang | 41 ++++++++++++++++++ 6 files changed, 152 insertions(+), 2 deletions(-) create mode 100644 tests/language-feature/extensions/extension-with-where-clause-1.slang create mode 100644 tests/language-feature/extensions/extension-with-where-clause-2.slang create mode 100644 tests/language-feature/extensions/extension-with-where-clause-3.slang diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 7c55c440c..7c6f8929a 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -363,6 +363,7 @@ DeclRef SemanticsVisitor::trySolveConstraintSystem( ValUnificationContext unificationContext; unificationContext.optionalConstraint = constraintDeclRef.getDecl()->hasModifier(); + unificationContext.equalityConstraint = constraintDeclRef.getDecl()->isEqualityConstraint; if (!TryUnifyTypes( *system, unificationContext, @@ -492,6 +493,8 @@ DeclRef SemanticsVisitor::trySolveConstraintSystem( { if (c.isOptional) joinType = type; + else if (c.isEquality) + joinType = type; else // failure! return DeclRef(); @@ -970,6 +973,7 @@ bool SemanticsVisitor::TryUnifyTypeParam( constraint.val = type; constraint.isUsedAsLValue = type.isLeftValue; constraint.isOptional = unificationContext.optionalConstraint; + constraint.isEquality = unificationContext.equalityConstraint; constraints.constraints.add(constraint); return true; diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 558834c34..5de83e24d 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -3398,9 +3398,9 @@ void SemanticsDeclHeaderVisitor::visitGenericDecl(GenericDecl* genericDecl) ensureDecl(valParam, DeclCheckState::ReadyForReference); valParam->parameterIndex = parameterIndex++; } - else if (auto constraint = as(m)) + else if (as(m) || as(m)) { - ensureDecl(constraint, DeclCheckState::ReadyForReference); + ensureDecl(m, DeclCheckState::ReadyForReference); } } } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 30e317401..4a6ccfe17 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2305,6 +2305,11 @@ public: // if it is otherwise unconstrained, but doesn't take precedence over a constraint that is // not optional. bool isOptional = false; + + // Is this constraint an equality? This tells us that "joining" types is meaningless, we + // know the result will be the sub type. If it is not, we will error once we start + // substituting types. + bool isEquality = false; }; // A collection of constraints that will need to be satisfied (solved) @@ -2658,6 +2663,7 @@ public: { Index indexInTypePack = 0; bool optionalConstraint = false; + bool equalityConstraint = false; }; // Try to find a unification for two values diff --git a/tests/language-feature/extensions/extension-with-where-clause-1.slang b/tests/language-feature/extensions/extension-with-where-clause-1.slang new file mode 100644 index 000000000..9facb8aa3 --- /dev/null +++ b/tests/language-feature/extensions/extension-with-where-clause-1.slang @@ -0,0 +1,49 @@ +//TEST:SIMPLE(filecheck=CHECK_FAIL): -target spirv -entry computeMain -stage compute -DFAIL +//TEST:SIMPLE(filecheck=CHECK_PASS): -target spirv -entry computeMain -stage compute +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type + +interface ITwoParamGeneric +{ + int getVal(); +} + +struct Foo : ITwoParamGeneric +{ + int val = 0; + int getVal() + { + return val; + } +} + +extension Foo where A == B +{ + [mutating] + void setVal(int dataIn) + { + val = dataIn; + } +} + +void test(Foo x) where A == B +{ +} + +//TEST_INPUT: set outBuffer = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer outBuffer; + +void computeMain() +{ +//CHECK_FAIL-DAG: error 30027:{{.*}}'setVal'{{.*}}'Foo'. +//CHECK_FAIL-DAG: error 39999: could not specialize generic for arguments of type +//CHECK_PASS: OpEntryPoint +//CHECK: 3 +#ifdef FAIL + Foo x = Foo(); +#else + Foo x = Foo(); +#endif + x.setVal(3); + test(x); + outBuffer[0] = x.getVal(); +} \ No newline at end of file diff --git a/tests/language-feature/extensions/extension-with-where-clause-2.slang b/tests/language-feature/extensions/extension-with-where-clause-2.slang new file mode 100644 index 000000000..bbdfff630 --- /dev/null +++ b/tests/language-feature/extensions/extension-with-where-clause-2.slang @@ -0,0 +1,50 @@ +//TEST:SIMPLE(filecheck=CHECK_FAIL): -target spirv -entry computeMain -stage compute -DFAIL +//TEST:SIMPLE(filecheck=CHECK_PASS): -target spirv -entry computeMain -stage compute +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type + +interface ITwoParamGeneric +{ + int getVal(); +} + +struct Foo : ITwoParamGeneric +{ + int val = 0; + int getVal() + { + return val; + } +} + +struct NotPrimitiveCastable +{ + double data; +} + +extension Foo + where int(A) +#ifdef FAIL + where NotPrimitiveCastable(B) +#else + where float(B) +#endif +{ + [mutating] + void setVal(int dataIn) + { + val = dataIn; + } +} + +//TEST_INPUT: set outBuffer = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer outBuffer; + +void computeMain() +{ +//CHECK_FAIL: error 30027:{{.*}}'setVal'{{.*}}'Foo'. +//CHECK_PASS: OpEntryPoint +//CHECK: 3 + Foo x = Foo(); + x.setVal(3); + outBuffer[0] = x.getVal(); +} \ No newline at end of file diff --git a/tests/language-feature/extensions/extension-with-where-clause-3.slang b/tests/language-feature/extensions/extension-with-where-clause-3.slang new file mode 100644 index 000000000..6a3574b2c --- /dev/null +++ b/tests/language-feature/extensions/extension-with-where-clause-3.slang @@ -0,0 +1,41 @@ +//TEST:SIMPLE(filecheck=CHECK_FAIL): -target spirv -entry computeMain -stage compute -DFAIL +//TEST:SIMPLE(filecheck=CHECK_PASS): -target spirv -entry computeMain -stage compute +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type + +struct Foo +{ + int val = 0; + int getVal() + { + return val; + } +} + +extension Foo + where A == int +{ + [mutating] + void setVal1(int dataIn) + { + val += dataIn; + } +} + +//TEST_INPUT: set outBuffer = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer outBuffer; + +void computeMain() +{ +//CHECK_FAIL: error 30027: 'setVal1'{{.*}}'Foo' +//CHECK_PASS: OpEntryPoint +//CHECK: 3 +#ifdef FAIL + // fails since while expanding A and applying `where`, + // we will find a `float`, not a `int` + Foo x = Foo(); +#else + Foo x = Foo(); +#endif + x.setVal1(3); + outBuffer[0] = x.getVal(); +} \ No newline at end of file -- cgit v1.2.3