summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-10-16 11:23:13 -0700
committerGitHub <noreply@github.com>2025-10-16 18:23:13 +0000
commitbedc3421c9e1e0837fa69e30396a27a60f0fee53 (patch)
tree4ac8a8f5a4fbfc5f601cd60f82e73e29664123ba
parentd8b732d7ba6d31a724cb18dc93f60d8bcc522c19 (diff)
Fix use of variadic generics with [Differentiable]. (#8736)
There was a bug that causes the compiler failing to treat a `no_diff TypePack` as a type pack, and thus diagnose an error when resolving the following call. The fix is to unwrap any ModifiedType wrappers in `IsTypePack()` check.
-rw-r--r--docs/user-guide/06-interfaces-generics.md6
-rw-r--r--source/slang/slang-ast-type.cpp3
-rw-r--r--source/slang/slang-ast-val.h4
-rw-r--r--tests/language-feature/generics/variadic-generic-differentiable.slang59
-rw-r--r--tests/language-feature/generics/variadic-user-guide.slang33
5 files changed, 102 insertions, 3 deletions
diff --git a/docs/user-guide/06-interfaces-generics.md b/docs/user-guide/06-interfaces-generics.md
index 30efbada1..b7fca069e 100644
--- a/docs/user-guide/06-interfaces-generics.md
+++ b/docs/user-guide/06-interfaces-generics.md
@@ -1026,12 +1026,12 @@ void printNumbers<each T>(expand each T args) where T == int
void compute<each T>(expand each T args) where T == int
{
// Maps every element in `args` to `elementValue + 1`, and forwards the
- // new values as arguments to `printNumber`.
- printNumber(expand (each args) + 1);
+ // new values as arguments to `printNumbers`.
+ printNumbers(expand (each args) + 1);
// The above statement is equivalent to:
// ```
- // printNumber(args[0] + 1, args[1] + 1, ..., args[n-1] + 1);
+ // printNumbers(args[0] + 1, args[1] + 1, ..., args[n-1] + 1);
// ```
}
void test()
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index e3f75cb2f..bfa347df3 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -4,6 +4,7 @@
#include "slang-ast-builder.h"
#include "slang-ast-dispatch.h"
#include "slang-ast-modifier.h"
+#include "slang-check.h"
#include "slang-syntax.h"
#include <assert.h>
@@ -13,6 +14,7 @@ namespace Slang
bool isAbstractTypePack(Type* type)
{
+ type = unwrapModifiedType(type);
if (as<ExpandType>(type))
return true;
if (isDeclRefTypeOf<GenericTypePackParamDecl>(type))
@@ -22,6 +24,7 @@ bool isAbstractTypePack(Type* type)
bool isTypePack(Type* type)
{
+ type = unwrapModifiedType(type);
if (as<ConcreteTypePack>(type))
return true;
return isAbstractTypePack(type);
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 7b9462f2f..d4edaae08 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -1014,6 +1014,10 @@ inline bool isTypeEqualityWitness(Val* witness)
}
return true;
}
+ else if (auto expandWitness = as<ExpandSubtypeWitness>(witness))
+ {
+ return isTypeEqualityWitness(expandWitness->getPatternTypeWitness());
+ }
return false;
}
diff --git a/tests/language-feature/generics/variadic-generic-differentiable.slang b/tests/language-feature/generics/variadic-generic-differentiable.slang
new file mode 100644
index 000000000..55af3424c
--- /dev/null
+++ b/tests/language-feature/generics/variadic-generic-differentiable.slang
@@ -0,0 +1,59 @@
+//TEST:INTERPRET(filecheck=CHECK):
+
+// Test that we can call variadic generic [Differentiable] methods.
+
+interface IParameterExtractor
+{
+ [Differentiable]
+ void extract(no_diff uint x);
+}
+
+struct FloatParameterExtractor : IParameterExtractor
+{
+ [Differentiable]
+ void extract(no_diff uint x)
+ {
+ printf("fff\n");
+ }
+}
+
+[Differentiable]
+void extract_parameters_helper<T : IParameterExtractor>(
+ no_diff uint x,
+ T arg,
+ )
+{
+ arg.extract(x);
+}
+
+[Differentiable]
+void wrapper1<each T>(
+ uint x,
+ expand each T args, // compiler will add no_diff modifier here.
+ )
+ where T : IParameterExtractor
+{
+ expand extract_parameters_helper(x, each args);
+}
+
+[Differentiable]
+void wrapper2<each T>(
+ uint x,
+ expand each T args, // compiler will add no_diff modifier here.
+ )
+ where T : IParameterExtractor
+{
+ wrapper1(x, args);
+}
+
+
+void main()
+{
+ // There was a bug that causes the compiler failing to treat a `no_diff TypePack` as
+ // a type pack, and thus diagnose an error when resolving the following call.
+ //
+ wrapper2(1, FloatParameterExtractor(), FloatParameterExtractor());
+}
+
+// CHECK: fff
+// CHECK: fff \ No newline at end of file
diff --git a/tests/language-feature/generics/variadic-user-guide.slang b/tests/language-feature/generics/variadic-user-guide.slang
new file mode 100644
index 000000000..1d1cfd790
--- /dev/null
+++ b/tests/language-feature/generics/variadic-user-guide.slang
@@ -0,0 +1,33 @@
+//TEST:INTERPRET(filecheck=CHECK):
+
+void printNumbers<each T>(expand each T args) where T == int
+{
+ // An single expression statement whose type will be `(void, void, ...)`.
+ // where each `void` is the result of evaluating expression `printf(...)` with
+ // each corresponding element in `args` passed as print operand.
+ //
+ expand printf("%d\n", each args);
+
+ // The above statement is equivalent to:
+ // ```
+ // (printf("%d\n", args[0]), printf("%d\n", args[1]), ..., printf("%d\n", args[n-1]));
+ // ```
+}
+void compute<each T>(expand each T args) where T == int
+{
+ // Maps every element in `args` to `elementValue + 1`, and forwards the
+ // new values as arguments to `printNumber`.
+ printNumbers(expand ((each args) + 1));
+
+ // The above statement is equivalent to:
+ // ```
+ // printNumber(args[0] + 1, args[1] + 1, ..., args[n-1] + 1);
+ // ```
+}
+void main()
+{
+ compute(1,2,3);
+ // CHECK: 2
+ // CHECK: 3
+ // CHECK: 4
+} \ No newline at end of file