diff options
Diffstat (limited to 'source/slang/slang-emit-cpp.cpp')
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 195 |
1 files changed, 178 insertions, 17 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index c31ef3bc7..accb290fa 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1446,6 +1446,24 @@ UnownedStringSlice CPPSourceEmitter::_getFuncName(const HLSLIntrinsic* specOp) return m_slicePool.getSlice(handle); } +UnownedStringSlice CPPSourceEmitter::_getWitnessTableWrapperFuncName(IRFunc* func) +{ + StringSlicePool::Handle handle = StringSlicePool::kNullHandle; + if (m_witnessTableWrapperFuncNameMap.TryGetValue(func, handle)) + { + return m_slicePool.getSlice(handle); + } + + StringBuilder builder; + builder << getName(func) << "_wtwrapper"; + + handle = m_slicePool.add(builder); + m_witnessTableWrapperFuncNameMap.Add(func, handle); + + SLANG_ASSERT(handle != StringSlicePool::kNullHandle); + return m_slicePool.getSlice(handle); +} + SlangResult CPPSourceEmitter::calcFuncName(const HLSLIntrinsic* specOp, StringBuilder& outBuilder) { typedef HLSLIntrinsic::Op Op; @@ -1591,6 +1609,83 @@ void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable) pendingWitnessTableDefinitions.add(witnessTable); } +void CPPSourceEmitter::_emitWitnessTableWrappers() +{ + for (auto witnessTable : pendingWitnessTableDefinitions) + { + for (auto child : witnessTable->getChildren()) + { + if (auto entry = as<IRWitnessTableEntry>(child)) + { + if (auto funcVal = as<IRFunc>(entry->getSatisfyingVal())) + { + emitType(funcVal->getResultType()); + m_writer->emit(" "); + m_writer->emit(_getWitnessTableWrapperFuncName(funcVal)); + m_writer->emit("("); + // Emit parameter list. + { + bool isFirst = true; + for (auto param : funcVal->getParams()) + { + if (as<IRTypeType>(param->getFullType())) + continue; + + if (isFirst) + isFirst = false; + else + m_writer->emit(","); + + if (param->findDecoration<IRThisPointerDecoration>()) + { + m_writer->emit("void* "); + m_writer->emit(getName(param)); + continue; + } + emitSimpleFuncParamImpl(param); + } + } + m_writer->emit(")\n{\n"); + m_writer->indent(); + m_writer->emit("return "); + m_writer->emit(getName(funcVal)); + m_writer->emit("("); + // Emit argument list. + { + bool isFirst = true; + for (auto param : funcVal->getParams()) + { + if (as<IRTypeType>(param->getFullType())) + continue; + + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); + + if (param->findDecoration<IRThisPointerDecoration>()) + { + m_writer->emit("*static_cast<"); + emitType(param->getFullType()); + m_writer->emit("*>("); + m_writer->emit(getName(param)); + m_writer->emit(")"); + } + else + { + m_writer->emit(getName(param)); + } + } + } + m_writer->emit(");\n"); + m_writer->dedent(); + m_writer->emit("}\n"); + } + } + } + } +} + void CPPSourceEmitter::_emitWitnessTableDefinitions() { for (auto witnessTable : pendingWitnessTableDefinitions) @@ -1612,8 +1707,9 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions() m_writer->emit(",\n"); else isFirstEntry = false; - m_writer->emit("&Context::"); - m_writer->emit(getName(funcVal)); + + m_writer->emit("&KernelContext::"); + m_writer->emit(_getWitnessTableWrapperFuncName(funcVal)); } else { @@ -1660,7 +1756,7 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition( else isFirstEntry = false; emitType(funcVal->getResultType()); - m_writer->emit(" (Context::*"); + m_writer->emit(" (KernelContext::*"); m_writer->emit(getName(entry->requirementKey.get())); m_writer->emit(")"); m_writer->emit("("); @@ -1671,7 +1767,13 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition( m_writer->emit(", "); else isFirstParam = false; - emitParamType(param->getFullType(), getName(param)); + if (param->findDecoration<IRThisPointerDecoration>()) + { + m_writer->emit("void* "); + m_writer->emit(getName(param)); + continue; + } + emitSimpleFuncParamImpl(param); } m_writer->emit(");\n"); } @@ -1681,7 +1783,7 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition( } } m_writer->dedent(); - m_writer->emit("\n};\n"); + m_writer->emit("};\n"); } bool CPPSourceEmitter::tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) @@ -1742,7 +1844,7 @@ void CPPSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoint SLANG_UNUSED(entryPointDecor); auto profile = m_effectiveProfile; - auto stage = profile.GetStage(); + auto stage = profile.getStage(); switch (stage) { @@ -1877,6 +1979,31 @@ void CPPSourceEmitter::emitSimpleValueImpl(IRInst* inst) } } +static bool isVoidPtrType(IRType* type) +{ + auto ptrType = as<IRPtrType>(type); + if (!ptrType) return false; + return ptrType->getValueType()->op == kIROp_VoidType; +} + +void CPPSourceEmitter::emitSimpleFuncParamImpl(IRParam* param) +{ + // Polymorphic types are already translated to void* type in + // lower-generics pass. However, the current emitting logic will + // emit "void&" instead of "void*" for pointer types. + // In the future, we will handle pointer types more properly, + // and this override logic will not be necessary. + // For now we special-case this scenario. + if (param->findDecoration<IRPolymorphicDecoration>() && + isVoidPtrType(param->getDataType())) + { + m_writer->emit("void* "); + m_writer->emit(getName(param)); + return; + } + CLikeSourceEmitter::emitSimpleFuncParamImpl(param); +} + void CPPSourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) { emitSimpleType(m_typeSet.addVectorType(elementType, int(elementCount))); @@ -2102,6 +2229,16 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut m_writer->emit(")"); return true; } + case kIROp_getAddr: + { + // Once we clean up the pointer emitting logic, we can + // just use GetElementAddress instruction in place of + // getAddr instruction, and this case can be removed. + m_writer->emit("(&("); + emitInstExpr(inst->getOperand(0), EmitOpInfo::get(EmitOp::General)); + m_writer->emit("))"); + return true; + } } } @@ -2134,6 +2271,20 @@ void CPPSourceEmitter::emitPreprocessorDirectivesImpl() writer->emit("\n"); + + if (m_target == CodeGenTarget::CPPSource) + { + // Put all into an anonymous namespace + // This includes any generated types, and generated intrinsics + + m_writer->emit("namespace { // anonymous \n\n"); + m_writer->emit("#ifdef SLANG_PRELUDE_NAMESPACE\n"); + m_writer->emit("using namespace SLANG_PRELUDE_NAMESPACE;\n"); + m_writer->emit("#endif\n\n"); + + m_writer->emit("struct KernelContext;\n\n"); + } + if (m_target == CodeGenTarget::CSource) { // For C output we need to emit type definitions. @@ -2169,7 +2320,6 @@ void CPPSourceEmitter::emitPreprocessorDirectivesImpl() { _maybeEmitSpecializedOperationDefinition(intrinsic); } - } } @@ -2293,14 +2443,14 @@ void CPPSourceEmitter::_emitEntryPointDefinitionStart(IRFunc* func, IRGlobalPara m_writer->emit("("); m_writer->emit(varyingTypeName); - m_writer->emit("* varyingInput, UniformEntryPointParams* params, UniformState* uniformState)"); + m_writer->emit("* varyingInput, void* params, void* uniformState)"); emitSemantics(func); m_writer->emit("\n{\n"); m_writer->indent(); // Initialize when constructing so that globals are zeroed - m_writer->emit("Context context = {};\n"); - m_writer->emit("context.uniformState = uniformState;\n"); + m_writer->emit("KernelContext context = {};\n"); + m_writer->emit("context.uniformState = (UniformState*)uniformState;\n"); if (entryPointGlobalParams) { @@ -2590,11 +2740,11 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) List<EmitAction> actions; computeEmitActions(module, actions); - + _emitForwardDeclarations(actions); IRGlobalParam* entryPointGlobalParams = nullptr; - + // Output the global parameters in a 'UniformState' structure { m_writer->emit("struct UniformState\n{\n"); @@ -2605,15 +2755,14 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) m_writer->dedent(); m_writer->emit("\n};\n\n"); } - + // Output the 'Context' which will be used for execution { - m_writer->emit("struct Context\n{\n"); + m_writer->emit("struct KernelContext\n{\n"); m_writer->indent(); m_writer->emit("UniformState* uniformState;\n"); - m_writer->emit("uint3 dispatchThreadID;\n"); //if (m_semanticUsedFlags & SemanticUsedFlag::GroupID) @@ -2658,13 +2807,25 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) } } + // Emit wrapper functions for each witness table entry. + // These wrapper functions takes an abstract type parameter (void*) + // in the place of `this` parameter. + _emitWitnessTableWrappers(); + m_writer->dedent(); - m_writer->emit("};\n\n"); + m_writer->emit("};\n\n"); } // Emit all witness table definitions. _emitWitnessTableDefinitions(); + if (m_target == CodeGenTarget::CPPSource) + { + // Need to close the anonymous namespace when outputting for C++ + + m_writer->emit("} // anonymous\n\n"); + } + // Finally we need to output dll entry points for (auto action : actions) @@ -2675,7 +2836,7 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module) IREntryPointDecoration* entryPointDecor = func->findDecoration<IREntryPointDecoration>(); - if (entryPointDecor && entryPointDecor->getProfile().GetStage() == Stage::Compute) + if (entryPointDecor && entryPointDecor->getProfile().getStage() == Stage::Compute) { // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-dispatchthreadid // SV_DispatchThreadID is the sum of SV_GroupID * numthreads and GroupThreadID. |
