From e5a75563a1ba2e378353af8b937b8b7bb0fe2c2b Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 21 Jun 2022 14:55:59 -0700 Subject: Lower throwing COM interface method. (#2282) * Lower throwing COM interface method. * Fix. * Fix warnings. Co-authored-by: Yong He --- source/slang/slang-ir-lower-com-methods.cpp | 138 ++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 source/slang/slang-ir-lower-com-methods.cpp (limited to 'source/slang/slang-ir-lower-com-methods.cpp') diff --git a/source/slang/slang-ir-lower-com-methods.cpp b/source/slang/slang-ir-lower-com-methods.cpp new file mode 100644 index 000000000..6c3a3f289 --- /dev/null +++ b/source/slang/slang-ir-lower-com-methods.cpp @@ -0,0 +1,138 @@ +// slang-ir-lower-com-methods.cpp + +#include "slang-ir-lower-com-methods.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-marshal-native-call.h" +#include "slang-ir-inst-pass-base.h" + +namespace Slang +{ + +struct ComMethodLoweringContext : public InstPassBase +{ + DiagnosticSink* diagnosticSink = nullptr; + + NativeCallMarshallingContext marshal; + + OrderedHashSet comCallees; + + ComMethodLoweringContext(IRModule* inModule) + : InstPassBase(inModule) + {} + + void processComCall(IRCall* comCall) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(comCall); + auto callee = as(comCall->getCallee()); + SLANG_ASSERT(callee); + + comCallees.Add(callee); + + auto calleeType = as(comCall->getCallee()->getDataType()); + SLANG_ASSERT(calleeType); + + auto nativeFuncType = marshal.getNativeFuncType(builder, calleeType); + ShortList args; + for (UInt i = 0; i < comCall->getArgCount(); i++) + args.add(comCall->getArg(i)); + auto currentBlock = builder.getBlock(); + auto nextInst = comCall->getNextInst(); + auto newResult = marshal.marshalNativeCall( + builder, + calleeType, + nativeFuncType, + comCall->getCallee(), + args.getCount(), + args.getArrayView().getBuffer()); + + comCall->replaceUsesWith(newResult); + if (builder.getBlock() != currentBlock) + { + // `marshalNativeCall` may have replaced the original call with branch insts. + // If this is the case, we need to move all insts after the original call in the original + // basic block to the new basic block. + while (nextInst) + { + auto next = nextInst->getNextInst(); + nextInst->removeFromParent(); + nextInst->insertAtEnd(builder.getBlock()); + nextInst = next; + } + } + comCall->removeAndDeallocate(); + } + + void processCall(IRCall* inst) + { + auto funcValue = inst->getOperand(0); + + // Detect if this is a call into a COM interface method. + if (funcValue->getOp() == kIROp_lookup_interface_method) + { + const auto operand0TypeOp = funcValue->getOperand(0)->getDataType(); + if (auto tableType = as(operand0TypeOp)) + { + if (tableType->getConformanceType()->findDecoration()) + { + processComCall(inst); + return; + } + } + } + } + + void processInterfaceType(IRInterfaceType* interfaceType) + { + if (!interfaceType->findDecoration()) + return; + IRBuilder builder(&sharedBuilderStorage); + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + { + auto entry = as(interfaceType->getOperand(i)); + if (!entry) + continue; + if (auto funcType = as(entry->getRequirementVal())) + { + builder.setInsertBefore(funcType); + entry->setRequirementVal(marshal.getNativeFuncType(builder, funcType)); + } + } + } + + void processModule() + { + sharedBuilderStorage.init(module); + + // Deduplicate equivalent types. + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + + // Translate all Calls to interface methods. + processInstsOfType(kIROp_Call, [this](IRCall* inst) { processCall(inst); }); + + // Update functypes of com callees. + for (auto callee : comCallees) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(callee); + auto nativeType = marshal.getNativeFuncType(builder, as(callee->getDataType())); + callee->setFullType(nativeType); + } + + // Update func types of COM interfaces. + processInstsOfType(kIROp_InterfaceType, [this](IRInterfaceType* inst) { processInterfaceType(inst); }); + + } +}; + +void lowerComMethods(IRModule* module, DiagnosticSink* sink) +{ + ComMethodLoweringContext context(module); + context.diagnosticSink = sink; + context.marshal.diagnosticSink = sink; + + return context.processModule(); +} +} -- cgit v1.2.3