summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-legalize-types.cpp20
-rw-r--r--tests/bugs/gh-7905.slang63
2 files changed, 80 insertions, 3 deletions
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index dd7107b18..085c3d933 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -179,14 +179,27 @@ static LegalVal maybeMaterializeWrappedValue(IRTypeLegalizationContext* context,
static LegalVal legalizeOperand(IRTypeLegalizationContext* context, IRInst* irValue)
{
LegalVal legalVal;
+
+ // Special handling for type operands
+ if (auto oldType = as<IRType>(irValue))
+ {
+ // e.g. ParameterBlock<Struct>, the inst. ParameterBlockType holds the operand `StructType`,
+ // if we don't legalize it here and the same structType is legalized somewhere else, the
+ // operand of ParameterBlockType might not get updated, and it would result in a type
+ // mismatch.
+ auto legalType = legalizeType(context, oldType);
+ if (legalType.flavor == LegalType::Flavor::simple)
+ return LegalVal::simple(legalType.getSimple());
+ // legalType is not simple, fallback to the original value
+ }
+
if (context->mapValToLegalVal.tryGetValue(irValue, legalVal))
{
return maybeMaterializeWrappedValue(context, legalVal);
}
// For now, assume that anything not covered
- // by the mapping is legal as-is.
-
+ // by the type legalization or val mapping is legal as-is.
return LegalVal::simple(irValue);
}
@@ -2386,7 +2399,8 @@ static LegalVal legalizeInst(IRTypeLegalizationContext* context, IRInst* inst)
for (UInt aa = 0; aa < argCount; ++aa)
{
auto oldArg = inst->getOperand(aa);
- auto legalArg = legalizeOperand(context, oldArg);
+
+ LegalVal legalArg = legalizeOperand(context, oldArg);
legalArgs.add(legalArg);
if (legalArg.flavor != LegalVal::Flavor::simple)
diff --git a/tests/bugs/gh-7905.slang b/tests/bugs/gh-7905.slang
new file mode 100644
index 000000000..69dd6d3ba
--- /dev/null
+++ b/tests/bugs/gh-7905.slang
@@ -0,0 +1,63 @@
+//TEST:SIMPLE(filecheck=CHECK): -target cuda -entry runPointEstimator
+
+// Test for issue #7905: CUDA Backend failure due to type mismatch
+// This test ensures that struct types in ParameterBlock operands are properly
+// legalized and don't create type mismatches in generated CUDA code.
+
+#define ZOMBIE_PROBLEM_DIMENSION 3
+
+public interface IExample<let DIM : int>
+{
+ // computes the distance to the boundary
+ float compute(vector<float, DIM> x);
+};
+
+public struct Query<Example,let DIM : int>
+ where Example : IExample<DIM>
+{
+ // private AbsorbingBoundaryGeometricQueries absorbingBoundaryGeometricQueries;
+ private Example query;
+ private uint hasNonEmptyAbsorbingBoundary;
+
+ public float compute(vector<float, DIM> x)
+ {
+ return query.compute(x);
+ }
+};
+
+public struct EmptyExample<let DIM : int> : IExample<DIM>
+{
+ // computes the distance to the boundary
+ public float compute(vector<float, DIM> x)
+ {
+ internal static const float FLT_MAX = 3.402823466e+38F;
+ return FLT_MAX;
+ }
+};
+
+
+typedef EmptyExample<ZOMBIE_PROBLEM_DIMENSION> ExampleQuery;
+typedef Query<ExampleQuery, ZOMBIE_PROBLEM_DIMENSION> QueryType;
+
+
+uniform ParameterBlock<QueryType> gQuery;
+uniform RWStructuredBuffer<float3> gInput;
+uniform RWStructuredBuffer<float> gOutput;
+
+
+[shader("compute")]
+[numthreads(256, 1, 1)]
+void runPointEstimator(uint3 threadId: SV_DispatchThreadID,
+ uniform uint n)
+{
+ const uint idx = threadId.x;
+
+ if (idx >= n) {
+ return;
+ }
+
+ float res = gQuery.compute(gInput[idx]);
+ gOutput[idx] = res;
+}
+
+// CHECK-NOT: Query_1