summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-specialize-dispatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-specialize-dispatch.cpp')
-rw-r--r--source/slang/slang-ir-specialize-dispatch.cpp127
1 files changed, 127 insertions, 0 deletions
diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp
new file mode 100644
index 000000000..0c519427d
--- /dev/null
+++ b/source/slang/slang-ir-specialize-dispatch.cpp
@@ -0,0 +1,127 @@
+#include "slang-ir-specialize-dispatch.h"
+
+#include "slang-ir-generics-lowering-context.h"
+#include "slang-ir-insts.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key)
+{
+ for (auto entry : table->getEntries())
+ {
+ if (entry->getRequirementKey() == key)
+ return entry->getSatisfyingVal();
+ }
+ return nullptr;
+}
+
+void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc)
+{
+ auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0);
+
+ // Collect all witness tables of `witnessTableType` in current module.
+ List<IRWitnessTable*> witnessTables;
+ for (auto globalInst : sharedContext->module->getGlobalInsts())
+ {
+ if (globalInst->op == kIROp_WitnessTable && globalInst->getDataType() == witnessTableType)
+ {
+ witnessTables.add(cast<IRWitnessTable>(globalInst));
+ }
+ }
+
+ SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock());
+ auto block = dispatchFunc->getFirstBlock();
+
+ // The dispatch function before modification must be in the form of
+ // call(lookup_interface_method(witnessTableParam, interfaceReqKey), args)
+ // We now find the relavent instructions.
+ IRCall* callInst = nullptr;
+ IRLookupWitnessMethod* lookupInst = nullptr;
+ IRReturn* returnInst = nullptr;
+ for (auto inst : block->getOrdinaryInsts())
+ {
+ switch (inst->op)
+ {
+ case kIROp_Call:
+ callInst = cast<IRCall>(inst);
+ break;
+ case kIROp_lookup_interface_method:
+ lookupInst = cast<IRLookupWitnessMethod>(inst);
+ break;
+ case kIROp_ReturnVal:
+ case kIROp_ReturnVoid:
+ returnInst = cast<IRReturn>(inst);
+ break;
+ default:
+ break;
+ }
+ }
+ SLANG_ASSERT(callInst && lookupInst && returnInst);
+
+ IRBuilder builderStorage;
+ auto builder = &builderStorage;
+ builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
+ builder->setInsertBefore(callInst);
+
+ auto witnessTableParam = block->getFirstParam();
+ auto requirementKey = lookupInst->getRequirementKey();
+ List<IRInst*> params;
+ for (auto param = block->getFirstParam()->getNextParam(); param; param = param->getNextParam())
+ {
+ params.add(param);
+ }
+
+ // Emit cascaded if statements to call the correct concrete function based on
+ // the witness table pointer passed in.
+ auto ifBlock = block;
+ for (Index i = 0; i < witnessTables.getCount(); i++)
+ {
+ auto witnessTable = witnessTables[i];
+ bool isLast = (i == witnessTables.getCount() - 1);
+ IRInst* cmpArgs[] =
+ {
+ builder->emitBitCast(builder->getUInt64Type(), witnessTableParam),
+ builder->emitBitCast(builder->getUInt64Type(),(IRInst*)witnessTable)
+ };
+ IRInst* condition = nullptr;
+ IRBlock* trueBlock = nullptr;
+ if (!isLast)
+ {
+ condition = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_Eql, 2, cmpArgs);
+ trueBlock = builder->emitBlock();
+ }
+ auto callee = findWitnessTableEntry(witnessTable, requirementKey);
+ SLANG_ASSERT(callee);
+ auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params);
+ if (callInst->getDataType()->op == kIROp_VoidType)
+ builder->emitReturn();
+ else
+ builder->emitReturn(specializedCallInst);
+ if (!isLast)
+ {
+ auto falseBlock = builder->emitBlock();
+ builder->setInsertInto(ifBlock);
+ builder->emitIf(condition, trueBlock, falseBlock);
+ builder->setInsertInto(falseBlock);
+ ifBlock = falseBlock;
+ }
+ }
+
+ // Remove old implementation.
+ lookupInst->removeAndDeallocate();
+ callInst->removeAndDeallocate();
+ returnInst->removeAndDeallocate();
+}
+
+void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext)
+{
+ sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+
+ for (auto kv : sharedContext->mapInterfaceRequirementKeyToDispatchMethods)
+ {
+ auto dispatchFunc = kv.Value;
+ specializeDispatchFunction(sharedContext, dispatchFunc);
+ }
+}
+} // namespace Slang