summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2025-01-08 00:25:32 -0600
committerGitHub <noreply@github.com>2025-01-07 22:25:32 -0800
commit1a56f58fdd0c704e6dc0fad0f0ec33a25a35e60b (patch)
treea44cdcd19379df09bdd4a8e585652e718f402ac6
parent7e278c3ad6eaedbce1d6b6babecbe32f1764b269 (diff)
Check whether array element is fully specialized (#6000)
* Check whether array element is fully specialized close #5776 When we start specialize a "specialize" IR, we should make sure all the elements are fully specialized, but we miss checking the elements of an array. This change will check the it. * add test * add all wrapper types into the check * add utility function to check if the type is wrapper type --------- Co-authored-by: zhangkai <zhangkai@zhangkais-MacBook-Pro.local> Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/slang-ir-specialize.cpp6
-rw-r--r--source/slang/slang-ir-util.cpp25
-rw-r--r--source/slang/slang-ir-util.h2
-rw-r--r--tests/bugs/gh-5776.slang86
-rw-r--r--tests/bugs/gh-5776.slang.expected.txt7
5 files changed, 126 insertions, 0 deletions
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 2757538a6..50dfa2c6a 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -125,6 +125,12 @@ struct SpecializationContext
}
}
+ if (isWrapperType(inst))
+ {
+ // For all the wrapper type, we need to make sure the operands are fully specialized.
+ return areAllOperandsFullySpecialized(inst);
+ }
+
// The default case is that a global value is always specialized.
if (inst->getParent() == module->getModuleInst())
{
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 7788a50d5..c753600a7 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -277,6 +277,31 @@ bool isSimpleHLSLDataType(IRInst* inst)
return true;
}
+bool isWrapperType(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_ArrayType:
+ case kIROp_TextureType:
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ case kIROp_PtrType:
+ case kIROp_RefType:
+ case kIROp_ConstRefType:
+ case kIROp_HLSLStructuredBufferType:
+ case kIROp_HLSLRWStructuredBufferType:
+ case kIROp_HLSLRasterizerOrderedStructuredBufferType:
+ case kIROp_HLSLAppendStructuredBufferType:
+ case kIROp_HLSLConsumeStructuredBufferType:
+ case kIROp_TupleType:
+ case kIROp_OptionalType:
+ case kIROp_TypePack:
+ return true;
+ default:
+ return false;
+ }
+}
+
SourceLoc findFirstUseLoc(IRInst* inst)
{
for (auto use = inst->firstUse; use; use = use->nextUse)
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 9a712ba96..e23aeb618 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -104,6 +104,8 @@ bool isSimpleDataType(IRType* type);
bool isSimpleHLSLDataType(IRInst* inst);
+bool isWrapperType(IRInst* inst);
+
SourceLoc findFirstUseLoc(IRInst* inst);
inline bool isChildInstOf(IRInst* inst, IRInst* parent)
diff --git a/tests/bugs/gh-5776.slang b/tests/bugs/gh-5776.slang
new file mode 100644
index 000000000..625a7b5cc
--- /dev/null
+++ b/tests/bugs/gh-5776.slang
@@ -0,0 +1,86 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -profile sm_6_0 -use-dxil -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -cuda -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -cpu -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -wgpu -output-using-type
+
+
+interface IFoo
+{
+ associatedtype FooType : IFoo;
+}
+
+extension float : IFoo
+{
+ typedef float FooType;
+}
+
+__generic<T:IFoo, let N:int>
+extension Array<T, N> : IFoo
+{
+ typedef Array<T.FooType, N> FooType;
+}
+
+__generic<T:IFoo, let N:int>
+extension vector<T, N> : IFoo
+{
+ typedef vector<T.FooType, N> FooType;
+}
+
+__generic<T:IFoo, let N:int, let M:int>
+extension matrix<T, N, M> : IFoo
+{
+ typedef matrix<T.FooType, N, M> FooType;
+}
+
+struct WrappedBuffer<T : IFoo>
+{
+ StructuredBuffer<T> buffer;
+ int shape;
+
+ T get(int idx) { return buffer[idx]; }
+}
+
+
+struct GradInBuffer<T : IFoo>
+{
+ WrappedBuffer<T.FooType> wrapBuffer;
+}
+
+struct CallData
+{
+ GradInBuffer<float[2]> grad_in1;
+ GradInBuffer<vector<float, 2>> grad_in2;
+ GradInBuffer<float2x2> grad_in3;
+}
+
+
+//TEST_INPUT: set call_data.grad_in1.wrapBuffer.buffer = ubuffer(data=[1.0 2.0 3.0 4.0], stride=4);
+//TEST_INPUT: set call_data.grad_in2.wrapBuffer.buffer = ubuffer(data=[5.0 6.0 7.0 8.0], stride=4);
+//TEST_INPUT: set call_data.grad_in3.wrapBuffer.buffer = ubuffer(data=[1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0], stride=4);
+ParameterBlock<CallData> call_data;
+
+
+//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0 0.0 0.0], stride=4):out, name outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain()
+{
+ float[2] data1 = call_data.grad_in1.wrapBuffer.buffer[0];
+ float[2] data2 = call_data.grad_in1.wrapBuffer.get(1);
+ outputBuffer[0] = data1[0];
+ outputBuffer[1] = data2[0];
+
+ vector<float, 2> data3 = call_data.grad_in2.wrapBuffer.buffer[0];
+ vector<float, 2> data4 = call_data.grad_in2.wrapBuffer.get(1);
+ outputBuffer[2] = data3[0];
+ outputBuffer[3] = data4[0];
+
+ float2x2 data5 = call_data.grad_in3.wrapBuffer.buffer[0];
+ float2x2 data6 = call_data.grad_in3.wrapBuffer.get(1);
+ outputBuffer[4] = data5[0][0];
+ outputBuffer[5] = data6[0][0];
+}
diff --git a/tests/bugs/gh-5776.slang.expected.txt b/tests/bugs/gh-5776.slang.expected.txt
new file mode 100644
index 000000000..ffde6889e
--- /dev/null
+++ b/tests/bugs/gh-5776.slang.expected.txt
@@ -0,0 +1,7 @@
+type: float
+1.000000
+3.000000
+5.000000
+7.000000
+1.000000
+5.000000