diff options
Diffstat (limited to 'source/slang/slang-ir-dll-import.cpp')
| -rw-r--r-- | source/slang/slang-ir-dll-import.cpp | 249 |
1 files changed, 249 insertions, 0 deletions
diff --git a/source/slang/slang-ir-dll-import.cpp b/source/slang/slang-ir-dll-import.cpp new file mode 100644 index 000000000..ed0a3e007 --- /dev/null +++ b/source/slang/slang-ir-dll-import.cpp @@ -0,0 +1,249 @@ +// slang-ir-dll-import.cpp +#include "slang-ir-dll-import.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +struct DllImportContext +{ + IRModule* module; + DiagnosticSink* diagnosticSink; + + SharedIRBuilder sharedBuilder; + + IRFunc* loadDllFunc = nullptr; + IRFunc* loadFuncPtrFunc = nullptr; + IRFunc* stringGetBufferFunc = nullptr; + + IRFunc* createBuiltinIntrinsicFunc(UInt paramCount, IRType** paramTypes, IRType* resultType, UnownedStringSlice targetIntrinsic) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(module->getModuleInst()); + IRFunc* result = builder.createFunc(); + builder.setInsertInto(result); + auto funcType = builder.getFuncType(paramCount, paramTypes, resultType); + builder.setDataType(result, funcType); + builder.addTargetIntrinsicDecoration( + result, CapabilitySet(CapabilityAtom::CPP), targetIntrinsic); + return result; + } + + IRFunc* getLoadDllFunc() + { + if (!loadDllFunc) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(module->getModuleInst()); + IRType* stringType = builder.getStringType(); + loadDllFunc = createBuiltinIntrinsicFunc( + 1, + &stringType, + builder.getPtrType(builder.getVoidType()), + UnownedStringSlice("_slang_rt_load_dll($0)")); + } + return loadDllFunc; + } + + IRFunc* getLoadFuncPtrFunc() + { + if (!loadFuncPtrFunc) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(module->getModuleInst()); + + IRType* stringType = builder.getStringType(); + IRType* paramTypes[] = {builder.getPtrType(builder.getVoidType()), stringType}; + loadFuncPtrFunc = createBuiltinIntrinsicFunc( + 2, + paramTypes, + builder.getPtrType(builder.getVoidType()), + UnownedStringSlice("_slang_rt_load_dll_func($0, $1)")); + } + return loadFuncPtrFunc; + } + + IRFunc* getStringGetBufferFunc() + { + if (!stringGetBufferFunc) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(module->getModuleInst()); + + IRType* stringType = builder.getStringType(); + IRType* paramTypes[] = {stringType}; + stringGetBufferFunc = createBuiltinIntrinsicFunc( + 1, + paramTypes, + builder.getPtrType(builder.getCharType()), + UnownedStringSlice("const_cast<char*>($0.getBuffer())")); + } + return stringGetBufferFunc; + } + + IRType* getNativeType(IRBuilder& builder, IRType* type) + { + switch (type->getOp()) + { + case kIROp_StringType: + return builder.getPtrType(builder.getCharType()); + default: + return type; + } + } + + IRType* getNativeFuncType(IRBuilder& builder, IRFunc* func) + { + List<IRInst*> nativeParamTypes; + auto declaredFuncType = func->getDataType(); + assert(declaredFuncType->getOp() == kIROp_FuncType); + for (UInt i = 0; i < declaredFuncType->getParamCount(); ++i) + { + auto paramType = declaredFuncType->getParamType(i); + nativeParamTypes.add(getNativeType(builder, as<IRType>(paramType))); + } + IRType* returnType = getNativeType(builder, func->getResultType()); + auto funcType = builder.getFuncType( + nativeParamTypes.getCount(), (IRType**)nativeParamTypes.getBuffer(), returnType); + + return funcType; + } + + void marshalImportRefParameter(IRBuilder& builder, IRParam* param, List<IRInst*>& args) + { + SLANG_UNUSED(builder); + + auto innerType = as<IRPtrTypeBase>(param->getDataType())->getValueType(); + switch (innerType->getOp()) + { + case kIROp_StringType: + { + diagnosticSink->diagnose( + param->sourceLoc, + Diagnostics::invalidTypeMarshallingForImportedDLLSymbol, + param->getParent()->getParent()); + } + break; + default: + args.add(param); + break; + } + } + + void marshalImportParameter(IRBuilder& builder, IRParam* param, List<IRInst*>& args) + { + switch (param->getDataType()->getOp()) + { + case kIROp_InOutType: + case kIROp_RefType: + return marshalImportRefParameter(builder, param, args); + case kIROp_StringType: + { + auto getStringBufferFunc = getStringGetBufferFunc(); + args.add(builder.emitCallInst( + builder.getPtrType(builder.getCharType()), getStringBufferFunc, 1, (IRInst**)¶m)); + } + break; + default: + args.add(param); + break; + } + } + + void processFunc(IRFunc* func, IRDllImportDecoration* dllImportDecoration) + { + assert(func->getFirstBlock() == nullptr); + + IRBuilder builder(sharedBuilder); + + auto nativeType = getNativeFuncType(builder, func); + builder.setInsertInto(module->getModuleInst()); + auto funcPtr = builder.createGlobalVar(nativeType); + builder.setInsertInto(funcPtr); + builder.emitBlock(); + builder.emitReturn(builder.getPtrValue(nullptr)); + + builder.setInsertInto(func); + auto block = builder.emitBlock(); + builder.setInsertInto(block); + + // Emit parameters. + auto declaredFuncType = func->getDataType(); + List<IRParam*> params; + for (UInt i = 0; i < declaredFuncType->getParamCount(); ++i) + { + auto paramType = declaredFuncType->getParamType(i); + params.add(builder.emitParam((IRType*)paramType)); + } + + // Marshal parameters to arguments into native func. + List<IRInst*> args; + for (auto param : params) + { + marshalImportParameter(builder, param, args); + } + IRInst* cmpArgs[] = {builder.emitLoad(nativeType, funcPtr), builder.getPtrValue(nullptr)}; + auto isUninitialized = + builder.emitIntrinsicInst(builder.getBoolType(), kIROp_Eql, 2, cmpArgs); + + auto trueBlock = builder.emitBlock(); + auto afterBlock = builder.emitBlock(); + + builder.setInsertInto(block); + builder.emitIf(isUninitialized, trueBlock, afterBlock); + + builder.setInsertInto(trueBlock); + auto modulePtr = builder.emitCallInst( + builder.getPtrType(builder.getVoidType()), + getLoadDllFunc(), + builder.getStringValue(dllImportDecoration->getLibraryName())); + + IRInst* loadDllFuncArgs[] = { + modulePtr, builder.getStringValue(dllImportDecoration->getFunctionName())}; + auto loadedNativeFuncPtr = builder.emitCallInst( + builder.getPtrType(builder.getVoidType()), getLoadFuncPtrFunc(), 2, loadDllFuncArgs); + builder.emitStore( + funcPtr, builder.emitBitCast(nativeType, loadedNativeFuncPtr)); + builder.emitBranch(afterBlock); + builder.setInsertInto(afterBlock); + + IRType* nativeReturnType = getNativeType(builder, func->getResultType()); + auto nativeFunc = builder.emitLoad(funcPtr); + auto call = builder.emitCallInst(nativeReturnType, nativeFunc, args); + if (declaredFuncType->getResultType()->getOp() != kIROp_VoidType) + { + builder.emitReturn(call); + } + } + + void processModule() + { + for (auto childFunc : module->getGlobalInsts()) + { + switch(childFunc->getOp()) + { + case kIROp_Func: + if (auto dllImportDecoration = childFunc->findDecoration<IRDllImportDecoration>()) + { + processFunc(as<IRFunc>(childFunc), dllImportDecoration); + } + break; + default: + break; + } + } + } +}; + +void generateDllImportFuncs(IRModule* module, DiagnosticSink* sink) +{ + DllImportContext context; + context.module = module; + context.diagnosticSink = sink; + context.sharedBuilder.init(module); + return context.processModule(); +} + +} |
