diff options
| author | Yong He <yonghe@outlook.com> | 2020-07-15 12:48:56 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-07-15 12:48:56 -0700 |
| commit | 5758d16612eda0f902d7d4c02535afe44dec2ac2 (patch) | |
| tree | 51e0ccbb46a68a0d6686ec2c40588efa9895e0cf /source | |
| parent | e9d5ecbf19147af6e1473020b64ced4286b79079 (diff) | |
IR pass to generate witness table wrappers. (#1443)
* Refactor lower-generics pass into separate subpasses.
* IR pass to generate witness table wrappers.
* Re-generate vs project files.
* Fix x86 build error.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 167 | ||||
| -rw-r--r-- | source/slang/slang-emit-cpp.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generic-function.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generics.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-witness-table-wrapper.cpp | 211 | ||||
| -rw-r--r-- | source/slang/slang-ir-witness-table-wrapper.h | 23 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj | 2 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj.filters | 6 |
8 files changed, 274 insertions, 152 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index b71feafc1..c949075fb 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1467,24 +1467,6 @@ UnownedStringSlice CPPSourceEmitter::_getFuncName(const HLSLIntrinsic* specOp) return m_slicePool.getSlice(handle); } -UnownedStringSlice CPPSourceEmitter::_getWitnessTableWrapperFuncName(IRFunc* func) -{ - StringSlicePool::Handle handle = StringSlicePool::kNullHandle; - if (m_witnessTableWrapperFuncNameMap.TryGetValue(func, handle)) - { - return m_slicePool.getSlice(handle); - } - - StringBuilder builder; - builder << getName(func) << "_wtwrapper"; - - handle = m_slicePool.add(builder); - m_witnessTableWrapperFuncNameMap.Add(func, handle); - - SLANG_ASSERT(handle != StringSlicePool::kNullHandle); - return m_slicePool.getSlice(handle); -} - SlangResult CPPSourceEmitter::calcFuncName(const HLSLIntrinsic* specOp, StringBuilder& outBuilder) { typedef HLSLIntrinsic::Op Op; @@ -1629,122 +1611,6 @@ void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable) pendingWitnessTableDefinitions.add(witnessTable); } -void CPPSourceEmitter::_emitWitnessTableWrappers() -{ - for (auto witnessTable : pendingWitnessTableDefinitions) - { - auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0)); - for (auto child : witnessTable->getChildren()) - { - if (auto entry = as<IRWitnessTableEntry>(child)) - { - if (auto funcVal = as<IRFunc>(entry->getSatisfyingVal())) - { - IRInst* requirementVal = nullptr; - for (UInt i = 0; i < interfaceType->getOperandCount(); i++) - { - if (auto reqEntry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i))) - { - if (reqEntry->getRequirementKey() == entry->getRequirementKey()) - { - requirementVal = reqEntry->getRequirementVal(); - break; - } - } - } - SLANG_ASSERT(requirementVal != nullptr); - IRFuncType* requirementFuncType = cast<IRFuncType>(requirementVal); - emitType(funcVal->getResultType()); - m_writer->emit(" "); - m_writer->emit(_getWitnessTableWrapperFuncName(funcVal)); - m_writer->emit("("); - // Emit parameter list. - { - bool isFirst = true; - SLANG_ASSERT(funcVal->getParamCount() == requirementFuncType->getParamCount()); - auto pp = funcVal->getParams().begin(); - for (UInt i = 0; i < requirementFuncType->getParamCount(); ++i, ++pp) - { - auto paramType = requirementFuncType->getParamType(i); - - if (as<IRTypeType>(paramType)) - continue; - - if (isFirst) - isFirst = false; - else - m_writer->emit(","); - emitParamType(paramType, getName(*pp)); - } - } - m_writer->emit(")\n{\n"); - m_writer->indent(); - m_writer->emit("return "); - m_writer->emit(getName(funcVal)); - m_writer->emit("("); - // Emit argument list. - { - bool isFirst = true; - UInt paramIndex = 0; - for (auto defParamIter = funcVal->getParams().begin(); - defParamIter!=funcVal->getParams().end(); - ++defParamIter, ++paramIndex) - { - auto param = *defParamIter; - auto reqParamType = requirementFuncType->getParamType(paramIndex); - if (as<IRTypeType>(param->getFullType())) - continue; - - if (isFirst) - isFirst = false; - else - m_writer->emit(", "); - - // If the implementation expects a concrete type - // (either in the form of a pointer for `out`/`inout` parameters, - // or in the form a a value for `in` parameters, while - // the interface exposes a raw pointer type (void*), - // we need to cast the raw pointer type to the appropriate - // concerete type. (void*->Concrete* / void*->Concrete&). - if (reqParamType->op == kIROp_RawPointerType && - param->getDataType()->op != kIROp_RawPointerType) - { - if (as<IRPtrTypeBase>(param->getFullType())) - { - // The implementation function expects a pointer to the - // concrete type. This is the case for inout/out parameters. - m_writer->emit("static_cast<"); - emitType(param->getFullType()); - m_writer->emit(">("); - m_writer->emit(getName(param)); - m_writer->emit(")"); - } - else - { - // The implementation function expects just a value of the - // concrete type. We need to insert a dereference in this case. - m_writer->emit("*static_cast<"); - emitType(param->getFullType()); - m_writer->emit("*>("); - m_writer->emit(getName(param)); - m_writer->emit(")"); - } - } - else - { - m_writer->emit(getName(param)); - } - } - } - m_writer->emit(");\n"); - m_writer->dedent(); - m_writer->emit("}\n"); - } - } - } - } -} - void CPPSourceEmitter::_emitWitnessTableDefinitions() { for (auto witnessTable : pendingWitnessTableDefinitions) @@ -1767,7 +1633,7 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions() else isFirstEntry = false; - m_writer->emit(_getWitnessTableWrapperFuncName(funcVal)); + m_writer->emit(getName(funcVal)); } else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal())) { @@ -1778,9 +1644,18 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions() m_writer->emit("&"); m_writer->emit(getName(witnessTableVal)); } + else if (entry->getSatisfyingVal() && + isPointerOfType(entry->getSatisfyingVal()->getDataType(), kIROp_RTTIType)) + { + if (!isFirstEntry) + m_writer->emit(",\n"); + else + isFirstEntry = false; + emitInstExpr(entry->getSatisfyingVal(), getInfo(EmitOp::General)); + } else { - // TODO: handle other witness table entry types. + SLANG_UNEXPECTED("unknown witnesstable entry type"); } } m_writer->dedent(); @@ -1857,6 +1732,12 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition( m_writer->emit(getName(entry->getRequirementKey())); m_writer->emit(";\n"); } + else if (isPointerOfType(entry->getRequirementVal(), kIROp_RTTIType)) + { + m_writer->emit("TypeInfo* "); + m_writer->emit(getName(entry->getRequirementKey())); + m_writer->emit(";\n"); + } } m_writer->dedent(); m_writer->emit("};\n"); @@ -2336,6 +2217,15 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut m_writer->emit("->typeSize)"); return true; } + case kIROp_BitCast: + { + m_writer->emit("(("); + emitType(inst->getDataType()); + m_writer->emit(")("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("))"); + return true; + } } } @@ -2668,11 +2558,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) emitGlobalInst(action.inst); } } - - // Emit wrapper functions for each witness table entry. - // These wrapper functions takes an abstract type parameter (void*) - // in the place of `this` parameter. - _emitWitnessTableWrappers(); } // Emit all witness table definitions. diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h index 29d6e215e..e12493b5a 100644 --- a/source/slang/slang-emit-cpp.h +++ b/source/slang/slang-emit-cpp.h @@ -106,10 +106,6 @@ protected: UnownedStringSlice _getFuncName(const HLSLIntrinsic* specOp); - // Returns a StringSlice representing the mangled name of a witness table - // wrapper function. - UnownedStringSlice _getWitnessTableWrapperFuncName(IRFunc* func); - UnownedStringSlice _getTypeName(IRType* type); SlangResult _calcCPPTextureTypeName(IRTextureTypeBase* texType, StringBuilder& outName); @@ -127,19 +123,12 @@ protected: // of all the witness table objects in `pendingWitnessTableDefinitions`. void _emitWitnessTableDefinitions(); - // Emit wrapper functions that are referenced in witness tables. - // Wrapper functions wraps the actual member function, and takes a `void*` - // as the `this` parameter instead of the actual object type, so that - // their signature is agnostic to the object type. - void _emitWitnessTableWrappers(); - HLSLIntrinsic* _addIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* argTypes, Index argTypeCount); static bool _isVariable(IROp op); Dictionary<IRType*, StringSlicePool::Handle> m_typeNameMap; Dictionary<const HLSLIntrinsic*, StringSlicePool::Handle> m_intrinsicNameMap; - Dictionary<IRFunc*, StringSlicePool::Handle> m_witnessTableWrapperFuncNameMap; IRTypeSet m_typeSet; diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index 1e725cfae..e930c6cc8 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -212,6 +212,10 @@ namespace Slang { entry->setRequirementVal(lowerGenericFuncType(&builder, genericFuncType)); } + else if (entry->getRequirementVal()->op == kIROp_AssociatedType) + { + entry->setRequirementVal(builder.getPtrType(builder.getRTTIType())); + } } } diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 7876cc7d8..61fa8ad17 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -5,6 +5,7 @@ #include "slang-ir-lower-generic-function.h" #include "slang-ir-lower-generic-call.h" #include "slang-ir-lower-generic-var.h" +#include "slang-ir-witness-table-wrapper.h" namespace Slang { @@ -16,5 +17,6 @@ namespace Slang lowerGenericFunctions(&sharedContext); lowerGenericCalls(&sharedContext); lowerGenericVar(&sharedContext); + generateWitnessTableWrapperFunctions(&sharedContext); } } // namespace Slang diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp new file mode 100644 index 000000000..8a30ed148 --- /dev/null +++ b/source/slang/slang-ir-witness-table-wrapper.cpp @@ -0,0 +1,211 @@ +// slang-ir-witness-table-wrapper.cpp +#include "slang-ir-witness-table-wrapper.h" + +#include "slang-ir-generics-lowering-context.h" +#include "slang-ir.h" +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + struct GenericsLoweringContext; + + struct GenerateWitnessTableWrapperContext + { + SharedGenericsLoweringContext* sharedContext; + + IRStringLit* _getWitnessTableWrapperFuncName(IRFunc* func) + { + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(func); + if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>()) + { + return builder->getStringValue((String(linkageDecoration->getMangledName()) + "_wtwrapper").getUnownedSlice()); + } + if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>()) + { + return builder->getStringValue((String(namehintDecoration->getName()) + "_wtwrapper").getUnownedSlice()); + } + return nullptr; + } + + IRFunc* emitWitnessTableWrapper(IRFunc* func, IRInst* interfaceRequirementVal) + { + auto funcTypeInInterface = cast<IRFuncType>(interfaceRequirementVal); + + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(func); + + auto wrapperFunc = builder->createFunc(); + wrapperFunc->setFullType((IRType*)interfaceRequirementVal); + if (auto name = _getWitnessTableWrapperFuncName(func)) + builder->addNameHintDecoration(wrapperFunc, name); + + builder->setInsertInto(wrapperFunc); + auto block = builder->emitBlock(); + builder->setInsertInto(block); + + ShortList<IRParam*> params; + for (UInt i = 0; i < funcTypeInInterface->getParamCount(); i++) + { + params.add(builder->emitParam(funcTypeInInterface->getParamType(i))); + } + + List<IRInst*> args; + bool callerAllocatesReturnVal = funcTypeInInterface->getResultType()->op == kIROp_VoidType + && func->getResultType()->op != kIROp_VoidType; + IRVar* retVar = nullptr; + if (callerAllocatesReturnVal) + { + // If return value is allocated by caller, we need to write the result + // of the call into a local variable, and copy from that local variable + // to the address passed in by the caller. + retVar = builder->emitVar(func->getResultType()); + SLANG_ASSERT(params.getCount() == (Index)(func->getParamCount() + 1)); + } + else + { + SLANG_ASSERT(params.getCount() == (Index)func->getParamCount()); + } + for (UInt i = 0; i < func->getParamCount(); i++) + { + auto wrapperParam = params[i + (callerAllocatesReturnVal ? 1 : 0)]; + // Type of the parameter in interface requirement. + auto reqParamType = wrapperParam->getDataType(); + // Type of the parameter in the callee. + auto funcParamType = func->getParamType(i); + + // If the implementation expects a concrete type + // (either in the form of a pointer for `out`/`inout` parameters, + // or in the form a a value for `in` parameters, while + // the interface exposes a raw pointer type (void*), + // we need to cast the raw pointer type to the appropriate + // concerete type. (void*->Concrete* / void*->Concrete&). + if (as<IRRawPointerTypeBase>(reqParamType) && + !as<IRRawPointerTypeBase>(funcParamType)) + { + if (as<IRPtrTypeBase>(funcParamType)) + { + // The implementation function expects a pointer to the + // concrete type. This is the case for inout/out parameters. + auto bitCast = builder->emitBitCast(funcParamType, wrapperParam); + args.add(bitCast); + } + else + { + // The implementation function expects just a value of the + // concrete type. We need to insert a load in this case. + auto bitCast = builder->emitBitCast( + builder->getPtrType(funcParamType), + wrapperParam); + auto load = builder->emitLoad(bitCast); + args.add(load); + } + } + else + { + args.add(wrapperParam); + } + } + auto call = builder->emitCallInst(func->getResultType(), func, args); + if (retVar) + { + // If the caller of the wrapper function allocates space, + // we need to store the result of the call into a local varaible, + // and then copy the local variable into the caller-provided + // buffer (params[0]). + builder->emitStore(retVar, call); + // The result type of the inner function can only be a concrete type + // if we reach here. If it is a generic type or generic associated type, + // it would have already been lowered out during interface lowering and + // lowerGenericFunction. + // This means that we can just grab the rtti object from the type directly. + auto rttiObject = sharedContext->maybeEmitRTTIObject(func->getResultType()); + auto rttiPtr = builder->emitGetAddress( + builder->getPtrType(builder->getRTTIType()), + rttiObject); + builder->emitCopy(params[0], retVar, rttiPtr); + builder->emitReturn(); + } + else + { + if (call->getDataType()->op == kIROp_VoidType) + builder->emitReturn(); + else + builder->emitReturn(call); + } + return wrapperFunc; + } + + void lowerWitnessTable(IRWitnessTable* witnessTable) + { + auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType()); + for (auto child : witnessTable->getChildren()) + { + auto entry = as<IRWitnessTableEntry>(child); + if (!entry) + continue; + auto interfaceRequirementVal = sharedContext->findInterfaceRequirementVal(interfaceType, entry->getRequirementKey()); + if (auto ordinaryFunc = as<IRFunc>(entry->getSatisfyingVal())) + { + auto wrapper = emitWitnessTableWrapper(ordinaryFunc, interfaceRequirementVal); + entry->satisfyingVal.set(wrapper); + sharedContext->addToWorkList(wrapper); + } + } + } + + void processInst(IRInst* inst) + { + if (auto witnessTable = as<IRWitnessTable>(inst)) + { + lowerWitnessTable(witnessTable); + } + } + + void processModule() + { + // We start by initializing our shared IR building state, + // since we will re-use that state for any code we + // generate along the way. + // + SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage; + sharedBuilder->module = sharedContext->module; + sharedBuilder->session = sharedContext->module->session; + + sharedContext->addToWorkList(sharedContext->module->getModuleInst()); + + while (sharedContext->workList.getCount() != 0) + { + // We will then iterate until our work list goes dry. + // + while (sharedContext->workList.getCount() != 0) + { + IRInst* inst = sharedContext->workList.getLast(); + + sharedContext->workList.removeLast(); + sharedContext->workListSet.Remove(inst); + + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + sharedContext->addToWorkList(child); + } + } + } + } + }; + + void generateWitnessTableWrapperFunctions(SharedGenericsLoweringContext* sharedContext) + { + GenerateWitnessTableWrapperContext context; + context.sharedContext = sharedContext; + context.processModule(); + } + +} diff --git a/source/slang/slang-ir-witness-table-wrapper.h b/source/slang/slang-ir-witness-table-wrapper.h new file mode 100644 index 000000000..62b8ffa2c --- /dev/null +++ b/source/slang/slang-ir-witness-table-wrapper.h @@ -0,0 +1,23 @@ +// slang-ir-witness-table-wrapper.h +#pragma once + +namespace Slang +{ + struct SharedGenericsLoweringContext; + + /// This pass generates wrapper functions for witness table function entries. + /// + /// Enabled for generation of dynamic dispatch code only. + /// + /// Functions that are used to satisfy interface requirement have concrete + /// type signatures for `this` and `associatedtype` parameters/return values. + /// However, when they are called from a witness table, the callee only have a + /// raw pointer for this arguments, since the conrete type is not known to the + /// callee. Therefore, we need to generate wrappers for each member function + /// callable through a witness table, so that the wrapper functions take general void* + /// pointer for arguments whose type is unknown at call sites, and convert them + /// to concrete types and calls the actual implementation. + void generateWitnessTableWrapperFunctions( + SharedGenericsLoweringContext* sharedContext); + +} diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index 53d4681b3..b92fb5967 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -259,6 +259,7 @@ <ClInclude Include="slang-ir-type-set.h" /> <ClInclude Include="slang-ir-union.h" /> <ClInclude Include="slang-ir-validate.h" /> + <ClInclude Include="slang-ir-witness-table-wrapper.h" /> <ClInclude Include="slang-ir-wrap-structured-buffers.h" /> <ClInclude Include="slang-ir.h" /> <ClInclude Include="slang-legalize-types.h" /> @@ -357,6 +358,7 @@ <ClCompile Include="slang-ir-type-set.cpp" /> <ClCompile Include="slang-ir-union.cpp" /> <ClCompile Include="slang-ir-validate.cpp" /> + <ClCompile Include="slang-ir-witness-table-wrapper.cpp" /> <ClCompile Include="slang-ir-wrap-structured-buffers.cpp" /> <ClCompile Include="slang-ir.cpp" /> <ClCompile Include="slang-legalize-types.cpp" /> diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index 5ae8d77ff..855b3fb7b 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -228,6 +228,9 @@ <ClInclude Include="slang-ir-validate.h"> <Filter>Header Files</Filter> </ClInclude> + <ClInclude Include="slang-ir-witness-table-wrapper.h"> + <Filter>Header Files</Filter> + </ClInclude> <ClInclude Include="slang-ir-wrap-structured-buffers.h"> <Filter>Header Files</Filter> </ClInclude> @@ -518,6 +521,9 @@ <ClCompile Include="slang-ir-validate.cpp"> <Filter>Source Files</Filter> </ClCompile> + <ClCompile Include="slang-ir-witness-table-wrapper.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="slang-ir-wrap-structured-buffers.cpp"> <Filter>Source Files</Filter> </ClCompile> |
