diff options
| author | Yong He <yonghe@outlook.com> | 2020-11-06 10:26:27 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-11-06 10:26:27 -0800 |
| commit | 444ff4d8fdeb721b94a9424d03c162f43fb217c9 (patch) | |
| tree | 7896e00e223d9b1a66a8479f510e60136c5713c3 /source | |
| parent | 94861d5d8afdf216c0a507af24fdbe9fda4b66d7 (diff) | |
Specialize witness table lookups. (#1596)
* Specialize witness table lookups.
* Remove generated files from vcxproj
* Fix call to generic interface methods.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-generics-lowering-context.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-ir-generics-lowering-context.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generic-call.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generics.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-dispatch.cpp | 22 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp | 228 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.h | 15 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 9 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj | 2 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj.filters | 6 |
13 files changed, 326 insertions, 23 deletions
diff --git a/source/slang/slang-ir-generics-lowering-context.cpp b/source/slang/slang-ir-generics-lowering-context.cpp index 9c4eb7856..bef4dbb0a 100644 --- a/source/slang/slang-ir-generics-lowering-context.cpp +++ b/source/slang/slang-ir-generics-lowering-context.cpp @@ -219,6 +219,21 @@ namespace Slang } } + List<IRWitnessTable*> SharedGenericsLoweringContext::getWitnessTablesFromInterfaceType(IRInst* interfaceType) + { + List<IRWitnessTable*> witnessTables; + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->op == kIROp_WitnessTable && + cast<IRWitnessTableType>(globalInst->getDataType())->getConformanceType() == + interfaceType) + { + witnessTables.add(cast<IRWitnessTable>(globalInst)); + } + } + return witnessTables; + } + IRIntegerValue SharedGenericsLoweringContext::getInterfaceAnyValueSize(IRInst* type, SourceLoc usageLocation) { if (auto decor = type->findDecoration<IRAnyValueSizeDecoration>()) diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h index 3bd86e068..2e5366561 100644 --- a/source/slang/slang-ir-generics-lowering-context.h +++ b/source/slang/slang-ir-generics-lowering-context.h @@ -76,6 +76,19 @@ namespace Slang { return lowerType(builder, paramType, Dictionary<IRInst*, IRInst*>()); } + + // Get a list of all witness tables whose conformance type is `interfaceType`. + List<IRWitnessTable*> getWitnessTablesFromInterfaceType(IRInst* interfaceType); + + IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key) + { + for (auto entry : table->getEntries()) + { + if (entry->getRequirementKey() == key) + return entry->getSatisfyingVal(); + } + return nullptr; + } }; bool isPolymorphicType(IRInst* typeInst); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 3681077cb..004975252 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -187,7 +187,12 @@ INST(TypeType, type_t, 0, 0) // An `IRWitnessTable` has type `WitnessTableType`. INST(WitnessTableType, witness_table_t, 1, 0) -INST_RANGE(Type, VoidType, WitnessTableType) +// An integer type representing a witness table for targets where +// witness tables are represented as integer IDs. This type is used +// during the lower-generics pass while generating dynamic dispatch +// code and will eventually lower into an uint type. +INST(WitnessTableIDType, witness_table_id_t, 1, 0) +INST_RANGE(Type, VoidType, WitnessTableIDType) /*IRGlobalValueWithCode*/ /* IRGlobalValueWithParams*/ diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index d3f94abed..335ab4827 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1817,6 +1817,7 @@ struct IRBuilder IRBasicBlockType* getBasicBlockType(); IRWitnessTableType* getWitnessTableType(IRType* baseType); + IRWitnessTableIDType* getWitnessTableIDType(IRType* baseType); IRType* getTypeType() { return getType(IROp::kIROp_TypeType); } IRType* getKeyType() { return nullptr; } diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp index bd01a78fb..369bc712f 100644 --- a/source/slang/slang-ir-lower-generic-call.cpp +++ b/source/slang/slang-ir-lower-generic-call.cpp @@ -266,11 +266,17 @@ namespace Slang return; SLANG_UNEXPECTED("Nested generics specialization."); } + else if (loweredFunc->op == kIROp_lookup_interface_method) + { + lowerCallToInterfaceMethod( + callInst, cast<IRLookupWitnessMethod>(loweredFunc), specializeInst); + return; + } IRFuncType* funcType = cast<IRFuncType>(loweredFunc->getDataType()); translateCallInst(callInst, funcType, loweredFunc, specializeInst); } - void lowerCallToInterfaceMethod(IRCall* callInst, IRLookupWitnessMethod* lookupInst) + void lowerCallToInterfaceMethod(IRCall* callInst, IRLookupWitnessMethod* lookupInst, IRSpecialize* specializeInst) { // If we see a call(lookup_interface_method(...), ...), we need to translate // all occurences of associatedtypes. @@ -312,7 +318,10 @@ namespace Slang // Translate the new call inst as normal, taking care of packing/unpacking inputs // and outputs. translateCallInst( - newCall, cast<IRFuncType>(dispatchFunc->getFullType()), dispatchFunc, nullptr); + newCall, + cast<IRFuncType>(dispatchFunc->getFullType()), + dispatchFunc, + specializeInst); } void lowerCall(IRCall* callInst) @@ -320,7 +329,7 @@ namespace Slang if (auto specializeInst = as<IRSpecialize>(callInst->getCallee())) lowerCallToSpecializedFunc(callInst, specializeInst); else if (auto lookupInst = as<IRLookupWitnessMethod>(callInst->getCallee())) - lowerCallToInterfaceMethod(callInst, lookupInst); + lowerCallToInterfaceMethod(callInst, lookupInst, nullptr); } void processInst(IRInst* inst) diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 11bb400b0..89194e594 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -9,6 +9,7 @@ #include "slang-ir-lower-generic-call.h" #include "slang-ir-lower-generic-type.h" #include "slang-ir-specialize-dispatch.h" +#include "slang-ir-specialize-dynamic-associatedtype-lookup.h" #include "slang-ir-witness-table-wrapper.h" #include "slang-ir-ssa.h" #include "slang-ir-dce.h" @@ -63,6 +64,10 @@ namespace Slang if (sink->getErrorCount() != 0) return; + specializeDynamicAssociatedTypeLookup(&sharedContext); + if (sink->getErrorCount() != 0) + return; + // We might have generated new temporary variables during lowering. // An SSA pass can clean up unnecessary load/stores. constructSSA(module); diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp index ddbb743a8..98a0ba3b7 100644 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -6,29 +6,13 @@ namespace Slang { -IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key) -{ - for (auto entry : table->getEntries()) - { - if (entry->getRequirementKey() == key) - return entry->getSatisfyingVal(); - } - return nullptr; -} - IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc) { auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0); // Collect all witness tables of `witnessTableType` in current module. - List<IRWitnessTable*> witnessTables; - for (auto globalInst : sharedContext->module->getGlobalInsts()) - { - if (globalInst->op == kIROp_WitnessTable && globalInst->getDataType() == witnessTableType) - { - witnessTables.add(cast<IRWitnessTable>(globalInst)); - } - } + List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType( + cast<IRWitnessTableType>(witnessTableType)->getConformanceType()); SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock()); auto block = dispatchFunc->getFirstBlock(); @@ -119,7 +103,7 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, builder->setInsertInto(defaultBlock); } - auto callee = findWitnessTableEntry(witnessTable, requirementKey); + auto callee = sharedContext->findWitnessTableEntry(witnessTable, requirementKey); SLANG_ASSERT(callee); auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params); if (callInst->getDataType()->op == kIROp_VoidType) diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp new file mode 100644 index 000000000..c3095eb98 --- /dev/null +++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp @@ -0,0 +1,228 @@ +#include "slang-ir-specialize-dispatch.h" + +#include "slang-ir-generics-lowering-context.h" +#include "slang-ir-insts.h" +#include "slang-ir.h" + +namespace Slang +{ + +struct AssociatedTypeLookupSpecializationContext +{ + SharedGenericsLoweringContext* sharedContext; + + IRFunc* createWitnessTableLookupFunc(IRInterfaceType* interfaceType, IRInst* key) + { + IRBuilder builder; + builder.sharedBuilder = &sharedContext->sharedBuilderStorage; + builder.setInsertBefore(interfaceType); + + auto inputWitnessTableIDType = builder.getWitnessTableIDType(interfaceType); + auto requirementEntry = sharedContext->findInterfaceRequirementVal(interfaceType, key); + + auto resultWitnessTableType = cast<IRWitnessTableType>(requirementEntry); + auto resultWitnessTableIDType = + builder.getWitnessTableIDType((IRType*)resultWitnessTableType->getConformanceType()); + + auto funcType = + builder.getFuncType(1, (IRType**)&inputWitnessTableIDType, resultWitnessTableIDType); + auto func = builder.createFunc(); + func->setFullType(funcType); + + if (auto linkage = key->findDecoration<IRLinkageDecoration>()) + builder.addNameHintDecoration(func, linkage->getMangledName()); + + builder.setInsertInto(func); + + auto block = builder.emitBlock(); + auto witnessTableParam = builder.emitParam(inputWitnessTableIDType); + + // Collect all witness tables of `witnessTableType` in current module. + List<IRWitnessTable*> witnessTables = + sharedContext->getWitnessTablesFromInterfaceType(interfaceType); + + // Generate case blocks for each possible witness table. + IRBlock* defaultBlock = nullptr; + List<IRInst*> caseBlocks; + for (Index i = 0; i < witnessTables.getCount(); i++) + { + auto witnessTable = witnessTables[i]; + auto seqIdDecoration = witnessTable->findDecoration<IRSequentialIDDecoration>(); + SLANG_ASSERT(seqIdDecoration); + + if (i != witnessTables.getCount() - 1) + { + // Create a case block if we are not the last case. + caseBlocks.add(seqIdDecoration->getSequentialIDOperand()); + builder.setInsertInto(func); + auto caseBlock = builder.emitBlock(); + caseBlocks.add(caseBlock); + } + else + { + // Generate code for the last possible value in the `default` block. + builder.setInsertInto(func); + defaultBlock = builder.emitBlock(); + builder.setInsertInto(defaultBlock); + } + + auto resultWitnessTable = sharedContext->findWitnessTableEntry(witnessTable, key); + auto resultWitnessTableIDDecoration = + resultWitnessTable->findDecoration<IRSequentialIDDecoration>(); + SLANG_ASSERT(resultWitnessTableIDDecoration); + builder.emitReturn(resultWitnessTableIDDecoration->getSequentialIDOperand()); + } + + // Emit a switch statement to return the correct witness table ID based on + // the witness table ID passed in. + builder.setInsertInto(func); + auto breakBlock = builder.emitBlock(); + builder.setInsertInto(breakBlock); + builder.emitUnreachable(); + + builder.setInsertInto(block); + builder.emitSwitch( + witnessTableParam, + breakBlock, + defaultBlock, + caseBlocks.getCount(), + caseBlocks.getBuffer()); + + return func; + } + + // Retrieves the conformance type from a WitnessTableType or a WitnessTableIDType. + IRInterfaceType* getInterfaceTypeFromWitnessTableTypes(IRInst* witnessTableType) + { + switch (witnessTableType->op) + { + case kIROp_WitnessTableType: + return cast<IRInterfaceType>( + cast<IRWitnessTableType>(witnessTableType)->getConformanceType()); + case kIROp_WitnessTableIDType: + return cast<IRInterfaceType>( + cast<IRWitnessTableIDType>(witnessTableType)->getConformanceType()); + default: + return nullptr; + } + } + + void processLookupInterfaceMethodInst(IRLookupWitnessMethod* inst) + { + // Ignore lookups for RTTI objects for now, since they are not used anywhere. + if (!as<IRWitnessTableType>(inst->getDataType())) + return; + + // Replace all witness table lookups with calls to specialized functions that directly + // returns the sequential ID of the resulting witness table, effectively getting rid + // of actual witness table objects in the target code (they all become IDs). + auto witnessTableType = inst->getWitnessTable()->getDataType(); + IRInterfaceType* interfaceType = getInterfaceTypeFromWitnessTableTypes(witnessTableType); + if (!interfaceType) + return; + auto key = inst->getRequirementKey(); + IRFunc* func = nullptr; + if (!sharedContext->mapInterfaceRequirementKeyToDispatchMethods.TryGetValue(key, func)) + { + func = createWitnessTableLookupFunc(interfaceType, key); + sharedContext->mapInterfaceRequirementKeyToDispatchMethods[key] = func; + } + IRBuilder builder; + builder.sharedBuilder = &sharedContext->sharedBuilderStorage; + builder.setInsertBefore(inst); + auto witnessTableArg = inst->getWitnessTable(); + if (witnessTableArg->getDataType()->op == kIROp_WitnessTableType) + { + witnessTableArg = builder.emitGetSequentialIDInst(witnessTableArg); + } + auto callInst = builder.emitCallInst( + builder.getWitnessTableIDType(interfaceType), func, witnessTableArg); + inst->replaceUsesWith(callInst); + inst->removeAndDeallocate(); + } + + void cleanUpWitnessTableIDType() + { + List<IRInst*> instsToRemove; + for (auto inst : sharedContext->module->getGlobalInsts()) + { + if (inst->op == kIROp_WitnessTableIDType) + { + IRBuilder builder; + builder.sharedBuilder = &sharedContext->sharedBuilderStorage; + builder.setInsertBefore(inst); + inst->replaceUsesWith(builder.getUIntType()); + instsToRemove.add(inst); + } + } + for (auto inst : instsToRemove) + inst->removeAndDeallocate(); + } + + void processGetSequentialIDInst(IRGetSequentialID* inst) + { + if (inst->getRTTIOperand()->getDataType()->op == kIROp_WitnessTableIDType) + { + inst->replaceUsesWith(inst->getRTTIOperand()); + inst->removeAndDeallocate(); + } + } + + void processModule() + { + SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage; + sharedBuilder->module = sharedContext->module; + sharedBuilder->session = sharedContext->module->session; + + sharedContext->addToWorkList(sharedContext->module->getModuleInst()); + + while (sharedContext->workList.getCount() != 0) + { + IRInst* inst = sharedContext->workList.getLast(); + + sharedContext->workList.removeLast(); + sharedContext->workListSet.Remove(inst); + + if (inst->op == kIROp_lookup_interface_method) + { + processLookupInterfaceMethodInst(cast<IRLookupWitnessMethod>(inst)); + } + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + sharedContext->addToWorkList(child); + } + } + + // `GetSequentialID(WitnessTableIDOperand)` becomes just `WitnessTableIDOperand`. + sharedContext->addToWorkList(sharedContext->module->getModuleInst()); + while (sharedContext->workList.getCount() != 0) + { + IRInst* inst = sharedContext->workList.getLast(); + + sharedContext->workList.removeLast(); + sharedContext->workListSet.Remove(inst); + + if (inst->op == kIROp_GetSequentialID) + { + processGetSequentialIDInst(cast<IRGetSequentialID>(inst)); + } + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + sharedContext->addToWorkList(child); + } + } + + cleanUpWitnessTableIDType(); + } +}; + +void specializeDynamicAssociatedTypeLookup(SharedGenericsLoweringContext* sharedContext) +{ + AssociatedTypeLookupSpecializationContext context; + context.sharedContext = sharedContext; + context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.h b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.h new file mode 100644 index 000000000..83039eca5 --- /dev/null +++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.h @@ -0,0 +1,15 @@ +// slang-ir-specialize-dynamic-associatedtype-lookup.h +#pragma once + +namespace Slang +{ +struct SharedGenericsLoweringContext; + +/// Modifies the lookup of associatedtype entries from witness tables into +/// calls to a specialized "lookup" function that takes a witness table id +/// and returns a witness table id. +/// This is used on GPU targets where all witness tables are replaced as +/// integral IDs instead of a real pointer table. +void specializeDynamicAssociatedTypeLookup(SharedGenericsLoweringContext* sharedContext); + +} // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 30790af84..5bc289dab 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2596,6 +2596,16 @@ namespace Slang (IRInst* const*)&baseType); } + IRWitnessTableIDType* IRBuilder::getWitnessTableIDType( + IRType* baseType) + { + return (IRWitnessTableIDType*)findOrEmitHoistableInst( + nullptr, + kIROp_WitnessTableIDType, + 1, + (IRInst* const*)&baseType); + } + IRConstantBufferType* IRBuilder::getConstantBufferType(IRType* elementType) { IRInst* operands[] = { elementType }; @@ -5496,6 +5506,7 @@ namespace Slang case kIROp_DefaultConstruct: case kIROp_Specialize: case kIROp_lookup_interface_method: + case kIROp_GetSequentialID: case kIROp_getAddr: case kIROp_GetValueFromExistentialBox: case kIROp_Construct: diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 3b79d9d14..4e78cbb78 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1289,6 +1289,15 @@ struct IRWitnessTableType : IRType IR_LEAF_ISA(WitnessTableType); }; +struct IRWitnessTableIDType : IRType +{ + IRInst* getConformanceType() + { + return getOperand(0); + } + IR_LEAF_ISA(WitnessTableIDType); +}; + struct IRBindExistentialsType : IRType { IR_LEAF_ISA(BindExistentialsType) diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index 17590a21a..fd5986481 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -259,6 +259,7 @@ <ClInclude Include="slang-ir-sccp.h" /> <ClInclude Include="slang-ir-specialize-arrays.h" /> <ClInclude Include="slang-ir-specialize-dispatch.h" /> + <ClInclude Include="slang-ir-specialize-dynamic-associatedtype-lookup.h" /> <ClInclude Include="slang-ir-specialize-function-call.h" /> <ClInclude Include="slang-ir-specialize-resources.h" /> <ClInclude Include="slang-ir-specialize.h" /> @@ -383,6 +384,7 @@ <ClCompile Include="slang-ir-sccp.cpp" /> <ClCompile Include="slang-ir-specialize-arrays.cpp" /> <ClCompile Include="slang-ir-specialize-dispatch.cpp" /> + <ClCompile Include="slang-ir-specialize-dynamic-associatedtype-lookup.cpp" /> <ClCompile Include="slang-ir-specialize-function-call.cpp" /> <ClCompile Include="slang-ir-specialize-resources.cpp" /> <ClCompile Include="slang-ir-specialize.cpp" /> diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index 5664b1f66..9320987cf 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -228,6 +228,9 @@ <ClInclude Include="slang-ir-specialize-dispatch.h"> <Filter>Header Files</Filter> </ClInclude> + <ClInclude Include="slang-ir-specialize-dynamic-associatedtype-lookup.h"> + <Filter>Header Files</Filter> + </ClInclude> <ClInclude Include="slang-ir-specialize-function-call.h"> <Filter>Header Files</Filter> </ClInclude> @@ -596,6 +599,9 @@ <ClCompile Include="slang-ir-specialize-dispatch.cpp"> <Filter>Source Files</Filter> </ClCompile> + <ClCompile Include="slang-ir-specialize-dynamic-associatedtype-lookup.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="slang-ir-specialize-function-call.cpp"> <Filter>Source Files</Filter> </ClCompile> |
