// slang-ir-lower-com-methods.cpp #include "slang-ir-lower-com-methods.h" #include "slang-ir.h" #include "slang-ir-insts.h" #include "slang-ir-marshal-native-call.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-util.h" namespace Slang { struct ComMethodLoweringContext : public InstPassBase { DiagnosticSink* diagnosticSink = nullptr; NativeCallMarshallingContext marshal; OrderedHashSet comCallees; ComMethodLoweringContext(IRModule* inModule) : InstPassBase(inModule) {} void processComCall(IRCall* comCall) { IRBuilder builder(module); builder.setInsertBefore(comCall); auto callee = as(comCall->getCallee()); SLANG_ASSERT(callee); IRLookupWitnessMethod* innerMostCallee = callee; while (innerMostCallee->getOperand(0)->getOp() == kIROp_LookupWitness) { innerMostCallee = as(innerMostCallee->getOperand(0)); } if (callee != innerMostCallee) { callee = (IRLookupWitnessMethod*)builder.emitLookupInterfaceMethodInst( callee->getDataType(), innerMostCallee->getWitnessTable(), callee->getRequirementKey()); } comCallees.add(callee); auto calleeType = as(callee->getDataType()); SLANG_ASSERT(calleeType); auto nativeFuncType = marshal.getNativeFuncType(builder, calleeType); ShortList args; for (UInt i = 0; i < comCall->getArgCount(); i++) args.add(comCall->getArg(i)); auto currentBlock = builder.getBlock(); auto nextInst = comCall->getNextInst(); auto newResult = marshal.marshalNativeCall( builder, calleeType, nativeFuncType, callee, args.getCount(), args.getArrayView().getBuffer()); comCall->replaceUsesWith(newResult); if (builder.getBlock() != currentBlock) { // `marshalNativeCall` may have replaced the original call with branch insts. // If this is the case, we need to move all insts after the original call in the original // basic block to the new basic block. while (nextInst) { auto next = nextInst->getNextInst(); nextInst->removeFromParent(); nextInst->insertAtEnd(builder.getBlock()); nextInst = next; } } comCall->removeAndDeallocate(); } void processCall(IRCall* inst) { auto funcValue = inst->getOperand(0); // Detect if this is a call into a COM interface method. if (funcValue->getOp() == kIROp_LookupWitness) { const auto operand0TypeOp = funcValue->getOperand(0)->getDataType(); if (auto tableType = as(operand0TypeOp)) { if (tableType->getConformanceType()->findDecoration()) { processComCall(inst); return; } } } } void processInterfaceType(IRInterfaceType* interfaceType) { if (!interfaceType->findDecoration()) return; IRBuilder builder(module); for (UInt i = 0; i < interfaceType->getOperandCount(); i++) { auto entry = as(interfaceType->getOperand(i)); if (!entry) continue; if (auto funcType = as(entry->getRequirementVal())) { builder.setInsertBefore(funcType); entry->setRequirementVal(marshal.getNativeFuncType(builder, funcType)); } } } void processWitnessTable(IRWitnessTable* witnessTable) { auto interfaceType = as(witnessTable->getConformanceType()); if (!interfaceType) return; if (!interfaceType->findDecoration()) return; auto interfaceReqDict = buildInterfaceRequirementDict(interfaceType); IRBuilder builder(module); NativeCallMarshallingContext marshalContext; marshalContext.diagnosticSink = diagnosticSink; for (auto entry : witnessTable->getEntries()) { IRInst* interfaceRequirement = nullptr; if (!interfaceReqDict.tryGetValue(entry->getRequirementKey(), interfaceRequirement)) continue; auto implFunc = as(entry->getSatisfyingVal()); if (!implFunc) continue; // If the function already has the same signature as the lowered COM interface method, // we don't need to do anything. if (isTypeEqual(entry->getSatisfyingVal()->getDataType(), (IRType*)interfaceRequirement)) continue; // Now we need to generate a wrapper function that calls into the original one. auto nativeFunc = marshalContext.generateDLLExportWrapperFunc(builder, implFunc); entry->setOperand(1, nativeFunc); } auto classType = witnessTable->getConcreteType(); builder.addCOMWitnessDecoration(classType, witnessTable); } void processModule() { // Translate all Calls to interface methods. processInstsOfType(kIROp_Call, [this](IRCall* inst) { processCall(inst); }); // Update functypes of com callees. for (auto callee : comCallees) { IRBuilder builder(module); builder.setInsertBefore(callee); auto nativeType = marshal.getNativeFuncType(builder, as(callee->getDataType())); callee->setFullType(nativeType); } // Update func types of COM interfaces. processInstsOfType(kIROp_InterfaceType, [this](IRInterfaceType* inst) { processInterfaceType(inst); }); // Update witness tables of classes that implement COM interfaces. // Generate native-to-managed wrappers for each witness table entry. processInstsOfType(kIROp_WitnessTable, [this](IRWitnessTable* table) { processWitnessTable(table); }); } }; void lowerComMethods(IRModule* module, DiagnosticSink* sink) { ComMethodLoweringContext context(module); context.diagnosticSink = sink; context.marshal.diagnosticSink = sink; return context.processModule(); } }