diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-metal.cpp | 88 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-wrap-global-context.cpp | 287 | ||||
| -rw-r--r-- | source/slang/slang-ir-wrap-global-context.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 9 |
7 files changed, 398 insertions, 14 deletions
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index 1ce25a2da..17d074e75 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -100,7 +100,7 @@ void MetalSourceEmitter::_emitHLSLTextureType(IRTextureTypeBase* texType) case SLANG_TEXTURE_2D: m_writer->emit("2d"); break; case SLANG_TEXTURE_3D: m_writer->emit("3d"); break; case SLANG_TEXTURE_CUBE: m_writer->emit("cube"); break; - case SLANG_TEXTURE_BUFFER: m_writer->emit("1d"); break; + case SLANG_TEXTURE_BUFFER: m_writer->emit("_buffer"); break; default: SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled resource shape"); break; @@ -274,7 +274,39 @@ bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inO return true; } break; - + case kIROp_RWStructuredBufferGetElementPtr: + { + EmitOpInfo outerPrec = inOuterPrec; + bool needClose = false; + + auto prec = getInfo(EmitOp::Add); + needClose = maybeEmitParens(outerPrec, prec); + emitOperand(inst->getOperand(0), leftSide(outerPrec, prec)); + m_writer->emit("+"); + emitOperand(inst->getOperand(1), rightSide(prec, outerPrec)); + maybeCloseParens(needClose); + return true; + } + case kIROp_StructuredBufferLoad: + case kIROp_RWStructuredBufferLoad: + { + auto prec = getInfo(EmitOp::Postfix); + emitOperand(inst->getOperand(0), leftSide(inOuterPrec, prec)); + m_writer->emit("["); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("]"); + return true; + } + case kIROp_RWStructuredBufferStore: + { + auto prec = getInfo(EmitOp::Postfix); + emitOperand(inst->getOperand(0), leftSide(inOuterPrec, prec)); + m_writer->emit("["); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("] = "); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + return true; + } default: break; } // Not handled @@ -479,12 +511,47 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) case kIROp_ParameterBlockType: case 
kIROp_ConstantBufferType: { - m_writer->emit("constant "); emitType((IRType*)type->getOperand(0)); + m_writer->emit(" constant*"); + return; + } + case kIROp_PtrType: + case kIROp_InOutType: + case kIROp_OutType: + case kIROp_RefType: + case kIROp_ConstRefType: + { + auto ptrType = cast<IRPtrTypeBase>(type); + emitType((IRType*)ptrType->getValueType()); + switch ((AddressSpace)ptrType->getAddressSpace()) + { + case AddressSpace::Global: + m_writer->emit(" device"); + break; + case AddressSpace::Uniform: + m_writer->emit(" constant"); + break; + case AddressSpace::ThreadLocal: + m_writer->emit(" thread"); + break; + case AddressSpace::GroupShared: + m_writer->emit(" threadgroup"); + break; + } m_writer->emit("*"); return; } - default: break; + case kIROp_ArrayType: + { + m_writer->emit("array<"); + emitType((IRType*)type->getOperand(0)); + m_writer->emit(", "); + emitVal(type->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(">"); + return; + } + default: + break; } if (auto texType = as<IRTextureType>(type)) @@ -499,9 +566,8 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) } else if (auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type)) { - m_writer->emit("device "); emitType(structuredBufferType->getElementType()); - m_writer->emit("*"); + m_writer->emit(" device*"); return; } else if (const auto untypedBufferType = as<IRUntypedBufferResourceType>(type)) @@ -511,10 +577,11 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) case kIROp_HLSLByteAddressBufferType: case kIROp_HLSLRWByteAddressBufferType: case kIROp_HLSLRasterizerOrderedByteAddressBufferType: - m_writer->emit("device "); - m_writer->emit("uint32_t *"); + m_writer->emit("uint32_t device*"); + break; + case kIROp_RaytracingAccelerationStructureType: + m_writer->emit("acceleration_structure<instancing>"); break; - case kIROp_RaytracingAccelerationStructureType: m_writer->emit("acceleration_structure<instancing>"); break; default: 
SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled buffer type"); break; @@ -650,7 +717,8 @@ void MetalSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) void MetalSourceEmitter::emitFrontMatterImpl(TargetRequest*) { - + m_writer->emit("#include <metal_stdlib>\n"); + m_writer->emit("using namespace metal;\n"); } void MetalSourceEmitter::emitGlobalInstImpl(IRInst* inst) diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h index 4c4f27be3..6c3de04c4 100644 --- a/source/slang/slang-emit-metal.h +++ b/source/slang/slang-emit-metal.h @@ -54,6 +54,8 @@ protected: virtual void emitGlobalInstImpl(IRInst* inst) SLANG_OVERRIDE; + virtual bool doesTargetSupportPtrTypes() SLANG_OVERRIDE { return true; } + // Emit a single `register` semantic, as appropriate for a given resource-type-specific layout info // Keyword to use in the uniform case (`register` for globals, `packoffset` inside a `cbuffer`) void _emitHLSLRegisterSemantic(LayoutResourceKind kind, EmitVarChain* chain, IRInst* inst, char const* uniformSemanticSpelling = "register"); diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 2b85be8c7..b5a94cf0d 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -520,10 +520,7 @@ struct SPIRVEmitContext // > Version number // - // TODO(JS): - // Was previously set to SpvVersion, but that doesn't work since we - // upgraded to SPIR-V headers 1.6. (It would lead to validation errors during vk tests) - // For now mark as version 1.5.0 + // We are targeting SPIRV 1.5 for now. 
static const uint32_t spvVersion1_5_0 = 0x00010500; m_words.add(spvVersion1_5_0); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index c00c7bfc3..b7227a4f0 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -69,6 +69,7 @@ #include "slang-ir-synthesize-active-mask.h" #include "slang-ir-validate.h" #include "slang-ir-wrap-structured-buffers.h" +#include "slang-ir-wrap-global-context.h" #include "slang-ir-liveness.h" #include "slang-ir-glsl-liveness.h" #include "slang-ir-translate-glsl-global-var.h" @@ -1082,6 +1083,12 @@ Result linkAndOptimizeIR( validateIRModuleIfEnabled(codeGenContext, irModule); } + // Metal does not allow global variables and global parameters, so + // we need to convert them into an explicit global context parameter + // passed around through a function parameter. + if (target == CodeGenTarget::Metal) + wrapGlobalScopeInContextType(irModule); + auto metadata = new ArtifactPostEmitMetadata; outLinkedIR.metadata = metadata; diff --git a/source/slang/slang-ir-wrap-global-context.cpp b/source/slang/slang-ir-wrap-global-context.cpp new file mode 100644 index 000000000..01a343a02 --- /dev/null +++ b/source/slang/slang-ir-wrap-global-context.cpp @@ -0,0 +1,287 @@ +#include "slang-ir-wrap-global-context.h" + +#include "slang-ir-util.h" + +namespace Slang +{ + struct WrapGlobalScopeContext + { + List<IRFunc*> entryPoints; + IRStructType* contextType; + struct GlobalVarInfo + { + IRStructKey* key; + }; + Dictionary<IRInst*, GlobalVarInfo> mapGlobalVarToInfo; + struct FuncInfo + { + IRInst* contextArg; + }; + Dictionary<IRFunc*, FuncInfo> mapFuncToInfo; + IRStringLit* findNameHint(IRInst* inst) + { + if (auto nameDecor = inst->findDecoration<IRNameHintDecoration>()) + return nameDecor->getNameOperand(); + if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>()) + return linkageDecor->getMangledNameOperand(); + return nullptr; + } + + // Move all global parameters to the entry point parameters, 
+ // and replace them with global variables that are initialized with + // the entry point parameters. + void moveGlobalParametersToEntryPoint(IRModule* module) + { + Dictionary<IRInst*, IRInst*> mapGlobalParamToGlobalVar; + + IRBuilder builder(module); + + for (auto globalInst : module->getGlobalInsts()) + { + if (auto globalParam = as<IRGlobalParam>(globalInst)) + { + builder.setInsertBefore(globalParam); + auto globalVar = builder.createGlobalVar( + globalParam->getFullType(), + (int)AddressSpace::ThreadLocal); + if (auto name = findNameHint(globalParam)) + builder.addNameHintDecoration(globalVar, name); + mapGlobalParamToGlobalVar[globalParam] = globalVar; + } + } + + // For every entry point, we need to add a new parameter for each global parameter. + for (auto entryPoint : entryPoints) + { + auto firstBlock = entryPoint->getFirstBlock(); + auto paramInsertPoint = firstBlock->getFirstInst(); + struct ParamInfo + { + IRInst* newParam; + IRInst* globalVar; + }; + List<ParamInfo> newParams; + for (auto globalParam : mapGlobalParamToGlobalVar) + { + auto newParam = builder.createParam(globalParam.first->getFullType()); + newParam->insertBefore(paramInsertPoint); + if (auto name = findNameHint(globalParam.first)) + builder.addNameHintDecoration(newParam, name); + newParams.add({newParam, globalParam.second}); + } + + // Insert assignments to the global variables at the start of the entry point. + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + for (auto& paramInfo : newParams) + { + auto globalVar = paramInfo.globalVar; + auto newParam = paramInfo.newParam; + builder.emitStore(globalVar, newParam); + } + } + + // Replace all uses of global parameters with a load from the global variable. 
+ for (auto globalParam : mapGlobalParamToGlobalVar) + { + auto globalVar = globalParam.second; + traverseUses(globalParam.first, [&](IRUse* use) + { + auto user = use->getUser(); + builder.setInsertBefore(user); + auto load = builder.emitLoad(globalParam.first->getFullType(), globalVar); + builder.replaceOperand(use, load); + }); + globalParam.first->removeAndDeallocate(); + } + } + + void processModule(IRModule* module) + { + IRBuilder builder(module); + List<IRInst*> instsToRemove; + + List<IRFunc*> functions; + + // Collect all entry points and functions. + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->findDecoration<IREntryPointDecoration>()) + entryPoints.add(as<IRFunc>(globalInst)); + if (auto func = as<IRFunc>(globalInst)) + functions.add(func); + } + + // Before everything, we need to move all global parameters to the entry point parameters. + // For each global parameter, e.g. `uniform float4 g;`, we will replace it with a global + // variable, e.g. `float4 _g;`, and add a new parameter to each entry point, and copy + // the value from the entry point parameter to the global variable. + moveGlobalParametersToEntryPoint(module); + + // The next step is to wrap all global variables in a context type, and pass them around + // with explicit function parameters. + + // Collect all global variables. + for (auto globalInst : module->getGlobalInsts()) + { + if (auto globalVar = as<IRGlobalVar>(globalInst)) + { + auto key = builder.createStructKey(); + + if (auto name = findNameHint(globalVar)) + builder.addNameHintDecoration(key, name); + + GlobalVarInfo info; + info.key = key; + mapGlobalVarToInfo[globalVar] = info; + } + } + if (mapGlobalVarToInfo.getCount() == 0) + return; + + // Create the context type for the global scope. 
+ contextType = builder.createStructType(); + builder.addNameHintDecoration(contextType, toSlice("_SlangGlobalContext")); + for (auto& fieldKV : mapGlobalVarToInfo) + { + auto ptrType = as<IRPtrTypeBase>(fieldKV.first->getFullType()); + if (!ptrType) + continue; + builder.createStructField( + contextType, fieldKV.second.key, ptrType->getValueType()); + } + + // Identify all functions that require the global scope context. + + // First, add every function that directly uses a global variable to the work list. + List<IRFunc*> funcWorkList; + HashSet<IRFunc*> funcWorkListSet; + for (auto& fieldKV : mapGlobalVarToInfo) + { + auto globalVar = fieldKV.first; + for (auto use = globalVar->firstUse; use; use = use->nextUse) + { + if (auto userFunc = getParentFunc(use->getUser())) + { + if (funcWorkListSet.add(userFunc)) + funcWorkList.add(userFunc); + } + } + } + + // Next, propagate through the call graph and add all functions that transitively use a global variable. + for (Index i = 0; i < funcWorkList.getCount(); i++) + { + auto func = funcWorkList[i]; + for (auto use = func->firstUse; use; use = use->nextUse) + { + if (auto call = as<IRCall>(use->getUser())) + { + if (call->getCallee() != func) + continue; + if (auto callerFunc = as<IRFunc>(getParentFunc(call))) + { + if (funcWorkListSet.add(callerFunc)) + funcWorkList.add(callerFunc); + } + } + } + } + + // Now, everything in funcWorkListSet is a function that requires the global scope context. + // We go ahead and add the context type as the first parameter to these functions. 
+ List<IRInst*> newCallArgs; + + auto threadPtrType = builder.getPtrType(kIROp_PtrType, contextType, (int)AddressSpace::ThreadLocal); + for (auto func : funcWorkListSet) + { + auto firstBlock = func->getFirstBlock(); + if (!firstBlock) + continue; + bool isEntryPoint = func->findDecoration<IREntryPointDecoration>() != nullptr; + FuncInfo funcInfo = {}; + if (isEntryPoint) + { + // If the function is an entry point, we need to declare a local variable to hold the context. + setInsertBeforeOrdinaryInst(&builder, firstBlock->getFirstOrdinaryInst()); + funcInfo.contextArg = builder.emitVar(contextType, (int)AddressSpace::ThreadLocal); + } + else + { + // For other functions, we just add the context as the first parameter. + builder.setInsertBefore(firstBlock->getFirstInst()); + funcInfo.contextArg = builder.emitParamAtHead(threadPtrType); + } + builder.addNameHintDecoration(funcInfo.contextArg, toSlice("_globalCtx")); + + mapFuncToInfo[func] = funcInfo; + + // Now go through the body of the function and insert the context as the first argument to all calls. + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (auto call = as<IRCall>(inst)) + { + if (funcWorkListSet.contains((IRFunc*)getResolvedInstForDecorations(call->getCallee()))) + { + builder.setInsertBefore(call); + newCallArgs.clear(); + newCallArgs.add(funcInfo.contextArg); + for (auto arg : call->getArgsList()) + newCallArgs.add(arg); + auto newCall = builder.emitCallInst(call->getFullType(), call->getCallee(), newCallArgs); + call->replaceUsesWith(newCall); + instsToRemove.add(call); + } + } + } + } + } + + // Next, we need to replace all accesses to global variables with accesses to the context. 
+ for (auto globalVarKV : mapGlobalVarToInfo) + { + auto globalVar = globalVarKV.first; + auto key = globalVarKV.second.key; + traverseUses(globalVar, [&](IRUse* use) + { + auto user = use->getUser(); + auto parentFunc = getParentFunc(user); + if (!parentFunc) + return; + auto funcInfo = mapFuncToInfo.tryGetValue(parentFunc); + SLANG_ASSERT(funcInfo); + + auto contextArg = funcInfo->contextArg; + builder.setInsertBefore(user); + auto replacement = builder.emitFieldAddress( + builder.getPtrType( + kIROp_PtrType, + tryGetPointedToType(&builder, globalVar->getFullType()), + (int)AddressSpace::ThreadLocal), + contextArg, + key); + builder.replaceOperand(use, replacement); + }); + SLANG_ASSERT(!globalVar->hasUses()); + instsToRemove.add(globalVar); + } + + // Fix up all function types. + for (auto func : functions) + { + fixUpFuncType(func); + } + + // Finally, clean up the IR by removing all the insts scheduled for removal. + for (auto inst : instsToRemove) + inst->removeAndDeallocate(); + } + }; + + void wrapGlobalScopeInContextType(IRModule* module) + { + WrapGlobalScopeContext context; + context.processModule(module); + } +} diff --git a/source/slang/slang-ir-wrap-global-context.h b/source/slang/slang-ir-wrap-global-context.h new file mode 100644 index 000000000..1cd411e0a --- /dev/null +++ b/source/slang/slang-ir-wrap-global-context.h @@ -0,0 +1,14 @@ +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ + // The Metal backend does not support global variables or parameters. + // To work around this restriction, we use this pass to wrap all the + // global scope variables in a context type, and pass that context + // type as the first parameter to all functions that need it. 
+ + void wrapGlobalScopeInContextType(IRModule* module); + +} diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index fb4ce117e..50ed0096d 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -39,6 +39,15 @@ struct IRModule; struct IRStructField; struct IRStructKey; +enum class AddressSpace +{ + Generic = 0, + ThreadLocal = 1, + Global = 2, + GroupShared = 3, + Uniform = 4, +}; + typedef unsigned int IROpFlags; enum : IROpFlags { |
