summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-cpp.cpp167
-rw-r--r--source/slang/slang-emit-cpp.h11
-rw-r--r--source/slang/slang-ir-lower-generic-function.cpp4
-rw-r--r--source/slang/slang-ir-lower-generics.cpp2
-rw-r--r--source/slang/slang-ir-witness-table-wrapper.cpp211
-rw-r--r--source/slang/slang-ir-witness-table-wrapper.h23
-rw-r--r--source/slang/slang.vcxproj2
-rw-r--r--source/slang/slang.vcxproj.filters6
8 files changed, 274 insertions, 152 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index b71feafc1..c949075fb 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1467,24 +1467,6 @@ UnownedStringSlice CPPSourceEmitter::_getFuncName(const HLSLIntrinsic* specOp)
return m_slicePool.getSlice(handle);
}
-UnownedStringSlice CPPSourceEmitter::_getWitnessTableWrapperFuncName(IRFunc* func)
-{
- StringSlicePool::Handle handle = StringSlicePool::kNullHandle;
- if (m_witnessTableWrapperFuncNameMap.TryGetValue(func, handle))
- {
- return m_slicePool.getSlice(handle);
- }
-
- StringBuilder builder;
- builder << getName(func) << "_wtwrapper";
-
- handle = m_slicePool.add(builder);
- m_witnessTableWrapperFuncNameMap.Add(func, handle);
-
- SLANG_ASSERT(handle != StringSlicePool::kNullHandle);
- return m_slicePool.getSlice(handle);
-}
-
SlangResult CPPSourceEmitter::calcFuncName(const HLSLIntrinsic* specOp, StringBuilder& outBuilder)
{
typedef HLSLIntrinsic::Op Op;
@@ -1629,122 +1611,6 @@ void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable)
pendingWitnessTableDefinitions.add(witnessTable);
}
-void CPPSourceEmitter::_emitWitnessTableWrappers()
-{
- for (auto witnessTable : pendingWitnessTableDefinitions)
- {
- auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0));
- for (auto child : witnessTable->getChildren())
- {
- if (auto entry = as<IRWitnessTableEntry>(child))
- {
- if (auto funcVal = as<IRFunc>(entry->getSatisfyingVal()))
- {
- IRInst* requirementVal = nullptr;
- for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
- {
- if (auto reqEntry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)))
- {
- if (reqEntry->getRequirementKey() == entry->getRequirementKey())
- {
- requirementVal = reqEntry->getRequirementVal();
- break;
- }
- }
- }
- SLANG_ASSERT(requirementVal != nullptr);
- IRFuncType* requirementFuncType = cast<IRFuncType>(requirementVal);
- emitType(funcVal->getResultType());
- m_writer->emit(" ");
- m_writer->emit(_getWitnessTableWrapperFuncName(funcVal));
- m_writer->emit("(");
- // Emit parameter list.
- {
- bool isFirst = true;
- SLANG_ASSERT(funcVal->getParamCount() == requirementFuncType->getParamCount());
- auto pp = funcVal->getParams().begin();
- for (UInt i = 0; i < requirementFuncType->getParamCount(); ++i, ++pp)
- {
- auto paramType = requirementFuncType->getParamType(i);
-
- if (as<IRTypeType>(paramType))
- continue;
-
- if (isFirst)
- isFirst = false;
- else
- m_writer->emit(",");
- emitParamType(paramType, getName(*pp));
- }
- }
- m_writer->emit(")\n{\n");
- m_writer->indent();
- m_writer->emit("return ");
- m_writer->emit(getName(funcVal));
- m_writer->emit("(");
- // Emit argument list.
- {
- bool isFirst = true;
- UInt paramIndex = 0;
- for (auto defParamIter = funcVal->getParams().begin();
- defParamIter!=funcVal->getParams().end();
- ++defParamIter, ++paramIndex)
- {
- auto param = *defParamIter;
- auto reqParamType = requirementFuncType->getParamType(paramIndex);
- if (as<IRTypeType>(param->getFullType()))
- continue;
-
- if (isFirst)
- isFirst = false;
- else
- m_writer->emit(", ");
-
- // If the implementation expects a concrete type
- // (either in the form of a pointer for `out`/`inout` parameters,
- // or in the form a a value for `in` parameters, while
- // the interface exposes a raw pointer type (void*),
- // we need to cast the raw pointer type to the appropriate
- // concerete type. (void*->Concrete* / void*->Concrete&).
- if (reqParamType->op == kIROp_RawPointerType &&
- param->getDataType()->op != kIROp_RawPointerType)
- {
- if (as<IRPtrTypeBase>(param->getFullType()))
- {
- // The implementation function expects a pointer to the
- // concrete type. This is the case for inout/out parameters.
- m_writer->emit("static_cast<");
- emitType(param->getFullType());
- m_writer->emit(">(");
- m_writer->emit(getName(param));
- m_writer->emit(")");
- }
- else
- {
- // The implementation function expects just a value of the
- // concrete type. We need to insert a dereference in this case.
- m_writer->emit("*static_cast<");
- emitType(param->getFullType());
- m_writer->emit("*>(");
- m_writer->emit(getName(param));
- m_writer->emit(")");
- }
- }
- else
- {
- m_writer->emit(getName(param));
- }
- }
- }
- m_writer->emit(");\n");
- m_writer->dedent();
- m_writer->emit("}\n");
- }
- }
- }
- }
-}
-
void CPPSourceEmitter::_emitWitnessTableDefinitions()
{
for (auto witnessTable : pendingWitnessTableDefinitions)
@@ -1767,7 +1633,7 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions()
else
isFirstEntry = false;
- m_writer->emit(_getWitnessTableWrapperFuncName(funcVal));
+ m_writer->emit(getName(funcVal));
}
else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal()))
{
@@ -1778,9 +1644,18 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions()
m_writer->emit("&");
m_writer->emit(getName(witnessTableVal));
}
+ else if (entry->getSatisfyingVal() &&
+ isPointerOfType(entry->getSatisfyingVal()->getDataType(), kIROp_RTTIType))
+ {
+ if (!isFirstEntry)
+ m_writer->emit(",\n");
+ else
+ isFirstEntry = false;
+ emitInstExpr(entry->getSatisfyingVal(), getInfo(EmitOp::General));
+ }
else
{
- // TODO: handle other witness table entry types.
+ SLANG_UNEXPECTED("unknown witnesstable entry type");
}
}
m_writer->dedent();
@@ -1857,6 +1732,12 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition(
m_writer->emit(getName(entry->getRequirementKey()));
m_writer->emit(";\n");
}
+ else if (isPointerOfType(entry->getRequirementVal(), kIROp_RTTIType))
+ {
+ m_writer->emit("TypeInfo* ");
+ m_writer->emit(getName(entry->getRequirementKey()));
+ m_writer->emit(";\n");
+ }
}
m_writer->dedent();
m_writer->emit("};\n");
@@ -2336,6 +2217,15 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
m_writer->emit("->typeSize)");
return true;
}
+ case kIROp_BitCast:
+ {
+ m_writer->emit("((");
+ emitType(inst->getDataType());
+ m_writer->emit(")(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("))");
+ return true;
+ }
}
}
@@ -2668,11 +2558,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module)
emitGlobalInst(action.inst);
}
}
-
- // Emit wrapper functions for each witness table entry.
- // These wrapper functions takes an abstract type parameter (void*)
- // in the place of `this` parameter.
- _emitWitnessTableWrappers();
}
// Emit all witness table definitions.
diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h
index 29d6e215e..e12493b5a 100644
--- a/source/slang/slang-emit-cpp.h
+++ b/source/slang/slang-emit-cpp.h
@@ -106,10 +106,6 @@ protected:
UnownedStringSlice _getFuncName(const HLSLIntrinsic* specOp);
- // Returns a StringSlice representing the mangled name of a witness table
- // wrapper function.
- UnownedStringSlice _getWitnessTableWrapperFuncName(IRFunc* func);
-
UnownedStringSlice _getTypeName(IRType* type);
SlangResult _calcCPPTextureTypeName(IRTextureTypeBase* texType, StringBuilder& outName);
@@ -127,19 +123,12 @@ protected:
// of all the witness table objects in `pendingWitnessTableDefinitions`.
void _emitWitnessTableDefinitions();
- // Emit wrapper functions that are referenced in witness tables.
- // Wrapper functions wraps the actual member function, and takes a `void*`
- // as the `this` parameter instead of the actual object type, so that
- // their signature is agnostic to the object type.
- void _emitWitnessTableWrappers();
-
HLSLIntrinsic* _addIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* argTypes, Index argTypeCount);
static bool _isVariable(IROp op);
Dictionary<IRType*, StringSlicePool::Handle> m_typeNameMap;
Dictionary<const HLSLIntrinsic*, StringSlicePool::Handle> m_intrinsicNameMap;
- Dictionary<IRFunc*, StringSlicePool::Handle> m_witnessTableWrapperFuncNameMap;
IRTypeSet m_typeSet;
diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp
index 1e725cfae..e930c6cc8 100644
--- a/source/slang/slang-ir-lower-generic-function.cpp
+++ b/source/slang/slang-ir-lower-generic-function.cpp
@@ -212,6 +212,10 @@ namespace Slang
{
entry->setRequirementVal(lowerGenericFuncType(&builder, genericFuncType));
}
+ else if (entry->getRequirementVal()->op == kIROp_AssociatedType)
+ {
+ entry->setRequirementVal(builder.getPtrType(builder.getRTTIType()));
+ }
}
}
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index 7876cc7d8..61fa8ad17 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -5,6 +5,7 @@
#include "slang-ir-lower-generic-function.h"
#include "slang-ir-lower-generic-call.h"
#include "slang-ir-lower-generic-var.h"
+#include "slang-ir-witness-table-wrapper.h"
namespace Slang
{
@@ -16,5 +17,6 @@ namespace Slang
lowerGenericFunctions(&sharedContext);
lowerGenericCalls(&sharedContext);
lowerGenericVar(&sharedContext);
+ generateWitnessTableWrapperFunctions(&sharedContext);
}
} // namespace Slang
diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp
new file mode 100644
index 000000000..8a30ed148
--- /dev/null
+++ b/source/slang/slang-ir-witness-table-wrapper.cpp
@@ -0,0 +1,211 @@
+// slang-ir-witness-table-wrapper.cpp
+#include "slang-ir-witness-table-wrapper.h"
+
+#include "slang-ir-generics-lowering-context.h"
+#include "slang-ir.h"
+#include "slang-ir-clone.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+ struct GenericsLoweringContext;
+
+ struct GenerateWitnessTableWrapperContext
+ {
+ SharedGenericsLoweringContext* sharedContext;
+
+ IRStringLit* _getWitnessTableWrapperFuncName(IRFunc* func)
+ {
+ IRBuilder builderStorage;
+ auto builder = &builderStorage;
+ builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
+ builder->setInsertBefore(func);
+ if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
+ {
+ return builder->getStringValue((String(linkageDecoration->getMangledName()) + "_wtwrapper").getUnownedSlice());
+ }
+ if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
+ {
+ return builder->getStringValue((String(namehintDecoration->getName()) + "_wtwrapper").getUnownedSlice());
+ }
+ return nullptr;
+ }
+
+ IRFunc* emitWitnessTableWrapper(IRFunc* func, IRInst* interfaceRequirementVal)
+ {
+ auto funcTypeInInterface = cast<IRFuncType>(interfaceRequirementVal);
+
+ IRBuilder builderStorage;
+ auto builder = &builderStorage;
+ builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
+ builder->setInsertBefore(func);
+
+ auto wrapperFunc = builder->createFunc();
+ wrapperFunc->setFullType((IRType*)interfaceRequirementVal);
+ if (auto name = _getWitnessTableWrapperFuncName(func))
+ builder->addNameHintDecoration(wrapperFunc, name);
+
+ builder->setInsertInto(wrapperFunc);
+ auto block = builder->emitBlock();
+ builder->setInsertInto(block);
+
+ ShortList<IRParam*> params;
+ for (UInt i = 0; i < funcTypeInInterface->getParamCount(); i++)
+ {
+ params.add(builder->emitParam(funcTypeInInterface->getParamType(i)));
+ }
+
+ List<IRInst*> args;
+ bool callerAllocatesReturnVal = funcTypeInInterface->getResultType()->op == kIROp_VoidType
+ && func->getResultType()->op != kIROp_VoidType;
+ IRVar* retVar = nullptr;
+ if (callerAllocatesReturnVal)
+ {
+ // If return value is allocated by caller, we need to write the result
+ // of the call into a local variable, and copy from that local variable
+ // to the address passed in by the caller.
+ retVar = builder->emitVar(func->getResultType());
+ SLANG_ASSERT(params.getCount() == (Index)(func->getParamCount() + 1));
+ }
+ else
+ {
+ SLANG_ASSERT(params.getCount() == (Index)func->getParamCount());
+ }
+ for (UInt i = 0; i < func->getParamCount(); i++)
+ {
+ auto wrapperParam = params[i + (callerAllocatesReturnVal ? 1 : 0)];
+ // Type of the parameter in interface requirement.
+ auto reqParamType = wrapperParam->getDataType();
+ // Type of the parameter in the callee.
+ auto funcParamType = func->getParamType(i);
+
+ // If the implementation expects a concrete type
+ // (either in the form of a pointer for `out`/`inout` parameters,
+ // or in the form a a value for `in` parameters, while
+ // the interface exposes a raw pointer type (void*),
+ // we need to cast the raw pointer type to the appropriate
+ // concerete type. (void*->Concrete* / void*->Concrete&).
+ if (as<IRRawPointerTypeBase>(reqParamType) &&
+ !as<IRRawPointerTypeBase>(funcParamType))
+ {
+ if (as<IRPtrTypeBase>(funcParamType))
+ {
+ // The implementation function expects a pointer to the
+ // concrete type. This is the case for inout/out parameters.
+ auto bitCast = builder->emitBitCast(funcParamType, wrapperParam);
+ args.add(bitCast);
+ }
+ else
+ {
+ // The implementation function expects just a value of the
+ // concrete type. We need to insert a load in this case.
+ auto bitCast = builder->emitBitCast(
+ builder->getPtrType(funcParamType),
+ wrapperParam);
+ auto load = builder->emitLoad(bitCast);
+ args.add(load);
+ }
+ }
+ else
+ {
+ args.add(wrapperParam);
+ }
+ }
+ auto call = builder->emitCallInst(func->getResultType(), func, args);
+ if (retVar)
+ {
+ // If the caller of the wrapper function allocates space,
+ // we need to store the result of the call into a local varaible,
+ // and then copy the local variable into the caller-provided
+ // buffer (params[0]).
+ builder->emitStore(retVar, call);
+ // The result type of the inner function can only be a concrete type
+ // if we reach here. If it is a generic type or generic associated type,
+ // it would have already been lowered out during interface lowering and
+ // lowerGenericFunction.
+ // This means that we can just grab the rtti object from the type directly.
+ auto rttiObject = sharedContext->maybeEmitRTTIObject(func->getResultType());
+ auto rttiPtr = builder->emitGetAddress(
+ builder->getPtrType(builder->getRTTIType()),
+ rttiObject);
+ builder->emitCopy(params[0], retVar, rttiPtr);
+ builder->emitReturn();
+ }
+ else
+ {
+ if (call->getDataType()->op == kIROp_VoidType)
+ builder->emitReturn();
+ else
+ builder->emitReturn(call);
+ }
+ return wrapperFunc;
+ }
+
+ void lowerWitnessTable(IRWitnessTable* witnessTable)
+ {
+ auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType());
+ for (auto child : witnessTable->getChildren())
+ {
+ auto entry = as<IRWitnessTableEntry>(child);
+ if (!entry)
+ continue;
+ auto interfaceRequirementVal = sharedContext->findInterfaceRequirementVal(interfaceType, entry->getRequirementKey());
+ if (auto ordinaryFunc = as<IRFunc>(entry->getSatisfyingVal()))
+ {
+ auto wrapper = emitWitnessTableWrapper(ordinaryFunc, interfaceRequirementVal);
+ entry->satisfyingVal.set(wrapper);
+ sharedContext->addToWorkList(wrapper);
+ }
+ }
+ }
+
+ void processInst(IRInst* inst)
+ {
+ if (auto witnessTable = as<IRWitnessTable>(inst))
+ {
+ lowerWitnessTable(witnessTable);
+ }
+ }
+
+ void processModule()
+ {
+ // We start by initializing our shared IR building state,
+ // since we will re-use that state for any code we
+ // generate along the way.
+ //
+ SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage;
+ sharedBuilder->module = sharedContext->module;
+ sharedBuilder->session = sharedContext->module->session;
+
+ sharedContext->addToWorkList(sharedContext->module->getModuleInst());
+
+ while (sharedContext->workList.getCount() != 0)
+ {
+ // We will then iterate until our work list goes dry.
+ //
+ while (sharedContext->workList.getCount() != 0)
+ {
+ IRInst* inst = sharedContext->workList.getLast();
+
+ sharedContext->workList.removeLast();
+ sharedContext->workListSet.Remove(inst);
+
+ processInst(inst);
+
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ sharedContext->addToWorkList(child);
+ }
+ }
+ }
+ }
+ };
+
+ void generateWitnessTableWrapperFunctions(SharedGenericsLoweringContext* sharedContext)
+ {
+ GenerateWitnessTableWrapperContext context;
+ context.sharedContext = sharedContext;
+ context.processModule();
+ }
+
+}
diff --git a/source/slang/slang-ir-witness-table-wrapper.h b/source/slang/slang-ir-witness-table-wrapper.h
new file mode 100644
index 000000000..62b8ffa2c
--- /dev/null
+++ b/source/slang/slang-ir-witness-table-wrapper.h
@@ -0,0 +1,23 @@
+// slang-ir-witness-table-wrapper.h
+#pragma once
+
+namespace Slang
+{
+ struct SharedGenericsLoweringContext;
+
+ /// This pass generates wrapper functions for witness table function entries.
+ ///
+ /// Enabled for generation of dynamic dispatch code only.
+ ///
+ /// Functions that are used to satisfy interface requirement have concrete
+ /// type signatures for `this` and `associatedtype` parameters/return values.
+ /// However, when they are called from a witness table, the callee only have a
+ /// raw pointer for this arguments, since the conrete type is not known to the
+ /// callee. Therefore, we need to generate wrappers for each member function
+ /// callable through a witness table, so that the wrapper functions take general void*
+ /// pointer for arguments whose type is unknown at call sites, and convert them
+ /// to concrete types and calls the actual implementation.
+ void generateWitnessTableWrapperFunctions(
+ SharedGenericsLoweringContext* sharedContext);
+
+}
diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj
index 53d4681b3..b92fb5967 100644
--- a/source/slang/slang.vcxproj
+++ b/source/slang/slang.vcxproj
@@ -259,6 +259,7 @@
<ClInclude Include="slang-ir-type-set.h" />
<ClInclude Include="slang-ir-union.h" />
<ClInclude Include="slang-ir-validate.h" />
+ <ClInclude Include="slang-ir-witness-table-wrapper.h" />
<ClInclude Include="slang-ir-wrap-structured-buffers.h" />
<ClInclude Include="slang-ir.h" />
<ClInclude Include="slang-legalize-types.h" />
@@ -357,6 +358,7 @@
<ClCompile Include="slang-ir-type-set.cpp" />
<ClCompile Include="slang-ir-union.cpp" />
<ClCompile Include="slang-ir-validate.cpp" />
+ <ClCompile Include="slang-ir-witness-table-wrapper.cpp" />
<ClCompile Include="slang-ir-wrap-structured-buffers.cpp" />
<ClCompile Include="slang-ir.cpp" />
<ClCompile Include="slang-legalize-types.cpp" />
diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters
index 5ae8d77ff..855b3fb7b 100644
--- a/source/slang/slang.vcxproj.filters
+++ b/source/slang/slang.vcxproj.filters
@@ -228,6 +228,9 @@
<ClInclude Include="slang-ir-validate.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="slang-ir-witness-table-wrapper.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="slang-ir-wrap-structured-buffers.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -518,6 +521,9 @@
<ClCompile Include="slang-ir-validate.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="slang-ir-witness-table-wrapper.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="slang-ir-wrap-structured-buffers.cpp">
<Filter>Source Files</Filter>
</ClCompile>