diff options
| author | Yong He <yonghe@outlook.com> | 2025-10-16 11:23:13 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-16 18:23:13 +0000 |
| commit | bedc3421c9e1e0837fa69e30396a27a60f0fee53 (patch) | |
| tree | 4ac8a8f5a4fbfc5f601cd60f82e73e29664123ba | |
| parent | d8b732d7ba6d31a724cb18dc93f60d8bcc522c19 (diff) | |
Fix use of variadic generics with [Differentiable]. (#8736)
There was a bug that causes the compiler failing to treat a `no_diff
TypePack` as a type pack, and thus diagnose an error when resolving the
following call.
The fix is to unwrap any ModifiedType wrappers in `IsTypePack()` check.
| -rw-r--r-- | docs/user-guide/06-interfaces-generics.md | 6 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 4 | ||||
| -rw-r--r-- | tests/language-feature/generics/variadic-generic-differentiable.slang | 59 | ||||
| -rw-r--r-- | tests/language-feature/generics/variadic-user-guide.slang | 33 |
5 files changed, 102 insertions, 3 deletions
diff --git a/docs/user-guide/06-interfaces-generics.md b/docs/user-guide/06-interfaces-generics.md index 30efbada1..b7fca069e 100644 --- a/docs/user-guide/06-interfaces-generics.md +++ b/docs/user-guide/06-interfaces-generics.md @@ -1026,12 +1026,12 @@ void printNumbers<each T>(expand each T args) where T == int void compute<each T>(expand each T args) where T == int { // Maps every element in `args` to `elementValue + 1`, and forwards the - // new values as arguments to `printNumber`. - printNumber(expand (each args) + 1); + // new values as arguments to `printNumbers`. + printNumbers(expand (each args) + 1); // The above statement is equivalent to: // ``` - // printNumber(args[0] + 1, args[1] + 1, ..., args[n-1] + 1); + // printNumbers(args[0] + 1, args[1] + 1, ..., args[n-1] + 1); // ``` } void test() diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index e3f75cb2f..bfa347df3 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -4,6 +4,7 @@ #include "slang-ast-builder.h" #include "slang-ast-dispatch.h" #include "slang-ast-modifier.h" +#include "slang-check.h" #include "slang-syntax.h" #include <assert.h> @@ -13,6 +14,7 @@ namespace Slang bool isAbstractTypePack(Type* type) { + type = unwrapModifiedType(type); if (as<ExpandType>(type)) return true; if (isDeclRefTypeOf<GenericTypePackParamDecl>(type)) @@ -22,6 +24,7 @@ bool isAbstractTypePack(Type* type) bool isTypePack(Type* type) { + type = unwrapModifiedType(type); if (as<ConcreteTypePack>(type)) return true; return isAbstractTypePack(type); diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 7b9462f2f..d4edaae08 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -1014,6 +1014,10 @@ inline bool isTypeEqualityWitness(Val* witness) } return true; } + else if (auto expandWitness = as<ExpandSubtypeWitness>(witness)) + { + return isTypeEqualityWitness(expandWitness->getPatternTypeWitness()); + } return false; } diff --git a/tests/language-feature/generics/variadic-generic-differentiable.slang b/tests/language-feature/generics/variadic-generic-differentiable.slang new file mode 100644 index 000000000..55af3424c --- /dev/null +++ b/tests/language-feature/generics/variadic-generic-differentiable.slang @@ -0,0 +1,59 @@ +//TEST:INTERPRET(filecheck=CHECK): + +// Test that we can call variadic generic [Differentiable] methods. + +interface IParameterExtractor +{ + [Differentiable] + void extract(no_diff uint x); +} + +struct FloatParameterExtractor : IParameterExtractor +{ + [Differentiable] + void extract(no_diff uint x) + { + printf("fff\n"); + } +} + +[Differentiable] +void extract_parameters_helper<T : IParameterExtractor>( + no_diff uint x, + T arg, + ) +{ + arg.extract(x); +} + +[Differentiable] +void wrapper1<each T>( + uint x, + expand each T args, // compiler will add no_diff modifier here. + ) + where T : IParameterExtractor +{ + expand extract_parameters_helper(x, each args); +} + +[Differentiable] +void wrapper2<each T>( + uint x, + expand each T args, // compiler will add no_diff modifier here. + ) + where T : IParameterExtractor +{ + wrapper1(x, args); +} + + +void main() +{ + // There was a bug that causes the compiler failing to treat a `no_diff TypePack` as + // a type pack, and thus diagnose an error when resolving the following call. + // + wrapper2(1, FloatParameterExtractor(), FloatParameterExtractor()); +} + +// CHECK: fff +// CHECK: fff
\ No newline at end of file diff --git a/tests/language-feature/generics/variadic-user-guide.slang b/tests/language-feature/generics/variadic-user-guide.slang new file mode 100644 index 000000000..1d1cfd790 --- /dev/null +++ b/tests/language-feature/generics/variadic-user-guide.slang @@ -0,0 +1,33 @@ +//TEST:INTERPRET(filecheck=CHECK): + +void printNumbers<each T>(expand each T args) where T == int +{ + // An single expression statement whose type will be `(void, void, ...)`. + // where each `void` is the result of evaluating expression `printf(...)` with + // each corresponding element in `args` passed as print operand. + // + expand printf("%d\n", each args); + + // The above statement is equivalent to: + // ``` + // (printf("%d\n", args[0]), printf("%d\n", args[1]), ..., printf("%d\n", args[n-1])); + // ``` +} +void compute<each T>(expand each T args) where T == int +{ + // Maps every element in `args` to `elementValue + 1`, and forwards the + // new values as arguments to `printNumber`. + printNumbers(expand ((each args) + 1)); + + // The above statement is equivalent to: + // ``` + // printNumber(args[0] + 1, args[1] + 1, ..., args[n-1] + 1); + // ``` +} +void main() +{ + compute(1,2,3); + // CHECK: 2 + // CHECK: 3 + // CHECK: 4 +}
\ No newline at end of file |
