diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 2 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-varying-params.cpp | 83 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-varying-params.h | 61 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 1969 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-parameter-binding.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-parameter-binding.h | 2 |
9 files changed, 1495 insertions, 646 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 17e2822a3..59a64a192 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -723,7 +723,6 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,1,format> [__readNone] [ForceInline] - [require(hlsl, texture_sm_4_0_fragment)] T Sample(vector<float, Shape.dimensions+isArray> location, vector<int, Shape.planeDimensions> offset, float clamp, out uint status) { __target_switch @@ -1328,7 +1327,6 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,0,format> [__readNone] [ForceInline] - [require(hlsl, texture_sm_4_0_fragment)] T Sample(SamplerState s, vector<float, Shape.dimensions+isArray> location, constexpr vector<int, Shape.planeDimensions> offset, float clamp, out uint status) { __target_switch diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 98af8a228..3e7d71369 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -708,6 +708,7 @@ DIAGNOSTIC(40002, Error, invalidBindingValue, "binding location '$0' is out of v DIAGNOSTIC(40003, Error, bindingExceedsLimit, "binding location '$0' assigned to component '$1' exceeds maximum limit.") DIAGNOSTIC(40004, Error, bindingAlreadyOccupiedByModule, "DescriptorSet ID '$0' is already occupied by module instance '$1'.") DIAGNOSTIC(40005, Error, topLevelModuleUsedWithoutSpecifyingBinding, "top level module '$0' is being used without specifying binding location. Use [Binding: \"index\"] attribute to provide a binding location.") +DIAGNOSTIC(40006, Error, unimplementedSystemValueSemantic, "unknown system-value semantic '$0'") DIAGNOSTIC(49999, Error, unknownSystemValueSemantic, "unknown system-value semantic '$0'") @@ -849,10 +850,9 @@ DIAGNOSTIC(55102, Error, invalidTorchKernelParamType, "'$0' is not a valid param DIAGNOSTIC(55200, Error, unsupportedBuiltinType, "'$0' is not a supported builtin type for the target.") DIAGNOSTIC(55201, Error, unsupportedRecursion, "recursion detected in call to '$0', but the current code generation target does not allow recursion.") DIAGNOSTIC(55202, Error, systemValueAttributeNotSupported, "system value semantic '$0' is not supported for the current target.") -DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0' should have type '$1' or convertible to type '$1'.") +DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0' should have type '$1' or be convertible to type '$1'.") DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'") - DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0") DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.") DIAGNOSTIC(57003, Error, unknownTessPartitioning, "unknown tessellation partitioning '$0'.") diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index fb9a9613a..83b38b3b6 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4405,7 +4405,7 @@ public: } // void addLayoutDecoration(IRInst* value, Layout* layout); - void addLayoutDecoration(IRInst* value, IRLayout* layout); + IRLayoutDecoration* addLayoutDecoration(IRInst* value, IRLayout* layout); // IRLayout* getLayout(Layout* astLayout); @@ -4525,9 +4525,9 @@ public: addDecoration(value, kIROp_ForceUnrollDecoration, getIntValue(getIntType(), iters)); } - void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0) + IRSemanticDecoration* addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0) { - addDecoration(value, kIROp_SemanticDecoration, getStringValue(text), getIntValue(getIntType(), index)); + return as<IRSemanticDecoration>(addDecoration(value, kIROp_SemanticDecoration, getStringValue(text), getIntValue(getIntType(), index))); } void addRequireSPIRVDescriptorIndexingExtensionDecoration(IRInst* value) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 582af4ac8..7d3e3c533 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -2,9 +2,34 @@ #include "slang-ir-legalize-varying-params.h" #include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir-clone.h" +#include "slang-parameter-binding.h" namespace Slang { + // Convert semantic name (ignores case) into equivlent `SystemValueSemanticName` + SystemValueSemanticName convertSystemValueSemanticNameToEnum(String rawSemanticName) + { + auto semanticName = rawSemanticName.toLower(); + + SystemValueSemanticName systemValueSemanticName = SystemValueSemanticName::None; + +#define CASE(ID, NAME) \ + if(semanticName == String(#NAME).toLower()) \ + { \ + systemValueSemanticName = SystemValueSemanticName::ID; \ + } \ + else + + SYSTEM_VALUE_SEMANTIC_NAMES(CASE) +#undef CASE + { + systemValueSemanticName = SystemValueSemanticName::Unknown; + // no match + } + return systemValueSemanticName; + } // This pass implements logic to "legalize" the varying parameter // signature of an entry point. @@ -51,36 +76,6 @@ namespace Slang // * Slang allows for `inout` varying parameters, which need to desugar into // distinct `in` and `out` parameters for targets like GLSL. - -#define SYSTEM_VALUE_SEMANTIC_NAMES(M) \ - M(DispatchThreadID, SV_DispatchThreadID) \ - M(GroupID, SV_GroupID) \ - M(GroupThreadID, SV_GroupThreadID) \ - M(GroupThreadIndex, SV_GroupIndex) \ - /* end */ - - /// A known system-value semantic name that can be applied to a parameter - /// -enum class SystemValueSemanticName -{ - None = 0, - - // TODO: Should this enumeration be responsible for differentiating - // cases where the same semantic name string is allowed in multiple stages, - // or as both input/output in a single stage, and those different uses - // might result in different meanings? The alternative is to always - // pass around the semantic name, stage, and direction together so - // that code can tell those special cases apart. - -#define CASE(ID, NAME) ID, -SYSTEM_VALUE_SEMANTIC_NAMES(CASE) -#undef CASE - - // TODO: There are many more system-value semantic names that we - // can/should handle here, but for now I've restricted this list - // to those that are necessary for translating compute shaders. -}; - /// A placeholder that represents the value of a legalized varying /// parameter, for the purposes of substituting it into IR code. /// @@ -249,7 +244,7 @@ IRInst* emitCalcDispatchThreadID( } /// Emit code to calculate `SV_GroupIndex` -IRInst* emitCalcGroupThreadIndex( +IRInst* emitCalcGroupIndex( IRBuilder& builder, IRInst* groupThreadID, IRInst* groupExtents) @@ -935,23 +930,7 @@ protected: // avoid all the `String`s we crete and thren throw // away here. // - String semanticNameSpelling = semanticInst->getName(); - auto semanticName = semanticNameSpelling.toLower(); - - SystemValueSemanticName systemValueSemanticName = SystemValueSemanticName::None; - - #define CASE(ID, NAME) \ - if(semanticName == String(#NAME).toLower()) \ - { \ - systemValueSemanticName = SystemValueSemanticName::ID; \ - } \ - else - - SYSTEM_VALUE_SEMANTIC_NAMES(CASE) - #undef CASE - { - // no match - } + auto systemValueSemanticName = convertSystemValueSemanticNameToEnum(String(semanticInst->getName())); if( systemValueSemanticName != SystemValueSemanticName::None ) { @@ -1223,7 +1202,7 @@ struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegaliz threadIdxGlobalParam, blockDimGlobalParam); - groupThreadIndex = emitCalcGroupThreadIndex( + groupThreadIndex = emitCalcGroupIndex( builder, threadIdxGlobalParam, blockDimGlobalParam); @@ -1254,7 +1233,7 @@ struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegaliz { case SystemValueSemanticName::GroupID: return LegalizedVaryingVal::makeValue(blockIdxGlobalParam); case SystemValueSemanticName::GroupThreadID: return LegalizedVaryingVal::makeValue(threadIdxGlobalParam); - case SystemValueSemanticName::GroupThreadIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex); + case SystemValueSemanticName::GroupIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex); case SystemValueSemanticName::DispatchThreadID: return LegalizedVaryingVal::makeValue(dispatchThreadID); default: return diagnoseUnsupportedSystemVal(info); @@ -1397,7 +1376,7 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents); - groupThreadIndex = emitCalcGroupThreadIndex(builder, groupThreadID, groupExtents); + groupThreadIndex = emitCalcGroupIndex(builder, groupThreadID, groupExtents); } LegalizedVaryingVal createLegalSystemVaryingValImpl(VaryingParamInfo const& info) SLANG_OVERRIDE @@ -1414,7 +1393,7 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize { case SystemValueSemanticName::GroupID: return LegalizedVaryingVal::makeValue(groupID); case SystemValueSemanticName::GroupThreadID: return LegalizedVaryingVal::makeValue(groupThreadID); - case SystemValueSemanticName::GroupThreadIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex); + case SystemValueSemanticName::GroupIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex); case SystemValueSemanticName::DispatchThreadID: return LegalizedVaryingVal::makeValue(dispatchThreadID); default: diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h index 952192def..58efa39a2 100644 --- a/source/slang/slang-ir-legalize-varying-params.h +++ b/source/slang/slang-ir-legalize-varying-params.h @@ -1,17 +1,18 @@ // slang-ir-legalize-varying-params.h #pragma once +#include "slang-ir-insts.h" namespace Slang { class DiagnosticSink; -struct IRFunc; struct IRModule; struct IRInst; struct IRFunc; struct IRVectorType; struct IRBuilder; +struct IREntryPointDecoration; void legalizeEntryPointVaryingParamsForCPU( IRModule* module, @@ -21,13 +22,63 @@ void legalizeEntryPointVaryingParamsForCUDA( IRModule* module, DiagnosticSink* sink); -IRInst* emitCalcGroupThreadIndex( - IRBuilder& builder, - IRInst* groupThreadID, - IRInst* groupExtents); + +// (#4375) Once `slang-ir-metal-legalize.cpp` is merged with +// `slang-ir-legalize-varying-params.cpp`, move the following +// below into `slang-ir-legalize-varying-params.cpp` as well IRInst* emitCalcGroupExtents( IRBuilder& builder, IRFunc* entryPoint, IRVectorType* type); + +IRInst* emitCalcGroupIndex( + IRBuilder& builder, + IRInst* groupThreadID, + IRInst* groupExtents); + +// SystemValueSemanticName member definition macro +#define SYSTEM_VALUE_SEMANTIC_NAMES(M) \ + M(Position, SV_Position) \ + M(ClipDistance, SV_ClipDistance) \ + M(CullDistance, SV_CullDistance) \ + M(Coverage, SV_Coverage) \ + M(InnerCoverage, SV_InnerCoverage) \ + M(Depth, SV_Depth) \ + M(DepthGreaterEqual, SV_DepthGreaterEqual) \ + M(DepthLessEqual, SV_DepthLessEqual) \ + M(DispatchThreadID, SV_DispatchThreadID) \ + M(DomainLsocation, SV_DomainLsocation) \ + M(GroupID, SV_GroupID) \ + M(GroupIndex, SV_GroupIndex) \ + M(GroupThreadID, SV_GroupThreadID) \ + M(GSInstanceID, SV_GSInstanceID) \ + M(InstanceID, SV_InstanceID) \ + M(IsFrontFace, SV_IsFrontFace) \ + M(OutputControlPointID, SV_OutputControlPointID)\ + M(PointSize, SV_PointSize) \ + M(PrimitiveID, SV_PrimitiveID) \ + M(RenderTargetArrayIndex, SV_RenderTargetArrayIndex) \ + M(SampleIndex, SV_SampleIndex) \ + M(StencilRef, SV_StencilRef) \ + M(TessFactor, SV_TessFactor) \ + M(VertexID, SV_VertexID) \ + M(ViewID, SV_ViewID) \ + M(ViewportArrayIndex, SV_ViewportArrayIndex) \ + M(Target, SV_Target) \ + /* end */ + +/// A known system-value semantic name that can be applied to a parameter +/// +enum class SystemValueSemanticName +{ + None = 0, + Unknown = 0, +#define CASE(ID, NAME) ID, + SYSTEM_VALUE_SEMANTIC_NAMES(CASE) +#undef CASE +}; + +SystemValueSemanticName convertSystemValueSemanticNameToEnum(String rawSemanticName); + } diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index 01bcb0295..6b333f892 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -1,5 +1,6 @@ #include "slang-ir-metal-legalize.h" +#include <set> #include "slang-ir.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" @@ -10,7 +11,6 @@ namespace Slang { - const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid"); struct EntryPointInfo { @@ -18,701 +18,1472 @@ namespace Slang IREntryPointDecoration* entryPointDecor; }; - void hoistEntryPointParameterFromStruct(EntryPointInfo entryPoint) + const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid"); + struct LegalizeMetalEntryPointContext { - // If an entry point has a input parameter with a struct type, we want to hoist out - // all the fields of the struct type to be individual parameters of the entry point. - // This will canonicalize the entry point signature, so we can handle all cases uniformly. - - // For example, given an entry point: - // ``` - // struct VertexInput { float3 pos; float 2 uv; int vertexId : SV_VertexID}; - // void main(VertexInput vin) { ... } - // ``` - // We will transform it to: - // ``` - // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) { - // VertexInput vin = {pos,uv,vertexId}; - // ... - // } - // ``` - - auto func = entryPoint.entryPointFunc; - List<IRParam*> paramsToProcess; - for (auto param : func->getParams()) - { - if (as<IRStructType>(param->getDataType())) - { - paramsToProcess.add(param); + ShortList<IRType*> permittedTypes_sv_target; + Dictionary<IRFunc*, IRInst*> entryPointToGroupThreadId; + HashSet<IRStructField*> semanticInfoToRemove; + + DiagnosticSink* m_sink; + IRModule* m_module; + + LegalizeMetalEntryPointContext(DiagnosticSink* sink, IRModule* module) : m_sink(sink), m_module(module) + { + } + + void removeSemanticLayoutsFromLegalizedStructs() + { + // Metal does not allow duplicate attributes to appear in the same shader. + // If we emit our own struct with `[[color(0)]`, all existing uses of `[[color(0)]]` + // must be removed. + for (auto field : semanticInfoToRemove) + { + auto key = field->getKey(); + // Some decorations appear twice, destroy all found + for (;;) + { + if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>()) + { + semanticDecor->removeAndDeallocate(); + continue; + } + else if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>()) + { + layoutDecor->removeAndDeallocate(); + continue; + } + break; + } } } - IRBuilder builder(func); - builder.setInsertBefore(func); - for (auto param : paramsToProcess) + void hoistEntryPointParameterFromStruct(EntryPointInfo entryPoint) { - auto structType = as<IRStructType>(param->getDataType()); - builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); - auto varLayout = findVarLayout(param); - - // If `param` already has a semantic, we don't want to hoist its fields out. - if (varLayout->findSystemValueSemanticAttr() != nullptr || - param->findDecoration<IRSemanticDecoration>()) - continue; - - IRStructTypeLayout* structTypeLayout = nullptr; - if (varLayout) - structTypeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout()); - Index fieldIndex = 0; - List<IRInst*> fieldParams; - for (auto field : structType->getFields()) + // If an entry point has a input parameter with a struct type, we want to hoist out + // all the fields of the struct type to be individual parameters of the entry point. + // This will canonicalize the entry point signature, so we can handle all cases uniformly. + + // For example, given an entry point: + // ``` + // struct VertexInput { float3 pos; float 2 uv; int vertexId : SV_VertexID}; + // void main(VertexInput vin) { ... } + // ``` + // We will transform it to: + // ``` + // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) { + // VertexInput vin = {pos,uv,vertexId}; + // ... + // } + // ``` + + auto func = entryPoint.entryPointFunc; + List<IRParam*> paramsToProcess; + for (auto param : func->getParams()) { - auto fieldParam = builder.emitParam(field->getFieldType()); - - IRCloneEnv cloneEnv; - cloneInstDecorationsAndChildren(&cloneEnv, builder.getModule(), field->getKey(), fieldParam); + if (as<IRStructType>(param->getDataType())) + { + paramsToProcess.add(param); + } + } + + IRBuilder builder(func); + builder.setInsertBefore(func); + for (auto param : paramsToProcess) + { + auto structType = as<IRStructType>(param->getDataType()); + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto varLayout = findVarLayout(param); - IRVarLayout* fieldLayout = structTypeLayout ? structTypeLayout->getFieldLayout(fieldIndex) : nullptr; + // If `param` already has a semantic, we don't want to hoist its fields out. + if (varLayout->findSystemValueSemanticAttr() != nullptr || + param->findDecoration<IRSemanticDecoration>()) + continue; + + IRStructTypeLayout* structTypeLayout = nullptr; if (varLayout) + structTypeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout()); + Index fieldIndex = 0; + List<IRInst*> fieldParams; + for (auto field : structType->getFields()) { - IRVarLayout::Builder varLayoutBuilder(&builder, fieldLayout->getTypeLayout()); - varLayoutBuilder.cloneEverythingButOffsetsFrom(fieldLayout); - for (auto offsetAttr : fieldLayout->getOffsetAttrs()) + auto fieldParam = builder.emitParam(field->getFieldType()); + IRCloneEnv cloneEnv; + cloneInstDecorationsAndChildren(&cloneEnv, builder.getModule(), field->getKey(), fieldParam); + + IRVarLayout* fieldLayout = structTypeLayout ? structTypeLayout->getFieldLayout(fieldIndex) : nullptr; + if (varLayout) { - auto parentOffsetAttr = varLayout->findOffsetAttr(offsetAttr->getResourceKind()); - UInt parentOffset = parentOffsetAttr ? parentOffsetAttr->getOffset() : 0; - UInt parentSpace = parentOffsetAttr ? parentOffsetAttr->getSpace() : 0; - auto resInfo = varLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind()); - resInfo->offset = parentOffset + offsetAttr->getOffset(); - resInfo->space = parentSpace + offsetAttr->getSpace(); + IRVarLayout::Builder varLayoutBuilder(&builder, fieldLayout->getTypeLayout()); + varLayoutBuilder.cloneEverythingButOffsetsFrom(fieldLayout); + for (auto offsetAttr : fieldLayout->getOffsetAttrs()) + { + auto parentOffsetAttr = varLayout->findOffsetAttr(offsetAttr->getResourceKind()); + UInt parentOffset = parentOffsetAttr ? parentOffsetAttr->getOffset() : 0; + UInt parentSpace = parentOffsetAttr ? parentOffsetAttr->getSpace() : 0; + auto resInfo = varLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind()); + resInfo->offset = parentOffset + offsetAttr->getOffset(); + resInfo->space = parentSpace + offsetAttr->getSpace(); + } + builder.addLayoutDecoration(fieldParam, varLayoutBuilder.build()); } - builder.addLayoutDecoration(fieldParam, varLayoutBuilder.build()); + param->insertBefore(fieldParam); + fieldParams.add(fieldParam); + fieldIndex++; } - param->insertBefore(fieldParam); - fieldParams.add(fieldParam); - fieldIndex++; + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto reconstructedParam = builder.emitMakeStruct(structType, fieldParams.getCount(), fieldParams.getBuffer()); + param->replaceUsesWith(reconstructedParam); + param->removeFromParent(); } - builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); - auto reconstructedParam = builder.emitMakeStruct(structType, fieldParams.getCount(), fieldParams.getBuffer()); - param->replaceUsesWith(reconstructedParam); - param->removeFromParent(); + fixUpFuncType(func); } - fixUpFuncType(func); - } - void packStageInParameters(EntryPointInfo entryPoint) - { - // If the entry point has any parameters whose layout contains VaryingInput, - // we need to pack those parameters into a single `struct` type, and decorate - // the fields with the appropriate `[[attribute]]` decorations. - // For other parameters that are not `VaryingInput`, we need to leave them as is. - // - // For example, given this code after `hoistEntryPointParameterFromStruct`: - // ``` - // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) { - // VertexInput vin = {pos,uv,vertexId}; - // ... - // } - // ``` - // We are going to transform it into: - // ``` - // struct VertexInput { - // float3 pos [[attribute(0)]]; - // float2 uv [[attribute(1)]]; - // }; - // void main(VertexInput vin, int vertexId : SV_VertexID) { - // let pos = vin.pos; - // let uv = vin.uv; - // ... - // } - - auto func = entryPoint.entryPointFunc; - - bool isGeometryStage = false; - switch (entryPoint.entryPointDecor->getProfile().getStage()) - { - case Stage::Vertex: - case Stage::Amplification: - case Stage::Mesh: - case Stage::Geometry: - case Stage::Domain: - case Stage::Hull: - isGeometryStage = true; - break; - } + // Flattens all struct parameters of an entryPoint to ensure parameters are a flat struct + void flattenInputParameters(EntryPointInfo entryPoint) + { + // Goal is to ensure we have a flattened IRParam (0 nested IRStructType members). + /* + // Assume the following code + struct NestedFragment + { + float2 p3; + }; + struct Fragment + { + float4 p1; + float3 p2; + NestedFragment p3_nested; + }; - List<IRParam*> paramsToPack; - for (auto param : func->getParams()) - { - auto layout = findVarLayout(param); - if (!layout) - continue; - if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput)) - continue; - if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) - continue; - paramsToPack.add(param); - } + // Fragment flattens into + struct Fragment + { + float4 p1; + float3 p2; + float2 p3; + }; + */ + + // This is important since Metal does not allow semantic's on a struct + /* + // Assume the following code + struct NestedFragment1 + { + float2 p3; + }; + struct Fragment1 + { + float4 p1 : SV_TARGET0; + float3 p2 : SV_TARGET1; + NestedFragment p3_nested : SV_TARGET2; // error, semantic on struct + }; + + */ + + // Metal does allow semantics on members of a nested struct but we are avoiding this + // approach since there are senarios where legalization (and verification) is + // hard/expensive without creating a flat struct: + // 1. Entry points may share structs, semantics may be inconsistent across entry points + // 2. Multiple of the same struct may be used in a param list + /* + // Assume the following code + struct NestedFragment + { + float2 p3; + }; + struct Fragment + { + float4 p1 : SV_TARGET0; + NestedFragment p2 : SV_TARGET1; + NestedFragment p3 : SV_TARGET2; + }; - if (paramsToPack.getCount() == 0) - return; + // Legalized without flattening -- abandoned + struct NestedFragment1 + { + float2 p3 : SV_TARGET1; + }; + struct NestedFragment2 + { + float2 p3 : SV_TARGET2; + }; + struct Fragment + { + float4 p1 : SV_TARGET0; + NestedFragment1 p2; + NestedFragment2 p3; + }; - IRBuilder builder(func); - builder.setInsertBefore(func); - IRStructType* structType = builder.createStructType(); - auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage()); - builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Input")).getUnownedSlice()); - List<IRStructKey*> keys; - IRStructTypeLayout::Builder layoutBuilder(&builder); - for (auto param : paramsToPack) - { - auto paramVarLayout = findVarLayout(param); - auto key = builder.createStructKey(); - param->transferDecorationsTo(key); - builder.createStructField(structType, key, param->getDataType()); - if (auto varyingInOffsetAttr = paramVarLayout->findOffsetAttr(LayoutResourceKind::VaryingInput)) - { - if (!key->findDecoration<IRSemanticDecoration>() && !paramVarLayout->findAttr<IRSemanticAttr>()) + // Legalized with flattening -- current approach + struct Fragment { - // If the parameter doesn't have a semantic, we need to add one for semantic matching. - builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varyingInOffsetAttr->getOffset()); - } - } - if (isGeometryStage) + float4 p1 : SV_TARGET0; + float2 p2 : SV_TARGET1; + float2 p3 : SV_TARGET2; + }; + */ + + auto func = entryPoint.entryPointFunc; + bool modified = false; + for (auto param : func->getParams()) { - // For geometric stages, we need to translate VaryingInput offsets to MetalAttribute offsets. - IRVarLayout::Builder elementVarLayoutBuilder(&builder, paramVarLayout->getTypeLayout()); - elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout); - for (auto offsetAttr : paramVarLayout->getOffsetAttrs()) + auto layout = findVarLayout(param); + if (!layout) + continue; + if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput)) + continue; + if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) + continue; + // If we find a IRParam with a IRStructType member, we need to flatten the entire IRParam + if (auto structType = as<IRStructType>(param->getDataType())) { - auto resourceKind = offsetAttr->getResourceKind(); - if (resourceKind == LayoutResourceKind::VaryingInput) + IRBuilder builder(func); + MapStructToFlatStruct mapOldFieldToNewField; + + // Flatten struct if we have nested IRStructType + auto flattenedStruct = maybeFlattenNestedStructs(builder, structType, mapOldFieldToNewField, semanticInfoToRemove); + if (flattenedStruct != structType) { - resourceKind = LayoutResourceKind::MetalAttribute; + // Validate/rearange all semantics which overlap in our flat struct + fixFieldSemanticsOfFlatStruct(flattenedStruct); + + // Replace the 'old IRParam type' with a 'new IRParam type' + param->setFullType(flattenedStruct); + + // Emit a new variable at EntryPoint of 'old IRParam type' + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto dstVal = builder.emitVar(structType); + auto dstLoad = builder.emitLoad(dstVal); + param->replaceUsesWith(dstLoad); + builder.setInsertBefore(dstLoad); + // Copy the 'new IRParam type' to our 'old IRParam type' + mapOldFieldToNewField.emitCopy<(int)MapStructToFlatStruct::CopyOptions::FlatStructIntoStruct>(builder, dstVal, param); + + modified = true; } - auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind); - resInfo->offset = offsetAttr->getOffset(); - resInfo->space = offsetAttr->getSpace(); } - paramVarLayout = elementVarLayoutBuilder.build(); } - layoutBuilder.addField(key, paramVarLayout); - builder.addLayoutDecoration(key, paramVarLayout); - keys.add(key); - } - builder.setInsertInto(func->getFirstBlock()); - auto packedParam = builder.emitParamAtHead(structType); - auto typeLayout = layoutBuilder.build(); - IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); - - // Add a VaryingInput resource info to the packed parameter layout, so that we can emit - // the needed `[[stage_in]]` attribute in Metal emitter. - varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput); - auto paramVarLayout = varLayoutBuilder.build(); - builder.addLayoutDecoration(packedParam, paramVarLayout); - - // Replace the original parameters with the packed parameter - builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); - for (Index paramIndex = 0; paramIndex < paramsToPack.getCount(); paramIndex++) - { - auto param = paramsToPack[paramIndex]; - auto key = keys[paramIndex]; - auto paramField = builder.emitFieldExtract(param->getDataType(), packedParam, key); - param->replaceUsesWith(paramField); - param->removeFromParent(); + if (modified) + fixUpFuncType(func); } - fixUpFuncType(func); - } - - struct MetalSystemValueInfo - { - String metalSystemValueName; - IRType* requiredType; - IRType* altRequiredType; - bool isUnsupported; - bool isSpecial; - }; - IRType* getGroupThreadIdType(IRBuilder& builder) - { - return builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)); - } - - MetalSystemValueInfo getSystemValueInfo(IRBuilder& builder, String inSemanticName) - { - inSemanticName = inSemanticName.toLower(); - - UnownedStringSlice semanticName; - UnownedStringSlice semanticIndex; - splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex); + void packStageInParameters(EntryPointInfo entryPoint) + { + // If the entry point has any parameters whose layout contains VaryingInput, + // we need to pack those parameters into a single `struct` type, and decorate + // the fields with the appropriate `[[attribute]]` decorations. + // For other parameters that are not `VaryingInput`, we need to leave them as is. + // + // For example, given this code after `hoistEntryPointParameterFromStruct`: + // ``` + // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) { + // VertexInput vin = {pos,uv,vertexId}; + // ... + // } + // ``` + // We are going to transform it into: + // ``` + // struct VertexInput { + // float3 pos [[attribute(0)]]; + // float2 uv [[attribute(1)]]; + // }; + // void main(VertexInput vin, int vertexId : SV_VertexID) { + // let pos = vin.pos; + // let uv = vin.uv; + // ... + // } + + auto func = entryPoint.entryPointFunc; + + bool isGeometryStage = false; + switch (entryPoint.entryPointDecor->getProfile().getStage()) + { + case Stage::Vertex: + case Stage::Amplification: + case Stage::Mesh: + case Stage::Geometry: + case Stage::Domain: + case Stage::Hull: + isGeometryStage = true; + break; + } - MetalSystemValueInfo result = {}; + List<IRParam*> paramsToPack; + for (auto param : func->getParams()) + { + auto layout = findVarLayout(param); + if (!layout) + continue; + if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput)) + continue; + if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) + continue; + paramsToPack.add(param); + } - if (semanticName == "sv_position") - { - result.metalSystemValueName = toSlice("position"); - result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 4)); - } - else if (semanticName == "sv_clipdistance") - { - result.isSpecial = true; - } - else if (semanticName == "sv_culldistance") - { - result.isSpecial = true; - } - else if (semanticName == "sv_coverage") - { - result.metalSystemValueName = toSlice("sample_mask"); - result.requiredType = builder.getBasicType(BaseType::UInt); + if (paramsToPack.getCount() == 0) + return; + + IRBuilder builder(func); + builder.setInsertBefore(func); + IRStructType* structType = builder.createStructType(); + auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage()); + builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Input")).getUnownedSlice()); + List<IRStructKey*> keys; + IRStructTypeLayout::Builder layoutBuilder(&builder); + for (auto param : paramsToPack) + { + auto paramVarLayout = findVarLayout(param); + auto key = builder.createStructKey(); + param->transferDecorationsTo(key); + builder.createStructField(structType, key, param->getDataType()); + if (auto varyingInOffsetAttr = paramVarLayout->findOffsetAttr(LayoutResourceKind::VaryingInput)) + { + if (!key->findDecoration<IRSemanticDecoration>() && !paramVarLayout->findAttr<IRSemanticAttr>()) + { + // If the parameter doesn't have a semantic, we need to add one for semantic matching. + builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varyingInOffsetAttr->getOffset()); + } + } + if (isGeometryStage) + { + // For geometric stages, we need to translate VaryingInput offsets to MetalAttribute offsets. + IRVarLayout::Builder elementVarLayoutBuilder(&builder, paramVarLayout->getTypeLayout()); + elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout); + for (auto offsetAttr : paramVarLayout->getOffsetAttrs()) + { + auto resourceKind = offsetAttr->getResourceKind(); + if (resourceKind == LayoutResourceKind::VaryingInput) + { + resourceKind = LayoutResourceKind::MetalAttribute; + } + auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind); + resInfo->offset = offsetAttr->getOffset(); + resInfo->space = offsetAttr->getSpace(); + } + paramVarLayout = elementVarLayoutBuilder.build(); + } + layoutBuilder.addField(key, paramVarLayout); + builder.addLayoutDecoration(key, paramVarLayout); + keys.add(key); + } + builder.setInsertInto(func->getFirstBlock()); + auto packedParam = builder.emitParamAtHead(structType); + auto typeLayout = layoutBuilder.build(); + IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); + + // Add a VaryingInput resource info to the packed parameter layout, so that we can emit + // the needed `[[stage_in]]` attribute in Metal emitter. + varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput); + auto paramVarLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(packedParam, paramVarLayout); + + // Replace the original parameters with the packed parameter + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + for (Index paramIndex = 0; paramIndex < paramsToPack.getCount(); paramIndex++) + { + auto param = paramsToPack[paramIndex]; + auto key = keys[paramIndex]; + auto paramField = builder.emitFieldExtract(param->getDataType(), packedParam, key); + param->replaceUsesWith(paramField); + param->removeFromParent(); + } + fixUpFuncType(func); } - else if (semanticName == "sv_innercoverage") - { - result.isSpecial = true; - } - else if (semanticName == "sv_depth") - { - result.metalSystemValueName = toSlice("depth(any)"); - result.requiredType = builder.getBasicType(BaseType::Float); - } - else if (semanticName == "sv_depthgreaterequal") + struct MetalSystemValueInfo { - result.metalSystemValueName = toSlice("depth(greater)"); - result.requiredType = builder.getBasicType(BaseType::Float); - } - else if (semanticName == "sv_depthlessequal") - { - result.metalSystemValueName = toSlice("depth(less)"); - result.requiredType = builder.getBasicType(BaseType::Float); - } - else if (semanticName == "sv_dispatchthreadid") - { - result.metalSystemValueName = toSlice("thread_position_in_grid"); - result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)); - } - else if (semanticName == "sv_domainlocation") - { - result.metalSystemValueName = toSlice("position_in_patch"); - result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 3)); - result.altRequiredType = builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 2)); - } - else if (semanticName == "sv_groupid") - { - result.metalSystemValueName = toSlice("threadgroup_position_in_grid"); - result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)); - } - else if (semanticName == "sv_groupindex") - { - result.isSpecial = true; - } - else if (semanticName == "sv_groupthreadid") - { - result.metalSystemValueName = toSlice("thread_position_in_threadgroup"); - result.requiredType = getGroupThreadIdType(builder); - } - else if (semanticName == "sv_gsinstanceid") + String metalSystemValueName; + SystemValueSemanticName metalSystemValueNameEnum; + ShortList<IRType*> permittedTypes; + bool isUnsupported = false; + bool isSpecial = false; + MetalSystemValueInfo() + { + // most commonly need 2 + permittedTypes.reserveOverflowBuffer(2); + } + }; + + IRType* getGroupThreadIdType(IRBuilder& builder) { - // Metal does not have geometry shader, so this is invalid. - result.isUnsupported = true; + return builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)); } - else if (semanticName == "sv_instanceid") + + // Get all permitted types of "sv_target" for Metal + ShortList<IRType*>& getPermittedTypes_sv_target(IRBuilder& builder) { - result.metalSystemValueName = toSlice("instance_id"); - result.requiredType = builder.getBasicType(BaseType::UInt); + permittedTypes_sv_target.reserveOverflowBuffer(5 * 4); + if (permittedTypes_sv_target.getCount() == 0) + { + for (auto baseType : { BaseType::Float, BaseType::Half, BaseType::Int, BaseType::UInt, BaseType::Int16, BaseType::UInt16 }) + { + for (IRIntegerValue i = 1; i <= 4; i++) + { + permittedTypes_sv_target.add(builder.getVectorType(builder.getBasicType(baseType), i)); + } + } + } + return permittedTypes_sv_target; } - else if (semanticName == "sv_isfrontface") + + MetalSystemValueInfo getSystemValueInfo(String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar) { - result.metalSystemValueName = toSlice("front_facing"); - result.requiredType = builder.getBasicType(BaseType::Bool); + IRBuilder builder(m_module); + MetalSystemValueInfo result = {}; + UnownedStringSlice semanticName; + UnownedStringSlice semanticIndex; + + auto hasExplicitIndex = splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex); + if (!hasExplicitIndex && optionalSemanticIndex) + semanticIndex = optionalSemanticIndex->getUnownedSlice(); + + result.metalSystemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName); + + switch (result.metalSystemValueNameEnum) + { + case SystemValueSemanticName::Position: + { + result.metalSystemValueName = toSlice("position"); + result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 4))); + break; + } + case SystemValueSemanticName::ClipDistance: + { + result.isSpecial = true; + break; + } + case SystemValueSemanticName::CullDistance: + { + result.isSpecial = true; + break; + } + case SystemValueSemanticName::Coverage: + { + result.metalSystemValueName = toSlice("sample_mask"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + break; + } + case SystemValueSemanticName::InnerCoverage: + { + result.isSpecial = true; + break; + } + case SystemValueSemanticName::Depth: + { + result.metalSystemValueName = toSlice("depth(any)"); + result.permittedTypes.add(builder.getBasicType(BaseType::Float)); + break; + } + case SystemValueSemanticName::DepthGreaterEqual: + { + result.metalSystemValueName = toSlice("depth(greater)"); + result.permittedTypes.add(builder.getBasicType(BaseType::Float)); + break; + } + case SystemValueSemanticName::DepthLessEqual: + { + result.metalSystemValueName = toSlice("depth(less)"); + result.permittedTypes.add(builder.getBasicType(BaseType::Float)); + break; + } + case SystemValueSemanticName::DispatchThreadID: + { + result.metalSystemValueName = toSlice("thread_position_in_grid"); + result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3))); + break; + } + case SystemValueSemanticName::DomainLsocation: + { + result.metalSystemValueName = toSlice("position_in_patch"); + result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 3))); + result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 2))); + break; + } + case SystemValueSemanticName::GroupID: + { + result.metalSystemValueName = toSlice("threadgroup_position_in_grid"); + result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3))); + break; + } + case SystemValueSemanticName::GroupIndex: + { + result.isSpecial = true; + break; + } + case SystemValueSemanticName::GroupThreadID: + { + result.metalSystemValueName = toSlice("thread_position_in_threadgroup"); + result.permittedTypes.add(getGroupThreadIdType(builder)); + break; + } + case SystemValueSemanticName::GSInstanceID: + { + result.isUnsupported = true; + break; + } + case SystemValueSemanticName::InstanceID: + { + result.metalSystemValueName = toSlice("instance_id"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + break; + } + case SystemValueSemanticName::IsFrontFace: + { + result.metalSystemValueName = toSlice("front_facing"); + result.permittedTypes.add(builder.getBasicType(BaseType::Bool)); + break; + } + case SystemValueSemanticName::OutputControlPointID: + { + // In metal, a hull shader is just a compute shader. + // This needs to be handled separately, by lowering into an ordinary buffer. + break; + } + case SystemValueSemanticName::PointSize: + { + result.metalSystemValueName = toSlice("point_size"); + result.permittedTypes.add(builder.getBasicType(BaseType::Float)); + break; + } + case SystemValueSemanticName::PrimitiveID: + { + result.metalSystemValueName = toSlice("patch_id"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt16)); + break; + } + case SystemValueSemanticName::RenderTargetArrayIndex: + { + result.metalSystemValueName = toSlice("render_target_array_index"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt16)); + break; + } + case SystemValueSemanticName::SampleIndex: + { + result.metalSystemValueName = toSlice("sample_id"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + break; + } + case SystemValueSemanticName::StencilRef: + { + result.metalSystemValueName = toSlice("stencil"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + break; + } + case SystemValueSemanticName::TessFactor: + { + // Tessellation factor outputs should be lowered into a write into a normal buffer. + break; + } + case SystemValueSemanticName::VertexID: + { + result.metalSystemValueName = toSlice("vertex_id"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + break; + } + case SystemValueSemanticName::ViewID: + { + result.isUnsupported = true; + break; + } + case SystemValueSemanticName::ViewportArrayIndex: + { + result.metalSystemValueName = toSlice("viewport_array_index"); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); + result.permittedTypes.add(builder.getBasicType(BaseType::UInt16)); + break; + } + case SystemValueSemanticName::Target: + { + result.metalSystemValueName = (StringBuilder() << "color(" + << (semanticIndex.getLength() != 0 ? semanticIndex : toSlice("0")) + << ")").produceString(); + result.permittedTypes = getPermittedTypes_sv_target(builder); + + break; + } + default: + m_sink->diagnose(parentVar, Diagnostics::unimplementedSystemValueSemantic, semanticName); + return result; + } + return result; } - else if (semanticName == "sv_outputcontrolpointid") + + void reportUnsupportedSystemAttribute(IRInst* param, String semanticName) { - // In metal, a hull shader is just a compute shader. - // This needs to be handled separately, by lowering into an ordinary buffer. + m_sink->diagnose(param->sourceLoc, Diagnostics::systemValueAttributeNotSupported, semanticName); } - else if (semanticName == "sv_pointsize") + + void ensureResultStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout) { - result.metalSystemValueName = toSlice("point_size"); - result.requiredType = builder.getBasicType(BaseType::Float); + // Ensure each field in an output struct type has either a system semantic or a user semantic, + // so that signature matching can happen correctly. + auto typeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout()); + Index index = 0; + IRBuilder builder(structType); + for (auto field : structType->getFields()) + { + auto key = field->getKey(); + if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>()) + { + if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) + { + auto indexAsString = String(UInt(semanticDecor->getSemanticIndex())); + auto sysValInfo = getSystemValueInfo(semanticDecor->getSemanticName(), &indexAsString, field); + if (sysValInfo.isUnsupported || sysValInfo.isSpecial) + { + reportUnsupportedSystemAttribute(field, semanticDecor->getSemanticName()); + } + else + { + builder.addTargetSystemValueDecoration(key, sysValInfo.metalSystemValueName.getUnownedSlice()); + semanticDecor->removeAndDeallocate(); + } + } + index++; + continue; + } + typeLayout->getFieldLayout(index); + auto fieldLayout = typeLayout->getFieldLayout(index); + if (auto offsetAttr = fieldLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput)) + { + UInt varOffset = 0; + if (auto varOffsetAttr = varLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput)) + varOffset = varOffsetAttr->getOffset(); + varOffset += offsetAttr->getOffset(); + builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset); + } + index++; + } } - else if (semanticName == "sv_primitiveid") + + // Stores a hicharchy of members and children which map 'oldStruct->member' to 'flatStruct->member' + // Note: this map assumes we map to FlatStruct since it is easier/faster to process + struct MapStructToFlatStruct { - result.metalSystemValueName = toSlice("patch_id"); - result.requiredType = builder.getBasicType(BaseType::UInt); - result.altRequiredType = builder.getBasicType(BaseType::UInt16); - } - else if (semanticName == "sv_rendertargetarrayindex") + /* + We need a hicharchy map to resolve dependencies for mapping + oldStruct to newStruct efficently. Example: + + MyStruct + | + / | \ + / | \ + / | \ + M0<A> M1<A> M2<B> + | | | + A_0 A_0 B_0 + + Without storing hicharchy information, there will be no way to tell apart + `myStruct.M0.A0` from `myStruct.M1.A0` since IRStructKey/IRStructField + only has 1 instance of `A::A0` + */ + + enum CopyOptions : int { - result.metalSystemValueName = toSlice("render_target_array_index"); - result.requiredType = builder.getBasicType(BaseType::UInt); - result.altRequiredType = builder.getBasicType(BaseType::UInt16); - } - else if (semanticName == "sv_sampleindex") + // Copy a flattened-struct into a struct + FlatStructIntoStruct = 0, + + // Copy a struct into a flattened-struct + StructIntoFlatStruct = 1, + }; + + private: + // Children of member if applicable. + Dictionary<IRStructField*, MapStructToFlatStruct> members; + + // Field correlating to MapStructToFlatStruct Node. + IRInst* node; + IRStructKey* getKey() + { + SLANG_ASSERT(as<IRStructField>(node)); + return as<IRStructField>(node)->getKey(); + } + IRInst* getNode() + { + return node; + } + IRType* getFieldType() + { + SLANG_ASSERT(as<IRStructField>(node)); + return as<IRStructField>(node)->getFieldType(); + } + + // Whom node maps to inside target flatStruct + IRStructField* targetMapping; + + auto begin() + { + return members.begin(); + } + auto end() + { + return members.end(); + } + + // Copies members of oldStruct to/from newFlatStruct. Assumes members of val1 maps to members in val2 + // using `MapStructToFlatStruct` + template<int copyOptions> + static void _emitCopy(IRBuilder& builder, IRInst* val1, IRStructType* type1, IRInst* val2, IRStructType* type2, MapStructToFlatStruct& node) + { + for (auto& field1Pair : node) + { + auto& field1 = field1Pair.second; + + // Get member of val1 + IRInst* fieldAddr1 = nullptr; + if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) + { + fieldAddr1 = builder.emitFieldAddress(type1, val1, field1.getKey()); + } + else + { + if (as<IRPtrTypeBase>(val1)) + val1 = builder.emitLoad(val1); + fieldAddr1 = builder.emitFieldExtract(type1, val1, field1.getKey()); + } + + // If val1 is a struct, recurse + if (auto fieldAsStruct1 = as<IRStructType>(field1.getFieldType())) + { + _emitCopy<copyOptions>(builder, fieldAddr1, fieldAsStruct1, val2, type2, field1); + continue; + } + + // Get member of val2 which maps to val1.member + auto field2 = field1.getMapping(); + SLANG_ASSERT(field2); + IRInst* fieldAddr2 = nullptr; + if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) + { + if (as<IRPtrTypeBase>(val2)) + val2 = builder.emitLoad(val1); + fieldAddr2 = builder.emitFieldExtract(type2, val2, field2->getKey()); + } + else + { + fieldAddr2 = builder.emitFieldAddress(type2, val2, field2->getKey()); + } + + // Copy val2/val1 member into val1/val2 member + if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) + { + builder.emitStore(fieldAddr1, fieldAddr2); + } + else + { + builder.emitStore(fieldAddr2, fieldAddr1); + } + } + } + + public: + void setNode(IRInst* newNode) + { + node = newNode; + } + // Get 'MapStructToFlatStruct' that is a child of 'parent'. + // Make 'MapStructToFlatStruct' if no 'member' is currently mapped to 'parent'. + MapStructToFlatStruct& getMember(IRStructField* member) + { + return members[member]; + } + MapStructToFlatStruct& operator[](IRStructField* member) + { + return getMember(member); + } + + void setMapping(IRStructField* newTargetMapping) + { + targetMapping = newTargetMapping; + } + // Get 'MapStructToFlatStruct' that is a child of 'parent'. + // Return nullptr if no member is mapped to 'parent' + IRStructField* getMapping() + { + return targetMapping; + } + + // Copies srcVal into dstVal using hicharchy map. + template<int copyOptions> + void emitCopy(IRBuilder& builder, IRInst* dstVal, IRInst* srcVal) + { + auto dstType = dstVal->getDataType(); + if (auto dstPtrType = as<IRPtrTypeBase>(dstType)) + dstType = dstPtrType->getValueType(); + auto dstStructType = as<IRStructType>(dstType); + SLANG_ASSERT(dstStructType); + + auto srcType = srcVal->getDataType(); + if (auto srcPtrType = as<IRPtrTypeBase>(srcType)) + srcType = srcPtrType->getValueType(); + auto srcStructType = as<IRStructType>(srcType); + SLANG_ASSERT(srcStructType); + + if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) + { + // CopyOptions::FlatStructIntoStruct copy a flattened-struct (mapped member) into a struct + SLANG_ASSERT(node == dstStructType); + _emitCopy<copyOptions>(builder, dstVal, dstStructType, srcVal, srcStructType, *this); + } + else + { + // CopyOptions::StructIntoFlatStruct copy a struct into a flattened-struct + SLANG_ASSERT(node == srcStructType); + _emitCopy<copyOptions>(builder, srcVal, srcStructType, dstVal, dstStructType, *this); + } + } + }; + + IRStructType* _flattenNestedStructs(IRBuilder& builder, IRStructType* dst, IRStructType* src, IRSemanticDecoration* parentSemanticDecoration, + IRLayoutDecoration* parentLayout, MapStructToFlatStruct& mapFieldToField, HashSet<IRStructField*>& varsWithSemanticInfo) { - result.metalSystemValueName = toSlice("sample_id"); - result.requiredType = builder.getBasicType(BaseType::UInt); + // For all fields ('oldField') of a struct do the following: + // 1. Check for 'decorations which carry semantic info' (IRSemanticDecoration, IRLayoutDecoration), store these if found. + // * Do not propagate semantic info if the current node has *any* form of semantic information. + // Update varsWithSemanticInfo. + // 2. If IRStructType: + // 2a. Recurse this function with 'decorations that carry semantic info' from parent. + // 3. If not IRStructType: + // 3a. Emit 'newField' equal to 'oldField', add 'decorations which carry semantic info'. + // 3b. Store a mapping from 'oldField' to 'newField' in 'mapFieldToField'. This info is needed to copy between types. + for (auto oldField : src->getFields()) + { + auto& fieldMappingNode = mapFieldToField[oldField]; + fieldMappingNode.setNode(oldField); + + // step 1 + bool foundSemanticDecor = false; + auto oldKey = oldField->getKey(); + IRSemanticDecoration* fieldSemanticDecoration = parentSemanticDecoration; + if (auto oldSemanticDecoration = oldKey->findDecoration<IRSemanticDecoration>()) + { + foundSemanticDecor = true; + fieldSemanticDecoration = oldSemanticDecoration; + parentLayout = nullptr; + } + + IRLayoutDecoration* fieldLayout = parentLayout; + if (auto oldLayout = oldKey->findDecoration<IRLayoutDecoration>()) + { + fieldLayout = oldLayout; + if (!foundSemanticDecor) + fieldSemanticDecoration = nullptr; + } + if (fieldSemanticDecoration != parentSemanticDecoration || parentLayout != fieldLayout) + varsWithSemanticInfo.add(oldField); + + // step 2a + if (auto structFieldType = as<IRStructType>(oldField->getFieldType())) + { + _flattenNestedStructs(builder, dst, structFieldType, fieldSemanticDecoration, fieldLayout, fieldMappingNode, varsWithSemanticInfo); + continue; + } + + // step 3a + auto newKey = builder.createStructKey(); + copyNameHintAndDebugDecorations(newKey, oldKey); + + auto newField = builder.createStructField(dst, newKey, oldField->getFieldType()); + copyNameHintAndDebugDecorations(newField, oldField); + + if (fieldSemanticDecoration) + builder.addSemanticDecoration(newKey, fieldSemanticDecoration->getSemanticName(), fieldSemanticDecoration->getSemanticIndex()); + + if (fieldLayout) + { + IRLayout* oldLayout = fieldLayout->getLayout(); + List<IRInst*> instToCopy; + // Only copy certain decorations needed for resolving system semantics + for (UInt i = 0; i < oldLayout->getOperandCount(); i++) + { + auto operand = oldLayout->getOperand(i); + if (as<IRVarOffsetAttr>(operand) || as<IRUserSemanticAttr>(operand) || as<IRSystemValueSemanticAttr>(operand) || as<IRStageAttr>(operand)) + instToCopy.add(operand); + } + IRVarLayout* newLayout = builder.getVarLayout(instToCopy); + builder.addLayoutDecoration(newKey, newLayout); + } + // step 3b + fieldMappingNode.setMapping(newField); + } + + return dst; } - else if (semanticName == "sv_stencilref") + + // Returns a `IRStructType*` without any `IRStructType*` members. `src` may be returned if there was no struct flattening. + // @param mapFieldToField Behavior maps all `IRStructField` of `src` to the new struct `IRStructFields`s + IRStructType* maybeFlattenNestedStructs(IRBuilder& builder, IRStructType* src, MapStructToFlatStruct& mapFieldToField, HashSet<IRStructField*>& varsWithSemanticInfo) { - result.metalSystemValueName = toSlice("stencil"); - result.requiredType = builder.getBasicType(BaseType::UInt); + // Find all values inside struct that need flattening and legalization. + bool hasStructTypeMembers = false; + for (auto field : src->getFields()) + { + if (as<IRStructType>(field->getFieldType())) + { + hasStructTypeMembers = true; + break; + } + } + if (!hasStructTypeMembers) + return src; + + // We need to: + // 1. Make new struct 1:1 with old struct but without nestested structs (flatten) + // 2. Ensure semantic attributes propegate. This will create overlapping semantics (can be handled later). + // 3. Store the mapping from old to new struct fields to allow copying a old-struct to new-struct. + builder.setInsertAfter(src); + auto newStruct = builder.createStructType(); + copyNameHintAndDebugDecorations(newStruct, src); + mapFieldToField.setNode(src); + return _flattenNestedStructs(builder, newStruct, src, nullptr, nullptr, mapFieldToField, varsWithSemanticInfo); } - else if (semanticName == "sv_tessfactor") + + // Replaces all 'IRReturn' by copying the current 'IRReturn' to a new var of type 'newType'. + // Copying logic from 'IRReturn' to 'newType' is controlled by 'copyLogicFunc' function. + template<typename CopyLogicFunc> + void _replaceAllReturnInst(IRBuilder& builder, IRFunc* targetFunc, IRStructType* newType, CopyLogicFunc copyLogicFunc) { - // Tessellation factor outputs should be lowered into a write into a normal buffer. + for (auto block : targetFunc->getBlocks()) + { + if (auto returnInst = as<IRReturn>(block->getTerminator())) + { + builder.setInsertBefore(returnInst); + auto returnVal = returnInst->getVal(); + returnInst->setOperand(0, copyLogicFunc(builder, newType, returnVal)); + } + } } - else if (semanticName == "sv_vertexid") + + UInt _returnNonOverlappingAttributeIndex(std::set<UInt>& usedSemanticIndex) { - result.metalSystemValueName = toSlice("vertex_id"); - result.requiredType = builder.getBasicType(BaseType::UInt); + // Find first unused semantic index of equal semantic type + // to fill any gaps in user set semantic bindings + UInt prev = 0; + for (auto i : usedSemanticIndex) + { + if (i > prev + 1) + { + break; + } + prev = i; + } + usedSemanticIndex.insert(prev + 1); + return prev + 1; } - else if (semanticName == "sv_viewid") + + template<typename T> + struct AttributeParentPair { - result.isUnsupported = true; - } - else if (semanticName == "sv_viewportarrayindex") + IRLayoutDecoration* layoutDecor; + T* attr; + }; + + IRLayoutDecoration* _replaceAttributeOfLayout(IRBuilder& builder, IRLayoutDecoration* parentLayoutDecor, IRInst* instToReplace, IRInst* instToReplaceWith) { - result.metalSystemValueName = toSlice("viewport_array_index"); - result.requiredType = builder.getBasicType(BaseType::UInt); - result.altRequiredType = builder.getBasicType(BaseType::UInt16); + // Replace `instToReplace` with a `instToReplaceWith` + + auto layout = parentLayoutDecor->getLayout(); + // Find the exact same decoration `instToReplace` in-case multiple of the same type exist + List<IRInst*> opList; + opList.add(instToReplaceWith); + for (UInt i = 0; i < layout->getOperandCount(); i++) + { + if (layout->getOperand(i) != instToReplace) + opList.add(layout->getOperand(i)); + } + auto newLayoutDecor = builder.addLayoutDecoration(parentLayoutDecor->getParent(), builder.getVarLayout(opList)); + parentLayoutDecor->removeAndDeallocate(); + return newLayoutDecor; } - else if (semanticName.startsWith("sv_target")) + + IRLayoutDecoration* _simplifyUserSemanticNames(IRBuilder& builder, IRLayoutDecoration* layoutDecor) { - - result.metalSystemValueName = (StringBuilder() << "color(" - << (semanticIndex.getLength() != 0 ? semanticIndex : toSlice("0")) - << ")").produceString(); + // Ensure all 'ExplicitIndex' semantics such as "SV_TARGET0" are simplified into ("SV_TARGET", 0) using 'IRUserSemanticAttr' + // This is done to ensure we can check semantic groups using 'IRUserSemanticAttr1->getName() == IRUserSemanticAttr2->getName()' + SLANG_ASSERT(layoutDecor); + auto layout = layoutDecor->getLayout(); + List<IRInst*> layoutOps; + layoutOps.reserve(3); + bool changed = false; + for(auto attr : layout->getAllAttrs()) + { + if(auto userSemantic = as<IRUserSemanticAttr>(attr)) + { + UnownedStringSlice outName; + UnownedStringSlice outIndex; + bool hasStringIndex = splitNameAndIndex(userSemantic->getName(), outName, outIndex); + if (hasStringIndex) + { + changed = true; + auto loweredName = String(outName).toLower(); + auto loweredNameSlice = loweredName.getUnownedSlice(); + auto newDecoration = builder.getUserSemanticAttr(loweredNameSlice, stringToInt(outIndex)); + userSemantic->replaceUsesWith(newDecoration); + userSemantic->removeAndDeallocate(); + userSemantic = newDecoration; + } + layoutOps.add(userSemantic); + continue; + } + layoutOps.add(attr); + } + if(changed) + { + auto parent = layoutDecor->parent; + layoutDecor->removeAndDeallocate(); + builder.addLayoutDecoration(parent, builder.getVarLayout(layoutOps)); + } + return layoutDecor; } - else + // Find overlapping field semantics and legalize them + void fixFieldSemanticsOfFlatStruct(IRStructType* structType) { - result.isUnsupported = true; - } - return result; - } + // Goal is to ensure we do not have overlapping semantics: + /* + // Assume the following code + struct Fragment + { + float4 p1 : SV_TARGET; + float3 p2 : SV_TARGET; + float2 p3 : SV_TARGET; + float2 p4 : SV_TARGET; + }; + + // Translates into + struct Fragment + { + float4 p1 : SV_TARGET0; + float3 p2 : SV_TARGET1; + float2 p3 : SV_TARGET2; + float2 p4 : SV_TARGET3; + }; + */ - void reportUnsupportedSystemAttribute(DiagnosticSink* sink, IRInst* param, String semanticName) - { - sink->diagnose(param->sourceLoc, Diagnostics::systemValueAttributeNotSupported, semanticName); - } + IRBuilder builder(this->m_module); - void ensureResultStructHasUserSemantic(DiagnosticSink* sink, IRStructType* structType, IRVarLayout* varLayout) - { - // Ensure each field in an output struct type has either a system semantic or a user semantic, - // so that signature matching can happen correctly. - auto typeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout()); - Index index = 0; - IRBuilder builder(structType); - for (auto field : structType->getFields()) - { - auto key = field->getKey(); - if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>()) + List<IRSemanticDecoration*> overlappingSemanticsDecor; + Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>> usedSemanticIndexSemanticDecor; + + List<AttributeParentPair<IRVarOffsetAttr>> overlappingVarOffset; + Dictionary<UInt, std::set<UInt, std::less<UInt>>> usedSemanticIndexVarOffset; + + List<AttributeParentPair<IRUserSemanticAttr>> overlappingUserSemantic; + Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt> >> usedSemanticIndexUserSemantic; + + // We store a map from old `IRLayoutDecoration*` to new `IRLayoutDecoration*` since when legalizing + // we may destroy and remake a `IRLayoutDecoration*` + Dictionary<IRLayoutDecoration*, IRLayoutDecoration*> oldLayoutDecorToNew; + + // Collect all "semantic info carrying decorations". Any collected decoration will + // fill up their respective 'Dictionary<SEMANTIC_TYPE, OrderedHashSet<UInt>>' + // to keep track of in-use offsets for a semantic type. + // Example: IRSemanticDecoration with name of "SV_TARGET1". + // * This will have SEMANTIC_TYPE of "sv_target". + // * This will use up index '1' + // + // Now if a second equal semantic "SV_TARGET1" is found, we add this decoration to + // a list of 'overlapping semantic info decorations' so we can legalize this + // 'semantic info decoration' later. + // + // NOTE: this is a flat struct, all members are children of the initial + // IRStructType. + for (auto field : structType->getFields()) { - if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) + auto key = field->getKey(); + if (auto semanticDecoration = key->findDecoration<IRSemanticDecoration>()) { - auto sysValInfo = getSystemValueInfo(builder, semanticDecor->getSemanticName()); - if (sysValInfo.isUnsupported || sysValInfo.isSpecial) + // Ensure names are in a uniform lowercase format so we can bunch together simmilar semantics + UnownedStringSlice outName; + UnownedStringSlice outIndex; + bool hasStringIndex = splitNameAndIndex(semanticDecoration->getSemanticName(), outName, outIndex); + if (hasStringIndex) { - reportUnsupportedSystemAttribute(sink, field, semanticDecor->getSemanticName()); + auto loweredName = String(outName).toLower(); + auto loweredNameSlice = loweredName.getUnownedSlice(); + auto newDecoration = builder.addSemanticDecoration(key, loweredNameSlice, stringToInt(outIndex)); + semanticDecoration->replaceUsesWith(newDecoration); + semanticDecoration->removeAndDeallocate(); + semanticDecoration = newDecoration; } + auto& semanticUse = usedSemanticIndexSemanticDecor[semanticDecoration->getSemanticName()]; + if (semanticUse.find(semanticDecoration->getSemanticIndex()) != semanticUse.end()) + overlappingSemanticsDecor.add(semanticDecoration); else + semanticUse.insert(semanticDecoration->getSemanticIndex()); + } + if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>()) + { + // Ensure names are in a uniform lowercase format so we can bunch together simmilar semantics + layoutDecor = _simplifyUserSemanticNames(builder, layoutDecor); + oldLayoutDecorToNew[layoutDecor] = layoutDecor; + auto layout = layoutDecor->getLayout(); + for (auto attr : layout->getAllAttrs()) { - builder.addTargetSystemValueDecoration(key, sysValInfo.metalSystemValueName.getUnownedSlice()); - semanticDecor->removeAndDeallocate(); + if (auto offset = as<IRVarOffsetAttr>(attr)) + { + auto& semanticUse = usedSemanticIndexVarOffset[offset->getResourceKind()]; + if (semanticUse.find(offset->getOffset()) != semanticUse.end()) + overlappingVarOffset.add({ layoutDecor, offset }); + else + semanticUse.insert(offset->getOffset()); + } + else if (auto userSemantic = as<IRUserSemanticAttr>(attr)) + { + auto& semanticUse = usedSemanticIndexUserSemantic[userSemantic->getName()]; + if (semanticUse.find(userSemantic->getIndex()) != semanticUse.end()) + overlappingUserSemantic.add({ layoutDecor, userSemantic }); + else + semanticUse.insert(userSemantic->getIndex()); + } } } - index++; - continue; } - typeLayout->getFieldLayout(index); - auto fieldLayout = typeLayout->getFieldLayout(index); - if (auto offsetAttr = fieldLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput)) + + // Legalize all overlapping 'semantic info decorations' + for (auto decor : overlappingSemanticsDecor) { - UInt varOffset = 0; - if (auto varOffsetAttr = varLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput)) - varOffset = varOffsetAttr->getOffset(); - varOffset += offsetAttr->getOffset(); - builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset); + auto newOffset = _returnNonOverlappingAttributeIndex(usedSemanticIndexSemanticDecor[decor->getSemanticName()]); + builder.addSemanticDecoration(decor->getParent(), decor->getSemanticName(), (int)newOffset); + decor->removeAndDeallocate(); + } + for (auto& varOffset : overlappingVarOffset) + { + auto newOffset = _returnNonOverlappingAttributeIndex(usedSemanticIndexVarOffset[varOffset.attr->getResourceKind()]); + auto newVarOffset = builder.getVarOffsetAttr(varOffset.attr->getResourceKind(), newOffset, varOffset.attr->getSpace()); + oldLayoutDecorToNew[varOffset.layoutDecor] = _replaceAttributeOfLayout(builder, oldLayoutDecorToNew[varOffset.layoutDecor], varOffset.attr, newVarOffset); + } + for (auto& userSemantic : overlappingUserSemantic) + { + auto newOffset = _returnNonOverlappingAttributeIndex(usedSemanticIndexUserSemantic[userSemantic.attr->getName()]); + auto newUserSemantic = builder.getUserSemanticAttr(userSemantic.attr->getName(), newOffset); + oldLayoutDecorToNew[userSemantic.layoutDecor] = _replaceAttributeOfLayout(builder, oldLayoutDecorToNew[userSemantic.layoutDecor], userSemantic.attr, newUserSemantic); } - index++; - } - } - - - void wrapReturnValueInStruct(DiagnosticSink* sink, EntryPointInfo entryPoint) - { - // Wrap return value into a struct if it is not already a struct. - // For example, given this entry point: - // ``` - // float4 main() : SV_Target { return float3(1,2,3); } - // ``` - // We are going to transform it into: - // ``` - // struct Output { - // float4 value : SV_Target; - // }; - // Output main() { return {float3(1,2,3)}; } - - auto func = entryPoint.entryPointFunc; - - auto returnType = func->getResultType(); - if (as<IRVoidType>(returnType)) - return; - auto entryPointLayoutDecor = func->findDecoration<IRLayoutDecoration>(); - if (!entryPointLayoutDecor) - return; - auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecor->getLayout()); - if (!entryPointLayout) - return; - auto resultLayout = entryPointLayout->getResultLayout(); - - // If return type is already a struct, just make sure every field has a semantic. - if (auto returnStructType = as<IRStructType>(returnType)) - { - ensureResultStructHasUserSemantic(sink, returnStructType, resultLayout); - return; } - // If not, we need to wrap the result into a struct type. - IRBuilder builder(func); - builder.setInsertBefore(func); - IRStructType* structType = builder.createStructType(); - auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage()); - builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Output")).getUnownedSlice()); - auto key = builder.createStructKey(); - builder.addNameHintDecoration(key, toSlice("output")); - builder.addLayoutDecoration(key, resultLayout); - builder.addTargetSystemValueDecoration(key, toSlice("color(0)")); - builder.createStructField(structType, key, returnType); - IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder); - structTypeLayoutBuilder.addField(key, resultLayout); - auto typeLayout = structTypeLayoutBuilder.build(); - IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); - auto varLayout = varLayoutBuilder.build(); - ensureResultStructHasUserSemantic(sink, structType, varLayout); - - for (auto block : func->getBlocks()) + void wrapReturnValueInStruct(EntryPointInfo entryPoint) { - if (auto returnInst = as<IRReturn>(block->getTerminator())) + // Wrap return value into a struct if it is not already a struct. + // For example, given this entry point: + // ``` + // float4 main() : SV_Target { return float3(1,2,3); } + // ``` + // We are going to transform it into: + // ``` + // struct Output { + // float4 value : SV_Target; + // }; + // Output main() { return {float3(1,2,3)}; } + + auto func = entryPoint.entryPointFunc; + + auto returnType = func->getResultType(); + if (as<IRVoidType>(returnType)) + return; + auto entryPointLayoutDecor = func->findDecoration<IRLayoutDecoration>(); + if (!entryPointLayoutDecor) + return; + auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecor->getLayout()); + if (!entryPointLayout) + return; + auto resultLayout = entryPointLayout->getResultLayout(); + + // If return type is already a struct, just make sure every field has a semantic. + if (auto returnStructType = as<IRStructType>(returnType)) { - builder.setInsertBefore(returnInst); - auto returnVal = returnInst->getVal(); - auto newResult = builder.emitMakeStruct(structType, 1, &returnVal); - returnInst->setOperand(0, newResult); + IRBuilder builder(func); + MapStructToFlatStruct mapOldFieldToNewField; + // Flatten result struct type to ensure we do not have nested semantics + auto flattenedStruct = maybeFlattenNestedStructs(builder, returnStructType, mapOldFieldToNewField, semanticInfoToRemove); + if (returnStructType != flattenedStruct) + { + // Replace all return-values with the flattenedStruct we made. + _replaceAllReturnInst(builder, func, flattenedStruct, + [&](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst* + { + auto srcStructType = as<IRStructType>(srcVal->getDataType()); + SLANG_ASSERT(srcStructType); + auto dstVal = copyBuilder.emitVar(dstType); + mapOldFieldToNewField.emitCopy<(int)MapStructToFlatStruct::CopyOptions::StructIntoFlatStruct>(copyBuilder, dstVal, srcVal); + return builder.emitLoad(dstVal); + } + ); + fixUpFuncType(func, flattenedStruct); + } + // Ensure non-overlapping semantics + fixFieldSemanticsOfFlatStruct(flattenedStruct); + ensureResultStructHasUserSemantic(flattenedStruct, resultLayout); + return; } - } - fixUpFuncType(func, structType); - } - void legalizeMeshEntryPoint(EntryPointInfo entryPoint) - { - auto func = entryPoint.entryPointFunc; + IRBuilder builder(func); + builder.setInsertBefore(func); + IRStructType* structType = builder.createStructType(); + auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage()); + builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Output")).getUnownedSlice()); + auto key = builder.createStructKey(); + builder.addNameHintDecoration(key, toSlice("output")); + builder.addLayoutDecoration(key, resultLayout); + builder.createStructField(structType, key, returnType); + IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder); + structTypeLayoutBuilder.addField(key, resultLayout); + auto typeLayout = structTypeLayoutBuilder.build(); + IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); + auto varLayout = varLayoutBuilder.build(); + ensureResultStructHasUserSemantic(structType, varLayout); + + _replaceAllReturnInst(builder, func, structType, + [](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst* + { + return copyBuilder.emitMakeStruct(dstType, 1, &srcVal); + } + ); - if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Mesh) - { - return; + // Assign an appropriate system value semantic for stage output + auto stage = entryPoint.entryPointDecor->getProfile().getStage(); + switch (stage) + { + case Stage::Compute: + case Stage::Fragment: + { + builder.addTargetSystemValueDecoration(key, toSlice("color(0)")); + break; + } + case Stage::Vertex: + { + builder.addTargetSystemValueDecoration(key, toSlice("position")); + break; + } + default: + SLANG_ASSERT(false); + return; + } + + fixUpFuncType(func, structType); } - IRBuilder builder{ entryPoint.entryPointFunc->getModule() }; - for (auto param : func->getParams()) + void legalizeMeshEntryPoint(EntryPointInfo entryPoint) { - if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) - { - IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build()); + auto func = entryPoint.entryPointFunc; - varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); - auto paramVarLayout = varLayoutBuilder.build(); - builder.addLayoutDecoration(param, paramVarLayout); + if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Mesh) + { + return; } - } - } + IRBuilder builder{ entryPoint.entryPointFunc->getModule() }; + for (auto param : func->getParams()) + { + if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) + { + IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build()); + + varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); + auto paramVarLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(param, paramVarLayout); + } + } - void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint) - { - if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Amplification) - { - return; } - // Find out DispatchMesh function - IRGlobalValueWithCode* dispatchMeshFunc = nullptr; - for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts()) + + void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint) { - if (const auto func = as<IRGlobalValueWithCode>(globalInst)) + if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Amplification) + { + return; + } + // Find out DispatchMesh function + IRGlobalValueWithCode* dispatchMeshFunc = nullptr; + for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts()) { - if (const auto dec = func->findDecoration<IRKnownBuiltinDecoration>()) + if (const auto func = as<IRGlobalValueWithCode>(globalInst)) { - if (dec->getName() == "DispatchMesh") + if (const auto dec = func->findDecoration<IRKnownBuiltinDecoration>()) { - SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found"); - dispatchMeshFunc = func; + if (dec->getName() == "DispatchMesh") + { + SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found"); + dispatchMeshFunc = func; + } } } } - } - - if (!dispatchMeshFunc) - return; - IRBuilder builder{ entryPoint.entryPointFunc->getModule() }; + if (!dispatchMeshFunc) + return; - // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid - traverseUses(dispatchMeshFunc, [&](const IRUse* use) { - if (const auto call = as<IRCall>(use->getUser())) - { - SLANG_ASSERT(call->getArgCount() == 4); - const auto payload = call->getArg(3); + IRBuilder builder{ entryPoint.entryPointFunc->getModule() }; - const auto payloadPtrType = composeGetters<IRPtrType>( - payload, - &IRInst::getDataType - ); - SLANG_ASSERT(payloadPtrType); - const auto payloadType = payloadPtrType->getValueType(); - SLANG_ASSERT(payloadType); + // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid + traverseUses(dispatchMeshFunc, [&](const IRUse* use) { + if (const auto call = as<IRCall>(use->getUser())) + { + SLANG_ASSERT(call->getArgCount() == 4); + const auto payload = call->getArg(3); - builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); - const auto annotatedPayloadType = - builder.getPtrType( - kIROp_RefType, - payloadPtrType->getValueType(), - AddressSpace::MetalObjectData + const auto payloadPtrType = composeGetters<IRPtrType>( + payload, + &IRInst::getDataType ); - auto packedParam = builder.emitParam(annotatedPayloadType); - builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload")); - IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build()); - - // Add the MetalPayload resource info, so we can emit [[payload]] - varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); - auto paramVarLayout = varLayoutBuilder.build(); - builder.addLayoutDecoration(packedParam, paramVarLayout); - - // Now we replace the call to DispatchMesh with a call to the mesh grid properties - // But first we need to create the parameter - const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType(); - auto mgp = builder.emitParam(meshGridPropertiesType); - builder.addExternCppDecoration(mgp, toSlice("_slang_mgp")); - } - }); - } + SLANG_ASSERT(payloadPtrType); + const auto payloadType = payloadPtrType->getValueType(); + SLANG_ASSERT(payloadType); + + builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); + const auto annotatedPayloadType = + builder.getPtrType( + kIROp_RefType, + payloadPtrType->getValueType(), + AddressSpace::MetalObjectData + ); + auto packedParam = builder.emitParam(annotatedPayloadType); + builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload")); + IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build()); + + // Add the MetalPayload resource info, so we can emit [[payload]] + varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); + auto paramVarLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(packedParam, paramVarLayout); + + // Now we replace the call to DispatchMesh with a call to the mesh grid properties + // But first we need to create the parameter + const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType(); + auto mgp = builder.emitParam(meshGridPropertiesType); + builder.addExternCppDecoration(mgp, toSlice("_slang_mgp")); + } + }); + } - IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType) - { - auto fromType = val->getFullType(); - if (auto fromVector = as<IRVectorType>(fromType)) + IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType) { - if (auto toVector = as<IRVectorType>(toType)) + auto fromType = val->getFullType(); + if (auto fromVector = as<IRVectorType>(fromType)) { - if (fromVector->getElementCount() != toVector->getElementCount()) + if (auto toVector = as<IRVectorType>(toType)) + { + if (fromVector->getElementCount() != toVector->getElementCount()) + { + fromType = builder.getVectorType(fromVector->getElementType(), toVector->getElementCount()); + val = builder.emitVectorReshape(fromType, val); + } + } + else if (as<IRBasicType>(toType)) { - fromType = builder.getVectorType(fromVector->getElementType(), toVector->getElementCount()); - val = builder.emitVectorReshape(fromType, val); + UInt index = 0; + val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index); + if (toType->getOp() == kIROp_VoidType) + return nullptr; } } - else if (as<IRBasicType>(toType)) + else if (auto fromBasicType = as<IRBasicType>(fromType)) { - UInt index = 0; - val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index); + if (fromBasicType->getOp() == kIROp_VoidType) + return nullptr; + if (!as<IRBasicType>(toType)) + return nullptr; if (toType->getOp() == kIROp_VoidType) return nullptr; } - } - else if (auto fromBasicType = as<IRBasicType>(fromType)) - { - if (fromBasicType->getOp() == kIROp_VoidType) - return nullptr; - if (!as<IRBasicType>(toType)) - return nullptr; - if (toType->getOp() == kIROp_VoidType) + else + { return nullptr; + } + return builder.emitCast(toType, val); } - else - { - return nullptr; - } - return builder.emitCast(toType, val); - } - - void legalizeSystemValueParameters(EntryPointInfo entryPoint, DiagnosticSink* sink) - { - SLANG_UNUSED(sink); struct SystemValLegalizationWorkItem { - IRParam* param; + IRInst* var; String attrName; UInt attrIndex; }; - List<SystemValLegalizationWorkItem> systemValWorkItems; - - IRBuilder builder(entryPoint.entryPointFunc); - for (auto param : entryPoint.entryPointFunc->getParams()) + std::optional<SystemValLegalizationWorkItem> tryToMakeSystemValWorkItem(IRInst* var) { - if (auto semanticDecoration = param->findDecoration<IRSemanticDecoration>()) + if (auto semanticDecoration = var->findDecoration<IRSemanticDecoration>()) { if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) { - systemValWorkItems.add({ param, String(semanticDecoration->getSemanticName()).toLower(), (UInt)semanticDecoration->getSemanticIndex() }); - continue; + return { { var, String(semanticDecoration->getSemanticName()).toLower(), (UInt)semanticDecoration->getSemanticIndex() } }; } } - auto layoutDecor = param->findDecoration<IRLayoutDecoration>(); + auto layoutDecor = var->findDecoration<IRLayoutDecoration>(); if (!layoutDecor) - continue; + return {}; auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>(); if (!sysValAttr) - continue; + return {}; auto semanticName = String(sysValAttr->getName()); auto sysAttrIndex = sysValAttr->getIndex(); - systemValWorkItems.add({ param, semanticName, sysAttrIndex }); + + return { { var, semanticName, sysAttrIndex } }; + } + + + List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(EntryPointInfo entryPoint) + { + List<SystemValLegalizationWorkItem> systemValWorkItems; + for (auto param : entryPoint.entryPointFunc->getParams()) + { + auto maybeWorkItem = tryToMakeSystemValWorkItem(param); + if (maybeWorkItem.has_value()) + systemValWorkItems.add(std::move(maybeWorkItem.value())); + } + return systemValWorkItems; } - IRParam* groupThreadId = nullptr; - for (auto index = 0; index < systemValWorkItems.getCount(); index++) + void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem) { - auto workItem = systemValWorkItems[index]; + IRBuilder builder(entryPoint.entryPointFunc); - auto param = workItem.param; + auto var = workItem.var; auto semanticName = workItem.attrName; - auto info = getSystemValueInfo(builder, semanticName); + auto indexAsString = String(workItem.attrIndex); + auto info = getSystemValueInfo(semanticName, &indexAsString, var); + if (info.isSpecial) { - if (semanticName == "sv_innercoverage") + if (info.metalSystemValueNameEnum == SystemValueSemanticName::InnerCoverage) { // Metal does not support conservative rasterization, so this is always false. auto val = builder.getBoolValue(false); - param->replaceUsesWith(val); - param->removeAndDeallocate(); + var->replaceUsesWith(val); + var->removeAndDeallocate(); } - else if (semanticName == "sv_groupindex") + else if (info.metalSystemValueNameEnum == SystemValueSemanticName::GroupIndex) { - // Ensure we have a cached "sv_groupthreadid" - if (!groupThreadId) + // Ensure we have a cached "sv_groupthreadid" in our entry point + if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc)) { + auto systemValWorkItems = collectSystemValFromEntryPoint(entryPoint); for (auto i : systemValWorkItems) { - if (i.attrName == groupThreadIDString) + auto indexAsStringGroupThreadId = String(i.attrIndex); + if (getSystemValueInfo(i.attrName, &indexAsStringGroupThreadId, i.var).metalSystemValueNameEnum == SystemValueSemanticName::GroupThreadID) { - groupThreadId = i.param; + entryPointToGroupThreadId[entryPoint.entryPointFunc] = i.var; } } - if (!groupThreadId) + if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc)) { // Add the missing groupthreadid needed to compute sv_groupindex IRBuilder groupThreadIdBuilder(builder); groupThreadIdBuilder.setInsertInto(entryPoint.entryPointFunc->getFirstBlock()); - groupThreadId = groupThreadIdBuilder.emitParamAtHead(getGroupThreadIdType(groupThreadIdBuilder)); + auto groupThreadId = groupThreadIdBuilder.emitParamAtHead(getGroupThreadIdType(groupThreadIdBuilder)); + entryPointToGroupThreadId[entryPoint.entryPointFunc] = groupThreadId; groupThreadIdBuilder.addNameHintDecoration(groupThreadId, groupThreadIDString); - + // Since "sv_groupindex" will be translated out to a global var and no longer be considered a system value // we can reuse its layout and semantic info Index foundRequiredDecorations = 0; IRLayoutDecoration* layoutDecoration = nullptr; UInt semanticIndex = 0; - for (auto decoration : param->getDecorations()) + for (auto decoration : var->getDecorations()) { if (auto layoutDecorationTmp = as<IRLayoutDecoration>(decoration)) { @@ -731,67 +1502,109 @@ namespace Slang SLANG_ASSERT(layoutDecoration); layoutDecoration->removeFromParent(); layoutDecoration->insertAtStart(groupThreadId); - systemValWorkItems.add({ groupThreadId, groupThreadIDString, semanticIndex }); + SystemValLegalizationWorkItem newWorkItem = { groupThreadId, groupThreadIDString, semanticIndex }; + legalizeSystemValue(entryPoint, newWorkItem); } } IRBuilder svBuilder(builder.getModule()); svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst()); auto computeExtent = emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, builder.getVectorType(builder.getUIntType(), builder.getIntValue(builder.getIntType(), 3))); - auto groupIndexCalc = emitCalcGroupThreadIndex(svBuilder, groupThreadId, computeExtent); + auto groupIndexCalc = emitCalcGroupIndex(svBuilder, entryPointToGroupThreadId[entryPoint.entryPointFunc], computeExtent); svBuilder.addNameHintDecoration(groupIndexCalc, UnownedStringSlice("sv_groupindex")); - param->replaceUsesWith(groupIndexCalc); - param->removeAndDeallocate(); + var->replaceUsesWith(groupIndexCalc); + var->removeAndDeallocate(); } } if (info.isUnsupported) { - reportUnsupportedSystemAttribute(sink, param, semanticName); - continue; + reportUnsupportedSystemAttribute(var, semanticName); + return; } - if (!info.requiredType) - continue; + if (!info.permittedTypes.getCount()) + return; - builder.addTargetSystemValueDecoration(param, info.metalSystemValueName.getUnownedSlice()); + builder.addTargetSystemValueDecoration(var, info.metalSystemValueName.getUnownedSlice()); - // If the required type is different from the actual type, we need to insert a conversion. - if (info.requiredType != param->getFullType() && info.altRequiredType != param->getFullType()) + bool varTypeIsPermitted = false; + auto varType = var->getFullType(); + for (auto& permittedType : info.permittedTypes) { - auto targetType = param->getFullType(); - builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); - param->setFullType(info.requiredType); - List<IRUse*> uses; - for (auto use = param->firstUse; use; use = use->nextUse) - uses.add(use); - auto convertedValue = tryConvertValue(builder, param, targetType); - copyNameHintAndDebugDecorations(convertedValue, param); - if (!convertedValue) - { - // If we can't convert the value, report an error. - StringBuilder typeNameSB; - getTypeNameHint(typeNameSB, info.requiredType); - sink->diagnose(param->sourceLoc, Diagnostics::systemValueTypeIncompatible, semanticName, typeNameSB.produceString()); - } - else + varTypeIsPermitted = varTypeIsPermitted || permittedType == varType; + } + + if (!varTypeIsPermitted) + { + // Note: we do not currently prefer any conversion + // example: + // * allowed types for semantic: `float4`, `uint4`, `int4` + // * user used, `float2` + // * Slang will equally prefer `float4` to `uint4` to `int4`. + // This means the type may lose data if slang selects `uint4` or `int4`. + bool foundAConversion = false; + for (auto permittedType : info.permittedTypes) { + var->setFullType(permittedType); + builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); + + // get uses before we `tryConvertValue` since this creates a new use + List<IRUse*> uses; + for (auto use = var->firstUse; use; use = use->nextUse) + uses.add(use); + + auto convertedValue = tryConvertValue(builder, var, varType); + if (convertedValue == nullptr) + continue; + + foundAConversion = true; + copyNameHintAndDebugDecorations(convertedValue, var); + for (auto use : uses) builder.replaceOperand(use, convertedValue); } + if (!foundAConversion) + { + // If we can't convert the value, report an error. + for (auto permittedType : info.permittedTypes) + { + StringBuilder typeNameSB; + getTypeNameHint(typeNameSB, permittedType); + m_sink->diagnose(var->sourceLoc, Diagnostics::systemValueTypeIncompatible, semanticName, typeNameSB.produceString()); + } + } } } - fixUpFuncType(entryPoint.entryPointFunc); - } - void legalizeEntryPointForMetal(EntryPointInfo entryPoint, DiagnosticSink* sink) - { - hoistEntryPointParameterFromStruct(entryPoint); - packStageInParameters(entryPoint); - legalizeSystemValueParameters(entryPoint, sink); - wrapReturnValueInStruct(sink, entryPoint); - legalizeMeshEntryPoint(entryPoint); - legalizeDispatchMeshPayloadForMetal(entryPoint); - } + void legalizeSystemValueParameters(EntryPointInfo entryPoint) + { + List<SystemValLegalizationWorkItem> systemValWorkItems = collectSystemValFromEntryPoint(entryPoint); + + for (auto index = 0; index < systemValWorkItems.getCount(); index++) + { + legalizeSystemValue(entryPoint, systemValWorkItems[index]); + } + fixUpFuncType(entryPoint.entryPointFunc); + } + + void legalizeEntryPointForMetal(EntryPointInfo entryPoint) + { + //Input Parameter Legalize + hoistEntryPointParameterFromStruct(entryPoint); + packStageInParameters(entryPoint); + flattenInputParameters(entryPoint); + + //System Value Legalize + legalizeSystemValueParameters(entryPoint); + + //Output Value Legalize + wrapReturnValueInStruct(entryPoint); + + //Other Legalize + legalizeMeshEntryPoint(entryPoint); + legalizeDispatchMeshPayloadForMetal(entryPoint); + } + }; void legalizeFuncBody(IRFunc* func) { @@ -871,8 +1684,10 @@ namespace Slang } } + LegalizeMetalEntryPointContext context(sink, module); for (auto entryPoint : entryPoints) - legalizeEntryPointForMetal(entryPoint, sink); + context.legalizeEntryPointForMetal(entryPoint); + context.removeSemanticLayoutsFromLegalizedStructs(); specializeAddressSpace(module); } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index a4ea100dc..c6409a7e1 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6267,9 +6267,9 @@ namespace Slang addDecoration(inst, kIROp_HighLevelDeclDecoration, ptrConst); } - void IRBuilder::addLayoutDecoration(IRInst* value, IRLayout* layout) + IRLayoutDecoration* IRBuilder::addLayoutDecoration(IRInst* value, IRLayout* layout) { - addDecoration(value, kIROp_LayoutDecoration, layout); + return as<IRLayoutDecoration>(addDecoration(value, kIROp_LayoutDecoration, layout)); } IRTypeSizeAttr* IRBuilder::getTypeSizeAttr( diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 6a1080b0e..046c35ef9 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -480,8 +480,8 @@ static bool isDigit(char c) return (c >= '0') && (c <= '9'); } -void splitNameAndIndex( - UnownedStringSlice const& text, +bool splitNameAndIndex( + UnownedStringSlice const& text, UnownedStringSlice& outName, UnownedStringSlice& outDigits) { @@ -489,14 +489,20 @@ void splitNameAndIndex( char const* digitsEnd = text.end(); char const* nameEnd = digitsEnd; + // ExplicitIndex is when a semantic has an index at the end of its name + // "SV_TARGET1" has an ExplicitIndex + // "SV_TARGET" does not have an ExplicitIndex + bool hasExplicitIndex = false; while( nameEnd != nameBegin && isDigit(*(nameEnd - 1)) ) { + hasExplicitIndex = true; nameEnd--; } char const* digitsBegin = nameEnd; outName = UnownedStringSlice(nameBegin, nameEnd); outDigits = UnownedStringSlice(digitsBegin, digitsEnd); + return hasExplicitIndex; } LayoutResourceKind findRegisterClassFromName(UnownedStringSlice const& registerClassName) diff --git a/source/slang/slang-parameter-binding.h b/source/slang/slang-parameter-binding.h index 3c0949e37..e2530e34e 100644 --- a/source/slang/slang-parameter-binding.h +++ b/source/slang/slang-parameter-binding.h @@ -32,7 +32,7 @@ void generateParameterBindings( /// Given a string that specifies a name and index (e.g., `COLOR0`), /// split it into slices for the name part and the index part. /// -void splitNameAndIndex( +bool splitNameAndIndex( UnownedStringSlice const& text, UnownedStringSlice& outName, UnownedStringSlice& outDigits); |
