summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp')
-rw-r--r--source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp73
1 files changed, 58 insertions, 15 deletions
diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
index a8d2902f6..eb77f651e 100644
--- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
+++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
@@ -37,6 +37,15 @@ struct AssociatedTypeLookupSpecializationContext
auto block = builder.emitBlock();
auto witnessTableParam = builder.emitParam(inputWitnessTableIDType);
+ // `witnessTableParam` is expected to have `IRWitnessTableID` type, which
+ // will later lower into a `uint2`. We only use the first element of the uint2
+ // to store the sequential ID and reserve the second 32-bit value for future
+ // pointer-compatibility. We insert a member extract inst right now
+ // to obtain the first element and use it in our switch statement.
+ UInt elemIdx = 0;
+ auto witnessTableSequentialID =
+ builder.emitSwizzle(builder.getUIntType(), witnessTableParam, 1, &elemIdx);
+
// Collect all witness tables of `witnessTableType` in current module.
List<IRWitnessTable*> witnessTables =
sharedContext->getWitnessTablesFromInterfaceType(interfaceType);
@@ -70,23 +79,41 @@ struct AssociatedTypeLookupSpecializationContext
auto resultWitnessTableIDDecoration =
resultWitnessTable->findDecoration<IRSequentialIDDecoration>();
SLANG_ASSERT(resultWitnessTableIDDecoration);
- builder.emitReturn(resultWitnessTableIDDecoration->getSequentialIDOperand());
+ // Pack the resulting witness table ID into a `uint2`.
+ auto uint2Type = builder.getVectorType(
+ builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2));
+ IRInst* uint2Args[] = {
+ resultWitnessTableIDDecoration->getSequentialIDOperand(),
+ builder.getIntValue(builder.getUIntType(), 0)};
+ auto resultID = builder.emitMakeVector(uint2Type, 2, uint2Args);
+ builder.emitReturn(resultID);
}
- // Emit a switch statement to return the correct witness table ID based on
- // the witness table ID passed in.
builder.setInsertInto(func);
- auto breakBlock = builder.emitBlock();
- builder.setInsertInto(breakBlock);
- builder.emitUnreachable();
-
- builder.setInsertInto(block);
- builder.emitSwitch(
- witnessTableParam,
- breakBlock,
- defaultBlock,
- caseBlocks.getCount(),
- caseBlocks.getBuffer());
+
+ if (witnessTables.getCount() == 1)
+ {
+ // If there is only 1 case, no switch statement is necessary.
+ builder.setInsertInto(block);
+ builder.emitBranch(defaultBlock);
+ }
+ else
+ {
+ // If there are more than 1 cases,
+ // emit a switch statement to return the correct witness table ID based on
+ // the witness table ID passed in.
+ auto breakBlock = builder.emitBlock();
+ builder.setInsertInto(breakBlock);
+ builder.emitUnreachable();
+
+ builder.setInsertInto(block);
+ builder.emitSwitch(
+ witnessTableSequentialID,
+ breakBlock,
+ defaultBlock,
+ caseBlocks.getCount(),
+ caseBlocks.getBuffer());
+ }
return func;
}
@@ -176,12 +203,28 @@ struct AssociatedTypeLookupSpecializationContext
});
// Replace all direct uses of IRWitnessTables with its sequential ID.
- workOnModule([](IRInst* inst)
+ workOnModule([this](IRInst* inst)
{
if (inst->op == kIROp_WitnessTable)
{
auto seqId = inst->findDecoration<IRSequentialIDDecoration>();
SLANG_ASSERT(seqId);
+ // Insert code to pack sequential ID into an uint2 at all use sites.
+ for (auto use = inst->firstUse; use; )
+ {
+ auto nextUse = use->nextUse;
+ IRBuilder builder;
+ builder.sharedBuilder = &sharedContext->sharedBuilderStorage;
+ builder.setInsertBefore(use->getUser());
+ auto uint2Type = builder.getVectorType(
+ builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2));
+ IRInst* uint2Args[] = {
+ seqId->getSequentialIDOperand(),
+ builder.getIntValue(builder.getUIntType(), 0)};
+ auto uint2seqID = builder.emitMakeVector(uint2Type, 2, uint2Args);
+ use->set(uint2seqID);
+ use = nextUse;
+ }
inst->replaceUsesWith(seqId->getSequentialIDOperand());
}
});