diff options
| author | Yong He <yonghe@outlook.com> | 2024-08-19 15:03:56 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-19 15:03:56 -0700 |
| commit | 453683bf44f2112719802eaac2b332d49eebd640 (patch) | |
| tree | d399db4c9cba90c11980186d3df1ffcc4d423b5a /source/slang/slang-ir-specialize.cpp | |
| parent | ecf85df6eee3da76ef54b14e4ab083f22da89e46 (diff) | |
Tuple swizzling, concat, comparison and `countof`. (#4856)
* Tuple swizzling and element access.
* Update proposal status.
* Cleanup.
* Fix merrge error.
* Address review.
Diffstat (limited to 'source/slang/slang-ir-specialize.cpp')
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 210 |
1 files changed, 181 insertions, 29 deletions
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 2eb16112f..c9e94352e 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -594,7 +594,165 @@ struct SpecializationContext case kIROp_GetTupleElement: return maybeSpecializeFoldableInst(inst); + + case kIROp_TypePack: + case kIROp_TupleType: + return maybeSpecializeTypePackOrTupleType(inst); + + case kIROp_MakeValuePack: + case kIROp_MakeTuple: + return maybeSpecializeMakeValuePackOrTuple(inst); + + case kIROp_CountOf: + return maybeSpecializeCountOf(inst); + } + } + + + void flattenPackOperand(ShortList<IRInst*>& flattenedList, IRInst* inst) + { + if (auto makeValuePack = as<IRMakeValuePack>(inst)) + { + for (UInt i = 0; i < makeValuePack->getOperandCount(); i++) + { + flattenPackOperand(flattenedList, makeValuePack->getOperand(i)); + } + } + else if (auto typePack = as<IRTypePack>(inst)) + { + for (UInt i = 0; i < typePack->getOperandCount(); i++) + { + flattenPackOperand(flattenedList, typePack->getOperand(i)); + } + } + else + { + SLANG_ASSERT(inst); + flattenedList.add(inst); + } + } + + bool maybeSpecializeTypePackOrTupleType(IRInst* inst) + { + // If any element of the type pack or tuple is a TypePack, we want to + // flatten that type pack into the current type pack or tuple. + + bool needProcess = false; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (as<IRTypePack>(inst->getOperand(i))) + { + needProcess = true; + break; + } + } + // If none of the operands are MakeValuePack, there is no need to flatten anything. + if (!needProcess) + return false; + + // We will recursively flatten all MakeValuePack operands. + ShortList<IRInst*> flattendOperands; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto operand = inst->getOperand(i); + flattenPackOperand(flattendOperands, operand); + } + + IRBuilder builder(module); + builder.setInsertBefore(inst); + IRInst* newInst; + if (inst->getOp() == kIROp_TypePack) + newInst = builder.getTypePack(flattendOperands.getCount(), (IRType* const*)flattendOperands.getArrayView().getBuffer()); + else + newInst = builder.getTupleType(flattendOperands.getCount(), (IRType* const*)flattendOperands.getArrayView().getBuffer()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + addUsersToWorkList(newInst); + return true; + } + + bool maybeSpecializeMakeValuePackOrTuple(IRInst* inst) + { + // If any element of the value pack or tuple is a ValuePack, we want to + // flatten that value pack into the current value pack or tuple. + + bool needProcess = false; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (as<IRMakeValuePack>(inst->getOperand(i))) + { + needProcess = true; + break; + } } + // If none of the operands are MakeValuePack, there is no need to flatten anything. + if (!needProcess) + return false; + + // We will recursively flatten all MakeValuePack operands. + ShortList<IRInst*> flattendOperands; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto operand = inst->getOperand(i); + flattenPackOperand(flattendOperands, operand); + } + + IRBuilder builder(module); + builder.setInsertBefore(inst); + IRInst* newInst = nullptr; + if (inst->getOp() == kIROp_MakeValuePack) + newInst = builder.emitMakeValuePack(inst->getFullType(), flattendOperands.getCount(), flattendOperands.getArrayView().getBuffer()); + else + newInst = builder.emitMakeTuple(inst->getFullType(), flattendOperands.getCount(), flattendOperands.getArrayView().getBuffer()); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + addUsersToWorkList(newInst); + return true; + } + + bool maybeSpecializeCountOf(IRInst* inst) + { + auto operand = inst->getOperand(0); + + // If operand is a value, make sure we are working on its type. + + switch (operand->getOp()) + { + case kIROp_MakeValuePack: + case kIROp_MakeTuple: + operand = operand->getDataType(); + break; + } + + // We can only figure out the count of a type pack or tuple type. + switch (operand->getOp()) + { + case kIROp_TypePack: + case kIROp_TupleType: + break; + default: + return false; + } + + // If none of the element type is a TypePack, we can just return the count. + for (UInt i = 0; i < operand->getOperandCount(); i++) + { + switch (operand->getOperand(i)->getOp()) + { + case kIROp_Param: + case kIROp_TypePack: + case kIROp_ExpandTypeOrVal: + return false; + } + } + IRBuilder builder(module); + builder.setInsertBefore(inst); + auto newInst = builder.getIntValue(inst->getDataType(), operand->getOperandCount()); + addUsersToWorkList(inst); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; } // Specializing lookup on witness tables is a general @@ -606,7 +764,7 @@ struct SpecializationContext { // Note: While we currently have named the instruction // `lookup_witness_method`, the `method` part is a misnomer - // and the same instruction can look up *any* interfacemay + // and the same instruction can look up *any* interface // requirement based on the witness table that provides // a conformance, and the "key" that indicates the interface // requirement. @@ -2268,7 +2426,7 @@ struct SpecializationContext for (UInt i = 0; i < expandInst->getCaptureCount(); i++) { - if (!as<IRTupleType>(expandInst->getCapture(i))) + if (!as<IRTypePack>(expandInst->getCapture(i))) return false; } @@ -2276,16 +2434,16 @@ struct SpecializationContext builder.setInsertBefore(expandInst); List<IRInst*> elements; UInt elementCount = 0; - if (auto firstTupleType = as<IRTupleType>(expandInst->getCapture(0))) + if (auto firstTypePack = as<IRTypePack>(expandInst->getCapture(0))) { - elementCount = firstTupleType->getOperandCount(); + elementCount = firstTypePack->getOperandCount(); } if (elementCount == 0) { - auto resultTuple = builder.emitMakeTuple(0, (IRInst*const*)nullptr); - expandInst->replaceUsesWith(resultTuple); + auto resultValuePack = builder.emitMakeValuePack(0, (IRInst*const*)nullptr); + expandInst->replaceUsesWith(resultValuePack); expandInst->removeAndDeallocate(); - addUsersToWorkList(resultTuple); + addUsersToWorkList(resultValuePack); return true; } @@ -2328,7 +2486,7 @@ struct SpecializationContext } } - auto resultTuple = builder.emitMakeTuple(elements); + auto resultValuePack = builder.emitMakeValuePack((UInt)elements.getCount(), elements.getBuffer()); auto currentBlock = builder.getBlock(); for (auto nextInst = expandInst->next; nextInst;) { @@ -2337,7 +2495,7 @@ struct SpecializationContext nextInst = next; } addUsersToWorkList(expandInst); - expandInst->replaceUsesWith(resultTuple); + expandInst->replaceUsesWith(resultValuePack); expandInst->removeAndDeallocate(); return true; } @@ -2355,15 +2513,15 @@ struct SpecializationContext { auto eachInst = as<IREach>(val); auto packInst = eachInst->getElement(); - if (auto tuple = as<IRTupleType>(packInst)) + if (auto typePack = as<IRTypePack>(packInst)) { - SLANG_RELEASE_ASSERT(indexInPack < tuple->getOperandCount()); - return tuple->getOperand(indexInPack); + SLANG_RELEASE_ASSERT(indexInPack < typePack->getOperandCount()); + return typePack->getOperand(indexInPack); } - else if (auto makeTuple = as<IRMakeTuple>(packInst)) + else if (auto makeValuePack = as<IRMakeValuePack>(packInst)) { - SLANG_RELEASE_ASSERT(indexInPack < makeTuple->getOperandCount()); - return makeTuple->getOperand(indexInPack); + SLANG_RELEASE_ASSERT(indexInPack < makeValuePack->getOperandCount()); + return makeValuePack->getOperand(indexInPack); } else if (!as<IRTypeKind>(packInst->getDataType())) { @@ -2413,24 +2571,18 @@ struct SpecializationContext if (expandInst->getCaptureCount() == 0) return false; - bool anyAbstractPack = false; for (UInt i = 0; i < expandInst->getCaptureCount(); i++) { - if (!as<IRTupleType>(expandInst->getCaptureType(i))) - { - anyAbstractPack = true; - break; - } + if (!as<IRTypePack>(expandInst->getCaptureType(i))) + return false; } - if (anyAbstractPack) - return false; IRBuilder builder(expandInst); builder.setInsertBefore(expandInst); List<IRInst*> elements; UInt elementCount = 0; - if (auto firstTupleType = as<IRTupleType>(expandInst->getCaptureType(0))) + if (auto firstTypePack = as<IRTypePack>(expandInst->getCaptureType(0))) { - elementCount = firstTupleType->getOperandCount(); + elementCount = firstTypePack->getOperandCount(); } for (UInt i = 0; i < elementCount; i++) { @@ -2444,16 +2596,16 @@ struct SpecializationContext List<IRType*> types; for (auto element : elements) types.add(element->getDataType()); - auto newTupleType = builder.getTupleType(types); - auto result = builder.emitMakeWitnessPack(newTupleType, elements.getArrayView()); + auto newTypePack = builder.getTypePack(elements.getCount(), types.getBuffer()); + auto result = builder.emitMakeWitnessPack(newTypePack, elements.getArrayView()); expandInst->replaceUsesWith(result); expandInst->removeAndDeallocate(); return true; } else { - auto newTupleType = builder.getTupleType(elements.getCount(), (IRType*const*)elements.getBuffer()); - expandInst->replaceUsesWith(newTupleType); + auto newTypePack = builder.getTypePack(elements.getCount(), (IRType*const*)elements.getBuffer()); + expandInst->replaceUsesWith(newTypePack); expandInst->removeAndDeallocate(); return true; } |
