diff options
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj | 2 | ||||
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj.filters | 6 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.cpp | 88 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-wrap-global-context.cpp | 287 | ||||
| -rw-r--r-- | source/slang/slang-ir-wrap-global-context.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 9 | ||||
| -rw-r--r-- | tests/metal/simple-compute.slang | 15 |
10 files changed, 417 insertions, 18 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index cfb7d474f..acf6df922 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -484,6 +484,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClInclude Include="..\..\..\source\slang\slang-ir-variable-scope-correction.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-vk-invert-y.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-witness-table-wrapper.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-wrap-global-context.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-wrap-structured-buffers.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir.h" />
<ClInclude Include="..\..\..\source\slang\slang-language-server-ast-lookup.h" />
@@ -713,6 +714,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClCompile Include="..\..\..\source\slang\slang-ir-variable-scope-correction.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-vk-invert-y.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-witness-table-wrapper.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-wrap-global-context.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-wrap-structured-buffers.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-language-server-ast-lookup.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index 14559ce5c..4b60068b3 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -540,6 +540,9 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-witness-table-wrapper.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-wrap-global-context.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-wrap-structured-buffers.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -1223,6 +1226,9 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-witness-table-wrapper.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-wrap-global-context.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-wrap-structured-buffers.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index 1ce25a2da..17d074e75 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -100,7 +100,7 @@ void MetalSourceEmitter::_emitHLSLTextureType(IRTextureTypeBase* texType) case SLANG_TEXTURE_2D: m_writer->emit("2d"); break; case SLANG_TEXTURE_3D: m_writer->emit("3d"); break; case SLANG_TEXTURE_CUBE: m_writer->emit("cube"); break; - case SLANG_TEXTURE_BUFFER: m_writer->emit("1d"); break; + case SLANG_TEXTURE_BUFFER: m_writer->emit("_buffer"); break; default: SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled resource shape"); break; @@ -274,7 +274,39 @@ bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inO return true; } break; - + case kIROp_RWStructuredBufferGetElementPtr: + { + EmitOpInfo outerPrec = inOuterPrec; + bool needClose = false; + + auto prec = getInfo(EmitOp::Add); + needClose = maybeEmitParens(outerPrec, prec); + emitOperand(inst->getOperand(0), leftSide(outerPrec, prec)); + m_writer->emit("+"); + emitOperand(inst->getOperand(1), rightSide(prec, outerPrec)); + maybeCloseParens(needClose); + return true; + } + case kIROp_StructuredBufferLoad: + case kIROp_RWStructuredBufferLoad: + { + auto prec = getInfo(EmitOp::Postfix); + emitOperand(inst->getOperand(0), leftSide(inOuterPrec, prec)); + m_writer->emit("["); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("]"); + return true; + } + case kIROp_RWStructuredBufferStore: + { + auto prec = getInfo(EmitOp::Postfix); + emitOperand(inst->getOperand(0), leftSide(inOuterPrec, prec)); + m_writer->emit("["); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("] = "); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + return true; + } default: break; } // Not handled @@ -479,12 +511,47 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) case kIROp_ParameterBlockType: case kIROp_ConstantBufferType: { - m_writer->emit("constant "); emitType((IRType*)type->getOperand(0)); + m_writer->emit(" constant*"); + return; + } + case kIROp_PtrType: + case kIROp_InOutType: + case kIROp_OutType: + case kIROp_RefType: + case kIROp_ConstRefType: + { + auto ptrType = cast<IRPtrTypeBase>(type); + emitType((IRType*)ptrType->getValueType()); + switch ((AddressSpace)ptrType->getAddressSpace()) + { + case AddressSpace::Global: + m_writer->emit(" device"); + break; + case AddressSpace::Uniform: + m_writer->emit(" constant"); + break; + case AddressSpace::ThreadLocal: + m_writer->emit(" thread"); + break; + case AddressSpace::GroupShared: + m_writer->emit(" threadgroup"); + break; + } m_writer->emit("*"); return; } - default: break; + case kIROp_ArrayType: + { + m_writer->emit("array<"); + emitType((IRType*)type->getOperand(0)); + m_writer->emit(", "); + emitVal(type->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(">"); + return; + } + default: + break; } if (auto texType = as<IRTextureType>(type)) @@ -499,9 +566,8 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) } else if (auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type)) { - m_writer->emit("device "); emitType(structuredBufferType->getElementType()); - m_writer->emit("*"); + m_writer->emit(" device*"); return; } else if (const auto untypedBufferType = as<IRUntypedBufferResourceType>(type)) @@ -511,10 +577,11 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) case kIROp_HLSLByteAddressBufferType: case kIROp_HLSLRWByteAddressBufferType: case kIROp_HLSLRasterizerOrderedByteAddressBufferType: - m_writer->emit("device "); - m_writer->emit("uint32_t *"); + m_writer->emit("uint32_t device*"); + break; + case kIROp_RaytracingAccelerationStructureType: + m_writer->emit("acceleration_structure<instancing>"); break; - case kIROp_RaytracingAccelerationStructureType: m_writer->emit("acceleration_structure<instancing>"); break; default: SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled buffer type"); break; @@ -650,7 +717,8 @@ void MetalSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) void MetalSourceEmitter::emitFrontMatterImpl(TargetRequest*) { - + m_writer->emit("#include <metal_stdlib>\n"); + m_writer->emit("using namespace metal;\n"); } void MetalSourceEmitter::emitGlobalInstImpl(IRInst* inst) diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h index 4c4f27be3..6c3de04c4 100644 --- a/source/slang/slang-emit-metal.h +++ b/source/slang/slang-emit-metal.h @@ -54,6 +54,8 @@ protected: virtual void emitGlobalInstImpl(IRInst* inst) SLANG_OVERRIDE; + virtual bool doesTargetSupportPtrTypes() SLANG_OVERRIDE { return true; } + // Emit a single `register` semantic, as appropriate for a given resource-type-specific layout info // Keyword to use in the uniform case (`register` for globals, `packoffset` inside a `cbuffer`) void _emitHLSLRegisterSemantic(LayoutResourceKind kind, EmitVarChain* chain, IRInst* inst, char const* uniformSemanticSpelling = "register"); diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 2b85be8c7..b5a94cf0d 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -520,10 +520,7 @@ struct SPIRVEmitContext // > Version nuumber // - // TODO(JS): - // Was previously set to SpvVersion, but that doesn't work since we - // upgraded to SPIR-V headers 1.6. (It would lead to validation errors during vk tests) - // For now mark as version 1.5.0 + // We are targeting SPIRV 1.5 for now. static const uint32_t spvVersion1_5_0 = 0x00010500; m_words.add(spvVersion1_5_0); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index c00c7bfc3..b7227a4f0 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -69,6 +69,7 @@ #include "slang-ir-synthesize-active-mask.h" #include "slang-ir-validate.h" #include "slang-ir-wrap-structured-buffers.h" +#include "slang-ir-wrap-global-context.h" #include "slang-ir-liveness.h" #include "slang-ir-glsl-liveness.h" #include "slang-ir-translate-glsl-global-var.h" @@ -1082,6 +1083,12 @@ Result linkAndOptimizeIR( validateIRModuleIfEnabled(codeGenContext, irModule); } + // Metal does not allow global variables and global parameters, so + // we need to convert them into an explicit global context parameter + // passed around through a function parameter. + if (target == CodeGenTarget::Metal) + wrapGlobalScopeInContextType(irModule); + auto metadata = new ArtifactPostEmitMetadata; outLinkedIR.metadata = metadata; diff --git a/source/slang/slang-ir-wrap-global-context.cpp b/source/slang/slang-ir-wrap-global-context.cpp new file mode 100644 index 000000000..01a343a02 --- /dev/null +++ b/source/slang/slang-ir-wrap-global-context.cpp @@ -0,0 +1,287 @@ +#include "slang-ir-wrap-global-context.h" + +#include "slang-ir-util.h" + +namespace Slang +{ + struct WrapGlobalScopeContext + { + List<IRFunc*> entryPoints; + IRStructType* contextType; + struct GlobalVarInfo + { + IRStructKey* key; + }; + Dictionary<IRInst*, GlobalVarInfo> mapGlobalVarToInfo; + struct FuncInfo + { + IRInst* contextArg; + }; + Dictionary<IRFunc*, FuncInfo> mapFuncToInfo; + IRStringLit* findNameHint(IRInst* inst) + { + if (auto nameDecor = inst->findDecoration<IRNameHintDecoration>()) + return nameDecor->getNameOperand(); + if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>()) + return linkageDecor->getMangledNameOperand(); + return nullptr; + } + + // Move all global parameters to the entry point parameters, + // and replace them with global variables that are initialized with + // the entry point parameters. + void moveGlobalParametersToEntryPoint(IRModule* module) + { + Dictionary<IRInst*, IRInst*> mapGlobalParamToGlobalVar; + + IRBuilder builder(module); + + for (auto globalInst : module->getGlobalInsts()) + { + if (auto globalParam = as<IRGlobalParam>(globalInst)) + { + builder.setInsertBefore(globalParam); + auto globalVar = builder.createGlobalVar( + globalParam->getFullType(), + (int)AddressSpace::ThreadLocal); + if (auto name = findNameHint(globalParam)) + builder.addNameHintDecoration(globalVar, name); + mapGlobalParamToGlobalVar[globalParam] = globalVar; + } + } + + // For every entry point, we need to add a new parameter for each global parameter. + for (auto entryPoint : entryPoints) + { + auto firstBlock = entryPoint->getFirstBlock(); + auto paramInsertPoint = firstBlock->getFirstInst(); + struct ParamInfo + { + IRInst* newParam; + IRInst* globalVar; + }; + List<ParamInfo> newParams; + for (auto globalParam : mapGlobalParamToGlobalVar) + { + auto newParam = builder.createParam(globalParam.first->getFullType()); + newParam->insertBefore(paramInsertPoint); + if (auto name = findNameHint(globalParam.first)) + builder.addNameHintDecoration(newParam, name); + newParams.add({newParam, globalParam.second}); + } + + // Insert assignments to the global variables at the start of the entry point. + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + for (auto& paramInfo : newParams) + { + auto globalVar = paramInfo.globalVar; + auto newParam = paramInfo.newParam; + builder.emitStore(globalVar, newParam); + } + } + + // Replace all uses of global parameters with a load from the global variable. + for (auto globalParam : mapGlobalParamToGlobalVar) + { + auto globalVar = globalParam.second; + traverseUses(globalParam.first, [&](IRUse* use) + { + auto user = use->getUser(); + builder.setInsertBefore(user); + auto load = builder.emitLoad(globalParam.first->getFullType(), globalVar); + builder.replaceOperand(use, load); + }); + globalParam.first->removeAndDeallocate(); + } + } + + void processModule(IRModule* module) + { + IRBuilder builder(module); + List<IRInst*> instsToRemove; + + List<IRFunc*> functions; + + // Collect all entry points and functions. + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->findDecoration<IREntryPointDecoration>()) + entryPoints.add(as<IRFunc>(globalInst)); + if (auto func = as<IRFunc>(globalInst)) + functions.add(func); + } + + // Before everything, we need to move all global parameters to the entry point parameters. + // For each global parameter, e.g. `uniform float4 g;`, we will replace it with a global + // variable, e.g. `float4 _g;`, and add a new parameter to the each entry point, and copy + // the value from the entry point parameter to the global variable. + moveGlobalParametersToEntryPoint(module); + + // The next step is to wrap all global variables in a context type, and pass them around + // with explicit function parameters. + + // Collect all global variables. + for (auto globalInst : module->getGlobalInsts()) + { + if (auto globalVar = as<IRGlobalVar>(globalInst)) + { + auto key = builder.createStructKey(); + + if (auto name = findNameHint(globalVar)) + builder.addNameHintDecoration(key, name); + + GlobalVarInfo info; + info.key = key; + mapGlobalVarToInfo[globalVar] = info; + } + } + if (mapGlobalVarToInfo.getCount() == 0) + return; + + // Create the context type for the global scope. + contextType = builder.createStructType(); + builder.addNameHintDecoration(contextType, toSlice("_SlangGlobalContext")); + for (auto& fieldKV : mapGlobalVarToInfo) + { + auto ptrType = as<IRPtrTypeBase>(fieldKV.first->getFullType()); + if (!ptrType) + continue; + builder.createStructField( + contextType, fieldKV.second.key, ptrType->getValueType()); + } + + // Identify all functions that requires the global scope context. + + // First, add all functions to the work list if it directly uses a global variable. + List<IRFunc*> funcWorkList; + HashSet<IRFunc*> funcWorkListSet; + for (auto& fieldKV : mapGlobalVarToInfo) + { + auto globalVar = fieldKV.first; + for (auto use = globalVar->firstUse; use; use = use->nextUse) + { + if (auto userFunc = getParentFunc(use->getUser())) + { + if (funcWorkListSet.add(userFunc)) + funcWorkList.add(userFunc); + } + } + } + + // Next, propagate the call graph and add all functions that transitively uses a global variable. + for (Index i = 0; i < funcWorkList.getCount(); i++) + { + auto func = funcWorkList[i]; + for (auto use = func->firstUse; use; use = use->nextUse) + { + if (auto call = as<IRCall>(use->getUser())) + { + if (call->getCallee() != func) + continue; + if (auto callerFunc = as<IRFunc>(getParentFunc(call))) + { + if (funcWorkListSet.add(callerFunc)) + funcWorkList.add(callerFunc); + } + } + } + } + + // Now, everything in funcWorkListSet is a function that requires the global scope context. + // We go ahead and add the context type as the first parameter to these functions. + List<IRInst*> newCallArgs; + + auto threadPtrType = builder.getPtrType(kIROp_PtrType, contextType, (int)AddressSpace::ThreadLocal); + for (auto func : funcWorkListSet) + { + auto firstBlock = func->getFirstBlock(); + if (!firstBlock) + continue; + bool isEntryPoint = func->findDecoration<IREntryPointDecoration>() != nullptr; + FuncInfo funcInfo = {}; + if (isEntryPoint) + { + // If the function is an entry point, we need to declare a local variable to hold the context. + setInsertBeforeOrdinaryInst(&builder, firstBlock->getFirstOrdinaryInst()); + funcInfo.contextArg = builder.emitVar(contextType, (int)AddressSpace::ThreadLocal); + } + else + { + // For other functions, we just add the context as the first parameter. + builder.setInsertBefore(firstBlock->getFirstInst()); + funcInfo.contextArg = builder.emitParamAtHead(threadPtrType); + } + builder.addNameHintDecoration(funcInfo.contextArg, toSlice("_globalCtx")); + + mapFuncToInfo[func] = funcInfo; + + // Now go through the body of the function and insert the context as the first argument to all calls. + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (auto call = as<IRCall>(inst)) + { + if (funcWorkListSet.contains((IRFunc*)getResolvedInstForDecorations(call->getCallee()))) + { + builder.setInsertBefore(call); + newCallArgs.clear(); + newCallArgs.add(funcInfo.contextArg); + for (auto arg : call->getArgsList()) + newCallArgs.add(arg); + auto newCall = builder.emitCallInst(call->getFullType(), call->getCallee(), newCallArgs); + call->replaceUsesWith(newCall); + instsToRemove.add(call); + } + } + } + } + } + + // Next, we need to replace all accesses to global variables with accesses to the context. + for (auto globalVarKV : mapGlobalVarToInfo) + { + auto globalVar = globalVarKV.first; + auto key = globalVarKV.second.key; + traverseUses(globalVar, [&](IRUse* use) + { + auto user = use->getUser(); + auto parentFunc = getParentFunc(user); + if (!parentFunc) + return; + auto funcInfo = mapFuncToInfo.tryGetValue(parentFunc); + SLANG_ASSERT(funcInfo); + + auto contextArg = funcInfo->contextArg; + builder.setInsertBefore(user); + auto replacement = builder.emitFieldAddress( + builder.getPtrType( + kIROp_PtrType, + tryGetPointedToType(&builder, globalVar->getFullType()), + (int)AddressSpace::ThreadLocal), + contextArg, + key); + builder.replaceOperand(use, replacement); + }); + SLANG_ASSERT(!globalVar->hasUses()); + instsToRemove.add(globalVar); + } + + // Fix up all function types. + for (auto func : functions) + { + fixUpFuncType(func); + } + + // Finally, cleanup the IR by removing all the insts scheduled for removal. + for (auto inst : instsToRemove) + inst->removeAndDeallocate(); + } + }; + + void wrapGlobalScopeInContextType(IRModule* module) + { + WrapGlobalScopeContext context; + context.processModule(module); + } +} diff --git a/source/slang/slang-ir-wrap-global-context.h b/source/slang/slang-ir-wrap-global-context.h new file mode 100644 index 000000000..1cd411e0a --- /dev/null +++ b/source/slang/slang-ir-wrap-global-context.h @@ -0,0 +1,14 @@ +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ + // The metal backend does not support global variables or parameters. + // To workaround this restriction, we use this pass to wrap all the + // global scope variables in a context type, and pass that context + // type as the first parameter to all functions. + + void wrapGlobalScopeInContextType(IRModule* module); + +} diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index fb4ce117e..50ed0096d 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -39,6 +39,15 @@ struct IRModule; struct IRStructField; struct IRStructKey; +enum class AddressSpace +{ + Generic = 0, + ThreadLocal = 1, + Global = 2, + GroupShared = 3, + Uniform = 4, +}; + typedef unsigned int IROpFlags; enum : IROpFlags { diff --git a/tests/metal/simple-compute.slang b/tests/metal/simple-compute.slang index e099704be..fe797dc2c 100644 --- a/tests/metal/simple-compute.slang +++ b/tests/metal/simple-compute.slang @@ -1,10 +1,17 @@ //TEST:SIMPLE(filecheck=CHECK): -target metal -RWStructuredBuffer<float> outputBuffer; +uniform RWStructuredBuffer<float> outputBuffer; + +// CHECK: {{.*}}kernel{{.*}} void main_kernel(float device* {{.*}}) + +void func(float v) +{ + outputBuffer[0] = v; + outputBuffer[1] = outputBuffer.Load(0); +} -// CHECK: {{.*}}kernel{{.*}} void main() [numthreads(1,1,1)] -void main() +void main_kernel() { - outputBuffer[0] = 1.0f; + func(3.0f); }
\ No newline at end of file |
