diff options
| author | Yong He <yonghe@outlook.com> | 2022-06-21 14:55:59 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-21 14:55:59 -0700 |
| commit | e5a75563a1ba2e378353af8b937b8b7bb0fe2c2b (patch) | |
| tree | 0f31040b408a66f49dc5cd2354c8424e5ff2e279 /source | |
| parent | ea3800e115d4ad1ce06ec07689152616f47a0e3d (diff) | |
Lower throwing COM interface method. (#2282)
* Lower throwing COM interface method.
* Fix.
* Fix warnings.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-com-interface.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-dll-import.cpp | 98 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-pass-base.h | 83 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-com-methods.cpp | 138 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-com-methods.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-result-type.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-marshal-native-call.cpp | 149 | ||||
| -rw-r--r-- | source/slang/slang-ir-marshal-native-call.h | 50 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 110 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa-simplification.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 67 |
16 files changed, 680 insertions, 94 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index e0b1a44f2..10ac2ca84 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1810,6 +1810,22 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO m_writer->emit("GroupMemoryBarrierWithGroupSync()"); break; + case kIROp_getNativeStr: + { + auto prec = getInfo(EmitOp::Postfix); + needClose = maybeEmitParens(outerPrec, prec); + emitOperand(inst->getOperand(0), leftSide(outerPrec, prec)); + m_writer->emit("->getBuffer()"); + break; + } + case kIROp_makeString: + { + m_writer->emit("String("); + emitOperand(inst->getOperand(0), EmitOpInfo()); + m_writer->emit(")"); + break; + } + case kIROp_getElement: case kIROp_getElementPtr: case kIROp_ImageSubscript: diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 0bb1d73e5..fb0f65c5f 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -206,12 +206,6 @@ Result linkAndOptimizeIR( switch (target) { case CodeGenTarget::CPPSource: - { - // TODO(JS): - // We want the interface transformation to take place for 'regular' CPPSource for now too. - lowerComInterfaces(irModule, artifactDesc.style, sink); - break; - } case CodeGenTarget::HostCPPSource: { lowerComInterfaces(irModule, artifactDesc.style, sink); diff --git a/source/slang/slang-ir-com-interface.cpp b/source/slang/slang-ir-com-interface.cpp index 899596209..d9d54d9d2 100644 --- a/source/slang/slang-ir-com-interface.cpp +++ b/source/slang/slang-ir-com-interface.cpp @@ -3,6 +3,7 @@ #include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir-lower-com-methods.h" namespace Slang { @@ -38,7 +39,8 @@ static bool _canReplace(IRUse* use) void lowerComInterfaces(IRModule* module, ArtifactStyle artifactStyle, DiagnosticSink* sink) { - SLANG_UNUSED(sink); + // First, lower all COM methods and their call sites out of `Result` and other managed types. + lowerComMethods(module, sink); // Find all of the COM interfaces List<IRInterfaceType*> comInterfaces; diff --git a/source/slang/slang-ir-dll-import.cpp b/source/slang/slang-ir-dll-import.cpp index 02743cba4..b123dfb03 100644 --- a/source/slang/slang-ir-dll-import.cpp +++ b/source/slang/slang-ir-dll-import.cpp @@ -3,6 +3,7 @@ #include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir-marshal-native-call.h" namespace Slang { @@ -83,82 +84,15 @@ struct DllImportContext return stringGetBufferFunc; } - IRType* getNativeType(IRBuilder& builder, IRType* type) - { - switch (type->getOp()) - { - case kIROp_StringType: - return builder.getPtrType(builder.getCharType()); - default: - return type; - } - } - - IRType* getNativeFuncType(IRBuilder& builder, IRFunc* func) - { - List<IRInst*> nativeParamTypes; - auto declaredFuncType = func->getDataType(); - assert(declaredFuncType->getOp() == kIROp_FuncType); - for (UInt i = 0; i < declaredFuncType->getParamCount(); ++i) - { - auto paramType = declaredFuncType->getParamType(i); - nativeParamTypes.add(getNativeType(builder, as<IRType>(paramType))); - } - IRType* returnType = getNativeType(builder, func->getResultType()); - auto funcType = builder.getFuncType( - nativeParamTypes.getCount(), (IRType**)nativeParamTypes.getBuffer(), returnType); - - return funcType; - } - - void marshalImportRefParameter(IRBuilder& builder, IRParam* param, List<IRInst*>& args) - { - SLANG_UNUSED(builder); - - auto innerType = as<IRPtrTypeBase>(param->getDataType())->getValueType(); - switch (innerType->getOp()) - { - case kIROp_StringType: - { - diagnosticSink->diagnose( - param->sourceLoc, - Diagnostics::invalidTypeMarshallingForImportedDLLSymbol, - param->getParent()->getParent()); - } - break; - default: - args.add(param); - break; - } - } - - void marshalImportParameter(IRBuilder& builder, IRParam* param, List<IRInst*>& args) - { - switch (param->getDataType()->getOp()) - { - case kIROp_InOutType: - case kIROp_RefType: - return marshalImportRefParameter(builder, param, args); - case kIROp_StringType: - { - auto getStringBufferFunc = getStringGetBufferFunc(); - args.add(builder.emitCallInst( - builder.getPtrType(builder.getCharType()), getStringBufferFunc, 1, (IRInst**)¶m)); - } - break; - default: - args.add(param); - break; - } - } void processFunc(IRFunc* func, IRDllImportDecoration* dllImportDecoration) { assert(func->getFirstBlock() == nullptr); IRBuilder builder(sharedBuilder); + NativeCallMarshallingContext marshalContext; - auto nativeType = getNativeFuncType(builder, func); + auto nativeType = marshalContext.getNativeFuncType(builder, func->getDataType()); builder.setInsertInto(module->getModuleInst()); auto funcPtr = builder.createGlobalVar(nativeType); builder.setInsertInto(funcPtr); @@ -178,12 +112,6 @@ struct DllImportContext params.add(builder.emitParam((IRType*)paramType)); } - // Marshal parameters to arguments into native func. - List<IRInst*> args; - for (auto param : params) - { - marshalImportParameter(builder, param, args); - } IRInst* cmpArgs[] = {builder.emitLoad(nativeType, funcPtr), builder.getPtrValue(nullptr)}; auto isUninitialized = builder.emitIntrinsicInst(builder.getBoolType(), kIROp_Eql, 2, cmpArgs); @@ -209,17 +137,15 @@ struct DllImportContext builder.emitBranch(afterBlock); builder.setInsertInto(afterBlock); - IRType* nativeReturnType = getNativeType(builder, func->getResultType()); - auto nativeFunc = builder.emitLoad(funcPtr); - auto call = builder.emitCallInst(nativeReturnType, nativeFunc, args); - if (declaredFuncType->getResultType()->getOp() != kIROp_VoidType) - { - builder.emitReturn(call); - } - else - { - builder.emitReturn(); - } + marshalContext.diagnosticSink = diagnosticSink; + auto callResult = marshalContext.marshalNativeCall( + builder, + declaredFuncType, + nativeType, + builder.emitLoad(funcPtr), + params.getCount(), + (IRInst**)params.getBuffer()); + builder.emitReturn(callResult); } void processModule() diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 4d927bdaf..c10ae8639 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -308,6 +308,12 @@ INST(getElement, getElement, 2, 0) INST(getElementPtr, getElementPtr, 2, 0) INST(getAddr, getAddr, 1, 0) +// Get an unowned NativeString from a String. +INST(getNativeStr, getNativeStr, 1, 0) + +// Make String from a NativeString. +INST(makeString, makeString, 1, 0) + // "Subscript" an image at a pixel coordinate to get pointer INST(ImageSubscript, imageSubscript, 2, 0) diff --git a/source/slang/slang-ir-inst-pass-base.h b/source/slang/slang-ir-inst-pass-base.h new file mode 100644 index 000000000..ec4506272 --- /dev/null +++ b/source/slang/slang-ir-inst-pass-base.h @@ -0,0 +1,83 @@ +// slang-ir-inst-pass-base.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + struct IRModule; + + class InstPassBase + { + protected: + IRModule* module; + SharedIRBuilder sharedBuilderStorage; + + List<IRInst*> workList; + HashSet<IRInst*> workListSet; + void addToWorkList(IRInst* inst) + { + if (workListSet.Contains(inst)) + return; + + workList.add(inst); + workListSet.Add(inst); + } + + public: + InstPassBase(IRModule* inModule) + : module(inModule) + {} + + template <typename InstType, typename Func> + void processInstsOfType(IROp instOp, const Func& f) + { + workList.clear(); + workListSet.Clear(); + + addToWorkList(module->getModuleInst()); + + while (workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + workListSet.Remove(inst); + if (inst->getOp() == instOp) + { + f(as<InstType>(inst)); + } + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + } + } + + template <typename Func> + void processAllInsts(const Func& f) + { + workList.clear(); + workListSet.Clear(); + + addToWorkList(module->getModuleInst()); + + while (workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + workListSet.Remove(inst); + f(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + } + } + }; + +} diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 80438504c..081c67d03 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2401,6 +2401,10 @@ public: return emitMakeTuple(SLANG_COUNT_OF(args), args); } + IRInst* emitMakeString(IRInst* nativeStr); + + IRInst* emitGetNativeString(IRInst* str); + IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, UInt element); IRInst* emitMakeResultError(IRType* resultType, IRInst* errorVal); @@ -2687,6 +2691,8 @@ public: IRInst* emitBranch( IRBlock* block); + IRInst* emitBranch(IRBlock* block, Int argCount, IRInst*const* args); + IRInst* emitBreak( IRBlock* target); @@ -2714,6 +2720,16 @@ public: IRBlock* falseBlock, IRBlock* afterBlock); + // Create basic blocks and insert an `IfElse` inst at current position that jumps into the blocks. + // The current insert position is changed to inside `outTrueBlock` after the call. + IRInst* emitIfElseWithBlocks( + IRInst* val, IRBlock*& outTrueBlock, IRBlock*& outFalseBlock, IRBlock*& outAfterBlock); + + // Create basic blocks and insert an `If` inst at current position that jumps into the blocks. + // The current insert position is changed to inside `outTrueBlock` after the call. + IRInst* emitIfWithBlocks( + IRInst* val, IRBlock*& outTrueBlock, IRBlock*& outAfterBlock); + IRInst* emitLoopTest( IRInst* val, IRBlock* bodyBlock, @@ -2778,6 +2794,8 @@ public: IRInst* emitMul(IRType* type, IRInst* left, IRInst* right); IRInst* emitEql(IRInst* left, IRInst* right); IRInst* emitNeq(IRInst* left, IRInst* right); + IRInst* emitLess(IRInst* left, IRInst* right); + IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1); IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1); diff --git a/source/slang/slang-ir-lower-com-methods.cpp b/source/slang/slang-ir-lower-com-methods.cpp new file mode 100644 index 000000000..6c3a3f289 --- /dev/null +++ b/source/slang/slang-ir-lower-com-methods.cpp @@ -0,0 +1,138 @@ +// slang-ir-lower-com-methods.cpp + +#include "slang-ir-lower-com-methods.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-marshal-native-call.h" +#include "slang-ir-inst-pass-base.h" + +namespace Slang +{ + +struct ComMethodLoweringContext : public InstPassBase +{ + DiagnosticSink* diagnosticSink = nullptr; + + NativeCallMarshallingContext marshal; + + OrderedHashSet<IRLookupWitnessMethod*> comCallees; + + ComMethodLoweringContext(IRModule* inModule) + : InstPassBase(inModule) + {} + + void processComCall(IRCall* comCall) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(comCall); + auto callee = as<IRLookupWitnessMethod>(comCall->getCallee()); + SLANG_ASSERT(callee); + + comCallees.Add(callee); + + auto calleeType = as<IRFuncType>(comCall->getCallee()->getDataType()); + SLANG_ASSERT(calleeType); + + auto nativeFuncType = marshal.getNativeFuncType(builder, calleeType); + ShortList<IRInst*> args; + for (UInt i = 0; i < comCall->getArgCount(); i++) + args.add(comCall->getArg(i)); + auto currentBlock = builder.getBlock(); + auto nextInst = comCall->getNextInst(); + auto newResult = marshal.marshalNativeCall( + builder, + calleeType, + nativeFuncType, + comCall->getCallee(), + args.getCount(), + args.getArrayView().getBuffer()); + + comCall->replaceUsesWith(newResult); + if (builder.getBlock() != currentBlock) + { + // `marshalNativeCall` may have replaced the original call with branch insts. + // If this is the case, we need to move all insts after the original call in the original + // basic block to the new basic block. + while (nextInst) + { + auto next = nextInst->getNextInst(); + nextInst->removeFromParent(); + nextInst->insertAtEnd(builder.getBlock()); + nextInst = next; + } + } + comCall->removeAndDeallocate(); + } + + void processCall(IRCall* inst) + { + auto funcValue = inst->getOperand(0); + + // Detect if this is a call into a COM interface method. + if (funcValue->getOp() == kIROp_lookup_interface_method) + { + const auto operand0TypeOp = funcValue->getOperand(0)->getDataType(); + if (auto tableType = as<IRWitnessTableTypeBase>(operand0TypeOp)) + { + if (tableType->getConformanceType()->findDecoration<IRComInterfaceDecoration>()) + { + processComCall(inst); + return; + } + } + } + } + + void processInterfaceType(IRInterfaceType* interfaceType) + { + if (!interfaceType->findDecoration<IRComInterfaceDecoration>()) + return; + IRBuilder builder(&sharedBuilderStorage); + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + { + auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)); + if (!entry) + continue; + if (auto funcType = as<IRFuncType>(entry->getRequirementVal())) + { + builder.setInsertBefore(funcType); + entry->setRequirementVal(marshal.getNativeFuncType(builder, funcType)); + } + } + } + + void processModule() + { + sharedBuilderStorage.init(module); + + // Deduplicate equivalent types. + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + + // Translate all Calls to interface methods. + processInstsOfType<IRCall>(kIROp_Call, [this](IRCall* inst) { processCall(inst); }); + + // Update functypes of com callees. + for (auto callee : comCallees) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(callee); + auto nativeType = marshal.getNativeFuncType(builder, as<IRFuncType>(callee->getDataType())); + callee->setFullType(nativeType); + } + + // Update func types of COM interfaces. + processInstsOfType<IRInterfaceType>(kIROp_InterfaceType, [this](IRInterfaceType* inst) { processInterfaceType(inst); }); + + } +}; + +void lowerComMethods(IRModule* module, DiagnosticSink* sink) +{ + ComMethodLoweringContext context(module); + context.diagnosticSink = sink; + context.marshal.diagnosticSink = sink; + + return context.processModule(); +} +} diff --git a/source/slang/slang-ir-lower-com-methods.h b/source/slang/slang-ir-lower-com-methods.h new file mode 100644 index 000000000..145d12733 --- /dev/null +++ b/source/slang/slang-ir-lower-com-methods.h @@ -0,0 +1,14 @@ +// slang-ir-lower-com-methods.h +#pragma once + +namespace Slang +{ + +struct IRModule; +class DiagnosticSink; + +/// Lower the signature of COM interface methods out of types that +/// cannot appear in a COM interface. For example, String, List, ComPtr, Result all need to be translated. +void lowerComMethods(IRModule* module, DiagnosticSink* sink); + +} diff --git a/source/slang/slang-ir-lower-result-type.cpp b/source/slang/slang-ir-lower-result-type.cpp index 7e632241b..21e046849 100644 --- a/source/slang/slang-ir-lower-result-type.cpp +++ b/source/slang/slang-ir-lower-result-type.cpp @@ -124,7 +124,7 @@ namespace Slang builder->setInsertBefore(inst); auto info = getLoweredResultType(builder, inst->getDataType()); - if (info->loweredType->getOp() == kIROp_VoidType) + if (info->loweredType->getOp() == kIROp_StructType) { List<IRInst*> operands; operands.add(inst->getOperand(0)); diff --git a/source/slang/slang-ir-marshal-native-call.cpp b/source/slang/slang-ir-marshal-native-call.cpp new file mode 100644 index 000000000..b8e2edb2f --- /dev/null +++ b/source/slang/slang-ir-marshal-native-call.cpp @@ -0,0 +1,149 @@ +// slang-ir-marshal-native-call.h +#include "slang-ir-marshal-native-call.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + + IRType* NativeCallMarshallingContext::getNativeType(IRBuilder& builder, IRType* type) + { + switch (type->getOp()) + { + case kIROp_StringType: + return builder.getNativeStringType(); + default: + return type; + } + } + + IRFuncType* NativeCallMarshallingContext::getNativeFuncType( + IRBuilder& builder, IRFuncType* declaredFuncType) + { + List<IRInst*> nativeParamTypes; + assert(declaredFuncType->getOp() == kIROp_FuncType); + for (UInt i = 0; i < declaredFuncType->getParamCount(); ++i) + { + auto paramType = declaredFuncType->getParamType(i); + nativeParamTypes.add(getNativeType(builder, as<IRType>(paramType))); + } + IRType* returnType = declaredFuncType->getResultType(); + if (auto resultType = as<IRResultType>(declaredFuncType->getResultType())) + { + nativeParamTypes.add(builder.getPtrType(resultType->getValueType())); + returnType = resultType->getErrorType(); + } + getNativeType(builder, declaredFuncType->getResultType()); + auto funcType = builder.getFuncType( + nativeParamTypes.getCount(), (IRType**)nativeParamTypes.getBuffer(), returnType); + + return funcType; + } + + void NativeCallMarshallingContext::marshalRefManagedValueToNativeValue( + IRBuilder& builder, IRInst* originalArg, List<IRInst*>& args) + { + SLANG_UNUSED(builder); + SLANG_UNUSED(originalArg); + args.add(originalArg); + } + + void NativeCallMarshallingContext::marshalManagedValueToNativeValue( + IRBuilder& builder, IRInst* originalArg, List<IRInst*>& args) + { + switch (originalArg->getDataType()->getOp()) + { + case kIROp_InOutType: + case kIROp_RefType: + return marshalRefManagedValueToNativeValue( + builder, originalArg, args); + case kIROp_StringType: + { + auto nativeStr = builder.emitGetNativeString(originalArg); + args.add(nativeStr); + } + break; + default: + args.add(originalArg); + break; + } + } + + IRInst* NativeCallMarshallingContext::marshalNativeValueToManagedValue( + IRBuilder& builder, IRInst* nativeVal) + { + switch (nativeVal->getDataType()->getOp()) + { + case kIROp_NativeStringType: + { + return builder.emitMakeString(nativeVal); + } + break; + default: + return nativeVal; + break; + } + } + + IRInst* NativeCallMarshallingContext::marshalNativeCall( + IRBuilder& builder, + IRFuncType* originalFuncType, + IRFuncType* nativeFuncType, + IRInst* nativeFunc, + Int argCount, + IRInst*const* originalArgs) + { + // Marshal parameters to arguments into native func. + List<IRInst*> args; + for (Int i = 0; i < argCount; i++) + { + marshalManagedValueToNativeValue(builder, originalArgs[i], args); + } + IRType* originalReturnType = originalFuncType->getResultType(); + + IRVar* resultVar = nullptr; + if (auto resultType = as<IRResultType>(originalReturnType)) + { + // Declare a local variable to receive result. + resultVar = builder.emitVar(getNativeType(builder, resultType->getValueType())); + args.add(resultVar); + } + + // Insert call. + IRInst* call = builder.emitCallInst(nativeFuncType->getResultType(), nativeFunc, args); + + // TODO: marshal output/ref args back to original args. + + IRInst* returnValue = call; + + // Marshal result and out arguments back to managed values. + if (auto resultType = as<IRResultType>(originalReturnType)) + { + auto val = builder.emitLoad(resultVar); + auto err = call; + val = marshalNativeValueToManagedValue(builder, val); + auto intErr = err; + if (err->getDataType()->getOp() != kIROp_IntType) + { + intErr = builder.emitConstructorInst(builder.getIntType(), 1, &err); + } + auto errIsError = builder.emitLess(intErr, builder.getIntValue(builder.getIntType(), 0)); + IRBlock *trueBlock, *falseBlock, *afterBlock; + builder.emitIfElseWithBlocks(errIsError, trueBlock, falseBlock, afterBlock); + builder.setInsertInto(trueBlock); + returnValue = builder.emitMakeResultError(resultType, err); + builder.emitBranch(afterBlock, 1, &returnValue); + builder.setInsertInto(falseBlock); + returnValue = builder.emitMakeResultValue(resultType, val); + builder.emitBranch(afterBlock, 1, &returnValue); + builder.setInsertInto(afterBlock); + returnValue = builder.emitParam(resultType); + } + else + { + returnValue = marshalNativeValueToManagedValue(builder, call); + } + return returnValue; + } + +} // namespace Slang diff --git a/source/slang/slang-ir-marshal-native-call.h b/source/slang/slang-ir-marshal-native-call.h new file mode 100644 index 000000000..bbc2078ea --- /dev/null +++ b/source/slang/slang-ir-marshal-native-call.h @@ -0,0 +1,50 @@ +// slang-ir-marshal-native-call.h +#pragma once + +#include "../core/slang-basic.h" + +namespace Slang +{ + class DiagnosticSink; + struct IRModule; + struct IRBuilder; + struct IRType; + struct IRFunc; + struct IRFuncType; + struct IRCall; + struct IRInst; + + class NativeCallMarshallingContext + { + public: + DiagnosticSink* diagnosticSink = nullptr; + public: + // Get a native type for `type` that can be used directly in a native function signature. + IRType* getNativeType(IRBuilder& builder, IRType* type); + + // Get a native function type of `func`. + IRFuncType* getNativeFuncType(IRBuilder& builder, IRFuncType* declaredFuncType); + + // Insert a call at builder's current position into a native func with original arguments. + // `originalArgs` will be marshalled to native args before the actual call. + // returns the managed result value of the call. + // Note: additional insts maybe inserted after the call inst to marshal the native output values back + // to non-native arguments/return values. + IRInst* marshalNativeCall( + IRBuilder& builder, + IRFuncType* originalFuncType, + IRFuncType* nativeFuncType, + IRInst* nativeFunc, + Int argCount, + IRInst* const* originalArgs); + + void marshalRefManagedValueToNativeValue( + IRBuilder& builder, IRInst* originalArg, List<IRInst*>& args); + + void marshalManagedValueToNativeValue( + IRBuilder& builder, IRInst* originalArg, List<IRInst*>& args); + + IRInst* marshalNativeValueToManagedValue( + IRBuilder& builder, IRInst* nativeValue); + }; +} diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp new file mode 100644 index 000000000..6cb5d2971 --- /dev/null +++ b/source/slang/slang-ir-peephole.cpp @@ -0,0 +1,110 @@ +#include "slang-ir-peephole.h" +#include "slang-ir-inst-pass-base.h" + +namespace Slang +{ +struct PeepholeContext : InstPassBase +{ + PeepholeContext(IRModule* inModule) + : InstPassBase(inModule) + {} + + bool changed = false; + + void processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_GetResultError: + if (inst->getOperand(0)->getOp() == kIROp_MakeResultError) + { + inst->replaceUsesWith(inst->getOperand(0)->getOperand(0)); + changed = true; + } + break; + case kIROp_GetResultValue: + if (inst->getOperand(0)->getOp() == kIROp_MakeResultValue) + { + inst->replaceUsesWith(inst->getOperand(0)->getOperand(0)); + inst->removeAndDeallocate(); + changed = true; + } + break; + case kIROp_IsResultError: + if (inst->getOperand(0)->getOp() == kIROp_MakeResultError) + { + IRBuilder builder(&sharedBuilderStorage); + inst->replaceUsesWith(builder.getBoolValue(true)); + inst->removeAndDeallocate(); + changed = true; + } + else if (inst->getOperand(0)->getOp() == kIROp_MakeResultValue) + { + IRBuilder builder(&sharedBuilderStorage); + inst->replaceUsesWith(builder.getBoolValue(false)); + inst->removeAndDeallocate(); + changed = true; + } + break; + case kIROp_GetTupleElement: + if (inst->getOperand(0)->getOp() == kIROp_MakeTuple) + { + auto element = inst->getOperand(1); + if (auto intLit = as<IRIntLit>(element)) + { + inst->replaceUsesWith(inst->getOperand(0)->getOperand((UInt)intLit->value.intVal)); + inst->removeAndDeallocate(); + changed = true; + } + } + break; + case kIROp_FieldExtract: + if (inst->getOperand(0)->getOp() == kIROp_makeStruct) + { + auto field = as<IRFieldExtract>(inst)->field.get(); + Index fieldIndex = -1; + auto structType = as<IRStructType>(inst->getOperand(0)->getDataType()); + if (structType) + { + Index i = 0; + for (auto sfield : structType->getFields()) + { + if (sfield == field) + { + fieldIndex = i; + break; + } + i++; + } + if (fieldIndex != -1 && fieldIndex < (Index)inst->getOperand(0)->getOperandCount()) + { + inst->replaceUsesWith(inst->getOperand(0)->getOperand((UInt)fieldIndex)); + inst->removeAndDeallocate(); + changed = true; + } + } + } + break; + default: + break; + } + } + + bool processModule() + { + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->init(module); + + changed = false; + processAllInsts([this](IRInst* inst) { processInst(inst); }); + return changed; + } +}; + +bool peepholeOptimize(IRModule* module) +{ + PeepholeContext context = PeepholeContext(module); + return context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-peephole.h b/source/slang/slang-ir-peephole.h new file mode 100644 index 000000000..e05c533eb --- /dev/null +++ b/source/slang/slang-ir-peephole.h @@ -0,0 +1,11 @@ +// slang-ir-peephole.h +#pragma once + +namespace Slang +{ + struct IRModule; + struct IRCall; + + /// Apply peephole optimizations. + bool peepholeOptimize(IRModule* module); +} diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index fcc6dc4ae..22aea8d36 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -5,6 +5,7 @@ #include "slang-ir-sccp.h" #include "slang-ir-dce.h" #include "slang-ir-simplify-cfg.h" +#include "slang-ir-peephole.h" namespace Slang { @@ -21,6 +22,7 @@ namespace Slang { changed = false; changed |= applySparseConditionalConstantPropagation(module); + changed |= peepholeOptimize(module); changed |= simplifyCFG(module); // Note: we disregard the `changed` state from dead code elimination pass since diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 562d0ea1a..42ff5823b 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3224,6 +3224,16 @@ namespace Slang return emitMakeTuple(type, count, args); } + IRInst* IRBuilder::emitMakeString(IRInst* nativeStr) + { + return emitIntrinsicInst(getStringType(), kIROp_makeString, 1, &nativeStr); + } + + IRInst* IRBuilder::emitGetNativeString(IRInst* str) + { + return emitIntrinsicInst(getNativeStringType(), kIROp_getNativeStr, 1, &str); + } + IRInst* IRBuilder::emitGetTupleElement(IRType* type, IRInst* tuple, UInt element) { // As a quick simplification/optimization, if the user requests @@ -4022,6 +4032,18 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitBranch(IRBlock* block, Int argCount, IRInst* const* args) + { + List<IRInst*> argList; + argList.add(block); + for (Int i = 0; i < argCount; ++i) + argList.add(args[i]); + auto inst = + createInst<IRUnconditionalBranch>(this, kIROp_unconditionalBranch, nullptr, argList.getCount(), argList.getBuffer()); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitBreak( IRBlock* target) { @@ -4089,6 +4111,25 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitIfElseWithBlocks( + IRInst* val, IRBlock*& outTrueBlock, IRBlock*& outFalseBlock, IRBlock*& outAfterBlock) + { + outTrueBlock = createBlock(); + outAfterBlock = createBlock(); + outFalseBlock = createBlock(); + auto f = getFunc(); + SLANG_ASSERT(f); + if (f) + { + f->addBlock(outTrueBlock); + f->addBlock(outAfterBlock); + f->addBlock(outFalseBlock); + } + auto result = emitIfElse(val, outTrueBlock, outFalseBlock, outAfterBlock); + setInsertInto(outTrueBlock); + return result; + } + IRInst* IRBuilder::emitIf( IRInst* val, IRBlock* trueBlock, @@ -4097,6 +4138,18 @@ namespace Slang return emitIfElse(val, trueBlock, afterBlock, afterBlock); } + IRInst* IRBuilder::emitIfWithBlocks( + IRInst* val, IRBlock*& outTrueBlock, IRBlock*& outAfterBlock) + { + outTrueBlock = createBlock(); + outAfterBlock = createBlock(); + auto result = emitIf(val, outTrueBlock, outAfterBlock); + insertBlock(outTrueBlock); + insertBlock(outAfterBlock); + setInsertInto(outTrueBlock); + return result; + } + IRInst* IRBuilder::emitLoopTest( IRInst* val, IRBlock* bodyBlock, @@ -4320,6 +4373,13 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitLess(IRInst* left, IRInst* right) + { + auto inst = createInst<IRInst>(this, kIROp_Less, getBoolType(), left, right); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitMul(IRType* type, IRInst* left, IRInst* right) { auto inst = createInst<IRInst>( @@ -5886,6 +5946,13 @@ namespace Slang case kIROp_MakeMatrix: case kIROp_makeArray: case kIROp_makeStruct: + case kIROp_makeString: + case kIROp_getNativeStr: + case kIROp_MakeResultError: + case kIROp_MakeResultValue: + case kIROp_GetResultError: + case kIROp_GetResultValue: + case kIROp_IsResultError: case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads case kIROp_ImageSubscript: case kIROp_FieldExtract: |
