summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-10-04 15:21:17 -0700
committerGitHub <noreply@github.com>2022-10-04 15:21:17 -0700
commit364e43264b9f69957ddaed8890392d82fb25c822 (patch)
treeaa4d9b6a90ddf398c12f7cce6499e3946d8ffeb1
parent8b1daa68a5ff1398cdf130aacad32d2e5646d1eb (diff)
Fix `ApplyExtensionToType` on own type being extended. (#2430)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-check-decl.cpp10
-rw-r--r--source/slang/slang.natvis7
-rw-r--r--tests/bugs/generic-extension.slang28
-rw-r--r--tests/bugs/generic-extension.slang.expected.txt4
4 files changed, 45 insertions, 4 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index a7cd768b1..b18e1c4da 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -5226,9 +5226,17 @@ namespace Slang
constraints.loc = extDecl->loc;
constraints.genericDecl = extGenericDecl;
+ // Inside the body of an extension declaration, we may end up trying to apply that
+ // extension to its own target type.
+ // If we see that we are in that case, we can apply the extension declaration as - is,
+ // without any additional substitutions.
+ if (extDecl->targetType->equals(type))
+ {
+ return extDeclRef;
+ }
+
if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type))
return DeclRef<ExtensionDecl>();
-
auto constraintSubst = trySolveConstraintSystem(&constraints, DeclRef<Decl>(extGenericDecl, nullptr).as<GenericDecl>());
if (!constraintSubst)
{
diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis
index 6c8f15ec9..ee868be62 100644
--- a/source/slang/slang.natvis
+++ b/source/slang/slang.natvis
@@ -282,6 +282,8 @@
<DisplayString Condition="nameAndLoc.name!=0">{nameAndLoc.name->text}: {astNodeType}</DisplayString>
<DisplayString Condition="nameAndLoc.name==0">{astNodeType}</DisplayString>
<Expand>
+ <Item Name="[Name]" Condition="nameAndLoc.name!=0">nameAndLoc.name->text</Item>
+ <Item Name="[Parent]">parentDecl</Item>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ContainerDecl">(Slang::ContainerDecl*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtensionDecl">(Slang::ExtensionDecl*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::StructDecl">(Slang::StructDecl*)&amp;astNodeType</ExpandedItem>
@@ -322,6 +324,7 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::EmptyDecl">(Slang::EmptyDecl*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::SyntaxDecl">(Slang::SyntaxDecl*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclGroup">(Slang::DeclGroup*)&amp;astNodeType</ExpandedItem>
+
<Item Name="Decl">(Slang::DeclBase*)this,nd</Item>
</Expand>
</Type>
@@ -369,7 +372,7 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::EmptyDecl">(Slang::EmptyDecl*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::SyntaxDecl">(Slang::SyntaxDecl*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclGroup">(Slang::DeclGroup*)&amp;astNodeType</ExpandedItem>
- <Item Name="Decl">(Slang::DeclBase*)this,nd</Item>
+ <Item Name="Decl">(Slang::Decl*)this,nd</Item>
</Expand>
</Type>
@@ -473,8 +476,6 @@
<Type Name="Slang::AggTypeDecl">
<DisplayString>{nameAndLoc.name}: {astNodeType}</DisplayString>
<Expand>
- <Item Name="[Name]">nameAndLoc.name</Item>
- <Item Name="[Parent]">parentDecl</Item>
<Item Name="[Members]">members</Item>
</Expand>
</Type>
diff --git a/tests/bugs/generic-extension.slang b/tests/bugs/generic-extension.slang
new file mode 100644
index 000000000..a3e039872
--- /dev/null
+++ b/tests/bugs/generic-extension.slang
@@ -0,0 +1,28 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+interface IFoo
+{
+ static This myAdd(This v, float val);
+}
+
+__generic<let N : int>
+extension vector<float, N> : IFoo
+{
+ static vector<float, N> myAdd(vector<float, N> v, float val)
+ {
+ return v + vector<float, N>(val);
+ }
+}
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ int index = int(dispatchThreadID.x);
+ float2 v = 1.0;
+ v = float2.myAdd(v, 2.0);
+ outputBuffer[index] = int(v.x);
+}
+
diff --git a/tests/bugs/generic-extension.slang.expected.txt b/tests/bugs/generic-extension.slang.expected.txt
new file mode 100644
index 000000000..463fa2702
--- /dev/null
+++ b/tests/bugs/generic-extension.slang.expected.txt
@@ -0,0 +1,4 @@
+3
+3
+3
+3