diff options
Diffstat (limited to 'source')
27 files changed, 657 insertions, 45 deletions
diff --git a/source/slang-rt/slang-rt.h b/source/slang-rt/slang-rt.h index b6f397c72..3941b3acc 100644 --- a/source/slang-rt/slang-rt.h +++ b/source/slang-rt/slang-rt.h @@ -4,6 +4,7 @@ #include "../core/slang-string.h" #include "../core/slang-smart-pointer.h" +#include "../core/slang-com-object.h" #ifdef SLANG_RT_DYNAMIC_EXPORT # define SLANG_RT_API SLANG_DLL_EXPORT @@ -11,6 +12,15 @@ # define SLANG_RT_API #endif +#if defined(_MSC_VER) +# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport) +#else +# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default"))) +//# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport)) __attribute__((__visibility__("default"))) +#endif + +#define SLANG_PRELUDE_EXPORT extern "C" SLANG_PRELUDE_SHARED_LIB_EXPORT + extern "C" { SLANG_RT_API void SLANG_MCALL _slang_rt_abort(Slang::String errorMessage); diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index e1f7503a8..fb39b43f2 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2187,8 +2187,11 @@ attribute_syntax [__unsafeForceInlineEarly] : UnsafeForceInlineEarlyAttribute; __attributeTarget(FuncDecl) attribute_syntax [DllImport(modulePath: String)] : DllImportAttribute; +__attributeTarget(FuncDecl) +attribute_syntax [DllExport] : DllExportAttribute; + __attributeTarget(InterfaceDecl) -attribute_syntax [COM] : ComInterfaceAttribute; +attribute_syntax [COM(guid: String)] : ComInterfaceAttribute; // Inheritance Control __attributeTarget(AggTypeDecl) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 62a32045b..f66867542 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -940,12 +940,21 @@ class DllImportAttribute : public Attribute SLANG_AST_CLASS(DllImportAttribute) String modulePath; + + String functionName; +}; + +class DllExportAttribute : public Attribute +{ + SLANG_AST_CLASS(DllExportAttribute) }; /// An attribute that marks an interface type as a COM interface declaration. class ComInterfaceAttribute : public Attribute { SLANG_AST_CLASS(ComInterfaceAttribute) + + String guid; }; /// A `[__requiresNVAPI]` attribute indicates that the declaration being modifed diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 28164c126..16a1cae26 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -1,5 +1,6 @@ // slang-check-modifier.cpp #include "slang-check-impl.h" +#include "../core/slang-char-util.h" // This file implements semantic checking behavior for // modifiers. @@ -568,7 +569,7 @@ namespace Slang } else if (auto dllImportAttr = as<DllImportAttribute>(attr)) { - SLANG_ASSERT(attr->args.getCount() == 1); + SLANG_ASSERT(attr->args.getCount() == 1 || attr->args.getCount() == 2); String libraryName; if (!checkLiteralStringVal(dllImportAttr->args[0], &libraryName)) @@ -576,6 +577,13 @@ namespace Slang return false; } dllImportAttr->modulePath = libraryName; + + String functionName; + if (dllImportAttr->args.getCount() == 2 && !checkLiteralStringVal(dllImportAttr->args[1], &functionName)) + { + return false; + } + dllImportAttr->functionName = functionName; } else if (auto rayPayloadAttr = as<VulkanRayPayloadAttribute>(attr)) { @@ -606,6 +614,38 @@ namespace Slang customJVPAttr->funcDeclRef = funcExpr; } + else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + String guid; + if (!checkLiteralStringVal(comInterfaceAttr->args[0], &guid)) + { + return false; + } + StringBuilder resultGUID; + for (auto ch : guid) + { + if (CharUtil::isHexDigit(ch)) + { + resultGUID.appendChar(ch); + } + else if (ch == '-') + { + continue; + } + else + { + getSink()->diagnose(attr, Diagnostics::invalidGUID, guid); + return false; + } + } + comInterfaceAttr->guid = resultGUID.ToString(); + if (comInterfaceAttr->guid.getLength() != 32) + { + getSink()->diagnose(attr, Diagnostics::invalidGUID, guid); + return false; + } + } else { if(attr->args.getCount() == 0) @@ -617,7 +657,7 @@ namespace Slang { // We should be special-casing the checking of any attribute // with a non-zero number of arguments. - SLANG_DIAGNOSE_UNEXPECTED(getSink(), attr, "unhandled attribute"); + getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.getCount(), 0); return false; } } diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index be351a78f..7f2386eb4 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -86,7 +86,7 @@ namespace Slang getSink()->diagnose(context.originalExpr, Diagnostics::newCanOnlyBeUsedToInitializeAClass); return false; } - if (!isNewExpr && isClassType) + if (!isNewExpr && isClassType && context.originalExpr) { getSink()->diagnose(context.originalExpr, Diagnostics::classCanOnlyBeInitializedWithNew); return false; diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 85f5f5fb0..2b12c3de4 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -322,6 +322,7 @@ DIAGNOSTIC(31120, Error, invalidAttributeTarget, "invalid syntax target for user DIAGNOSTIC(31121, Error, anyValueSizeExceedsLimit, "'anyValueSize' cannot exceed $0") DIAGNOSTIC(31122, Error, associatedTypeNotAllowInComInterface, "associatedtype not allowed in a [COM] interface") +DIAGNOSTIC(31123, Error, invalidGUID, "'$0' is not a valid GUID") // Enums diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 1e0eb614d..c3ae31894 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -344,6 +344,50 @@ void CLikeSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable) SLANG_UNUSED(witnessTable); } +void CLikeSourceEmitter::emitComWitnessTable(IRWitnessTable* witnessTable) +{ + auto classType = witnessTable->getConcreteType(); + for (auto ent : witnessTable->getEntries()) + { + auto req = ent->getRequirementKey(); + auto func = as<IRFunc>(ent->getSatisfyingVal()); + if (!func) + continue; + + auto resultType = func->getResultType(); + + auto name = getName(classType) + "::" + getName(req); + + emitFuncDecorations(func); + + emitType(resultType, name); + m_writer->emit("("); + // Skip declaration of `this` parameter. + auto firstParam = func->getFirstParam()->getNextParam(); + for (auto pp = firstParam; pp; pp = pp->getNextParam()) + { + if (pp != firstParam) + m_writer->emit(", "); + + emitSimpleFuncParamImpl(pp); + } + m_writer->emit(")"); + m_writer->emit("\n{\n"); + m_writer->indent(); + + // emit definition for `this` param. + m_writer->emit("auto "); + m_writer->emit(getName(func->getFirstParam())); + m_writer->emit(" = this;\n"); + + // Need to emit the operations in the blocks of the function + emitFunctionBody(func); + + m_writer->dedent(); + m_writer->emit("}\n\n"); + } +} + void CLikeSourceEmitter::emitInterface(IRInterfaceType* interfaceType) { SLANG_UNUSED(interfaceType); @@ -1869,7 +1913,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO { auto prec = getInfo(EmitOp::Postfix); needClose = maybeEmitParens(outerPrec, prec); - emitDereferenceOperand(inst->getOperand(0), leftSide(outerPrec, prec)); + emitOperand(inst->getOperand(0), leftSide(outerPrec, prec)); m_writer->emit(".detach()"); break; } @@ -2630,6 +2674,12 @@ void CLikeSourceEmitter::emitParamTypeImpl(IRType* type, String const& name) void CLikeSourceEmitter::emitFuncDecl(IRFunc* func) { + auto name = getName(func); + emitFuncDecl(func, name); +} + +void CLikeSourceEmitter::emitFuncDecl(IRFunc* func, const String& name) +{ // We don't want to emit declarations for operations // that only appear in the IR as stand-ins for built-in // operations on that target. @@ -2652,8 +2702,6 @@ void CLikeSourceEmitter::emitFuncDecl(IRFunc* func) auto funcType = func->getDataType(); auto resultType = func->getResultType(); - auto name = getName(func); - emitFuncDecorations(func); emitType(resultType, name); @@ -2790,17 +2838,59 @@ void CLikeSourceEmitter::emitClass(IRClassType* classType) { return; } - + List<IRWitnessTable*> comWitnessTables; + for (auto child : classType->getDecorations()) + { + if (auto decoration = as<IRCOMWitnessDecoration>(child)) + { + comWitnessTables.add(cast<IRWitnessTable>(decoration->getWitnessTable())); + } + } m_writer->emit("class "); emitPostKeywordTypeAttributes(classType); m_writer->emit(getName(classType)); - m_writer->emit(" : public RefObject"); + if (comWitnessTables.getCount() == 0) + { + m_writer->emit(" : public RefObject"); + } + else + { + m_writer->emit(" : public ComObject"); + for (auto wt : comWitnessTables) + { + m_writer->emit(", public "); + m_writer->emit(getName(wt->getConformanceType())); + } + } m_writer->emit("\n{\n"); m_writer->emit("public:\n"); m_writer->indent(); + if (comWitnessTables.getCount()) + { + m_writer->emit("SLANG_COM_OBJECT_IUNKNOWN_ALL\n"); + m_writer->emit("void* getInterface(const Guid & uuid)\n{\n"); + m_writer->indent(); + m_writer->emit("if (uuid == ISlangUnknown::getTypeGuid()) return static_cast<ISlangUnknown*>(this);\n"); + for (auto wt : comWitnessTables) + { + auto interfaceName = getName(wt->getConformanceType()); + m_writer->emit("if (uuid == "); + m_writer->emit(interfaceName); + m_writer->emit("::getTypeGuid())\n"); + m_writer->indent(); + m_writer->emit("return static_cast<"); + m_writer->emit(interfaceName); + m_writer->emit("*>(this);\n"); + m_writer->dedent(); + } + m_writer->emit("return nullptr;\n"); + m_writer->dedent(); + m_writer->emit("}\n"); + } + for (auto ff : classType->getFields()) { auto fieldKey = ff->getKey(); @@ -2818,6 +2908,28 @@ void CLikeSourceEmitter::emitClass(IRClassType* classType) m_writer->emit(";\n"); } + // Emit COM method declarations. + for (auto wt : comWitnessTables) + { + for (auto wtEntry : wt->getChildren()) + { + auto req = as<IRWitnessTableEntry>(wtEntry); + if (!req) continue; + auto func = as<IRFunc>(req->getSatisfyingVal()); + if (!func) continue; + m_writer->emit("virtual SLANG_NO_THROW "); + emitType(func->getResultType(), "SLANG_MCALL " + getName(req->getRequirementKey())); + m_writer->emit("("); + auto param = func->getFirstParam(); + param = param->getNextParam(); + for (; param; param = param->getNextParam()) + { + emitParamType(param->getFullType(), getName(param)); + } + m_writer->emit(") override;\n"); + } + } + m_writer->dedent(); m_writer->emit("};\n\n"); } @@ -3135,6 +3247,32 @@ void CLikeSourceEmitter::emitGlobalInst(IRInst* inst) emitGlobalInstImpl(inst); } +static bool _shouldSkipFuncEmit(IRInst* func) +{ + // Skip emitting a func if it is a COM interface wrapper implementation and used + // only by the witness table. We will emit this func differently than normal funcs + // and this is handled by `emitComWitnessTable`. + + if (func->hasMoreThanOneUse()) return false; + if (func->firstUse) + { + if (auto entry = as<IRWitnessTableEntry>(func->firstUse->getUser())) + { + if (auto table = as<IRWitnessTable>(entry->getParent())) + { + if (auto interfaceType = table->getConformanceType()) + { + if (interfaceType->findDecoration<IRComInterfaceDecoration>()) + { + return true; + } + } + } + } + } + return false; +} + void CLikeSourceEmitter::emitGlobalInstImpl(IRInst* inst) { m_writer->advanceToSourceLocation(inst->sourceLoc); @@ -3152,7 +3290,10 @@ void CLikeSourceEmitter::emitGlobalInstImpl(IRInst* inst) break; case kIROp_Func: - emitFunc((IRFunc*) inst); + if (!_shouldSkipFuncEmit(inst)) + { + emitFunc((IRFunc*) inst); + } break; case kIROp_GlobalVar: @@ -3219,6 +3360,9 @@ void CLikeSourceEmitter::ensureInstOperandsRec(ComputeEmitActionsContext* ctx, I case kIROp_InterfaceType: requiredLevel = EmitAction::ForwardDeclaration; break; + case kIROp_COMWitnessDecoration: + requiredLevel = EmitAction::ForwardDeclaration; + break; default: break; } diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 825223278..d85e336e0 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -339,6 +339,8 @@ public: void emitParamType(IRType* type, String const& name) { emitParamTypeImpl(type, name); } void emitFuncDecl(IRFunc* func); + void emitFuncDecl(IRFunc* func, const String& name); + IREntryPointLayout* getEntryPointLayout(IRFunc* func); @@ -472,6 +474,8 @@ public: virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) = 0; virtual void emitWitnessTable(IRWitnessTable* witnessTable); + void emitComWitnessTable(IRWitnessTable* witnessTable); + virtual void emitInterface(IRInterfaceType* interfaceType); virtual void emitRTTIObject(IRRTTIObject* rttiObject); diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 96b42af43..5df13cbde 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1677,7 +1677,11 @@ void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable) if (isBuiltin(interfaceType)) return; - auto witnessTableItems = witnessTable->getChildren(); + if (interfaceType->findDecoration<IRComInterfaceDecoration>()) + { + pendingWitnessTableDefinitions.add(witnessTable); + return; + } // Declare a global variable for the witness table. m_writer->emit("extern \"C\" { "); @@ -1701,6 +1705,11 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions() for (auto witnessTable : pendingWitnessTableDefinitions) { auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType()); + if (interfaceType->findDecoration<IRComInterfaceDecoration>()) + { + emitComWitnessTable(witnessTable); + continue; + } List<IRWitnessTableEntry*> sortedWitnessTableEntries = getSortedWitnessTableEntries(witnessTable); m_writer->emit("extern \"C\"\n{\n"); m_writer->indent(); @@ -1750,6 +1759,12 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions() void CPPSourceEmitter::emitComInterface(IRInterfaceType* interfaceType) { + auto comDecoration = interfaceType->findDecoration<IRComInterfaceDecoration>(); + auto guidInst = as<IRStringLit>(comDecoration->getOperand(0)); + SLANG_RELEASE_ASSERT(guidInst); + auto guid = guidInst->getStringSlice(); + SLANG_RELEASE_ASSERT(guid.getLength() == 32); + m_writer->emit("struct "); emitSimpleType(interfaceType); m_writer->emit(" : "); @@ -1779,6 +1794,23 @@ void CPPSourceEmitter::emitComInterface(IRInterfaceType* interfaceType) // Emit methods. m_writer->emit("\n{\n"); m_writer->indent(); + // Emit GUID. + m_writer->emit("SLANG_COM_INTERFACE(0x"); + m_writer->emit(guid.subString(0, 8)); + m_writer->emit(", 0x"); + m_writer->emit(guid.subString(8, 4)); + m_writer->emit(", 0x"); + m_writer->emit(guid.subString(12, 4)); + m_writer->emit(", { "); + for (UInt i = 0; i < 8; i++) + { + if (i > 0) + m_writer->emit(", "); + m_writer->emit("0x"); + m_writer->emit(guid.subString(16 + i * 2, 2)); + } + m_writer->emit(" })\n"); + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) { auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)); @@ -2476,6 +2508,33 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut } return true; } + case kIROp_MakeExistential: + case kIROp_MakeExistentialWithRTTI: + { + auto rsType = cast<IRComPtrType>(inst->getDataType()); + m_writer->emit("ComPtr<"); + m_writer->emit(getName(rsType->getOperand(0))); + m_writer->emit(">("); + m_writer->emit("static_cast<"); + m_writer->emit(getName(rsType->getOperand(0))); + m_writer->emit("*>("); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(inst->getOperand(0), leftSide(getInfo(EmitOp::General), prec)); + m_writer->emit(".Ptr()"); + m_writer->emit("))"); + return true; + } + case kIROp_GetValueFromBoundInterface: + { + m_writer->emit("static_cast<"); + m_writer->emit(getName(inst->getFullType())); + m_writer->emit("*>("); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(inst->getOperand(0), leftSide(getInfo(EmitOp::General), prec)); + m_writer->emit(".get()"); + m_writer->emit(")"); + return true; + } } } @@ -2604,9 +2663,10 @@ void CPPSourceEmitter::emitVarDecorationsImpl(IRInst* inst) Super::emitVarDecorationsImpl(inst); } - -void CPPSourceEmitter::_maybeEmitExportLike(IRInst* inst) +void CPPSourceEmitter::_getExportStyle(IRInst* inst, bool& outIsExport, bool& outIsExternC) { + outIsExport = false; + outIsExternC = false; // Specially handle export, as we don't want to emit it multiple times if (getTargetReq()->isWholeProgramRequest()) { @@ -2619,34 +2679,38 @@ void CPPSourceEmitter::_maybeEmitExportLike(IRInst* inst) } } - bool isExternC = false; - bool isExported = false; - // If public/export made it externally visible for (auto decoration : inst->getDecorations()) { const auto op = decoration->getOp(); if (op == kIROp_ExternCppDecoration) { - isExternC = true; + outIsExternC = true; } else if (op == kIROp_PublicDecoration || op == kIROp_HLSLExportDecoration) { - isExported = true; + outIsExport = true; } } + } +} - // TODO(JS): Currently export *also* implies it's extern "C" and we can't list twice - if (isExported) - { - m_writer->emit("SLANG_PRELUDE_EXPORT\n"); - } - else if (isExternC) - { - // It's name is not manged. - m_writer->emit("extern \"C\"\n"); - } +void CPPSourceEmitter::_maybeEmitExportLike(IRInst* inst) +{ + bool isExternC = false; + bool isExported = false; + _getExportStyle(inst, isExternC, isExported); + + // TODO(JS): Currently export *also* implies it's extern "C" and we can't list twice + if (isExported) + { + m_writer->emit("SLANG_PRELUDE_EXPORT\n"); + } + else if (isExternC) + { + // It's name is not manged. + m_writer->emit("extern \"C\"\n"); } } diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h index 6199c33f2..c5b9f3d9c 100644 --- a/source/slang/slang-emit-cpp.h +++ b/source/slang/slang-emit-cpp.h @@ -133,6 +133,7 @@ protected: void _emitWitnessTableDefinitions(); /// Maybe emits 'export' (such that visible outside binary/dll) and `extern "C"` naming + void _getExportStyle(IRInst* inst, bool& outIsExport, bool& outIsExternC); void _maybeEmitExportLike(IRInst* inst); HLSLIntrinsic* _addIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* argTypes, Index argTypeCount); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 9b7fcfa38..1ded8668f 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -9,6 +9,7 @@ #include "slang-ir-byte-address-legalize.h" #include "slang-ir-collect-global-uniforms.h" #include "slang-ir-dce.h" +#include "slang-ir-dll-export.h" #include "slang-ir-dll-import.h" #include "slang-ir-eliminate-phis.h" #include "slang-ir-entry-point-uniforms.h" @@ -211,6 +212,7 @@ Result linkAndOptimizeIR( { lowerComInterfaces(irModule, artifactDesc.style, sink); generateDllImportFuncs(irModule, sink); + generateDllExportFuncs(irModule, sink); break; } default: break; diff --git a/source/slang/slang-ir-dll-export.cpp b/source/slang/slang-ir-dll-export.cpp new file mode 100644 index 000000000..a8b464b43 --- /dev/null +++ b/source/slang/slang-ir-dll-export.cpp @@ -0,0 +1,72 @@ +// slang-ir-dll-export.cpp +#include "slang-ir-dll-export.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-marshal-native-call.h" + +namespace Slang +{ + +struct DllExportContext +{ + IRModule* module; + DiagnosticSink* diagnosticSink; + + SharedIRBuilder sharedBuilder; + + void processFunc(IRFunc* func, IRDllExportDecoration* dllExportDecoration) + { + NativeCallMarshallingContext marshalContext; + marshalContext.diagnosticSink = diagnosticSink; + IRBuilder builder(sharedBuilder); + auto wrapper = marshalContext.generateDLLExportWrapperFunc(builder, func); + dllExportDecoration->removeFromParent(); + dllExportDecoration->insertAtStart(wrapper); + builder.addNameHintDecoration(wrapper, dllExportDecoration->getFunctionName()); + builder.addExternCppDecoration(wrapper, dllExportDecoration->getFunctionName()); + builder.addPublicDecoration(wrapper); + builder.addKeepAliveDecoration(wrapper); + builder.addHLSLExportDecoration(wrapper); + if (auto oldPublicDecoration = func->findDecoration<IRPublicDecoration>()) + { + oldPublicDecoration->removeFromParent(); + } + } + + void processModule() + { + struct Candidate { IRFunc* func; IRDllExportDecoration* exportDecoration; }; + List<Candidate> candidates; + for (auto childFunc : module->getGlobalInsts()) + { + switch(childFunc->getOp()) + { + case kIROp_Func: + if (auto dllExportDecoration = childFunc->findDecoration<IRDllExportDecoration>()) + { + candidates.add(Candidate{ as<IRFunc>(childFunc), dllExportDecoration }); + } + break; + default: + break; + } + } + + for (auto candidate : candidates) + { + processFunc(candidate.func, candidate.exportDecoration); + } + } +}; + +void generateDllExportFuncs(IRModule* module, DiagnosticSink* sink) +{ + DllExportContext context; + context.module = module; + context.diagnosticSink = sink; + context.sharedBuilder.init(module); + return context.processModule(); +} + +} diff --git a/source/slang/slang-ir-dll-export.h b/source/slang/slang-ir-dll-export.h new file mode 100644 index 000000000..eb7c2792e --- /dev/null +++ b/source/slang/slang-ir-dll-export.h @@ -0,0 +1,10 @@ +// slang-ir-dll-export.h +#pragma once + +namespace Slang +{ + struct IRModule; + class DiagnosticSink; + /// Generate wrappers for functions marked as [DllExport]. + void generateDllExportFuncs(IRModule* module, DiagnosticSink* sink); +} diff --git a/source/slang/slang-ir-dll-import.cpp b/source/slang/slang-ir-dll-import.cpp index b123dfb03..43ed9a102 100644 --- a/source/slang/slang-ir-dll-import.cpp +++ b/source/slang/slang-ir-dll-import.cpp @@ -123,11 +123,19 @@ struct DllImportContext builder.emitIf(isUninitialized, trueBlock, afterBlock); builder.setInsertInto(trueBlock); - auto modulePtr = builder.emitCallInst( - builder.getPtrType(builder.getVoidType()), - getLoadDllFunc(), - builder.getStringValue(dllImportDecoration->getLibraryName())); + IRInst* modulePtr; + if (dllImportDecoration->getLibraryName() == "") + { + modulePtr = builder.getIntValue(builder.getIntType(), 0); + } + else + { + modulePtr = builder.emitCallInst( + builder.getPtrType(builder.getVoidType()), + getLoadDllFunc(), + builder.getStringValue(dllImportDecoration->getLibraryName())); + } IRInst* loadDllFuncArgs[] = { modulePtr, builder.getStringValue(dllImportDecoration->getFunctionName())}; auto loadedNativeFuncPtr = builder.emitCallInst( diff --git a/source/slang/slang-ir-dll-import.h b/source/slang/slang-ir-dll-import.h index c330f803f..d2dc1a9a3 100644 --- a/source/slang/slang-ir-dll-import.h +++ b/source/slang/slang-ir-dll-import.h @@ -7,5 +7,4 @@ namespace Slang class DiagnosticSink; /// Generate implementations for functions marked as [DllImport]. void generateDllImportFuncs(IRModule* module, DiagnosticSink* sink); - } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index a91c21434..8a62e34d4 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -645,6 +645,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// An dllImport decoration marks a function as imported from a DLL. Slang will generate dynamic function loading logic to use this function at runtime. INST(DllImportDecoration, dllImport, 2, 0) + /// An dllExport decoration marks a function as an export symbol. Slang will generate a native wrapper function that is exported to DLL. + INST(DllExportDecoration, dllExport, 1, 0) /// Marks an interface as a COM interface declaration. INST(ComInterfaceDecoration, COMInterface, 0, 0) @@ -690,6 +692,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// generated derivative function. INST(JVPDerivativeReferenceDecoration, jvpFnReference, 1, 0) + /// Marks a class type as a COM interface implementation, which enables + /// the witness table to be easily picked up by emit. + INST(COMWitnessDecoration, COMWitnessDecoration, 1, 0) + /// Marks a struct type as being used as a structured buffer block. /// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration. INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index b8cce4c17..f5d5a86ac 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -186,6 +186,17 @@ struct IRComInterfaceDecoration : IRDecoration IR_LEAF_ISA(ComInterfaceDecoration) }; +struct IRCOMWitnessDecoration : IRDecoration +{ + enum + { + kOp = kIROp_COMWitnessDecoration + }; + IR_LEAF_ISA(COMWitnessDecoration) + + IRInst* getWitnessTable() { return getOperand(0); } +}; + /// A decoration on `IRParam`s that represent generic parameters, /// marking the interface type that the generic parameter conforms to. /// A generic parameter can have more than one `IRTypeConstraintDecoration`s @@ -457,6 +468,18 @@ struct IRDllImportDecoration : IRDecoration UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); } }; +struct IRDllExportDecoration : IRDecoration +{ + enum + { + kOp = kIROp_DllExportDecoration + }; + IR_LEAF_ISA(DllExportDecoration) + + IRStringLit* getFunctionNameOperand() { return cast<IRStringLit>(getOperand(0)); } + UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); } +}; + struct IRFormatDecoration : IRDecoration { enum { kOp = kIROp_FormatDecoration }; @@ -2521,7 +2544,7 @@ public: IRInst* emitManagedPtrAttach(IRInst* managedPtrVar, IRInst* value); - IRInst* emitManagedPtrDetach(IRInst* managedPtrVar); + IRInst* emitManagedPtrDetach(IRType* type, IRInst* managedPtrVal); IRInst* emitGetNativePtr(IRInst* value); @@ -3052,11 +3075,21 @@ public: addDecoration(value, kIROp_JVPDerivativeReferenceDecoration, jvpFn); } + void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable) + { + addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1); + } + void addDllImportDecoration(IRInst* value, UnownedStringSlice const& libraryName, UnownedStringSlice const& functionName) { addDecoration(value, kIROp_DllImportDecoration, getStringValue(libraryName), getStringValue(functionName)); } + void addDllExportDecoration(IRInst* value, UnownedStringSlice const& functionName) + { + addDecoration(value, kIROp_DllExportDecoration, getStringValue(functionName)); + } + void addEntryPointDecoration(IRInst* value, Profile profile, UnownedStringSlice const& name, UnownedStringSlice const& moduleName) { IRInst* operands[] = { getIntValue(getIntType(), profile.raw), getStringValue(name), getStringValue(moduleName) }; @@ -3117,9 +3150,9 @@ public: addDecoration(inst, kIROp_AnyValueSizeDecoration, getIntValue(getIntType(), value)); } - void addComInterfaceDecoration(IRInst* inst) + void addComInterfaceDecoration(IRInst* inst, UnownedStringSlice guid) { - addDecoration(inst, kIROp_ComInterfaceDecoration); + addDecoration(inst, kIROp_ComInterfaceDecoration, getStringValue(guid)); } void addTypeConstraintDecoration(IRInst* inst, IRInst* constraintType) diff --git a/source/slang/slang-ir-lower-com-methods.cpp b/source/slang/slang-ir-lower-com-methods.cpp index 6c3a3f289..6d5ddb261 100644 --- a/source/slang/slang-ir-lower-com-methods.cpp +++ b/source/slang/slang-ir-lower-com-methods.cpp @@ -6,6 +6,7 @@ #include "slang-ir-insts.h" #include "slang-ir-marshal-native-call.h" #include "slang-ir-inst-pass-base.h" +#include "slang-ir-util.h" namespace Slang { @@ -102,6 +103,37 @@ struct ComMethodLoweringContext : public InstPassBase } } + void processWitnessTable(IRWitnessTable* witnessTable) + { + auto interfaceType = as<IRInterfaceType>(witnessTable->getConformanceType()); + if (!interfaceType) return; + if (!interfaceType->findDecoration<IRComInterfaceDecoration>()) + return; + auto interfaceReqDict = buildInterfaceRequirementDict(interfaceType); + + IRBuilder builder(&sharedBuilderStorage); + NativeCallMarshallingContext marshalContext; + marshalContext.diagnosticSink = diagnosticSink; + for (auto entry : witnessTable->getEntries()) + { + IRInst* interfaceRequirement = nullptr; + if (!interfaceReqDict.TryGetValue(entry->getRequirementKey(), interfaceRequirement)) + continue; + auto implFunc = as<IRFunc>(entry->getSatisfyingVal()); + if (!implFunc) continue; + // If the function already has the same signature as the lowered COM interface method, + // we don't need to do anything. + if (isTypeEqual(entry->getSatisfyingVal()->getDataType(), (IRType*)interfaceRequirement)) + continue; + // Now we need to generate a wrapper function that calls into the original one. + auto nativeFunc = marshalContext.generateDLLExportWrapperFunc(builder, implFunc); + entry->setOperand(1, nativeFunc); + } + + auto classType = witnessTable->getConcreteType(); + builder.addCOMWitnessDecoration(classType, witnessTable); + } + void processModule() { sharedBuilderStorage.init(module); @@ -124,6 +156,9 @@ struct ComMethodLoweringContext : public InstPassBase // Update func types of COM interfaces. processInstsOfType<IRInterfaceType>(kIROp_InterfaceType, [this](IRInterfaceType* inst) { processInterfaceType(inst); }); + // Update witness tables of classes that implement COM interfaces. + // Generate native-to-managed wrappers for each witness table entry. + processInstsOfType<IRWitnessTable>(kIROp_WitnessTable, [this](IRWitnessTable* table) { processWitnessTable(table); }); } }; diff --git a/source/slang/slang-ir-lower-existential.cpp b/source/slang/slang-ir-lower-existential.cpp index b0d9e6f2f..cfdacc7ac 100644 --- a/source/slang/slang-ir-lower-existential.cpp +++ b/source/slang/slang-ir-lower-existential.cpp @@ -24,6 +24,8 @@ namespace Slang auto valueType = sharedContext->lowerType(builder, value->getDataType()); auto witnessTableType = cast<IRWitnessTableTypeBase>(inst->getWitnessTable()->getDataType()); auto interfaceType = witnessTableType->getConformanceType(); + if (interfaceType->findDecoration<IRComInterfaceDecoration>()) + return; auto witnessTableIdType = builder->getWitnessTableIDType((IRType*)interfaceType); auto anyValueSize = sharedContext->getInterfaceAnyValueSize(interfaceType, inst->sourceLoc); auto anyValueType = builder->getAnyValueType(anyValueSize); @@ -139,7 +141,10 @@ namespace Slang IRBuilder builderStorage(sharedContext->sharedBuilderStorage); auto builder = &builderStorage; builder->setInsertBefore(inst); - + if (inst->getDataType()->getOp() == kIROp_ClassType) + { + return; + } // A value of interface will lower as a tuple, and // the third element of that tuple represents the // concrete value that was put into the existential. diff --git a/source/slang/slang-ir-marshal-native-call.cpp b/source/slang/slang-ir-marshal-native-call.cpp index 8e342cb26..8a97bfd3e 100644 --- a/source/slang/slang-ir-marshal-native-call.cpp +++ b/source/slang/slang-ir-marshal-native-call.cpp @@ -123,6 +123,123 @@ namespace Slang } } + void NativeCallMarshallingContext::marshalManagedValueToNativeResultValue( + IRBuilder& builder, IRInst* originalArg, List<IRInst*>& args) + { + switch (originalArg->getDataType()->getOp()) + { + case kIROp_InOutType: + case kIROp_RefType: + SLANG_UNREACHABLE("out and ref types should be handled before reaching here."); + break; + case kIROp_StringType: + { + diagnosticSink->diagnose(originalArg, Diagnostics::unimplemented, "marshal string to native return value"); + } + break; + case kIROp_ClassType: + { + diagnosticSink->diagnose(originalArg, Diagnostics::unimplemented, "marshal class to native return value"); + } + break; + case kIROp_InterfaceType: + { + auto nativePtr = builder.emitManagedPtrDetach(builder.getNativePtrType(originalArg->getDataType()), originalArg); + args.add(nativePtr); + } + break; + case kIROp_ComPtrType: + { + auto nativePtr = builder.emitManagedPtrDetach( + builder.getNativePtrType( + (IRType*)cast<IRComPtrType>(originalArg->getDataType())->getOperand(0)), + originalArg); + args.add(nativePtr); + } + break; + default: + args.add(originalArg); + break; + } + } + + IRInst* NativeCallMarshallingContext::marshalNativeArgToManagedArg( + IRBuilder& builder, const List<IRInst*>& args, Index& consumeIndex, IRType* expectedManagedType) + { + // For now, all managed values maps to one native value, so we just call `marshalNativeValueToManagedValue`. + // This function can be extended in the future to support things like `List` that maps to more than one + // native args. + SLANG_UNUSED(expectedManagedType); + auto managedVal = marshalNativeValueToManagedValue(builder, args[consumeIndex]); + consumeIndex++; + return managedVal; + } + + IRFunc* NativeCallMarshallingContext::generateDLLExportWrapperFunc(IRBuilder& builder, IRFunc* originalFunc) + { + builder.setInsertBefore(originalFunc); + auto funcType = getNativeFuncType(builder, originalFunc->getDataType()); + auto newFunc = builder.createFunc(); + newFunc->setFullType(funcType); + builder.setInsertInto(newFunc); + builder.emitBlock(); + List<IRInst*> params; + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + auto paramType = funcType->getParamType(i); + params.add(builder.emitParam(paramType)); + } + List<IRInst*> args; + Index nativeParamConsumeIndex = 0; + for (UInt i = 0; i < originalFunc->getParamCount(); i++) + { + auto managedParamType = originalFunc->getParamType(i); + auto managedArg = marshalNativeArgToManagedArg(builder, params, nativeParamConsumeIndex, managedParamType); + args.add(managedArg); + } + auto originalReturnType = originalFunc->getResultType(); + auto callInst = builder.emitCallInst(originalReturnType, originalFunc, args); + if (auto resultType = as<IRResultType>(originalReturnType)) + { + auto isResultError = builder.emitIsResultError(callInst); + IRBlock* trueBlock = nullptr; + IRBlock* falseBlock = nullptr; + IRBlock* afterBlock = nullptr; + builder.emitIfElseWithBlocks(isResultError, trueBlock, falseBlock, afterBlock); + + builder.setInsertInto(trueBlock); + builder.emitReturn(builder.emitGetResultError(callInst)); + + builder.setInsertInto(falseBlock); + auto resultVal = builder.emitGetResultValue(callInst); + List<IRInst*> nativeVals; + marshalManagedValueToNativeResultValue(builder, resultVal, nativeVals); + for (Index i = 0; i < nativeVals.getCount(); i++) + { + SLANG_RELEASE_ASSERT(nativeParamConsumeIndex < params.getCount()); + builder.emitStore(params[nativeParamConsumeIndex], nativeVals[i]); + nativeParamConsumeIndex++; + } + builder.emitReturn(builder.getIntValue(builder.getIntType(), 0)); + + builder.setInsertInto(afterBlock); + builder.emitUnreachable(); + } + else + { + List<IRInst*> nativeVals; + marshalManagedValueToNativeResultValue(builder, callInst, nativeVals); + for (Index i = 1; i < nativeVals.getCount(); i++) + { + SLANG_RELEASE_ASSERT(nativeParamConsumeIndex < params.getCount()); + builder.emitStore(params[nativeParamConsumeIndex], nativeVals[i]); + nativeParamConsumeIndex++; + } + builder.emitReturn(nativeVals[0]); + } + return newFunc; + } + IRInst* NativeCallMarshallingContext::marshalNativeCall( IRBuilder& builder, IRFuncType* originalFuncType, diff --git a/source/slang/slang-ir-marshal-native-call.h b/source/slang/slang-ir-marshal-native-call.h index bbc2078ea..e70d177ac 100644 --- a/source/slang/slang-ir-marshal-native-call.h +++ b/source/slang/slang-ir-marshal-native-call.h @@ -41,10 +41,19 @@ namespace Slang void marshalRefManagedValueToNativeValue( IRBuilder& builder, IRInst* originalArg, List<IRInst*>& args); + // Marshal a managed value to a native value for input into a native functions. void marshalManagedValueToNativeValue( IRBuilder& builder, IRInst* originalArg, List<IRInst*>& args); + // Marshal a managed value to a native value for the return value of a native function. + void marshalManagedValueToNativeResultValue( + IRBuilder& builder, IRInst* originalArg, List<IRInst*>& args); + IRInst* marshalNativeValueToManagedValue( IRBuilder& builder, IRInst* nativeValue); + + IRInst* marshalNativeArgToManagedArg(IRBuilder& builder, const List<IRInst*>& args, Index& consumeIndex, IRType* expectedManagedType); + + IRFunc* generateDLLExportWrapperFunc(IRBuilder& builder, IRFunc* originalFunc); }; } diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp index 7464a1c35..52c8edca6 100644 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp +++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp @@ -208,9 +208,12 @@ struct AssociatedTypeLookupSpecializationContext auto seqId = inst->findDecoration<IRSequentialIDDecoration>(); SLANG_ASSERT(seqId); // Insert code to pack sequential ID into an uint2 at all use sites. - for (auto use = inst->firstUse; use; ) + IRUse* nextUse = nullptr; + for (auto use = inst->firstUse; use; use = nextUse) { - auto nextUse = use->nextUse; + nextUse = use->nextUse; + if (as<IRCOMWitnessDecoration>(use->getUser())) + continue; IRBuilder builder(sharedContext->sharedBuilderStorage); builder.setInsertBefore(use->getUser()); auto uint2Type = builder.getVectorType( @@ -222,7 +225,6 @@ struct AssociatedTypeLookupSpecializationContext use->set(uint2seqID); use = nextUse; } - inst->replaceUsesWith(seqId->getSequentialIDOperand()); } }); diff --git a/source/slang/slang-ir-strip-witness-tables.cpp b/source/slang/slang-ir-strip-witness-tables.cpp index 8536508ba..4c8901c52 100644 --- a/source/slang/slang-ir-strip-witness-tables.cpp +++ b/source/slang/slang-ir-strip-witness-tables.cpp @@ -25,9 +25,12 @@ void stripWitnessTables(IRModule* module) auto witnessTable = as<IRWitnessTable>(inst); if(!witnessTable) continue; + auto conformanceType = witnessTable->getConformanceType(); + if (conformanceType && conformanceType->findDecoration<IRComInterfaceDecoration>()) + continue; witnessTable->removeAndDeallocateAllDecorationsAndChildren(); } } -}
\ No newline at end of file +} diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index fda7b25cc..a515217d9 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -12,6 +12,18 @@ bool isPointerOfType(IRInst* type, IROp opCode) return false; } +Dictionary<IRInst*, IRInst*> buildInterfaceRequirementDict(IRInterfaceType* interfaceType) +{ + Dictionary<IRInst*, IRInst*> result; + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + { + auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)); + if (!entry) continue; + result[entry->getRequirementKey()] = entry->getRequirementVal(); + } + return result; +} + bool isPointerOfType(IRInst* type, IRInst* elementType) { if (auto ptrType = as<IRPtrTypeBase>(type)) diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 69a9c4d16..2c11374e3 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -17,6 +17,9 @@ bool isPointerOfType(IRInst* ptrType, IRInst* elementType); // True if ptrType is a pointer type to a type of opCode bool isPointerOfType(IRInst* ptrType, IROp opCode); +// Builds a dictionary that maps from requirement key to requirement value for `interfaceType`. +Dictionary<IRInst*, IRInst*> buildInterfaceRequirementDict(IRInterfaceType* interfaceType); + } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index ca6435d8e..86586c2e8 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4487,9 +4487,9 @@ namespace Slang return emitIntrinsicInst(getVoidType(), kIROp_ManagedPtrAttach, 2, args); } - IRInst* IRBuilder::emitManagedPtrDetach(IRInst* managedPtrVar) + IRInst* IRBuilder::emitManagedPtrDetach(IRType* type, IRInst* managedPtrVal) { - return emitIntrinsicInst(getVoidType(), kIROp_ManagedPtrDetach, 1, &managedPtrVar); + return emitIntrinsicInst(type, kIROp_ManagedPtrDetach, 1, &managedPtrVal); } IRInst* IRBuilder::emitGetManagedPtrWriteRef(IRInst* ptrToManagedPtr) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 81201f5f8..66f314d06 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1153,7 +1153,15 @@ static void addLinkageDecoration( if (auto dllImportModifier = decl->findModifier<DllImportAttribute>()) { auto libraryName = dllImportModifier->modulePath; - builder->addDllImportDecoration(inst, libraryName.getUnownedSlice(), decl->getName()->text.getUnownedSlice()); + auto functionName = dllImportModifier->functionName.getLength() + ? dllImportModifier->functionName.getUnownedSlice() + : decl->getName()->text.getUnownedSlice(); + builder->addDllImportDecoration(inst, libraryName.getUnownedSlice(), functionName); + } + if (decl->findModifier<DllExportAttribute>()) + { + builder->addDllExportDecoration(inst, decl->getName()->text.getUnownedSlice()); + builder->addPublicDecoration(inst); } } @@ -2409,6 +2417,12 @@ ParameterDirection getParameterDirection(VarDeclBase* paramDecl) /// ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection defaultDirection) { + // The `this` parameter for a `class` is always `in`. + if (as<ClassDecl>(parentDecl->parentDecl)) + { + return kParameterDirection_In; + } + // Applications can opt in to a mutable `this` parameter, // by applying the `[mutating]` attribute to their // declaration. @@ -5917,6 +5931,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> addLinkageDecoration(context, irWitnessTable, inheritanceDecl, mangledName.getUnownedSlice()); + // If the witness table is for a COM interface, always keep it alive. + if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>()) + { + subBuilder->addPublicDecoration(irWitnessTable); + } + // TODO(JS): // Not clear what to do here around HLSLExportModifier. // In HLSL it only (currently) applies to functions, so perhaps do nothing is reasonable. @@ -6596,7 +6616,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } if (auto comInterfaceAttr = decl->findModifier<ComInterfaceAttribute>()) { - subBuilder->addComInterfaceDecoration(irInterface); + subBuilder->addComInterfaceDecoration(irInterface, comInterfaceAttr->guid.getUnownedSlice()); } if (auto builtinAttr = decl->findModifier<BuiltinAttribute>()) { |
