summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-specialize.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-19 15:03:56 -0700
committerGitHub <noreply@github.com>2024-08-19 15:03:56 -0700
commit453683bf44f2112719802eaac2b332d49eebd640 (patch)
treed399db4c9cba90c11980186d3df1ffcc4d423b5a /source/slang/slang-ir-specialize.cpp
parentecf85df6eee3da76ef54b14e4ab083f22da89e46 (diff)
Tuple swizzling, concat, comparison and `countof`. (#4856)
* Tuple swizzling and element access. * Update proposal status. * Cleanup. * Fix merrge error. * Address review.
Diffstat (limited to 'source/slang/slang-ir-specialize.cpp')
-rw-r--r--source/slang/slang-ir-specialize.cpp210
1 files changed, 181 insertions, 29 deletions
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 2eb16112f..c9e94352e 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -594,7 +594,165 @@ struct SpecializationContext
case kIROp_GetTupleElement:
return maybeSpecializeFoldableInst(inst);
+
+ case kIROp_TypePack:
+ case kIROp_TupleType:
+ return maybeSpecializeTypePackOrTupleType(inst);
+
+ case kIROp_MakeValuePack:
+ case kIROp_MakeTuple:
+ return maybeSpecializeMakeValuePackOrTuple(inst);
+
+ case kIROp_CountOf:
+ return maybeSpecializeCountOf(inst);
+ }
+ }
+
+
+ void flattenPackOperand(ShortList<IRInst*>& flattenedList, IRInst* inst)
+ {
+ if (auto makeValuePack = as<IRMakeValuePack>(inst))
+ {
+ for (UInt i = 0; i < makeValuePack->getOperandCount(); i++)
+ {
+ flattenPackOperand(flattenedList, makeValuePack->getOperand(i));
+ }
+ }
+ else if (auto typePack = as<IRTypePack>(inst))
+ {
+ for (UInt i = 0; i < typePack->getOperandCount(); i++)
+ {
+ flattenPackOperand(flattenedList, typePack->getOperand(i));
+ }
+ }
+ else
+ {
+ SLANG_ASSERT(inst);
+ flattenedList.add(inst);
+ }
+ }
+
+ bool maybeSpecializeTypePackOrTupleType(IRInst* inst)
+ {
+ // If any element of the type pack or tuple is a TypePack, we want to
+ // flatten that type pack into the current type pack or tuple.
+
+ bool needProcess = false;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (as<IRTypePack>(inst->getOperand(i)))
+ {
+ needProcess = true;
+ break;
+ }
+ }
+ // If none of the operands are MakeValuePack, there is no need to flatten anything.
+ if (!needProcess)
+ return false;
+
+ // We will recursively flatten all MakeValuePack operands.
+ ShortList<IRInst*> flattendOperands;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto operand = inst->getOperand(i);
+ flattenPackOperand(flattendOperands, operand);
+ }
+
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+ IRInst* newInst;
+ if (inst->getOp() == kIROp_TypePack)
+ newInst = builder.getTypePack(flattendOperands.getCount(), (IRType* const*)flattendOperands.getArrayView().getBuffer());
+ else
+ newInst = builder.getTupleType(flattendOperands.getCount(), (IRType* const*)flattendOperands.getArrayView().getBuffer());
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ addUsersToWorkList(newInst);
+ return true;
+ }
+
+ bool maybeSpecializeMakeValuePackOrTuple(IRInst* inst)
+ {
+ // If any element of the value pack or tuple is a ValuePack, we want to
+ // flatten that value pack into the current value pack or tuple.
+
+ bool needProcess = false;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (as<IRMakeValuePack>(inst->getOperand(i)))
+ {
+ needProcess = true;
+ break;
+ }
}
+ // If none of the operands are MakeValuePack, there is no need to flatten anything.
+ if (!needProcess)
+ return false;
+
+ // We will recursively flatten all MakeValuePack operands.
+ ShortList<IRInst*> flattendOperands;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto operand = inst->getOperand(i);
+ flattenPackOperand(flattendOperands, operand);
+ }
+
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+ IRInst* newInst = nullptr;
+ if (inst->getOp() == kIROp_MakeValuePack)
+ newInst = builder.emitMakeValuePack(inst->getFullType(), flattendOperands.getCount(), flattendOperands.getArrayView().getBuffer());
+ else
+ newInst = builder.emitMakeTuple(inst->getFullType(), flattendOperands.getCount(), flattendOperands.getArrayView().getBuffer());
+
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ addUsersToWorkList(newInst);
+ return true;
+ }
+
+ bool maybeSpecializeCountOf(IRInst* inst)
+ {
+ auto operand = inst->getOperand(0);
+
+ // If operand is a value, make sure we are working on its type.
+
+ switch (operand->getOp())
+ {
+ case kIROp_MakeValuePack:
+ case kIROp_MakeTuple:
+ operand = operand->getDataType();
+ break;
+ }
+
+ // We can only figure out the count of a type pack or tuple type.
+ switch (operand->getOp())
+ {
+ case kIROp_TypePack:
+ case kIROp_TupleType:
+ break;
+ default:
+ return false;
+ }
+
+ // If none of the element type is a TypePack, we can just return the count.
+ for (UInt i = 0; i < operand->getOperandCount(); i++)
+ {
+ switch (operand->getOperand(i)->getOp())
+ {
+ case kIROp_Param:
+ case kIROp_TypePack:
+ case kIROp_ExpandTypeOrVal:
+ return false;
+ }
+ }
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+ auto newInst = builder.getIntValue(inst->getDataType(), operand->getOperandCount());
+ addUsersToWorkList(inst);
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ return true;
}
// Specializing lookup on witness tables is a general
@@ -606,7 +764,7 @@ struct SpecializationContext
{
// Note: While we currently have named the instruction
// `lookup_witness_method`, the `method` part is a misnomer
- // and the same instruction can look up *any* interfacemay
+ // and the same instruction can look up *any* interface
// requirement based on the witness table that provides
// a conformance, and the "key" that indicates the interface
// requirement.
@@ -2268,7 +2426,7 @@ struct SpecializationContext
for (UInt i = 0; i < expandInst->getCaptureCount(); i++)
{
- if (!as<IRTupleType>(expandInst->getCapture(i)))
+ if (!as<IRTypePack>(expandInst->getCapture(i)))
return false;
}
@@ -2276,16 +2434,16 @@ struct SpecializationContext
builder.setInsertBefore(expandInst);
List<IRInst*> elements;
UInt elementCount = 0;
- if (auto firstTupleType = as<IRTupleType>(expandInst->getCapture(0)))
+ if (auto firstTypePack = as<IRTypePack>(expandInst->getCapture(0)))
{
- elementCount = firstTupleType->getOperandCount();
+ elementCount = firstTypePack->getOperandCount();
}
if (elementCount == 0)
{
- auto resultTuple = builder.emitMakeTuple(0, (IRInst*const*)nullptr);
- expandInst->replaceUsesWith(resultTuple);
+ auto resultValuePack = builder.emitMakeValuePack(0, (IRInst*const*)nullptr);
+ expandInst->replaceUsesWith(resultValuePack);
expandInst->removeAndDeallocate();
- addUsersToWorkList(resultTuple);
+ addUsersToWorkList(resultValuePack);
return true;
}
@@ -2328,7 +2486,7 @@ struct SpecializationContext
}
}
- auto resultTuple = builder.emitMakeTuple(elements);
+ auto resultValuePack = builder.emitMakeValuePack((UInt)elements.getCount(), elements.getBuffer());
auto currentBlock = builder.getBlock();
for (auto nextInst = expandInst->next; nextInst;)
{
@@ -2337,7 +2495,7 @@ struct SpecializationContext
nextInst = next;
}
addUsersToWorkList(expandInst);
- expandInst->replaceUsesWith(resultTuple);
+ expandInst->replaceUsesWith(resultValuePack);
expandInst->removeAndDeallocate();
return true;
}
@@ -2355,15 +2513,15 @@ struct SpecializationContext
{
auto eachInst = as<IREach>(val);
auto packInst = eachInst->getElement();
- if (auto tuple = as<IRTupleType>(packInst))
+ if (auto typePack = as<IRTypePack>(packInst))
{
- SLANG_RELEASE_ASSERT(indexInPack < tuple->getOperandCount());
- return tuple->getOperand(indexInPack);
+ SLANG_RELEASE_ASSERT(indexInPack < typePack->getOperandCount());
+ return typePack->getOperand(indexInPack);
}
- else if (auto makeTuple = as<IRMakeTuple>(packInst))
+ else if (auto makeValuePack = as<IRMakeValuePack>(packInst))
{
- SLANG_RELEASE_ASSERT(indexInPack < makeTuple->getOperandCount());
- return makeTuple->getOperand(indexInPack);
+ SLANG_RELEASE_ASSERT(indexInPack < makeValuePack->getOperandCount());
+ return makeValuePack->getOperand(indexInPack);
}
else if (!as<IRTypeKind>(packInst->getDataType()))
{
@@ -2413,24 +2571,18 @@ struct SpecializationContext
if (expandInst->getCaptureCount() == 0)
return false;
- bool anyAbstractPack = false;
for (UInt i = 0; i < expandInst->getCaptureCount(); i++)
{
- if (!as<IRTupleType>(expandInst->getCaptureType(i)))
- {
- anyAbstractPack = true;
- break;
- }
+ if (!as<IRTypePack>(expandInst->getCaptureType(i)))
+ return false;
}
- if (anyAbstractPack)
- return false;
IRBuilder builder(expandInst);
builder.setInsertBefore(expandInst);
List<IRInst*> elements;
UInt elementCount = 0;
- if (auto firstTupleType = as<IRTupleType>(expandInst->getCaptureType(0)))
+ if (auto firstTypePack = as<IRTypePack>(expandInst->getCaptureType(0)))
{
- elementCount = firstTupleType->getOperandCount();
+ elementCount = firstTypePack->getOperandCount();
}
for (UInt i = 0; i < elementCount; i++)
{
@@ -2444,16 +2596,16 @@ struct SpecializationContext
List<IRType*> types;
for (auto element : elements)
types.add(element->getDataType());
- auto newTupleType = builder.getTupleType(types);
- auto result = builder.emitMakeWitnessPack(newTupleType, elements.getArrayView());
+ auto newTypePack = builder.getTypePack(elements.getCount(), types.getBuffer());
+ auto result = builder.emitMakeWitnessPack(newTypePack, elements.getArrayView());
expandInst->replaceUsesWith(result);
expandInst->removeAndDeallocate();
return true;
}
else
{
- auto newTupleType = builder.getTupleType(elements.getCount(), (IRType*const*)elements.getBuffer());
- expandInst->replaceUsesWith(newTupleType);
+ auto newTypePack = builder.getTypePack(elements.getCount(), (IRType*const*)elements.getBuffer());
+ expandInst->replaceUsesWith(newTypePack);
expandInst->removeAndDeallocate();
return true;
}