summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/ir-legalize-types.cpp123
-rw-r--r--source/slang/legalize-types.h6
-rw-r--r--tests/bugs/gh-566.slang34
-rw-r--r--tests/bugs/gh-566.slang.expected.txt4
4 files changed, 164 insertions, 3 deletions
diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp
index 05b8ca647..969e3eb96 100644
--- a/source/slang/ir-legalize-types.cpp
+++ b/source/slang/ir-legalize-types.cpp
@@ -669,6 +669,122 @@ static LegalVal legalizeGetElementPtr(
indexOperand);
}
+static LegalVal legalizeMakeStruct(
+ IRTypeLegalizationContext* context,
+ LegalType legalType,
+ LegalVal const* legalArgs,
+ UInt argCount)
+{
+ auto builder = context->builder;
+
+ switch(legalType.flavor)
+ {
+ case LegalType::Flavor::simple:
+ {
+ List<IRInst*> args;
+ for(UInt aa = 0; aa < argCount; ++aa)
+ {
+ // Note: we assume that all the arguments
+ // must be simple here, because otherwise
+ // the `struct` type with them as fields
+ // would not be simple...
+ //
+ args.Add(legalArgs[aa].getSimple());
+ }
+ return LegalVal::simple(
+ builder->emitMakeStruct(
+ legalType.getSimple(),
+ argCount,
+ args.Buffer()));
+ }
+
+ case LegalType::Flavor::pair:
+ {
+ // There are two sides, the ordinary and the special,
+ // and we basically just dispatch to both of them.
+ auto pairType = legalType.getPair();
+ auto pairInfo = pairType->pairInfo;
+ LegalType ordinaryType = pairType->ordinaryType;
+ LegalType specialType = pairType->specialType;
+
+ List<LegalVal> ordinaryArgs;
+ List<LegalVal> specialArgs;
+ UInt argCounter = 0;
+ for(auto ee : pairInfo->elements)
+ {
+ UInt argIndex = argCounter++;
+ LegalVal arg = legalArgs[argIndex];
+
+ if((ee.flags & Slang::PairInfo::kFlag_hasOrdinaryAndSpecial) == Slang::PairInfo::kFlag_hasOrdinaryAndSpecial)
+ {
+ // The field is itself a pair type, so we expect
+ // the argument value to be one too...
+ auto argPair = arg.getPair();
+ ordinaryArgs.Add(argPair->ordinaryVal);
+ specialArgs.Add(argPair->specialVal);
+ }
+ else if(ee.flags & Slang::PairInfo::kFlag_hasOrdinary)
+ {
+ ordinaryArgs.Add(arg);
+ }
+ else if(ee.flags & Slang::PairInfo::kFlag_hasSpecial)
+ {
+ specialArgs.Add(arg);
+ }
+ }
+
+ LegalVal ordinaryVal = legalizeMakeStruct(
+ context,
+ ordinaryType,
+ ordinaryArgs.Buffer(),
+ ordinaryArgs.Count());
+
+ LegalVal specialVal = legalizeMakeStruct(
+ context,
+ specialType,
+ specialArgs.Buffer(),
+ specialArgs.Count());
+
+ return LegalVal::pair(ordinaryVal, specialVal, pairInfo);
+ }
+ break;
+
+ case LegalType::Flavor::tuple:
+ {
+ // We are constructing a tuple of values from
+ // the individual fields. We need to identify
+ // for each tuple element what field it uses,
+ // and then extract that field's value.
+
+ auto tupleType = legalType.getTuple();
+
+ RefPtr<TuplePseudoVal> resTupleInfo = new TuplePseudoVal();
+ UInt argCounter = 0;
+ for(auto typeElem : tupleType->elements)
+ {
+ auto elemKey = typeElem.key;
+ UInt argIndex = argCounter++;
+ SLANG_ASSERT(argIndex < argCount);
+
+ LegalVal argVal = legalArgs[argIndex];
+
+ TuplePseudoVal::Element resElem;
+ resElem.key = elemKey;
+ resElem.val = argVal;
+
+ resTupleInfo->elements.Add(resElem);
+ }
+ return LegalVal::tuple(resTupleInfo);
+ }
+
+ default:
+ SLANG_UNEXPECTED("unhandled");
+ UNREACHABLE_RETURN(LegalVal());
+ }
+}
+
+
+
static LegalVal legalizeInst(
IRTypeLegalizationContext* context,
IRInst* inst,
@@ -695,6 +811,13 @@ static LegalVal legalizeInst(
case kIROp_Call:
return legalizeCall(context, (IRCall*)inst);
+ case kIROp_makeStruct:
+ return legalizeMakeStruct(
+ context,
+ type,
+ args,
+ inst->getOperandCount());
+
default:
// TODO: produce a user-visible diagnostic here
SLANG_UNEXPECTED("non-simple operand(s)!");
diff --git a/source/slang/legalize-types.h b/source/slang/legalize-types.h
index c4cafe157..014df123f 100644
--- a/source/slang/legalize-types.h
+++ b/source/slang/legalize-types.h
@@ -290,7 +290,7 @@ struct LegalVal
return result;
}
- IRInst* getSimple()
+ IRInst* getSimple() const
{
SLANG_ASSERT(flavor == Flavor::simple);
return irValue;
@@ -298,7 +298,7 @@ struct LegalVal
static LegalVal tuple(RefPtr<TuplePseudoVal> tupleVal);
- RefPtr<TuplePseudoVal> getTuple()
+ RefPtr<TuplePseudoVal> getTuple() const
{
SLANG_ASSERT(flavor == Flavor::tuple);
return obj.As<TuplePseudoVal>();
@@ -313,7 +313,7 @@ struct LegalVal
LegalVal const& specialVal,
RefPtr<PairInfo> pairInfo);
- RefPtr<PairPseudoVal> getPair()
+ RefPtr<PairPseudoVal> getPair() const
{
SLANG_ASSERT(flavor == Flavor::pair);
return obj.As<PairPseudoVal>();
diff --git a/tests/bugs/gh-566.slang b/tests/bugs/gh-566.slang
new file mode 100644
index 000000000..eeb8c1639
--- /dev/null
+++ b/tests/bugs/gh-566.slang
@@ -0,0 +1,34 @@
+// legalize-struct-init.slang
+
+//TEST(compute):COMPARE_COMPUTE:
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
+//TEST_INPUT:ubuffer(data=[4 3 2 1], stride=4):dxbinding(1),glbinding(1)
+
+
+RWStructuredBuffer<uint> outputBuffer;
+RWStructuredBuffer<uint> inputBuffer;
+
+struct Helper
+{
+ RWStructuredBuffer<uint> o;
+ RWStructuredBuffer<uint> i;
+ uint t;
+};
+
+void test(Helper h)
+{
+ h.o[h.t] = h.i[h.t] * 16 + h.t;
+}
+
+void test(uint t)
+{
+ Helper h = { outputBuffer, inputBuffer, t };
+ test(h);
+}
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ test(tid);
+}
diff --git a/tests/bugs/gh-566.slang.expected.txt b/tests/bugs/gh-566.slang.expected.txt
new file mode 100644
index 000000000..309b89a9a
--- /dev/null
+++ b/tests/bugs/gh-566.slang.expected.txt
@@ -0,0 +1,4 @@
+40
+31
+22
+13