diff options
| -rw-r--r-- | source/slang/ir-legalize-types.cpp | 123 | ||||
| -rw-r--r-- | source/slang/legalize-types.h | 6 | ||||
| -rw-r--r-- | tests/bugs/gh-566.slang | 34 | ||||
| -rw-r--r-- | tests/bugs/gh-566.slang.expected.txt | 4 |
4 files changed, 164 insertions, 3 deletions
diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp index 05b8ca647..969e3eb96 100644 --- a/source/slang/ir-legalize-types.cpp +++ b/source/slang/ir-legalize-types.cpp @@ -669,6 +669,122 @@ static LegalVal legalizeGetElementPtr( indexOperand); } +static LegalVal legalizeMakeStruct( + IRTypeLegalizationContext* context, + LegalType legalType, + LegalVal const* legalArgs, + UInt argCount) +{ + auto builder = context->builder; + + switch(legalType.flavor) + { + case LegalType::Flavor::simple: + { + List<IRInst*> args; + for(UInt aa = 0; aa < argCount; ++aa) + { + // Note: we assume that all the arguments + // must be simple here, because otherwise + // the `struct` type with them as fields + // would not be simple... + // + args.Add(legalArgs[aa].getSimple()); + } + return LegalVal::simple( + builder->emitMakeStruct( + legalType.getSimple(), + argCount, + args.Buffer())); + } + + case LegalType::Flavor::pair: + { + // There are two sides, the ordinary and the special, + // and we basically just dispatch to both of them. + auto pairType = legalType.getPair(); + auto pairInfo = pairType->pairInfo; + LegalType ordinaryType = pairType->ordinaryType; + LegalType specialType = pairType->specialType; + + List<LegalVal> ordinaryArgs; + List<LegalVal> specialArgs; + UInt argCounter = 0; + for(auto ee : pairInfo->elements) + { + UInt argIndex = argCounter++; + LegalVal arg = legalArgs[argIndex]; + + if((ee.flags & Slang::PairInfo::kFlag_hasOrdinaryAndSpecial) == Slang::PairInfo::kFlag_hasOrdinaryAndSpecial) + { + // The field is itself a pair type, so we expect + // the argument value to be one too... + auto argPair = arg.getPair(); + ordinaryArgs.Add(argPair->ordinaryVal); + specialArgs.Add(argPair->specialVal); + } + else if(ee.flags & Slang::PairInfo::kFlag_hasOrdinary) + { + ordinaryArgs.Add(arg); + } + else if(ee.flags & Slang::PairInfo::kFlag_hasSpecial) + { + specialArgs.Add(arg); + } + } + + LegalVal ordinaryVal = legalizeMakeStruct( + context, + ordinaryType, + ordinaryArgs.Buffer(), + ordinaryArgs.Count()); + + LegalVal specialVal = legalizeMakeStruct( + context, + specialType, + specialArgs.Buffer(), + specialArgs.Count()); + + return LegalVal::pair(ordinaryVal, specialVal, pairInfo); + } + break; + + case LegalType::Flavor::tuple: + { + // We are constructing a tuple of values from + // the individual fields. We need to identify + // for each tuple element what field it uses, + // and then extract that field's value. + + auto tupleType = legalType.getTuple(); + + RefPtr<TuplePseudoVal> resTupleInfo = new TuplePseudoVal(); + UInt argCounter = 0; + for(auto typeElem : tupleType->elements) + { + auto elemKey = typeElem.key; + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < argCount); + + LegalVal argVal = legalArgs[argIndex]; + + TuplePseudoVal::Element resElem; + resElem.key = elemKey; + resElem.val = argVal; + + resTupleInfo->elements.Add(resElem); + } + return LegalVal::tuple(resTupleInfo); + } + + default: + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(LegalVal()); + } +} + + + static LegalVal legalizeInst( IRTypeLegalizationContext* context, IRInst* inst, @@ -695,6 +811,13 @@ static LegalVal legalizeInst( case kIROp_Call: return legalizeCall(context, (IRCall*)inst); + case kIROp_makeStruct: + return legalizeMakeStruct( + context, + type, + args, + inst->getOperandCount()); + default: // TODO: produce a user-visible diagnostic here SLANG_UNEXPECTED("non-simple operand(s)!"); diff --git a/source/slang/legalize-types.h b/source/slang/legalize-types.h index c4cafe157..014df123f 100644 --- a/source/slang/legalize-types.h +++ b/source/slang/legalize-types.h @@ -290,7 +290,7 @@ struct LegalVal return result; } - IRInst* getSimple() + IRInst* getSimple() const { SLANG_ASSERT(flavor == Flavor::simple); return irValue; @@ -298,7 +298,7 @@ struct LegalVal static LegalVal tuple(RefPtr<TuplePseudoVal> tupleVal); - RefPtr<TuplePseudoVal> getTuple() + RefPtr<TuplePseudoVal> getTuple() const { SLANG_ASSERT(flavor == Flavor::tuple); return obj.As<TuplePseudoVal>(); @@ -313,7 +313,7 @@ struct LegalVal LegalVal const& specialVal, RefPtr<PairInfo> pairInfo); - RefPtr<PairPseudoVal> getPair() + RefPtr<PairPseudoVal> getPair() const { SLANG_ASSERT(flavor == Flavor::pair); return obj.As<PairPseudoVal>(); diff --git a/tests/bugs/gh-566.slang b/tests/bugs/gh-566.slang new file mode 100644 index 000000000..eeb8c1639 --- /dev/null +++ b/tests/bugs/gh-566.slang @@ -0,0 +1,34 @@ +// legalize-struct-init.slang + +//TEST(compute):COMPARE_COMPUTE: +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out +//TEST_INPUT:ubuffer(data=[4 3 2 1], stride=4):dxbinding(1),glbinding(1) + + +RWStructuredBuffer<uint> outputBuffer; +RWStructuredBuffer<uint> inputBuffer; + +struct Helper +{ + RWStructuredBuffer<uint> o; + RWStructuredBuffer<uint> i; + uint t; +}; + +void test(Helper h) +{ + h.o[h.t] = h.i[h.t] * 16 + h.t; +} + +void test(uint t) +{ + Helper h = { outputBuffer, inputBuffer, t }; + test(h); +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + test(tid); +} diff --git a/tests/bugs/gh-566.slang.expected.txt b/tests/bugs/gh-566.slang.expected.txt new file mode 100644 index 000000000..309b89a9a --- /dev/null +++ b/tests/bugs/gh-566.slang.expected.txt @@ -0,0 +1,4 @@ +40 +31 +22 +13 |
