diff options
| author | Yong He <yonghe@outlook.com> | 2024-04-30 09:57:54 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-04-30 09:57:54 -0700 |
| commit | f1221b80c3c5f59ed533147825ea414bef5e9df2 (patch) | |
| tree | 2b737438f2fe82d40035118a34b6d7074991f5a6 /source | |
| parent | 019d68fc14dd006c179417ffdb06827abe089a53 (diff) | |
Metal: Vertex/Fragment builtin and layouts. (#4044)
* Metal: Vertex/Fragment builtin and layouts.
* Fix.
* Fix test.
* Emit user semantic on vertex/fragment attributes.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-metal.cpp | 128 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 337 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.h | 10 | ||||
| -rw-r--r-- | source/slang/slang-options.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 3 |
7 files changed, 486 insertions, 4 deletions
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index 3773b7c2a..7580ed74d 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -128,8 +128,13 @@ void MetalSourceEmitter::emitFuncParamLayoutImpl(IRInst* param) m_writer->emit(rr->getOffset()); m_writer->emit(")]]"); break; + case LayoutResourceKind::VaryingInput: + m_writer->emit(" [[stage_in]]"); + break; } } + if (auto sysSemanticAttr = layout->findSystemValueSemanticAttr()) + _emitSystemSemantic(sysSemanticAttr->getName(), sysSemanticAttr->getIndex()); } void MetalSourceEmitter::emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) @@ -610,11 +615,130 @@ void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[m } } +void MetalSourceEmitter::_emitSystemSemantic(UnownedStringSlice semanticName, IRIntegerValue semanticIndex) +{ + if (semanticName.caseInsensitiveEquals(toSlice("SV_POSITION"))) + { + m_writer->emit(" [[position]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_VERTEXID"))) + { + m_writer->emit(" [[vertex_id]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_INSTANCEID"))) + { + m_writer->emit(" [[instance_id]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_Target"))) + { + m_writer->emit(" [[color("); + m_writer->emit(semanticIndex); + m_writer->emit(")]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_PRIMITIVEID"))) + { + m_writer->emit(" [[primitive_id]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_GROUPID"))) + { + // TODO: not supported by metal. + // We need to implement the transformation logic in slang-ir-metal-legalize.cpp + // to convert SV_GroupID to something like METAL_threadgroup_position_in_grid. + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_GROUPINDEX"))) + { + // TODO: not supported by metal. + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_DISPATCHTHREADID"))) + { + m_writer->emit(" [[thread_position_in_grid]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_GROUPTHREADID"))) + { + m_writer->emit(" [[thread_position_in_threadgroup]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_CLIPDISTANCE"))) + { + m_writer->emit(" [[clip_distance]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_RENDERTARGETARRAYINDEX"))) + { + m_writer->emit(" [[render_target_array_index]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_VIEWPORTARRAYINDEX"))) + { + m_writer->emit(" [[viewport_array_index]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_Depth"))) + { + m_writer->emit(" [[depth(any)]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_DepthGreaterEqual"))) + { + m_writer->emit(" [[depth(greater)]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_DepthLessEqual"))) + { + m_writer->emit(" [[depth(less)]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_Coverage"))) + { + m_writer->emit(" [[sample_mask]]"); + } + else if (semanticName.caseInsensitiveEquals(toSlice("SV_StencilRef"))) + { + m_writer->emit(" [[stencil]]"); + } + else + { + m_writer->emit(" [[user("); + m_writer->emit(semanticName); + if (semanticIndex != 0) + { + m_writer->emit("_"); + m_writer->emit(semanticIndex); + } + m_writer->emit(")]]"); + } +} + void MetalSourceEmitter::emitSemanticsImpl(IRInst* inst, bool allowOffsets) { - // Metal does not use semantics. - SLANG_UNUSED(inst); SLANG_UNUSED(allowOffsets); + if (inst->getOp() == kIROp_StructKey) + { + // Only emit [[attribute(n)]] on struct keys. + bool hasSemanticFromLayout = false; + if (auto varLayout = findVarLayout(inst)) + { + for (auto attr : varLayout->getAllAttrs()) + { + if (auto offsetAttr = as<IRVarOffsetAttr>(attr)) + { + if (offsetAttr->getResourceKind() == LayoutResourceKind::MetalAttribute) + { + m_writer->emit(" [[attribute("); + m_writer->emit(offsetAttr->getOffset()); + m_writer->emit(")]]"); + } + } + else if (auto semanticAttr = as<IRSemanticAttr>(attr)) + { + auto semanticName = String(semanticAttr->getName()).toUpper(); + _emitSystemSemantic(semanticAttr->getName(), semanticAttr->getIndex()); + hasSemanticFromLayout = true; + } + } + + } + if (!hasSemanticFromLayout) + { + if (auto semanticDecor = inst->findDecoration<IRSemanticDecoration>()) + { + _emitSystemSemantic(semanticDecor->getSemanticName(), semanticDecor->getSemanticIndex()); + } + } + } } void MetalSourceEmitter::_emitStageAccessSemantic(IRStageAccessDecoration* decoration, const char* name) diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h index 38d4c3a2c..d925365da 100644 --- a/source/slang/slang-emit-metal.h +++ b/source/slang/slang-emit-metal.h @@ -53,7 +53,6 @@ protected: virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE; virtual void emitGlobalInstImpl(IRInst* inst) SLANG_OVERRIDE; - virtual bool doesTargetSupportPtrTypes() SLANG_OVERRIDE { return true; } void emitFuncParamLayoutImpl(IRInst* param); @@ -68,6 +67,7 @@ protected: void _emitHLSLDecorationSingleInt(const char* name, IRFunc* entryPoint, IRIntLit* val); void _emitStageAccessSemantic(IRStageAccessDecoration* decoration, const char* name); + void _emitSystemSemantic(UnownedStringSlice semanticName, IRIntegerValue semanticIndex); }; } diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 9369afbc5..87f0911e7 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -29,6 +29,7 @@ #include "slang-ir-fuse-satcoop.h" #include "slang-ir-glsl-legalize.h" #include "slang-ir-hlsl-legalize.h" +#include "slang-ir-metal-legalize.h" #include "slang-ir-insts.h" #include "slang-ir-inline.h" #include "slang-ir-legalize-array-return-type.h" @@ -834,7 +835,11 @@ Result linkAndOptimizeIR( validateIRModuleIfEnabled(codeGenContext, irModule); } break; - + case CodeGenTarget::Metal: + { + legalizeIRForMetal(irModule, sink); + } + break; case CodeGenTarget::CSource: case CodeGenTarget::CPPSource: { diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp new file mode 100644 index 000000000..822a1e2f1 --- /dev/null +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -0,0 +1,337 @@ +#include "slang-ir-metal-legalize.h" + +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir-clone.h" + +namespace Slang +{ + struct EntryPointInfo + { + IRFunc* entryPointFunc; + IREntryPointDecoration* entryPointDecor; + }; + + void hoistEntryPointParameterFromStruct(EntryPointInfo entryPoint) + { + // If an entry point has a input parameter with a struct type, we want to hoist out + // all the fields of the struct type to be individual parameters of the entry point. + // This will canonicalize the entry point signature, so we can handle all cases uniformly. + + // For example, given an entry point: + // ``` + // struct VertexInput { float3 pos; float 2 uv; int vertexId : SV_VertexID}; + // void main(VertexInput vin) { ... } + // ``` + // We will transform it to: + // ``` + // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) { + // VertexInput vin = {pos,uv,vertexId}; + // ... + // } + // ``` + + auto func = entryPoint.entryPointFunc; + List<IRParam*> paramsToProcess; + for (auto param : func->getParams()) + { + if (auto structType = as<IRStructType>(param->getDataType())) + { + paramsToProcess.add(param); + } + } + + IRBuilder builder(func); + builder.setInsertBefore(func); + for (auto param : paramsToProcess) + { + auto structType = as<IRStructType>(param->getDataType()); + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto varLayout = findVarLayout(param); + IRStructTypeLayout* structTypeLayout = nullptr; + if (varLayout) + structTypeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout()); + Index fieldIndex = 0; + List<IRInst*> fieldParams; + for (auto field : structType->getFields()) + { + auto fieldParam = builder.emitParam(field->getFieldType()); + + IRCloneEnv cloneEnv; + cloneInstDecorationsAndChildren(&cloneEnv, builder.getModule(), field->getKey(), fieldParam); + + IRVarLayout* fieldLayout = structTypeLayout ? structTypeLayout->getFieldLayout(fieldIndex) : nullptr; + if (varLayout) + { + IRVarLayout::Builder varLayoutBuilder(&builder, fieldLayout->getTypeLayout()); + varLayoutBuilder.cloneEverythingButOffsetsFrom(fieldLayout); + for (auto offsetAttr : fieldLayout->getOffsetAttrs()) + { + auto parentOffsetAttr = varLayout->findOffsetAttr(offsetAttr->getResourceKind()); + UInt parentOffset = parentOffsetAttr ? parentOffsetAttr->getOffset() : 0; + UInt parentSpace = parentOffsetAttr ? parentOffsetAttr->getSpace() : 0; + auto resInfo = varLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind()); + resInfo->offset = parentOffset + offsetAttr->getOffset(); + resInfo->space = parentSpace + offsetAttr->getSpace(); + } + builder.addLayoutDecoration(fieldParam, varLayoutBuilder.build()); + } + param->insertBefore(fieldParam); + fieldParams.add(fieldParam); + fieldIndex++; + } + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto reconstructedParam = builder.emitMakeStruct(structType, fieldParams.getCount(), fieldParams.getBuffer()); + param->replaceUsesWith(reconstructedParam); + param->removeFromParent(); + } + fixUpFuncType(func); + } + + void packStageInParameters(EntryPointInfo entryPoint) + { + // If the entry point has any parameters whose layout contains VaryingInput, + // we need to pack those parameters into a single `struct` type, and decorate + // the fields with the appropriate `[[attribute]]` decorations. + // For other parameters that are not `VaryingInput`, we need to leave them as is. + // + // For example, given this code after `hoistEntryPointParameterFromStruct`: + // ``` + // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) { + // VertexInput vin = {pos,uv,vertexId}; + // ... + // } + // ``` + // We are going to transform it into: + // ``` + // struct VertexInput { + // float3 pos [[attribute(0)]]; + // float2 uv [[attribute(1)]]; + // }; + // void main(VertexInput vin, int vertexId : SV_VertexID) { + // let pos = vin.pos; + // let uv = vin.uv; + // ... + // } + + auto func = entryPoint.entryPointFunc; + + bool isGeometryStage = false; + switch (entryPoint.entryPointDecor->getProfile().getStage()) + { + case Stage::Vertex: + case Stage::Amplification: + case Stage::Mesh: + case Stage::Geometry: + case Stage::Domain: + case Stage::Hull: + isGeometryStage = true; + break; + } + + List<IRParam*> paramsToPack; + for (auto param : func->getParams()) + { + auto layout = findVarLayout(param); + if (!layout) + continue; + if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput)) + continue; + paramsToPack.add(param); + } + + if (paramsToPack.getCount() == 0) + return; + + IRBuilder builder(func); + builder.setInsertBefore(func); + IRStructType* structType = builder.createStructType(); + auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage()); + builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Input")).getUnownedSlice()); + List<IRStructKey*> keys; + IRStructTypeLayout::Builder layoutBuilder(&builder); + for (auto param : paramsToPack) + { + auto paramVarLayout = findVarLayout(param); + auto key = builder.createStructKey(); + param->transferDecorationsTo(key); + builder.createStructField(structType, key, param->getDataType()); + if (auto varyingInOffsetAttr = paramVarLayout->findOffsetAttr(LayoutResourceKind::VaryingInput)) + { + if (!key->findDecoration<IRSemanticDecoration>() && !paramVarLayout->findAttr<IRSemanticAttr>()) + { + // If the parameter doesn't have a semantic, we need to add one for semantic matching. + builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varyingInOffsetAttr->getOffset()); + } + } + if (isGeometryStage) + { + // For geometric stages, we need to translate VaryingInput offsets to MetalAttribute offsets. + IRVarLayout::Builder elementVarLayoutBuilder(&builder, paramVarLayout->getTypeLayout()); + elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout); + for (auto offsetAttr : paramVarLayout->getOffsetAttrs()) + { + auto resourceKind = offsetAttr->getResourceKind(); + if (resourceKind == LayoutResourceKind::VaryingInput) + { + resourceKind = LayoutResourceKind::MetalAttribute; + } + auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind); + resInfo->offset = offsetAttr->getOffset(); + resInfo->space = offsetAttr->getSpace(); + } + paramVarLayout = elementVarLayoutBuilder.build(); + } + layoutBuilder.addField(key, paramVarLayout); + builder.addLayoutDecoration(key, paramVarLayout); + keys.add(key); + } + builder.setInsertInto(func->getFirstBlock()); + auto packedParam = builder.emitParamAtHead(structType); + auto typeLayout = layoutBuilder.build(); + IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); + + // Add a VaryingInput resource info to the packed parameter layout, so that we can emit + // the needed `[[stage_in]]` attribute in Metal emitter. + varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput); + auto paramVarLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(packedParam, paramVarLayout); + + // Replace the original parameters with the packed parameter + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + for (Index paramIndex = 0; paramIndex < paramsToPack.getCount(); paramIndex++) + { + auto param = paramsToPack[paramIndex]; + auto key = keys[paramIndex]; + auto paramField = builder.emitFieldExtract(param->getDataType(), packedParam, key); + param->replaceUsesWith(paramField); + param->removeFromParent(); + } + fixUpFuncType(func); + } + + + void ensureResultStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout) + { + // Ensure each field in an output struct type has either a system semantic or a user semantic, + // so that signature matching can happen correctly. + auto typeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout()); + Index index = 0; + IRBuilder builder(structType); + for (auto field : structType->getFields()) + { + auto key = field->getKey(); + if (key->findDecoration<IRSemanticDecoration>()) + { + index++; + continue; + } + typeLayout->getFieldLayout(index); + auto fieldLayout = typeLayout->getFieldLayout(index); + if (auto offsetAttr = fieldLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput)) + { + UInt varOffset = 0; + if (auto varOffsetAttr = varLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput)) + varOffset = varOffsetAttr->getOffset(); + varOffset += offsetAttr->getOffset(); + builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset); + } + index++; + } + } + + + void wrapReturnValueInStruct(EntryPointInfo entryPoint) + { + // Wrap return value into a struct if it is not already a struct. + // For example, given this entry point: + // ``` + // float4 main() : SV_Target { return float3(1,2,3); } + // ``` + // We are going to transform it into: + // ``` + // struct Output { + // float4 value : SV_Target; + // }; + // Output main() { return {float3(1,2,3)}; } + + auto func = entryPoint.entryPointFunc; + + auto returnType = func->getResultType(); + if (as<IRVoidType>(returnType)) + return; + auto entryPointLayoutDecor = func->findDecoration<IRLayoutDecoration>(); + if (!entryPointLayoutDecor) + return; + auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecor->getLayout()); + if (!entryPointLayout) + return; + auto resultLayout = entryPointLayout->getResultLayout(); + + // If return type is already a struct, just make sure every field has a semantic. + if (auto returnStructType = as<IRStructType>(returnType)) + { + ensureResultStructHasUserSemantic(returnStructType, resultLayout); + return; + } + + // If not, we need to wrap the result into a struct type. + IRBuilder builder(func); + builder.setInsertBefore(func); + IRStructType* structType = builder.createStructType(); + auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage()); + builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Output")).getUnownedSlice()); + auto key = builder.createStructKey(); + builder.addNameHintDecoration(key, toSlice("output")); + builder.addLayoutDecoration(key, resultLayout); + builder.createStructField(structType, key, returnType); + IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder); + structTypeLayoutBuilder.addField(key, resultLayout); + auto typeLayout = structTypeLayoutBuilder.build(); + IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); + auto varLayout = varLayoutBuilder.build(); + ensureResultStructHasUserSemantic(structType, varLayout); + + for (auto block : func->getBlocks()) + { + if (auto returnInst = as<IRReturn>(block->getTerminator())) + { + builder.setInsertBefore(returnInst); + auto returnVal = returnInst->getVal(); + auto newResult = builder.emitMakeStruct(structType, 1, &returnVal); + returnInst->setOperand(0, newResult); + } + } + fixUpFuncType(func, structType); + } + + void legalizeEntryPointForMetal(EntryPointInfo entryPoint, DiagnosticSink* sink) + { + SLANG_UNUSED(sink); + + hoistEntryPointParameterFromStruct(entryPoint); + packStageInParameters(entryPoint); + wrapReturnValueInStruct(entryPoint); + } + + void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) + { + List<EntryPointInfo> entryPoints; + for (auto inst : module->getGlobalInsts()) + { + if (auto func = as<IRFunc>(inst)) + { + if (auto entryPointDecor = func->findDecoration<IREntryPointDecoration>()) + { + EntryPointInfo info; + info.entryPointDecor = entryPointDecor; + info.entryPointFunc = func; + entryPoints.add(info); + } + } + } + + for (auto entryPoint : entryPoints) + legalizeEntryPointForMetal(entryPoint, sink); + } +} diff --git a/source/slang/slang-ir-metal-legalize.h b/source/slang/slang-ir-metal-legalize.h new file mode 100644 index 000000000..3eb64438f --- /dev/null +++ b/source/slang/slang-ir-metal-legalize.h @@ -0,0 +1,10 @@ +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ + class DiagnosticSink; + + void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink); +} diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index 9d9e8039b..dad69350e 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -2746,6 +2746,9 @@ SlangResult OptionsParser::_parse( m_rawTargets[0].format == CodeGenTarget::CUDASource || m_rawTargets[0].format == CodeGenTarget::SPIRV || m_rawTargets[0].format == CodeGenTarget::SPIRVAssembly || + m_rawTargets[0].format == CodeGenTarget::Metal || + m_rawTargets[0].format == CodeGenTarget::MetalLib || + m_rawTargets[0].format == CodeGenTarget::MetalLibAssembly || ArtifactDescUtil::makeDescForCompileTarget(asExternal(m_rawTargets[0].format)).kind == ArtifactKind::HostCallable)) { RawOutput rawOutput; diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 971d6056f..4a446a351 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -6445,6 +6445,9 @@ SlangResult EndToEndCompileRequest::isParameterLocationUsed(Int entryPointIndex, if (SLANG_FAILED(_getEntryPointResult(this, static_cast<int>(entryPointIndex), static_cast<int>(targetIndex), artifact))) return SLANG_E_INVALID_ARG; + if (!artifact) + return SLANG_E_NOT_AVAILABLE; + // Find a rep auto metadata = findAssociatedRepresentation<IArtifactPostEmitMetadata>(artifact); if (!metadata) |
