summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-07-25 10:08:28 -0700
committerGitHub <noreply@github.com>2022-07-25 10:08:28 -0700
commit9566e8af25f87ad034a984db9d847942e454a180 (patch)
tree2f295bf2bf60c39fd35b6b634b903d574b4ca99e /source
parent70147fc7ba6abe0b669363ed5adfd8d4d9545c3f (diff)
Allow `class` to implement COM interface, [DLLExport] (#2338)
* Allow `class` to implement COM interface, [DLLExport] * Fix [COM] usage in tests and examples with UUIDs. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang-rt/slang-rt.h10
-rw-r--r--source/slang/core.meta.slang5
-rw-r--r--source/slang/slang-ast-modifier.h9
-rw-r--r--source/slang/slang-check-modifier.cpp44
-rw-r--r--source/slang/slang-check-overload.cpp2
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-emit-c-like.cpp156
-rw-r--r--source/slang/slang-emit-c-like.h4
-rw-r--r--source/slang/slang-emit-cpp.cpp100
-rw-r--r--source/slang/slang-emit-cpp.h1
-rw-r--r--source/slang/slang-emit.cpp2
-rw-r--r--source/slang/slang-ir-dll-export.cpp72
-rw-r--r--source/slang/slang-ir-dll-export.h10
-rw-r--r--source/slang/slang-ir-dll-import.cpp16
-rw-r--r--source/slang/slang-ir-dll-import.h1
-rw-r--r--source/slang/slang-ir-inst-defs.h6
-rw-r--r--source/slang/slang-ir-insts.h39
-rw-r--r--source/slang/slang-ir-lower-com-methods.cpp35
-rw-r--r--source/slang/slang-ir-lower-existential.cpp7
-rw-r--r--source/slang/slang-ir-marshal-native-call.cpp117
-rw-r--r--source/slang/slang-ir-marshal-native-call.h9
-rw-r--r--source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp8
-rw-r--r--source/slang/slang-ir-strip-witness-tables.cpp5
-rw-r--r--source/slang/slang-ir-util.cpp12
-rw-r--r--source/slang/slang-ir-util.h3
-rw-r--r--source/slang/slang-ir.cpp4
-rw-r--r--source/slang/slang-lower-to-ir.cpp24
27 files changed, 657 insertions, 45 deletions
diff --git a/source/slang-rt/slang-rt.h b/source/slang-rt/slang-rt.h
index b6f397c72..3941b3acc 100644
--- a/source/slang-rt/slang-rt.h
+++ b/source/slang-rt/slang-rt.h
@@ -4,6 +4,7 @@
#include "../core/slang-string.h"
#include "../core/slang-smart-pointer.h"
+#include "../core/slang-com-object.h"
#ifdef SLANG_RT_DYNAMIC_EXPORT
# define SLANG_RT_API SLANG_DLL_EXPORT
@@ -11,6 +12,15 @@
# define SLANG_RT_API
#endif
+#if defined(_MSC_VER)
+# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport)
+#else
+# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default")))
+//# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport)) __attribute__((__visibility__("default")))
+#endif
+
+#define SLANG_PRELUDE_EXPORT extern "C" SLANG_PRELUDE_SHARED_LIB_EXPORT
+
extern "C"
{
SLANG_RT_API void SLANG_MCALL _slang_rt_abort(Slang::String errorMessage);
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index e1f7503a8..fb39b43f2 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2187,8 +2187,11 @@ attribute_syntax [__unsafeForceInlineEarly] : UnsafeForceInlineEarlyAttribute;
__attributeTarget(FuncDecl)
attribute_syntax [DllImport(modulePath: String)] : DllImportAttribute;
+__attributeTarget(FuncDecl)
+attribute_syntax [DllExport] : DllExportAttribute;
+
__attributeTarget(InterfaceDecl)
-attribute_syntax [COM] : ComInterfaceAttribute;
+attribute_syntax [COM(guid: String)] : ComInterfaceAttribute;
// Inheritance Control
__attributeTarget(AggTypeDecl)
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 62a32045b..f66867542 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -940,12 +940,21 @@ class DllImportAttribute : public Attribute
SLANG_AST_CLASS(DllImportAttribute)
String modulePath;
+
+ String functionName;
+};
+
+class DllExportAttribute : public Attribute
+{
+ SLANG_AST_CLASS(DllExportAttribute)
};
/// An attribute that marks an interface type as a COM interface declaration.
class ComInterfaceAttribute : public Attribute
{
SLANG_AST_CLASS(ComInterfaceAttribute)
+
+ String guid;
};
/// A `[__requiresNVAPI]` attribute indicates that the declaration being modifed
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 28164c126..16a1cae26 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -1,5 +1,6 @@
// slang-check-modifier.cpp
#include "slang-check-impl.h"
+#include "../core/slang-char-util.h"
// This file implements semantic checking behavior for
// modifiers.
@@ -568,7 +569,7 @@ namespace Slang
}
else if (auto dllImportAttr = as<DllImportAttribute>(attr))
{
- SLANG_ASSERT(attr->args.getCount() == 1);
+ SLANG_ASSERT(attr->args.getCount() == 1 || attr->args.getCount() == 2);
String libraryName;
if (!checkLiteralStringVal(dllImportAttr->args[0], &libraryName))
@@ -576,6 +577,13 @@ namespace Slang
return false;
}
dllImportAttr->modulePath = libraryName;
+
+ String functionName;
+ if (dllImportAttr->args.getCount() == 2 && !checkLiteralStringVal(dllImportAttr->args[1], &functionName))
+ {
+ return false;
+ }
+ dllImportAttr->functionName = functionName;
}
else if (auto rayPayloadAttr = as<VulkanRayPayloadAttribute>(attr))
{
@@ -606,6 +614,38 @@ namespace Slang
customJVPAttr->funcDeclRef = funcExpr;
}
+ else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
+ String guid;
+ if (!checkLiteralStringVal(comInterfaceAttr->args[0], &guid))
+ {
+ return false;
+ }
+ StringBuilder resultGUID;
+ for (auto ch : guid)
+ {
+ if (CharUtil::isHexDigit(ch))
+ {
+ resultGUID.appendChar(ch);
+ }
+ else if (ch == '-')
+ {
+ continue;
+ }
+ else
+ {
+ getSink()->diagnose(attr, Diagnostics::invalidGUID, guid);
+ return false;
+ }
+ }
+ comInterfaceAttr->guid = resultGUID.ToString();
+ if (comInterfaceAttr->guid.getLength() != 32)
+ {
+ getSink()->diagnose(attr, Diagnostics::invalidGUID, guid);
+ return false;
+ }
+ }
else
{
if(attr->args.getCount() == 0)
@@ -617,7 +657,7 @@ namespace Slang
{
// We should be special-casing the checking of any attribute
// with a non-zero number of arguments.
- SLANG_DIAGNOSE_UNEXPECTED(getSink(), attr, "unhandled attribute");
+ getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), 0);
return false;
}
}
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index be351a78f..7f2386eb4 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -86,7 +86,7 @@ namespace Slang
getSink()->diagnose(context.originalExpr, Diagnostics::newCanOnlyBeUsedToInitializeAClass);
return false;
}
- if (!isNewExpr && isClassType)
+ if (!isNewExpr && isClassType && context.originalExpr)
{
getSink()->diagnose(context.originalExpr, Diagnostics::classCanOnlyBeInitializedWithNew);
return false;
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 85f5f5fb0..2b12c3de4 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -322,6 +322,7 @@ DIAGNOSTIC(31120, Error, invalidAttributeTarget, "invalid syntax target for user
DIAGNOSTIC(31121, Error, anyValueSizeExceedsLimit, "'anyValueSize' cannot exceed $0")
DIAGNOSTIC(31122, Error, associatedTypeNotAllowInComInterface, "associatedtype not allowed in a [COM] interface")
+DIAGNOSTIC(31123, Error, invalidGUID, "'$0' is not a valid GUID")
// Enums
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 1e0eb614d..c3ae31894 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -344,6 +344,50 @@ void CLikeSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable)
SLANG_UNUSED(witnessTable);
}
+void CLikeSourceEmitter::emitComWitnessTable(IRWitnessTable* witnessTable)
+{
+ auto classType = witnessTable->getConcreteType();
+ for (auto ent : witnessTable->getEntries())
+ {
+ auto req = ent->getRequirementKey();
+ auto func = as<IRFunc>(ent->getSatisfyingVal());
+ if (!func)
+ continue;
+
+ auto resultType = func->getResultType();
+
+ auto name = getName(classType) + "::" + getName(req);
+
+ emitFuncDecorations(func);
+
+ emitType(resultType, name);
+ m_writer->emit("(");
+ // Skip declaration of `this` parameter.
+ auto firstParam = func->getFirstParam()->getNextParam();
+ for (auto pp = firstParam; pp; pp = pp->getNextParam())
+ {
+ if (pp != firstParam)
+ m_writer->emit(", ");
+
+ emitSimpleFuncParamImpl(pp);
+ }
+ m_writer->emit(")");
+ m_writer->emit("\n{\n");
+ m_writer->indent();
+
+ // emit definition for `this` param.
+ m_writer->emit("auto ");
+ m_writer->emit(getName(func->getFirstParam()));
+ m_writer->emit(" = this;\n");
+
+ // Need to emit the operations in the blocks of the function
+ emitFunctionBody(func);
+
+ m_writer->dedent();
+ m_writer->emit("}\n\n");
+ }
+}
+
void CLikeSourceEmitter::emitInterface(IRInterfaceType* interfaceType)
{
SLANG_UNUSED(interfaceType);
@@ -1869,7 +1913,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
{
auto prec = getInfo(EmitOp::Postfix);
needClose = maybeEmitParens(outerPrec, prec);
- emitDereferenceOperand(inst->getOperand(0), leftSide(outerPrec, prec));
+ emitOperand(inst->getOperand(0), leftSide(outerPrec, prec));
m_writer->emit(".detach()");
break;
}
@@ -2630,6 +2674,12 @@ void CLikeSourceEmitter::emitParamTypeImpl(IRType* type, String const& name)
void CLikeSourceEmitter::emitFuncDecl(IRFunc* func)
{
+ auto name = getName(func);
+ emitFuncDecl(func, name);
+}
+
+void CLikeSourceEmitter::emitFuncDecl(IRFunc* func, const String& name)
+{
// We don't want to emit declarations for operations
// that only appear in the IR as stand-ins for built-in
// operations on that target.
@@ -2652,8 +2702,6 @@ void CLikeSourceEmitter::emitFuncDecl(IRFunc* func)
auto funcType = func->getDataType();
auto resultType = func->getResultType();
- auto name = getName(func);
-
emitFuncDecorations(func);
emitType(resultType, name);
@@ -2790,17 +2838,59 @@ void CLikeSourceEmitter::emitClass(IRClassType* classType)
{
return;
}
-
+ List<IRWitnessTable*> comWitnessTables;
+ for (auto child : classType->getDecorations())
+ {
+ if (auto decoration = as<IRCOMWitnessDecoration>(child))
+ {
+ comWitnessTables.add(cast<IRWitnessTable>(decoration->getWitnessTable()));
+ }
+ }
m_writer->emit("class ");
emitPostKeywordTypeAttributes(classType);
m_writer->emit(getName(classType));
- m_writer->emit(" : public RefObject");
+ if (comWitnessTables.getCount() == 0)
+ {
+ m_writer->emit(" : public RefObject");
+ }
+ else
+ {
+ m_writer->emit(" : public ComObject");
+ for (auto wt : comWitnessTables)
+ {
+ m_writer->emit(", public ");
+ m_writer->emit(getName(wt->getConformanceType()));
+ }
+ }
m_writer->emit("\n{\n");
m_writer->emit("public:\n");
m_writer->indent();
+ if (comWitnessTables.getCount())
+ {
+ m_writer->emit("SLANG_COM_OBJECT_IUNKNOWN_ALL\n");
+ m_writer->emit("void* getInterface(const Guid & uuid)\n{\n");
+ m_writer->indent();
+ m_writer->emit("if (uuid == ISlangUnknown::getTypeGuid()) return static_cast<ISlangUnknown*>(this);\n");
+ for (auto wt : comWitnessTables)
+ {
+ auto interfaceName = getName(wt->getConformanceType());
+ m_writer->emit("if (uuid == ");
+ m_writer->emit(interfaceName);
+ m_writer->emit("::getTypeGuid())\n");
+ m_writer->indent();
+ m_writer->emit("return static_cast<");
+ m_writer->emit(interfaceName);
+ m_writer->emit("*>(this);\n");
+ m_writer->dedent();
+ }
+ m_writer->emit("return nullptr;\n");
+ m_writer->dedent();
+ m_writer->emit("}\n");
+ }
+
for (auto ff : classType->getFields())
{
auto fieldKey = ff->getKey();
@@ -2818,6 +2908,28 @@ void CLikeSourceEmitter::emitClass(IRClassType* classType)
m_writer->emit(";\n");
}
+ // Emit COM method declarations.
+ for (auto wt : comWitnessTables)
+ {
+ for (auto wtEntry : wt->getChildren())
+ {
+ auto req = as<IRWitnessTableEntry>(wtEntry);
+ if (!req) continue;
+ auto func = as<IRFunc>(req->getSatisfyingVal());
+ if (!func) continue;
+ m_writer->emit("virtual SLANG_NO_THROW ");
+ emitType(func->getResultType(), "SLANG_MCALL " + getName(req->getRequirementKey()));
+ m_writer->emit("(");
+ auto param = func->getFirstParam();
+ param = param->getNextParam();
+ for (; param; param = param->getNextParam())
+ {
+ emitParamType(param->getFullType(), getName(param));
+ }
+ m_writer->emit(") override;\n");
+ }
+ }
+
m_writer->dedent();
m_writer->emit("};\n\n");
}
@@ -3135,6 +3247,32 @@ void CLikeSourceEmitter::emitGlobalInst(IRInst* inst)
emitGlobalInstImpl(inst);
}
+static bool _shouldSkipFuncEmit(IRInst* func)
+{
+ // Skip emitting a func if it is a COM interface wrapper implementation and used
+ // only by the witness table. We will emit this func differently than normal funcs
+ // and this is handled by `emitComWitnessTable`.
+
+ if (func->hasMoreThanOneUse()) return false;
+ if (func->firstUse)
+ {
+ if (auto entry = as<IRWitnessTableEntry>(func->firstUse->getUser()))
+ {
+ if (auto table = as<IRWitnessTable>(entry->getParent()))
+ {
+ if (auto interfaceType = table->getConformanceType())
+ {
+ if (interfaceType->findDecoration<IRComInterfaceDecoration>())
+ {
+ return true;
+ }
+ }
+ }
+ }
+ }
+ return false;
+}
+
void CLikeSourceEmitter::emitGlobalInstImpl(IRInst* inst)
{
m_writer->advanceToSourceLocation(inst->sourceLoc);
@@ -3152,7 +3290,10 @@ void CLikeSourceEmitter::emitGlobalInstImpl(IRInst* inst)
break;
case kIROp_Func:
- emitFunc((IRFunc*) inst);
+ if (!_shouldSkipFuncEmit(inst))
+ {
+ emitFunc((IRFunc*) inst);
+ }
break;
case kIROp_GlobalVar:
@@ -3219,6 +3360,9 @@ void CLikeSourceEmitter::ensureInstOperandsRec(ComputeEmitActionsContext* ctx, I
case kIROp_InterfaceType:
requiredLevel = EmitAction::ForwardDeclaration;
break;
+ case kIROp_COMWitnessDecoration:
+ requiredLevel = EmitAction::ForwardDeclaration;
+ break;
default:
break;
}
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index 825223278..d85e336e0 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -339,6 +339,8 @@ public:
void emitParamType(IRType* type, String const& name) { emitParamTypeImpl(type, name); }
void emitFuncDecl(IRFunc* func);
+ void emitFuncDecl(IRFunc* func, const String& name);
+
IREntryPointLayout* getEntryPointLayout(IRFunc* func);
@@ -472,6 +474,8 @@ public:
virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) = 0;
virtual void emitWitnessTable(IRWitnessTable* witnessTable);
+ void emitComWitnessTable(IRWitnessTable* witnessTable);
+
virtual void emitInterface(IRInterfaceType* interfaceType);
virtual void emitRTTIObject(IRRTTIObject* rttiObject);
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 96b42af43..5df13cbde 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1677,7 +1677,11 @@ void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable)
if (isBuiltin(interfaceType))
return;
- auto witnessTableItems = witnessTable->getChildren();
+ if (interfaceType->findDecoration<IRComInterfaceDecoration>())
+ {
+ pendingWitnessTableDefinitions.add(witnessTable);
+ return;
+ }
// Declare a global variable for the witness table.
m_writer->emit("extern \"C\" { ");
@@ -1701,6 +1705,11 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions()
for (auto witnessTable : pendingWitnessTableDefinitions)
{
auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType());
+ if (interfaceType->findDecoration<IRComInterfaceDecoration>())
+ {
+ emitComWitnessTable(witnessTable);
+ continue;
+ }
List<IRWitnessTableEntry*> sortedWitnessTableEntries = getSortedWitnessTableEntries(witnessTable);
m_writer->emit("extern \"C\"\n{\n");
m_writer->indent();
@@ -1750,6 +1759,12 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions()
void CPPSourceEmitter::emitComInterface(IRInterfaceType* interfaceType)
{
+ auto comDecoration = interfaceType->findDecoration<IRComInterfaceDecoration>();
+ auto guidInst = as<IRStringLit>(comDecoration->getOperand(0));
+ SLANG_RELEASE_ASSERT(guidInst);
+ auto guid = guidInst->getStringSlice();
+ SLANG_RELEASE_ASSERT(guid.getLength() == 32);
+
m_writer->emit("struct ");
emitSimpleType(interfaceType);
m_writer->emit(" : ");
@@ -1779,6 +1794,23 @@ void CPPSourceEmitter::emitComInterface(IRInterfaceType* interfaceType)
// Emit methods.
m_writer->emit("\n{\n");
m_writer->indent();
+ // Emit GUID.
+ m_writer->emit("SLANG_COM_INTERFACE(0x");
+ m_writer->emit(guid.subString(0, 8));
+ m_writer->emit(", 0x");
+ m_writer->emit(guid.subString(8, 4));
+ m_writer->emit(", 0x");
+ m_writer->emit(guid.subString(12, 4));
+ m_writer->emit(", { ");
+ for (UInt i = 0; i < 8; i++)
+ {
+ if (i > 0)
+ m_writer->emit(", ");
+ m_writer->emit("0x");
+ m_writer->emit(guid.subString(16 + i * 2, 2));
+ }
+ m_writer->emit(" })\n");
+
for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
{
auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
@@ -2476,6 +2508,33 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
}
return true;
}
+ case kIROp_MakeExistential:
+ case kIROp_MakeExistentialWithRTTI:
+ {
+ auto rsType = cast<IRComPtrType>(inst->getDataType());
+ m_writer->emit("ComPtr<");
+ m_writer->emit(getName(rsType->getOperand(0)));
+ m_writer->emit(">(");
+ m_writer->emit("static_cast<");
+ m_writer->emit(getName(rsType->getOperand(0)));
+ m_writer->emit("*>(");
+ auto prec = getInfo(EmitOp::Postfix);
+ emitOperand(inst->getOperand(0), leftSide(getInfo(EmitOp::General), prec));
+ m_writer->emit(".Ptr()");
+ m_writer->emit("))");
+ return true;
+ }
+ case kIROp_GetValueFromBoundInterface:
+ {
+ m_writer->emit("static_cast<");
+ m_writer->emit(getName(inst->getFullType()));
+ m_writer->emit("*>(");
+ auto prec = getInfo(EmitOp::Postfix);
+ emitOperand(inst->getOperand(0), leftSide(getInfo(EmitOp::General), prec));
+ m_writer->emit(".get()");
+ m_writer->emit(")");
+ return true;
+ }
}
}
@@ -2604,9 +2663,10 @@ void CPPSourceEmitter::emitVarDecorationsImpl(IRInst* inst)
Super::emitVarDecorationsImpl(inst);
}
-
-void CPPSourceEmitter::_maybeEmitExportLike(IRInst* inst)
+void CPPSourceEmitter::_getExportStyle(IRInst* inst, bool& outIsExport, bool& outIsExternC)
{
+ outIsExport = false;
+ outIsExternC = false;
// Specially handle export, as we don't want to emit it multiple times
if (getTargetReq()->isWholeProgramRequest())
{
@@ -2619,34 +2679,38 @@ void CPPSourceEmitter::_maybeEmitExportLike(IRInst* inst)
}
}
- bool isExternC = false;
- bool isExported = false;
-
// If public/export made it externally visible
for (auto decoration : inst->getDecorations())
{
const auto op = decoration->getOp();
if (op == kIROp_ExternCppDecoration)
{
- isExternC = true;
+ outIsExternC = true;
}
else if (op == kIROp_PublicDecoration ||
op == kIROp_HLSLExportDecoration)
{
- isExported = true;
+ outIsExport = true;
}
}
+ }
+}
- // TODO(JS): Currently export *also* implies it's extern "C" and we can't list twice
- if (isExported)
- {
- m_writer->emit("SLANG_PRELUDE_EXPORT\n");
- }
- else if (isExternC)
- {
- // It's name is not manged.
- m_writer->emit("extern \"C\"\n");
- }
+void CPPSourceEmitter::_maybeEmitExportLike(IRInst* inst)
+{
+ bool isExternC = false;
+ bool isExported = false;
+ _getExportStyle(inst, isExternC, isExported);
+
+ // TODO(JS): Currently export *also* implies it's extern "C" and we can't list twice
+ if (isExported)
+ {
+ m_writer->emit("SLANG_PRELUDE_EXPORT\n");
+ }
+ else if (isExternC)
+ {
+ // It's name is not manged.
+ m_writer->emit("extern \"C\"\n");
}
}
diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h
index 6199c33f2..c5b9f3d9c 100644
--- a/source/slang/slang-emit-cpp.h
+++ b/source/slang/slang-emit-cpp.h
@@ -133,6 +133,7 @@ protected:
void _emitWitnessTableDefinitions();
/// Maybe emits 'export' (such that visible outside binary/dll) and `extern "C"` naming
+ void _getExportStyle(IRInst* inst, bool& outIsExport, bool& outIsExternC);
void _maybeEmitExportLike(IRInst* inst);
HLSLIntrinsic* _addIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* argTypes, Index argTypeCount);
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 9b7fcfa38..1ded8668f 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -9,6 +9,7 @@
#include "slang-ir-byte-address-legalize.h"
#include "slang-ir-collect-global-uniforms.h"
#include "slang-ir-dce.h"
+#include "slang-ir-dll-export.h"
#include "slang-ir-dll-import.h"
#include "slang-ir-eliminate-phis.h"
#include "slang-ir-entry-point-uniforms.h"
@@ -211,6 +212,7 @@ Result linkAndOptimizeIR(
{
lowerComInterfaces(irModule, artifactDesc.style, sink);
generateDllImportFuncs(irModule, sink);
+ generateDllExportFuncs(irModule, sink);
break;
}
default: break;
diff --git a/source/slang/slang-ir-dll-export.cpp b/source/slang/slang-ir-dll-export.cpp
new file mode 100644
index 000000000..a8b464b43
--- /dev/null
+++ b/source/slang/slang-ir-dll-export.cpp
@@ -0,0 +1,72 @@
+// slang-ir-dll-export.cpp
+#include "slang-ir-dll-export.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-marshal-native-call.h"
+
+namespace Slang
+{
+
+struct DllExportContext
+{
+ IRModule* module;
+ DiagnosticSink* diagnosticSink;
+
+ SharedIRBuilder sharedBuilder;
+
+ void processFunc(IRFunc* func, IRDllExportDecoration* dllExportDecoration)
+ {
+ NativeCallMarshallingContext marshalContext;
+ marshalContext.diagnosticSink = diagnosticSink;
+ IRBuilder builder(sharedBuilder);
+ auto wrapper = marshalContext.generateDLLExportWrapperFunc(builder, func);
+ dllExportDecoration->removeFromParent();
+ dllExportDecoration->insertAtStart(wrapper);
+ builder.addNameHintDecoration(wrapper, dllExportDecoration->getFunctionName());
+ builder.addExternCppDecoration(wrapper, dllExportDecoration->getFunctionName());
+ builder.addPublicDecoration(wrapper);
+ builder.addKeepAliveDecoration(wrapper);
+ builder.addHLSLExportDecoration(wrapper);
+ if (auto oldPublicDecoration = func->findDecoration<IRPublicDecoration>())
+ {
+ oldPublicDecoration->removeFromParent();
+ }
+ }
+
+ void processModule()
+ {
+ struct Candidate { IRFunc* func; IRDllExportDecoration* exportDecoration; };
+ List<Candidate> candidates;
+ for (auto childFunc : module->getGlobalInsts())
+ {
+ switch(childFunc->getOp())
+ {
+ case kIROp_Func:
+ if (auto dllExportDecoration = childFunc->findDecoration<IRDllExportDecoration>())
+ {
+ candidates.add(Candidate{ as<IRFunc>(childFunc), dllExportDecoration });
+ }
+ break;
+ default:
+ break;
+ }
+ }
+
+ for (auto candidate : candidates)
+ {
+ processFunc(candidate.func, candidate.exportDecoration);
+ }
+ }
+};
+
+void generateDllExportFuncs(IRModule* module, DiagnosticSink* sink)
+{
+ DllExportContext context;
+ context.module = module;
+ context.diagnosticSink = sink;
+ context.sharedBuilder.init(module);
+ return context.processModule();
+}
+
+}
diff --git a/source/slang/slang-ir-dll-export.h b/source/slang/slang-ir-dll-export.h
new file mode 100644
index 000000000..eb7c2792e
--- /dev/null
+++ b/source/slang/slang-ir-dll-export.h
@@ -0,0 +1,10 @@
+// slang-ir-dll-export.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+ class DiagnosticSink;
+ /// Generate wrappers for functions marked as [DllExport].
+ void generateDllExportFuncs(IRModule* module, DiagnosticSink* sink);
+}
diff --git a/source/slang/slang-ir-dll-import.cpp b/source/slang/slang-ir-dll-import.cpp
index b123dfb03..43ed9a102 100644
--- a/source/slang/slang-ir-dll-import.cpp
+++ b/source/slang/slang-ir-dll-import.cpp
@@ -123,11 +123,19 @@ struct DllImportContext
builder.emitIf(isUninitialized, trueBlock, afterBlock);
builder.setInsertInto(trueBlock);
- auto modulePtr = builder.emitCallInst(
- builder.getPtrType(builder.getVoidType()),
- getLoadDllFunc(),
- builder.getStringValue(dllImportDecoration->getLibraryName()));
+ IRInst* modulePtr;
+ if (dllImportDecoration->getLibraryName() == "")
+ {
+ modulePtr = builder.getIntValue(builder.getIntType(), 0);
+ }
+ else
+ {
+ modulePtr = builder.emitCallInst(
+ builder.getPtrType(builder.getVoidType()),
+ getLoadDllFunc(),
+ builder.getStringValue(dllImportDecoration->getLibraryName()));
+ }
IRInst* loadDllFuncArgs[] = {
modulePtr, builder.getStringValue(dllImportDecoration->getFunctionName())};
auto loadedNativeFuncPtr = builder.emitCallInst(
diff --git a/source/slang/slang-ir-dll-import.h b/source/slang/slang-ir-dll-import.h
index c330f803f..d2dc1a9a3 100644
--- a/source/slang/slang-ir-dll-import.h
+++ b/source/slang/slang-ir-dll-import.h
@@ -7,5 +7,4 @@ namespace Slang
class DiagnosticSink;
/// Generate implementations for functions marked as [DllImport].
void generateDllImportFuncs(IRModule* module, DiagnosticSink* sink);
-
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index a91c21434..8a62e34d4 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -645,6 +645,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// An dllImport decoration marks a function as imported from a DLL. Slang will generate dynamic function loading logic to use this function at runtime.
INST(DllImportDecoration, dllImport, 2, 0)
+ /// An dllExport decoration marks a function as an export symbol. Slang will generate a native wrapper function that is exported to DLL.
+ INST(DllExportDecoration, dllExport, 1, 0)
/// Marks an interface as a COM interface declaration.
INST(ComInterfaceDecoration, COMInterface, 0, 0)
@@ -690,6 +692,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// generated derivative function.
INST(JVPDerivativeReferenceDecoration, jvpFnReference, 1, 0)
+ /// Marks a class type as a COM interface implementation, which enables
+ /// the witness table to be easily picked up by emit.
+ INST(COMWitnessDecoration, COMWitnessDecoration, 1, 0)
+
/// Marks a struct type as being used as a structured buffer block.
/// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration.
INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index b8cce4c17..f5d5a86ac 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -186,6 +186,17 @@ struct IRComInterfaceDecoration : IRDecoration
IR_LEAF_ISA(ComInterfaceDecoration)
};
+struct IRCOMWitnessDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_COMWitnessDecoration
+ };
+ IR_LEAF_ISA(COMWitnessDecoration)
+
+ IRInst* getWitnessTable() { return getOperand(0); }
+};
+
/// A decoration on `IRParam`s that represent generic parameters,
/// marking the interface type that the generic parameter conforms to.
/// A generic parameter can have more than one `IRTypeConstraintDecoration`s
@@ -457,6 +468,18 @@ struct IRDllImportDecoration : IRDecoration
UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); }
};
+struct IRDllExportDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_DllExportDecoration
+ };
+ IR_LEAF_ISA(DllExportDecoration)
+
+ IRStringLit* getFunctionNameOperand() { return cast<IRStringLit>(getOperand(0)); }
+ UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); }
+};
+
struct IRFormatDecoration : IRDecoration
{
enum { kOp = kIROp_FormatDecoration };
@@ -2521,7 +2544,7 @@ public:
IRInst* emitManagedPtrAttach(IRInst* managedPtrVar, IRInst* value);
- IRInst* emitManagedPtrDetach(IRInst* managedPtrVar);
+ IRInst* emitManagedPtrDetach(IRType* type, IRInst* managedPtrVal);
IRInst* emitGetNativePtr(IRInst* value);
@@ -3052,11 +3075,21 @@ public:
addDecoration(value, kIROp_JVPDerivativeReferenceDecoration, jvpFn);
}
+ void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable)
+ {
+ addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1);
+ }
+
void addDllImportDecoration(IRInst* value, UnownedStringSlice const& libraryName, UnownedStringSlice const& functionName)
{
addDecoration(value, kIROp_DllImportDecoration, getStringValue(libraryName), getStringValue(functionName));
}
+ void addDllExportDecoration(IRInst* value, UnownedStringSlice const& functionName)
+ {
+ addDecoration(value, kIROp_DllExportDecoration, getStringValue(functionName));
+ }
+
void addEntryPointDecoration(IRInst* value, Profile profile, UnownedStringSlice const& name, UnownedStringSlice const& moduleName)
{
IRInst* operands[] = { getIntValue(getIntType(), profile.raw), getStringValue(name), getStringValue(moduleName) };
@@ -3117,9 +3150,9 @@ public:
addDecoration(inst, kIROp_AnyValueSizeDecoration, getIntValue(getIntType(), value));
}
- void addComInterfaceDecoration(IRInst* inst)
+ void addComInterfaceDecoration(IRInst* inst, UnownedStringSlice guid)
{
- addDecoration(inst, kIROp_ComInterfaceDecoration);
+ addDecoration(inst, kIROp_ComInterfaceDecoration, getStringValue(guid));
}
void addTypeConstraintDecoration(IRInst* inst, IRInst* constraintType)
diff --git a/source/slang/slang-ir-lower-com-methods.cpp b/source/slang/slang-ir-lower-com-methods.cpp
index 6c3a3f289..6d5ddb261 100644
--- a/source/slang/slang-ir-lower-com-methods.cpp
+++ b/source/slang/slang-ir-lower-com-methods.cpp
@@ -6,6 +6,7 @@
#include "slang-ir-insts.h"
#include "slang-ir-marshal-native-call.h"
#include "slang-ir-inst-pass-base.h"
+#include "slang-ir-util.h"
namespace Slang
{
@@ -102,6 +103,37 @@ struct ComMethodLoweringContext : public InstPassBase
}
}
+ void processWitnessTable(IRWitnessTable* witnessTable)
+ {
+ auto interfaceType = as<IRInterfaceType>(witnessTable->getConformanceType());
+ if (!interfaceType) return;
+ if (!interfaceType->findDecoration<IRComInterfaceDecoration>())
+ return;
+ auto interfaceReqDict = buildInterfaceRequirementDict(interfaceType);
+
+ IRBuilder builder(&sharedBuilderStorage);
+ NativeCallMarshallingContext marshalContext;
+ marshalContext.diagnosticSink = diagnosticSink;
+ for (auto entry : witnessTable->getEntries())
+ {
+ IRInst* interfaceRequirement = nullptr;
+ if (!interfaceReqDict.TryGetValue(entry->getRequirementKey(), interfaceRequirement))
+ continue;
+ auto implFunc = as<IRFunc>(entry->getSatisfyingVal());
+ if (!implFunc) continue;
+ // If the function already has the same signature as the lowered COM interface method,
+ // we don't need to do anything.
+ if (isTypeEqual(entry->getSatisfyingVal()->getDataType(), (IRType*)interfaceRequirement))
+ continue;
+ // Now we need to generate a wrapper function that calls into the original one.
+ auto nativeFunc = marshalContext.generateDLLExportWrapperFunc(builder, implFunc);
+ entry->setOperand(1, nativeFunc);
+ }
+
+ auto classType = witnessTable->getConcreteType();
+ builder.addCOMWitnessDecoration(classType, witnessTable);
+ }
+
void processModule()
{
sharedBuilderStorage.init(module);
@@ -124,6 +156,9 @@ struct ComMethodLoweringContext : public InstPassBase
// Update func types of COM interfaces.
processInstsOfType<IRInterfaceType>(kIROp_InterfaceType, [this](IRInterfaceType* inst) { processInterfaceType(inst); });
+ // Update witness tables of classes that implement COM interfaces.
+ // Generate native-to-managed wrappers for each witness table entry.
+ processInstsOfType<IRWitnessTable>(kIROp_WitnessTable, [this](IRWitnessTable* table) { processWitnessTable(table); });
}
};
diff --git a/source/slang/slang-ir-lower-existential.cpp b/source/slang/slang-ir-lower-existential.cpp
index b0d9e6f2f..cfdacc7ac 100644
--- a/source/slang/slang-ir-lower-existential.cpp
+++ b/source/slang/slang-ir-lower-existential.cpp
@@ -24,6 +24,8 @@ namespace Slang
auto valueType = sharedContext->lowerType(builder, value->getDataType());
auto witnessTableType = cast<IRWitnessTableTypeBase>(inst->getWitnessTable()->getDataType());
auto interfaceType = witnessTableType->getConformanceType();
+ if (interfaceType->findDecoration<IRComInterfaceDecoration>())
+ return;
auto witnessTableIdType = builder->getWitnessTableIDType((IRType*)interfaceType);
auto anyValueSize = sharedContext->getInterfaceAnyValueSize(interfaceType, inst->sourceLoc);
auto anyValueType = builder->getAnyValueType(anyValueSize);
@@ -139,7 +141,10 @@ namespace Slang
IRBuilder builderStorage(sharedContext->sharedBuilderStorage);
auto builder = &builderStorage;
builder->setInsertBefore(inst);
-
+ if (inst->getDataType()->getOp() == kIROp_ClassType)
+ {
+ return;
+ }
// A value of interface will lower as a tuple, and
// the third element of that tuple represents the
// concrete value that was put into the existential.
diff --git a/source/slang/slang-ir-marshal-native-call.cpp b/source/slang/slang-ir-marshal-native-call.cpp
index 8e342cb26..8a97bfd3e 100644
--- a/source/slang/slang-ir-marshal-native-call.cpp
+++ b/source/slang/slang-ir-marshal-native-call.cpp
@@ -123,6 +123,123 @@ namespace Slang
}
}
+ void NativeCallMarshallingContext::marshalManagedValueToNativeResultValue(
+ IRBuilder& builder, IRInst* originalArg, List<IRInst*>& args)
+ {
+ switch (originalArg->getDataType()->getOp())
+ {
+ case kIROp_InOutType:
+ case kIROp_RefType:
+ SLANG_UNREACHABLE("out and ref types should be handled before reaching here.");
+ break;
+ case kIROp_StringType:
+ {
+ diagnosticSink->diagnose(originalArg, Diagnostics::unimplemented, "marshal string to native return value");
+ }
+ break;
+ case kIROp_ClassType:
+ {
+ diagnosticSink->diagnose(originalArg, Diagnostics::unimplemented, "marshal class to native return value");
+ }
+ break;
+ case kIROp_InterfaceType:
+ {
+ auto nativePtr = builder.emitManagedPtrDetach(builder.getNativePtrType(originalArg->getDataType()), originalArg);
+ args.add(nativePtr);
+ }
+ break;
+ case kIROp_ComPtrType:
+ {
+ auto nativePtr = builder.emitManagedPtrDetach(
+ builder.getNativePtrType(
+ (IRType*)cast<IRComPtrType>(originalArg->getDataType())->getOperand(0)),
+ originalArg);
+ args.add(nativePtr);
+ }
+ break;
+ default:
+ args.add(originalArg);
+ break;
+ }
+ }
+
+ IRInst* NativeCallMarshallingContext::marshalNativeArgToManagedArg(
+ IRBuilder& builder, const List<IRInst*>& args, Index& consumeIndex, IRType* expectedManagedType)
+ {
+ // For now, all managed values maps to one native value, so we just call `marshalNativeValueToManagedValue`.
+ // This function can be extended in the future to support things like `List` that maps to more than one
+ // native args.
+ SLANG_UNUSED(expectedManagedType);
+ auto managedVal = marshalNativeValueToManagedValue(builder, args[consumeIndex]);
+ consumeIndex++;
+ return managedVal;
+ }
+
+ IRFunc* NativeCallMarshallingContext::generateDLLExportWrapperFunc(IRBuilder& builder, IRFunc* originalFunc)
+ {
+ builder.setInsertBefore(originalFunc);
+ auto funcType = getNativeFuncType(builder, originalFunc->getDataType());
+ auto newFunc = builder.createFunc();
+ newFunc->setFullType(funcType);
+ builder.setInsertInto(newFunc);
+ builder.emitBlock();
+ List<IRInst*> params;
+ for (UInt i = 0; i < funcType->getParamCount(); i++)
+ {
+ auto paramType = funcType->getParamType(i);
+ params.add(builder.emitParam(paramType));
+ }
+ List<IRInst*> args;
+ Index nativeParamConsumeIndex = 0;
+ for (UInt i = 0; i < originalFunc->getParamCount(); i++)
+ {
+ auto managedParamType = originalFunc->getParamType(i);
+ auto managedArg = marshalNativeArgToManagedArg(builder, params, nativeParamConsumeIndex, managedParamType);
+ args.add(managedArg);
+ }
+ auto originalReturnType = originalFunc->getResultType();
+ auto callInst = builder.emitCallInst(originalReturnType, originalFunc, args);
+ if (auto resultType = as<IRResultType>(originalReturnType))
+ {
+ auto isResultError = builder.emitIsResultError(callInst);
+ IRBlock* trueBlock = nullptr;
+ IRBlock* falseBlock = nullptr;
+ IRBlock* afterBlock = nullptr;
+ builder.emitIfElseWithBlocks(isResultError, trueBlock, falseBlock, afterBlock);
+
+ builder.setInsertInto(trueBlock);
+ builder.emitReturn(builder.emitGetResultError(callInst));
+
+ builder.setInsertInto(falseBlock);
+ auto resultVal = builder.emitGetResultValue(callInst);
+ List<IRInst*> nativeVals;
+ marshalManagedValueToNativeResultValue(builder, resultVal, nativeVals);
+ for (Index i = 0; i < nativeVals.getCount(); i++)
+ {
+ SLANG_RELEASE_ASSERT(nativeParamConsumeIndex < params.getCount());
+ builder.emitStore(params[nativeParamConsumeIndex], nativeVals[i]);
+ nativeParamConsumeIndex++;
+ }
+ builder.emitReturn(builder.getIntValue(builder.getIntType(), 0));
+
+ builder.setInsertInto(afterBlock);
+ builder.emitUnreachable();
+ }
+ else
+ {
+ List<IRInst*> nativeVals;
+ marshalManagedValueToNativeResultValue(builder, callInst, nativeVals);
+ for (Index i = 1; i < nativeVals.getCount(); i++)
+ {
+ SLANG_RELEASE_ASSERT(nativeParamConsumeIndex < params.getCount());
+ builder.emitStore(params[nativeParamConsumeIndex], nativeVals[i]);
+ nativeParamConsumeIndex++;
+ }
+ builder.emitReturn(nativeVals[0]);
+ }
+ return newFunc;
+ }
+
IRInst* NativeCallMarshallingContext::marshalNativeCall(
IRBuilder& builder,
IRFuncType* originalFuncType,
diff --git a/source/slang/slang-ir-marshal-native-call.h b/source/slang/slang-ir-marshal-native-call.h
index bbc2078ea..e70d177ac 100644
--- a/source/slang/slang-ir-marshal-native-call.h
+++ b/source/slang/slang-ir-marshal-native-call.h
@@ -41,10 +41,19 @@ namespace Slang
void marshalRefManagedValueToNativeValue(
IRBuilder& builder, IRInst* originalArg, List<IRInst*>& args);
+ // Marshal a managed value to a native value for input into a native functions.
void marshalManagedValueToNativeValue(
IRBuilder& builder, IRInst* originalArg, List<IRInst*>& args);
+ // Marshal a managed value to a native value for the return value of a native function.
+ void marshalManagedValueToNativeResultValue(
+ IRBuilder& builder, IRInst* originalArg, List<IRInst*>& args);
+
IRInst* marshalNativeValueToManagedValue(
IRBuilder& builder, IRInst* nativeValue);
+
+ IRInst* marshalNativeArgToManagedArg(IRBuilder& builder, const List<IRInst*>& args, Index& consumeIndex, IRType* expectedManagedType);
+
+ IRFunc* generateDLLExportWrapperFunc(IRBuilder& builder, IRFunc* originalFunc);
};
}
diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
index 7464a1c35..52c8edca6 100644
--- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
+++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
@@ -208,9 +208,12 @@ struct AssociatedTypeLookupSpecializationContext
auto seqId = inst->findDecoration<IRSequentialIDDecoration>();
SLANG_ASSERT(seqId);
// Insert code to pack sequential ID into an uint2 at all use sites.
- for (auto use = inst->firstUse; use; )
+ IRUse* nextUse = nullptr;
+ for (auto use = inst->firstUse; use; use = nextUse)
{
- auto nextUse = use->nextUse;
+ nextUse = use->nextUse;
+ if (as<IRCOMWitnessDecoration>(use->getUser()))
+ continue;
IRBuilder builder(sharedContext->sharedBuilderStorage);
builder.setInsertBefore(use->getUser());
auto uint2Type = builder.getVectorType(
@@ -222,7 +225,6 @@ struct AssociatedTypeLookupSpecializationContext
use->set(uint2seqID);
use = nextUse;
}
- inst->replaceUsesWith(seqId->getSequentialIDOperand());
}
});
diff --git a/source/slang/slang-ir-strip-witness-tables.cpp b/source/slang/slang-ir-strip-witness-tables.cpp
index 8536508ba..4c8901c52 100644
--- a/source/slang/slang-ir-strip-witness-tables.cpp
+++ b/source/slang/slang-ir-strip-witness-tables.cpp
@@ -25,9 +25,12 @@ void stripWitnessTables(IRModule* module)
auto witnessTable = as<IRWitnessTable>(inst);
if(!witnessTable)
continue;
+ auto conformanceType = witnessTable->getConformanceType();
+ if (conformanceType && conformanceType->findDecoration<IRComInterfaceDecoration>())
+ continue;
witnessTable->removeAndDeallocateAllDecorationsAndChildren();
}
}
-} \ No newline at end of file
+}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index fda7b25cc..a515217d9 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -12,6 +12,18 @@ bool isPointerOfType(IRInst* type, IROp opCode)
return false;
}
+Dictionary<IRInst*, IRInst*> buildInterfaceRequirementDict(IRInterfaceType* interfaceType)
+{
+ Dictionary<IRInst*, IRInst*> result;
+ for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
+ {
+ auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
+ if (!entry) continue;
+ result[entry->getRequirementKey()] = entry->getRequirementVal();
+ }
+ return result;
+}
+
bool isPointerOfType(IRInst* type, IRInst* elementType)
{
if (auto ptrType = as<IRPtrTypeBase>(type))
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 69a9c4d16..2c11374e3 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -17,6 +17,9 @@ bool isPointerOfType(IRInst* ptrType, IRInst* elementType);
// True if ptrType is a pointer type to a type of opCode
bool isPointerOfType(IRInst* ptrType, IROp opCode);
+// Builds a dictionary that maps from requirement key to requirement value for `interfaceType`.
+Dictionary<IRInst*, IRInst*> buildInterfaceRequirementDict(IRInterfaceType* interfaceType);
+
}
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index ca6435d8e..86586c2e8 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4487,9 +4487,9 @@ namespace Slang
return emitIntrinsicInst(getVoidType(), kIROp_ManagedPtrAttach, 2, args);
}
- IRInst* IRBuilder::emitManagedPtrDetach(IRInst* managedPtrVar)
+ IRInst* IRBuilder::emitManagedPtrDetach(IRType* type, IRInst* managedPtrVal)
{
- return emitIntrinsicInst(getVoidType(), kIROp_ManagedPtrDetach, 1, &managedPtrVar);
+ return emitIntrinsicInst(type, kIROp_ManagedPtrDetach, 1, &managedPtrVal);
}
IRInst* IRBuilder::emitGetManagedPtrWriteRef(IRInst* ptrToManagedPtr)
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 81201f5f8..66f314d06 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1153,7 +1153,15 @@ static void addLinkageDecoration(
if (auto dllImportModifier = decl->findModifier<DllImportAttribute>())
{
auto libraryName = dllImportModifier->modulePath;
- builder->addDllImportDecoration(inst, libraryName.getUnownedSlice(), decl->getName()->text.getUnownedSlice());
+ auto functionName = dllImportModifier->functionName.getLength()
+ ? dllImportModifier->functionName.getUnownedSlice()
+ : decl->getName()->text.getUnownedSlice();
+ builder->addDllImportDecoration(inst, libraryName.getUnownedSlice(), functionName);
+ }
+ if (decl->findModifier<DllExportAttribute>())
+ {
+ builder->addDllExportDecoration(inst, decl->getName()->text.getUnownedSlice());
+ builder->addPublicDecoration(inst);
}
}
@@ -2409,6 +2417,12 @@ ParameterDirection getParameterDirection(VarDeclBase* paramDecl)
///
ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection defaultDirection)
{
+ // The `this` parameter for a `class` is always `in`.
+ if (as<ClassDecl>(parentDecl->parentDecl))
+ {
+ return kParameterDirection_In;
+ }
+
// Applications can opt in to a mutable `this` parameter,
// by applying the `[mutating]` attribute to their
// declaration.
@@ -5917,6 +5931,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
addLinkageDecoration(context, irWitnessTable, inheritanceDecl, mangledName.getUnownedSlice());
+ // If the witness table is for a COM interface, always keep it alive.
+ if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>())
+ {
+ subBuilder->addPublicDecoration(irWitnessTable);
+ }
+
// TODO(JS):
// Not clear what to do here around HLSLExportModifier.
// In HLSL it only (currently) applies to functions, so perhaps do nothing is reasonable.
@@ -6596,7 +6616,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
if (auto comInterfaceAttr = decl->findModifier<ComInterfaceAttribute>())
{
- subBuilder->addComInterfaceDecoration(irInterface);
+ subBuilder->addComInterfaceDecoration(irInterface, comInterfaceAttr->guid.getUnownedSlice());
}
if (auto builtinAttr = decl->findModifier<BuiltinAttribute>())
{