diff options
Diffstat (limited to 'source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp')
| -rw-r--r-- | source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp | 73 |
1 files changed, 58 insertions, 15 deletions
diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp index a8d2902f6..eb77f651e 100644 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp +++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp @@ -37,6 +37,15 @@ struct AssociatedTypeLookupSpecializationContext auto block = builder.emitBlock(); auto witnessTableParam = builder.emitParam(inputWitnessTableIDType); + // `witnessTableParam` is expected to have `IRWitnessTableID` type, which + // will later lower into a `uint2`. We only use the first element of the uint2 + // to store the sequential ID and reserve the second 32-bit value for future + // pointer-compatibility. We insert a member extract inst right now + // to obtain the first element and use it in our switch statement. + UInt elemIdx = 0; + auto witnessTableSequentialID = + builder.emitSwizzle(builder.getUIntType(), witnessTableParam, 1, &elemIdx); + // Collect all witness tables of `witnessTableType` in current module. List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType(interfaceType); @@ -70,23 +79,41 @@ struct AssociatedTypeLookupSpecializationContext auto resultWitnessTableIDDecoration = resultWitnessTable->findDecoration<IRSequentialIDDecoration>(); SLANG_ASSERT(resultWitnessTableIDDecoration); - builder.emitReturn(resultWitnessTableIDDecoration->getSequentialIDOperand()); + // Pack the resulting witness table ID into a `uint2`. + auto uint2Type = builder.getVectorType( + builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2)); + IRInst* uint2Args[] = { + resultWitnessTableIDDecoration->getSequentialIDOperand(), + builder.getIntValue(builder.getUIntType(), 0)}; + auto resultID = builder.emitMakeVector(uint2Type, 2, uint2Args); + builder.emitReturn(resultID); } - // Emit a switch statement to return the correct witness table ID based on - // the witness table ID passed in. builder.setInsertInto(func); - auto breakBlock = builder.emitBlock(); - builder.setInsertInto(breakBlock); - builder.emitUnreachable(); - - builder.setInsertInto(block); - builder.emitSwitch( - witnessTableParam, - breakBlock, - defaultBlock, - caseBlocks.getCount(), - caseBlocks.getBuffer()); + + if (witnessTables.getCount() == 1) + { + // If there is only 1 case, no switch statement is necessary. + builder.setInsertInto(block); + builder.emitBranch(defaultBlock); + } + else + { + // If there are more than 1 cases, + // emit a switch statement to return the correct witness table ID based on + // the witness table ID passed in. + auto breakBlock = builder.emitBlock(); + builder.setInsertInto(breakBlock); + builder.emitUnreachable(); + + builder.setInsertInto(block); + builder.emitSwitch( + witnessTableSequentialID, + breakBlock, + defaultBlock, + caseBlocks.getCount(), + caseBlocks.getBuffer()); + } return func; } @@ -176,12 +203,28 @@ struct AssociatedTypeLookupSpecializationContext }); // Replace all direct uses of IRWitnessTables with its sequential ID. - workOnModule([](IRInst* inst) + workOnModule([this](IRInst* inst) { if (inst->op == kIROp_WitnessTable) { auto seqId = inst->findDecoration<IRSequentialIDDecoration>(); SLANG_ASSERT(seqId); + // Insert code to pack sequential ID into an uint2 at all use sites. + for (auto use = inst->firstUse; use; ) + { + auto nextUse = use->nextUse; + IRBuilder builder; + builder.sharedBuilder = &sharedContext->sharedBuilderStorage; + builder.setInsertBefore(use->getUser()); + auto uint2Type = builder.getVectorType( + builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2)); + IRInst* uint2Args[] = { + seqId->getSequentialIDOperand(), + builder.getIntValue(builder.getUIntType(), 0)}; + auto uint2seqID = builder.emitMakeVector(uint2Type, 2, uint2Args); + use->set(uint2seqID); + use = nextUse; + } inst->replaceUsesWith(seqId->getSequentialIDOperand()); } }); |
