summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-lower-com-methods.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-07-25 10:08:28 -0700
committerGitHub <noreply@github.com>2022-07-25 10:08:28 -0700
commit9566e8af25f87ad034a984db9d847942e454a180 (patch)
tree2f295bf2bf60c39fd35b6b634b903d574b4ca99e /source/slang/slang-ir-lower-com-methods.cpp
parent70147fc7ba6abe0b669363ed5adfd8d4d9545c3f (diff)
Allow `class` to implement COM interface, [DLLExport] (#2338)
* Allow `class` to implement COM interface, [DLLExport] * Fix [COM] usage in tests and examples with UUIDs. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-lower-com-methods.cpp')
-rw-r--r--source/slang/slang-ir-lower-com-methods.cpp35
1 files changed, 35 insertions, 0 deletions
diff --git a/source/slang/slang-ir-lower-com-methods.cpp b/source/slang/slang-ir-lower-com-methods.cpp
index 6c3a3f289..6d5ddb261 100644
--- a/source/slang/slang-ir-lower-com-methods.cpp
+++ b/source/slang/slang-ir-lower-com-methods.cpp
@@ -6,6 +6,7 @@
#include "slang-ir-insts.h"
#include "slang-ir-marshal-native-call.h"
#include "slang-ir-inst-pass-base.h"
+#include "slang-ir-util.h"
namespace Slang
{
@@ -102,6 +103,37 @@ struct ComMethodLoweringContext : public InstPassBase
}
}
+ void processWitnessTable(IRWitnessTable* witnessTable)
+ {
+ auto interfaceType = as<IRInterfaceType>(witnessTable->getConformanceType());
+ if (!interfaceType) return;
+ if (!interfaceType->findDecoration<IRComInterfaceDecoration>())
+ return;
+ auto interfaceReqDict = buildInterfaceRequirementDict(interfaceType);
+
+ IRBuilder builder(&sharedBuilderStorage);
+ NativeCallMarshallingContext marshalContext;
+ marshalContext.diagnosticSink = diagnosticSink;
+ for (auto entry : witnessTable->getEntries())
+ {
+ IRInst* interfaceRequirement = nullptr;
+ if (!interfaceReqDict.TryGetValue(entry->getRequirementKey(), interfaceRequirement))
+ continue;
+ auto implFunc = as<IRFunc>(entry->getSatisfyingVal());
+ if (!implFunc) continue;
+ // If the function already has the same signature as the lowered COM interface method,
+ // we don't need to do anything.
+ if (isTypeEqual(entry->getSatisfyingVal()->getDataType(), (IRType*)interfaceRequirement))
+ continue;
+ // Now we need to generate a wrapper function that calls into the original one.
+ auto nativeFunc = marshalContext.generateDLLExportWrapperFunc(builder, implFunc);
+ entry->setOperand(1, nativeFunc);
+ }
+
+ auto classType = witnessTable->getConcreteType();
+ builder.addCOMWitnessDecoration(classType, witnessTable);
+ }
+
void processModule()
{
sharedBuilderStorage.init(module);
@@ -124,6 +156,9 @@ struct ComMethodLoweringContext : public InstPassBase
// Update func types of COM interfaces.
processInstsOfType<IRInterfaceType>(kIROp_InterfaceType, [this](IRInterfaceType* inst) { processInterfaceType(inst); });
+ // Update witness tables of classes that implement COM interfaces.
+ // Generate native-to-managed wrappers for each witness table entry.
+ processInstsOfType<IRWitnessTable>(kIROp_WitnessTable, [this](IRWitnessTable* table) { processWitnessTable(table); });
}
};