summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2025-07-08 17:18:02 -0700
committerGitHub <noreply@github.com>2025-07-09 00:18:02 +0000
commite036c7f38e3b8b1f0b0e0ac1b4ef22fc6f16963b (patch)
tree75805f1a7deb3df8f2ee6fbc7eb47fc405a24e93
parent2c4bfce49d9af2414f6a3f70f7221d6890a017e7 (diff)
Fix `extension` incorrectly interacting with `equality` and `type-coercion` constraints (#7578)
* fix problem * cleanup comment * format code * make change more restrictive * format code * push logic update * format code * push test fix * make test more general --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
-rw-r--r--source/slang/slang-check-constraint.cpp4
-rw-r--r--source/slang/slang-check-decl.cpp4
-rw-r--r--source/slang/slang-check-impl.h6
-rw-r--r--tests/language-feature/extensions/extension-with-where-clause-1.slang49
-rw-r--r--tests/language-feature/extensions/extension-with-where-clause-2.slang50
-rw-r--r--tests/language-feature/extensions/extension-with-where-clause-3.slang41
6 files changed, 152 insertions, 2 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 7c55c440c..7c6f8929a 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -363,6 +363,7 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
ValUnificationContext unificationContext;
unificationContext.optionalConstraint =
constraintDeclRef.getDecl()->hasModifier<OptionalConstraintModifier>();
+ unificationContext.equalityConstraint = constraintDeclRef.getDecl()->isEqualityConstraint;
if (!TryUnifyTypes(
*system,
unificationContext,
@@ -492,6 +493,8 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
{
if (c.isOptional)
joinType = type;
+ else if (c.isEquality)
+ joinType = type;
else
// failure!
return DeclRef<Decl>();
@@ -970,6 +973,7 @@ bool SemanticsVisitor::TryUnifyTypeParam(
constraint.val = type;
constraint.isUsedAsLValue = type.isLeftValue;
constraint.isOptional = unificationContext.optionalConstraint;
+ constraint.isEquality = unificationContext.equalityConstraint;
constraints.constraints.add(constraint);
return true;
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 558834c34..5de83e24d 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -3398,9 +3398,9 @@ void SemanticsDeclHeaderVisitor::visitGenericDecl(GenericDecl* genericDecl)
ensureDecl(valParam, DeclCheckState::ReadyForReference);
valParam->parameterIndex = parameterIndex++;
}
- else if (auto constraint = as<GenericTypeConstraintDecl>(m))
+ else if (as<GenericTypeConstraintDecl>(m) || as<TypeCoercionConstraintDecl>(m))
{
- ensureDecl(constraint, DeclCheckState::ReadyForReference);
+ ensureDecl(m, DeclCheckState::ReadyForReference);
}
}
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 30e317401..4a6ccfe17 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2305,6 +2305,11 @@ public:
// if it is otherwise unconstrained, but doesn't take precedence over a constraint that is
// not optional.
bool isOptional = false;
+
+ // Is this constraint an equality? This tells us that "joining" types is meaningless, we
+ // know the result will be the sub type. If it is not, we will error once we start
+ // substituting types.
+ bool isEquality = false;
};
// A collection of constraints that will need to be satisfied (solved)
@@ -2658,6 +2663,7 @@ public:
{
Index indexInTypePack = 0;
bool optionalConstraint = false;
+ bool equalityConstraint = false;
};
// Try to find a unification for two values
diff --git a/tests/language-feature/extensions/extension-with-where-clause-1.slang b/tests/language-feature/extensions/extension-with-where-clause-1.slang
new file mode 100644
index 000000000..9facb8aa3
--- /dev/null
+++ b/tests/language-feature/extensions/extension-with-where-clause-1.slang
@@ -0,0 +1,49 @@
+//TEST:SIMPLE(filecheck=CHECK_FAIL): -target spirv -entry computeMain -stage compute -DFAIL
+//TEST:SIMPLE(filecheck=CHECK_PASS): -target spirv -entry computeMain -stage compute
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type
+
+interface ITwoParamGeneric<A, B>
+{
+ int getVal();
+}
+
+struct Foo<A, B> : ITwoParamGeneric<A, B>
+{
+ int val = 0;
+ int getVal()
+ {
+ return val;
+ }
+}
+
+extension<A, B> Foo<A,B> where A == B
+{
+ [mutating]
+ void setVal(int dataIn)
+ {
+ val = dataIn;
+ }
+}
+
+void test<A, B>(Foo<A,B> x) where A == B
+{
+}
+
+//TEST_INPUT: set outBuffer = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<float4> outBuffer;
+
+void computeMain()
+{
+//CHECK_FAIL-DAG: error 30027:{{.*}}'setVal'{{.*}}'Foo<int, float>'.
+//CHECK_FAIL-DAG: error 39999: could not specialize generic for arguments of type
+//CHECK_PASS: OpEntryPoint
+//CHECK: 3
+#ifdef FAIL
+ Foo<int, float> x = Foo<int, float>();
+#else
+ Foo<int, int> x = Foo<int, int>();
+#endif
+ x.setVal(3);
+ test(x);
+ outBuffer[0] = x.getVal();
+} \ No newline at end of file
diff --git a/tests/language-feature/extensions/extension-with-where-clause-2.slang b/tests/language-feature/extensions/extension-with-where-clause-2.slang
new file mode 100644
index 000000000..bbdfff630
--- /dev/null
+++ b/tests/language-feature/extensions/extension-with-where-clause-2.slang
@@ -0,0 +1,50 @@
+//TEST:SIMPLE(filecheck=CHECK_FAIL): -target spirv -entry computeMain -stage compute -DFAIL
+//TEST:SIMPLE(filecheck=CHECK_PASS): -target spirv -entry computeMain -stage compute
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type
+
+interface ITwoParamGeneric<A, B>
+{
+ int getVal();
+}
+
+struct Foo<A, B> : ITwoParamGeneric<A, B>
+{
+ int val = 0;
+ int getVal()
+ {
+ return val;
+ }
+}
+
+struct NotPrimitiveCastable
+{
+ double data;
+}
+
+extension<A, B> Foo<A,B>
+ where int(A)
+#ifdef FAIL
+ where NotPrimitiveCastable(B)
+#else
+ where float(B)
+#endif
+{
+ [mutating]
+ void setVal(int dataIn)
+ {
+ val = dataIn;
+ }
+}
+
+//TEST_INPUT: set outBuffer = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<float4> outBuffer;
+
+void computeMain()
+{
+//CHECK_FAIL: error 30027:{{.*}}'setVal'{{.*}}'Foo<float, int>'.
+//CHECK_PASS: OpEntryPoint
+//CHECK: 3
+ Foo<float, int> x = Foo<float, int>();
+ x.setVal(3);
+ outBuffer[0] = x.getVal();
+} \ No newline at end of file
diff --git a/tests/language-feature/extensions/extension-with-where-clause-3.slang b/tests/language-feature/extensions/extension-with-where-clause-3.slang
new file mode 100644
index 000000000..6a3574b2c
--- /dev/null
+++ b/tests/language-feature/extensions/extension-with-where-clause-3.slang
@@ -0,0 +1,41 @@
+//TEST:SIMPLE(filecheck=CHECK_FAIL): -target spirv -entry computeMain -stage compute -DFAIL
+//TEST:SIMPLE(filecheck=CHECK_PASS): -target spirv -entry computeMain -stage compute
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type
+
+struct Foo<each A>
+{
+ int val = 0;
+ int getVal()
+ {
+ return val;
+ }
+}
+
+extension<each A> Foo<A>
+ where A == int
+{
+ [mutating]
+ void setVal1(int dataIn)
+ {
+ val += dataIn;
+ }
+}
+
+//TEST_INPUT: set outBuffer = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<float4> outBuffer;
+
+void computeMain()
+{
+//CHECK_FAIL: error 30027: 'setVal1'{{.*}}'Foo<float>'
+//CHECK_PASS: OpEntryPoint
+//CHECK: 3
+#ifdef FAIL
+ // fails since while expanding A and applying `where`,
+ // we will find a `float`, not a `int`
+ Foo<float> x = Foo<float>();
+#else
+ Foo<int> x = Foo<int>();
+#endif
+ x.setVal1(3);
+ outBuffer[0] = x.getVal();
+} \ No newline at end of file