// slang-ir-marshal-native-call.h #include "slang-ir-marshal-native-call.h" #include "slang-ir-insts.h" #include "slang-ir.h" namespace Slang { IRType* NativeCallMarshallingContext::getNativeType(IRBuilder& builder, IRType* type) { switch (type->getOp()) { case kIROp_StringType: return builder.getNativeStringType(); case kIROp_InterfaceType: return builder.getNativePtrType(type); case kIROp_ComPtrType: return builder.getNativePtrType((IRType*)as(type)->getOperand(0)); case kIROp_ArrayType: case kIROp_UnsizedArrayType: { auto arrayType = as(type); auto elementType = arrayType->getElementType(); auto nativeElementType = getNativeType(builder, elementType); return builder.getArrayTypeBase( elementType->getOp(), nativeElementType, arrayType->getElementCount()); } case kIROp_BorrowInOutParamType: case kIROp_RefParamType: case kIROp_BorrowInParamType: case kIROp_OutParamType: return builder.getPtrType(getNativeType(builder, (IRType*)type->getOperand(0))); default: return type; } } IRFuncType* NativeCallMarshallingContext::getNativeFuncType( IRBuilder& builder, IRFuncType* declaredFuncType) { List nativeParamTypes; assert(declaredFuncType->getOp() == kIROp_FuncType); for (UInt i = 0; i < declaredFuncType->getParamCount(); ++i) { auto paramType = declaredFuncType->getParamType(i); nativeParamTypes.add(getNativeType(builder, (IRType*)(paramType))); } IRType* returnType = declaredFuncType->getResultType(); if (auto resultType = as(declaredFuncType->getResultType())) { auto nativeResultType = getNativeType(builder, resultType->getValueType()); auto nativeErrorType = getNativeType(builder, resultType->getErrorType()); nativeParamTypes.add(builder.getPtrType(nativeResultType)); nativeParamTypes.add(builder.getPtrType(nativeErrorType)); returnType = builder.getIntType(); } else { returnType = getNativeType(builder, returnType); } auto funcType = builder.getFuncType( nativeParamTypes.getCount(), (IRType**)nativeParamTypes.getBuffer(), returnType); return funcType; } void NativeCallMarshallingContext::marshalRefManagedValueToNativeValue( IRBuilder& builder, IRInst* originalArg, List& args) { auto ptrTypeBase = as(originalArg->getDataType()); SLANG_RELEASE_ASSERT(ptrTypeBase); switch (ptrTypeBase->getValueType()->getOp()) { case kIROp_InterfaceType: case kIROp_ComPtrType: args.add(builder.emitGetManagedPtrWriteRef(originalArg)); break; default: args.add(originalArg); break; } } void NativeCallMarshallingContext::marshalManagedValueToNativeValue( IRBuilder& builder, IRType* originalParamType, IRInst* originalArg, List& args) { switch (originalParamType->getOp()) { case kIROp_BorrowInOutParamType: case kIROp_RefParamType: case kIROp_BorrowInParamType: case kIROp_OutParamType: return marshalRefManagedValueToNativeValue(builder, originalArg, args); case kIROp_StringType: { auto nativeStr = builder.emitGetNativeString(originalArg); args.add(nativeStr); } break; case kIROp_InterfaceType: { auto nativePtr = builder.emitGetNativePtr(originalArg); args.add(nativePtr); } break; default: args.add(originalArg); break; } } IRInst* NativeCallMarshallingContext::marshalNativeValueToManagedValue( IRBuilder& builder, IRInst* nativeVal) { switch (nativeVal->getDataType()->getOp()) { case kIROp_NativeStringType: { return builder.emitMakeString(nativeVal); } break; case kIROp_NativePtrType: { SLANG_RELEASE_ASSERT( nativeVal->getDataType()->getOperand(0)->getOp() == kIROp_InterfaceType); auto comPtrVar = builder.emitVar( builder.getComPtrType((IRType*)nativeVal->getDataType()->getOperand(0))); builder.emitManagedPtrAttach(comPtrVar, nativeVal); return builder.emitLoad(comPtrVar); } break; case kIROp_InterfaceType: { auto comPtrVar = builder.emitVar(nativeVal->getDataType()); builder.emitManagedPtrAttach(comPtrVar, nativeVal); return builder.emitLoad(comPtrVar); } break; default: return nativeVal; break; } } void NativeCallMarshallingContext::marshalManagedValueToNativeResultValue( IRBuilder& builder, IRInst* originalArg, List& args) { switch (originalArg->getDataType()->getOp()) { case kIROp_BorrowInOutParamType: case kIROp_RefParamType: case kIROp_BorrowInParamType: SLANG_UNREACHABLE("out and ref types should be handled before reaching here."); break; case kIROp_StringType: { diagnosticSink->diagnose( originalArg, Diagnostics::unimplemented, "marshal string to native return value"); } break; case kIROp_ClassType: { diagnosticSink->diagnose( originalArg, Diagnostics::unimplemented, "marshal class to native return value"); } break; case kIROp_InterfaceType: { auto nativePtr = builder.emitManagedPtrDetach( builder.getNativePtrType(originalArg->getDataType()), originalArg); args.add(nativePtr); } break; case kIROp_ComPtrType: { auto nativePtr = builder.emitManagedPtrDetach( builder.getNativePtrType( (IRType*)cast(originalArg->getDataType())->getOperand(0)), originalArg); args.add(nativePtr); } break; default: args.add(originalArg); break; } } IRInst* NativeCallMarshallingContext::marshalNativeArgToManagedArg( IRBuilder& builder, const List& args, Index& consumeIndex, IRType* expectedManagedType) { // For now, all managed values maps to one native value, so we just call // `marshalNativeValueToManagedValue`. This function can be extended in the future to support // things like `List` that maps to more than one native args. SLANG_UNUSED(expectedManagedType); auto managedVal = marshalNativeValueToManagedValue(builder, args[consumeIndex]); consumeIndex++; return managedVal; } IRFunc* NativeCallMarshallingContext::generateDLLExportWrapperFunc( IRBuilder& builder, IRFunc* originalFunc) { builder.setInsertBefore(originalFunc); auto funcType = getNativeFuncType(builder, originalFunc->getDataType()); auto newFunc = builder.createFunc(); newFunc->setFullType(funcType); builder.setInsertInto(newFunc); builder.emitBlock(); List params; for (UInt i = 0; i < funcType->getParamCount(); i++) { auto paramType = funcType->getParamType(i); params.add(builder.emitParam(paramType)); } List args; Index nativeParamConsumeIndex = 0; for (UInt i = 0; i < originalFunc->getParamCount(); i++) { auto managedParamType = originalFunc->getParamType(i); auto managedArg = marshalNativeArgToManagedArg( builder, params, nativeParamConsumeIndex, managedParamType); args.add(managedArg); } auto originalReturnType = originalFunc->getResultType(); auto callInst = builder.emitCallInst(originalReturnType, originalFunc, args); if (const auto resultType = as(originalReturnType)) { auto isResultError = builder.emitIsResultError(callInst); IRBlock* trueBlock = nullptr; IRBlock* falseBlock = nullptr; IRBlock* afterBlock = nullptr; builder.emitIfElseWithBlocks(isResultError, trueBlock, falseBlock, afterBlock); builder.setInsertInto(falseBlock); auto resultVal = builder.emitGetResultValue(callInst); List nativeVals; marshalManagedValueToNativeResultValue(builder, resultVal, nativeVals); for (Index i = 0; i < nativeVals.getCount(); i++) { SLANG_RELEASE_ASSERT(nativeParamConsumeIndex < params.getCount()); builder.emitStore(params[nativeParamConsumeIndex], nativeVals[i]); nativeParamConsumeIndex++; } // S_OK builder.emitReturn(builder.getIntValue(builder.getIntType(), 0)); builder.setInsertInto(trueBlock); nativeVals.clear(); auto errorVal = builder.emitGetResultError(callInst); marshalManagedValueToNativeResultValue(builder, errorVal, nativeVals); for (Index i = 0; i < nativeVals.getCount(); i++) { SLANG_RELEASE_ASSERT(nativeParamConsumeIndex < params.getCount()); builder.emitStore(params[nativeParamConsumeIndex], nativeVals[i]); nativeParamConsumeIndex++; } // E_FAIL builder.emitReturn(builder.getIntValue(builder.getIntType(), 0x80004005)); builder.setInsertInto(afterBlock); builder.emitUnreachable(); } else { List nativeVals; marshalManagedValueToNativeResultValue(builder, callInst, nativeVals); for (Index i = 1; i < nativeVals.getCount(); i++) { SLANG_RELEASE_ASSERT(nativeParamConsumeIndex < params.getCount()); builder.emitStore(params[nativeParamConsumeIndex], nativeVals[i]); nativeParamConsumeIndex++; } builder.emitReturn(nativeVals[0]); } return newFunc; } IRInst* NativeCallMarshallingContext::marshalNativeCall( IRBuilder& builder, IRFuncType* originalFuncType, IRFuncType* nativeFuncType, IRInst* nativeFunc, Int argCount, IRInst* const* originalArgs) { // Marshal parameters to arguments into native func. List args; for (Int i = 0; i < argCount; i++) { auto paramType = originalFuncType->getParamType(i); marshalManagedValueToNativeValue(builder, paramType, originalArgs[i], args); } IRType* originalReturnType = originalFuncType->getResultType(); IRVar* resultVar = nullptr; IRVar* errorVar = nullptr; if (auto resultType = as(originalReturnType)) { // Declare a local variable to receive result. resultVar = builder.emitVar(getNativeType(builder, resultType->getValueType())); errorVar = builder.emitVar(getNativeType(builder, resultType->getErrorType())); args.add(resultVar); args.add(resultVar); } // Insert call. IRInst* call = builder.emitCallInst(nativeFuncType->getResultType(), nativeFunc, args); // TODO: marshal output/ref args back to original args. IRInst* returnValue = call; // Marshal result and out arguments back to managed values. if (auto resultType = as(originalReturnType)) { auto val = builder.emitLoad(resultVar); auto err = builder.emitLoad(errorVar); auto tag = call; val = marshalNativeValueToManagedValue(builder, val); err = marshalNativeValueToManagedValue(builder, err); auto intTag = tag; if (tag->getDataType()->getOp() != kIROp_IntType) { intTag = builder.emitCast(builder.getIntType(), tag); } auto tagIsError = builder.emitLess(intTag, builder.getIntValue(builder.getIntType(), 0)); IRBlock *trueBlock, *falseBlock, *afterBlock; builder.emitIfElseWithBlocks(tagIsError, trueBlock, falseBlock, afterBlock); builder.setInsertInto(trueBlock); returnValue = builder.emitMakeResultError(resultType, err); builder.emitBranch(afterBlock, 1, &returnValue); builder.setInsertInto(falseBlock); returnValue = builder.emitMakeResultValue(resultType, val); builder.emitBranch(afterBlock, 1, &returnValue); builder.setInsertInto(afterBlock); returnValue = builder.emitParam(resultType); } else { returnValue = marshalNativeValueToManagedValue(builder, call); } return returnValue; } } // namespace Slang