summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-lower-tuple-types.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-19 15:03:56 -0700
committerGitHub <noreply@github.com>2024-08-19 15:03:56 -0700
commit453683bf44f2112719802eaac2b332d49eebd640 (patch)
treed399db4c9cba90c11980186d3df1ffcc4d423b5a /source/slang/slang-ir-lower-tuple-types.cpp
parentecf85df6eee3da76ef54b14e4ab083f22da89e46 (diff)
Tuple swizzling, concat, comparison and `countof`. (#4856)
* Tuple swizzling and element access. * Update proposal status. * Cleanup. * Fix merrge error. * Address review.
Diffstat (limited to 'source/slang/slang-ir-lower-tuple-types.cpp')
-rw-r--r--source/slang/slang-ir-lower-tuple-types.cpp174
1 files changed, 173 insertions, 1 deletions
diff --git a/source/slang/slang-ir-lower-tuple-types.cpp b/source/slang/slang-ir-lower-tuple-types.cpp
index caa031d85..6177cfec2 100644
--- a/source/slang/slang-ir-lower-tuple-types.cpp
+++ b/source/slang/slang-ir-lower-tuple-types.cpp
@@ -85,7 +85,7 @@ namespace Slang
workListSet.add(inst);
}
- void processMakeTuple(IRMakeTuple* inst)
+ void processMakeTuple(IRInst* inst)
{
IRBuilder builderStorage(module);
auto builder = &builderStorage;
@@ -121,6 +121,124 @@ namespace Slang
inst->removeAndDeallocate();
}
+ void processGetElementPtr(IRGetElementPtr* inst)
+ {
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+
+ auto base = inst->getBase();
+ auto baseValueType = tryGetPointedToType(&builder, base->getDataType());
+ auto loweredTupleInfo = getLoweredTupleType(&builder, baseValueType);
+ if (!loweredTupleInfo)
+ return;
+
+ auto elementIndex = getIntVal(inst->getIndex());
+ SLANG_ASSERT((Index)elementIndex < loweredTupleInfo->fields.getCount());
+
+ auto field = loweredTupleInfo->fields[(Index)elementIndex];
+ auto getElement = builder.emitFieldAddress(builder.getPtrType(field->getFieldType()), base, field->getKey());
+ inst->replaceUsesWith(getElement);
+ inst->removeAndDeallocate();
+ }
+
+ void processSwizzle(IRSwizzle* inst)
+ {
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+
+ auto base = inst->getBase();
+ auto loweredTupleInfo = getLoweredTupleType(&builder, base->getDataType());
+
+ if (!loweredTupleInfo)
+ return;
+
+ if (inst->getElementCount() == 1)
+ {
+ auto elementIndex = getIntVal(inst->getElementIndex(0));
+ SLANG_ASSERT((Index)elementIndex < loweredTupleInfo->fields.getCount());
+
+ auto field = loweredTupleInfo->fields[(Index)elementIndex];
+ auto getElement = builder.emitFieldExtract(field->getFieldType(), base, field->getKey());
+ inst->replaceUsesWith(getElement);
+ inst->removeAndDeallocate();
+ }
+ else
+ {
+ List<IRInst*> elements;
+ for (UInt i = 0; i < inst->getElementCount(); i++)
+ {
+ auto elementIndex = getIntVal(inst->getElementIndex(i));
+ SLANG_ASSERT((Index)elementIndex < loweredTupleInfo->fields.getCount());
+
+ auto field = loweredTupleInfo->fields[(Index)elementIndex];
+ auto getElement = builder.emitFieldExtract(field->getFieldType(), base, field->getKey());
+ elements.add(getElement);
+ }
+ auto resultTypeInfo = getLoweredTupleType(&builder, inst->getDataType());
+ auto makeStruct = builder.emitMakeStruct(resultTypeInfo->structType, elements);
+ inst->replaceUsesWith(makeStruct);
+ inst->removeAndDeallocate();
+ }
+ }
+
+ void processSwizzleSet(IRSwizzleSet* inst)
+ {
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+
+ auto base = inst->getBase();
+ auto loweredTupleInfo = getLoweredTupleType(&builder, base->getDataType());
+ auto sourceTupleInfo = getLoweredTupleType(&builder, inst->getSource()->getDataType());
+ if (!loweredTupleInfo)
+ return;
+
+ List<IRInst*> elements;
+ for (Index i = 0; i < loweredTupleInfo->fields.getCount(); i++)
+ {
+ auto field = loweredTupleInfo->fields[i];
+ auto getElement = builder.emitFieldExtract(field->getFieldType(), base, field->getKey());
+ elements.add(getElement);
+ }
+
+ for (UInt i = 0; i < inst->getElementCount(); i++)
+ {
+ auto baseIndex = getIntVal(inst->getElementIndex(i));
+ auto sourceElement = sourceTupleInfo
+ ? builder.emitFieldExtract(sourceTupleInfo->fields[i]->getFieldType(), inst->getSource(), sourceTupleInfo->fields[i]->getKey())
+ : inst->getSource();
+ elements[baseIndex] = sourceElement;
+ }
+ auto resultTypeInfo = getLoweredTupleType(&builder, inst->getDataType());
+ auto makeStruct = builder.emitMakeStruct(resultTypeInfo->structType, elements);
+ inst->replaceUsesWith(makeStruct);
+ inst->removeAndDeallocate();
+ }
+
+ void processSwizzledStore(IRSwizzledStore* inst)
+ {
+ IRBuilder builder(module);
+ builder.setInsertBefore(inst);
+
+ auto dest = inst->getDest();
+ auto destValueType = tryGetPointedToType(&builder, dest->getDataType());
+ auto loweredTupleInfo = getLoweredTupleType(&builder, destValueType);
+ auto sourceTupleInfo = getLoweredTupleType(&builder, inst->getSource()->getDataType());
+ if (!loweredTupleInfo)
+ return;
+
+ for (UInt i = 0; i < inst->getElementCount(); i++)
+ {
+ auto baseIndex = getIntVal(inst->getElementIndex(i));
+ auto destField = loweredTupleInfo->fields[baseIndex];
+ auto destFieldPtr = builder.emitFieldAddress(builder.getPtrType(destField->getFieldType()), dest, destField->getKey());
+ auto sourceElement = sourceTupleInfo
+ ? builder.emitFieldExtract(sourceTupleInfo->fields[i]->getFieldType(), inst->getSource(), sourceTupleInfo->fields[i]->getKey())
+ : inst->getSource();
+ builder.emitStore(destFieldPtr, sourceElement);
+ }
+ inst->removeAndDeallocate();
+ }
+
void processTupleType(IRTupleType* inst)
{
IRBuilder builderStorage(module);
@@ -132,19 +250,47 @@ namespace Slang
SLANG_UNUSED(loweredTupleInfo);
}
+ void processIndexedFieldKey(IRIndexedFieldKey* inst)
+ {
+ IRBuilder builder(module);
+ auto loweredTupleInfo = getLoweredTupleType(&builder, inst->getBaseType());
+ if (!loweredTupleInfo)
+ return;
+ auto fieldIndex = getIntVal(inst->getIndex());
+ SLANG_ASSERT(fieldIndex >= 0 && (Index)fieldIndex < loweredTupleInfo->fields.getCount());
+ inst->replaceUsesWith(loweredTupleInfo->fields[fieldIndex]->getKey());
+ inst->removeAndDeallocate();
+ }
+
void processInst(IRInst* inst)
{
switch (inst->getOp())
{
case kIROp_MakeTuple:
+ case kIROp_MakeValuePack:
processMakeTuple((IRMakeTuple*)inst);
break;
case kIROp_GetTupleElement:
processGetTupleElement((IRGetTupleElement*)inst);
break;
+ case kIROp_GetElementPtr:
+ processGetElementPtr((IRGetElementPtr*)inst);
+ break;
+ case kIROp_swizzle:
+ processSwizzle((IRSwizzle*)inst);
+ break;
+ case kIROp_swizzleSet:
+ processSwizzleSet((IRSwizzleSet*)inst);
+ break;
+ case kIROp_SwizzledStore:
+ processSwizzledStore((IRSwizzledStore*)inst);
+ break;
case kIROp_TupleType:
processTupleType((IRTupleType*)inst);
break;
+ case kIROp_IndexedFieldKey:
+ processIndexedFieldKey((IRIndexedFieldKey*)inst);
+ break;
default:
break;
}
@@ -152,6 +298,32 @@ namespace Slang
void processModule()
{
+ // First, we want to replace all TypePack with TupleType.
+
+ List<IRInst*> typePacks;
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (inst->getOp() == kIROp_TypePack)
+ {
+ typePacks.add(inst);
+ }
+ }
+ IRBuilder builder(module);
+ for (auto inst : typePacks)
+ {
+ builder.setInsertBefore(inst);
+ ShortList<IRType*> types;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ types.add((IRType*)inst->getOperand(i));
+ }
+ auto tupleType = builder.getTupleType((UInt)types.getCount(), types.getArrayView().getBuffer());
+ inst->replaceUsesWith(tupleType);
+ inst->removeAndDeallocate();
+ }
+
+ // Next, lower all tuples to structs.
+
addToWorkList(module->getModuleInst());
while (workList.getCount() != 0)