summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-07-31 17:35:08 -0700
committerGitHub <noreply@github.com>2024-07-31 17:35:08 -0700
commit4c6b0a2831a7edd1419bd0b2e6edd089080e07be (patch)
tree54557efa8bfc9e316014fe1555f07c94afa93cd4
parentbab4b821dc6bcee4ff86751743762584c17e9103 (diff)
Allow generic type deduction from ParameterBlock arguments. (#4766)
* Allow generic type deduction from ParameterBlock arguments. * Fix test. * Update expected failure list. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-check-constraint.cpp5
-rw-r--r--tests/expected-failure-github.txt2
-rw-r--r--tests/language-feature/generics/parameter-block-unify.slang33
-rw-r--r--tools/render-test/shader-input-layout.cpp5
4 files changed, 42 insertions, 3 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 23f7354d9..1195ed1f9 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -949,6 +949,11 @@ namespace Slang
QualType(sndVectorType->getElementType(), snd.isLeftValue));
}
}
+
+ if (auto fstUniformParamGroupType = as<UniformParameterGroupType>(fst))
+ return TryUnifyTypes(constraints, QualType(fstUniformParamGroupType->getElementType(), fst.isLeftValue), snd);
+ if (auto sndUniformParamGroupType = as<UniformParameterGroupType>(snd))
+ return TryUnifyTypes(constraints, fst, QualType(sndUniformParamGroupType->getElementType(), snd.isLeftValue));
return false;
}
diff --git a/tests/expected-failure-github.txt b/tests/expected-failure-github.txt
index f157d2b9a..524930f62 100644
--- a/tests/expected-failure-github.txt
+++ b/tests/expected-failure-github.txt
@@ -5,5 +5,3 @@ tests/language-feature/saturated-cooperation/fuse-product.slang (vk)
tests/language-feature/saturated-cooperation/fuse.slang (vk)
tests/bugs/byte-address-buffer-interlocked-add-f32.slang (vk)
tests/serialization/obfuscated-serialized-module-test.slang.2 syn (mtl)
-tests/render/cross-compile-entry-point.slang.2 syn (mtl)
-tests/bindings/nested-parameter-block-2.slang.3 syn (mtl)
diff --git a/tests/language-feature/generics/parameter-block-unify.slang b/tests/language-feature/generics/parameter-block-unify.slang
new file mode 100644
index 000000000..b549f555b
--- /dev/null
+++ b/tests/language-feature/generics/parameter-block-unify.slang
@@ -0,0 +1,33 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-mtl -compute -shaderobj -output-using-type -render-features argument-buffer-tier-2
+
+struct TestStruct<Format:__BuiltinIntegerType, let count : int>
+{
+ Format f;
+};
+
+Format testFunction<Format : __BuiltinIntegerType, let count : int>(TestStruct<Format, count> data)
+{
+ return data.f + __int_cast<Format>(count);
+}
+
+//TEST_INPUT: set testBlock = new TestStruct<int, 12>{1}
+ParameterBlock<TestStruct<int, 12>> testBlock;
+
+//TEST_INPUT: set testBlock2 = new TestStruct<int, 12>{2}
+ConstantBuffer<TestStruct<int, 2>> testBlock2;
+
+//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // CHECK: 13
+ outputBuffer[0] = testFunction(testBlock);
+ // CHECK: 13
+ outputBuffer[1] = testFunction<int, 12>(testBlock);
+ // CHECK: 4
+ outputBuffer[2] = testFunction(testBlock2);
+} \ No newline at end of file
diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp
index 3012d45a4..ac99d5cd8 100644
--- a/tools/render-test/shader-input-layout.cpp
+++ b/tools/render-test/shader-input-layout.cpp
@@ -462,7 +462,10 @@ namespace renderer_test
sb << typeName << "<";
for (;;)
{
- sb << parseTypeName(parser);
+ if (parser.LookAhead(Misc::TokenType::IntLiteral))
+ sb << parser.ReadInt();
+ else
+ sb << parseTypeName(parser);
if (!parser.AdvanceIf(","))
break;
sb << ",";