summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2024-07-10 16:24:12 -0400
committerGitHub <noreply@github.com>2024-07-10 13:24:12 -0700
commita08ccfa50a06797dab60918b788570a520c45454 (patch)
tree79e3b90b68c72229efdc6ed51f3f8e42788843bd
parent667e50498a226103278d0997528cc76979b2c4ef (diff)
Fixes to Metal Input parameters and Output value input/output semantics (#4536)
* initial change to test with CI for CPU/CUDA errors * Fixes to Metal Input parameters and Output values Note: 1. Flattening a struct is the process of making a struct have 0 struct/class members. Changes: 1. Separated `legalizeSystemValueParameters`. This was done to make it easier to run `legalizeSystemValue` 1 system-value at a time to simplify logic. This change is optional and can be undone if not preferred. 2. Wrap everything inside a Metal legalization context. This was done since it simplifies a lot of logic and will be required for #4375 3. Created `convertSystemValueSemanticNameToEnum` and expanded the existing System-Value Enum system. This allows (sometimes) faster comparisons and helps prepare code for porting into `slang-ir-legalize-varying-params.cpp` (#4375) 4. Added a more dynamic `legalizeSystemValue` system so more than 2 types can be targeted for legalization. This is required to legalize `output`. There is still no preference for any converted type, the first valid type will be converted to. 5. Flatten all input(`flattenInputParameters`)/output(part of `wrapReturnValueInStruct`) structs and assign semantics accordingly. 6. Semantics when legalized have no specific logic other than to: 1. avoid overlapping semantics 2. Prefer assigning explicit semantics specified by a user. 7. Fixed some issue with incorrect output semantics if not a fragment stage (when there are not any assigned semantics) * change metallib test to the correct metal test * comment code & cleanup -- Did not address all review Added comments for clarity + cleaned up some odd areas which were messy * Add comment to `fixFieldSemanticsOfFlatStruct` I found `fixFieldSemanticsOfFlatStruct` to still be confusing at a cursory glance. Added comments to make the function be more understandable. * white space * Address review comments 1. Fix semantic propegation. 2. Fix how we map struct fields of the flat struct to struct. This is specifically important for if reusing the same struct twice since struct member info is not unique per struct instance used. * Fix semantic legalization by adding TreeMap Add TreeMap to allow proper sorted-object data iteration. * Fix some compile issues * try to fix gcc compile error * compile error * fix logic bug in treeMap iterator next-semantic setter * fix vsproject filters * filter file syntax error * remove need of a context to make copies stable * Rename treemap to the more appropriate name of "treeset", adjust code comments accordingly. * remove custom type `TreeSet` and use `std::set` * remove TreeMap fully --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/hlsl.meta.slang2
-rw-r--r--source/slang/slang-diagnostic-defs.h4
-rw-r--r--source/slang/slang-ir-insts.h6
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp83
-rw-r--r--source/slang/slang-ir-legalize-varying-params.h61
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp1969
-rw-r--r--source/slang/slang-ir.cpp4
-rw-r--r--source/slang/slang-parameter-binding.cpp10
-rw-r--r--source/slang/slang-parameter-binding.h2
-rw-r--r--tests/compute/compile-time-loop.slang2
-rw-r--r--tests/compute/constexpr.slang2
-rw-r--r--tests/compute/discard-stmt.slang2
-rw-r--r--tests/compute/texture-sampling.slang4
-rw-r--r--tests/metal/atomic-intrinsics.slang1
-rw-r--r--tests/metal/nested-struct-fragment-input.slang68
-rw-r--r--tests/metal/nested-struct-fragment-output.slang74
-rw-r--r--tests/metal/nested-struct-multi-entry-point-vertex.slang45
-rw-r--r--tests/metal/no-struct-vertex-output.slang12
-rw-r--r--tests/metal/stage-in-2.slang2
-rw-r--r--tests/metal/sv_target-complex-1.slang34
-rw-r--r--tests/metal/sv_target-complex-2.slang27
21 files changed, 1762 insertions, 652 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 17e2822a3..59a64a192 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -723,7 +723,6 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,1,format>
[__readNone]
[ForceInline]
- [require(hlsl, texture_sm_4_0_fragment)]
T Sample(vector<float, Shape.dimensions+isArray> location, vector<int, Shape.planeDimensions> offset, float clamp, out uint status)
{
__target_switch
@@ -1328,7 +1327,6 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,0,format>
[__readNone]
[ForceInline]
- [require(hlsl, texture_sm_4_0_fragment)]
T Sample(SamplerState s, vector<float, Shape.dimensions+isArray> location, constexpr vector<int, Shape.planeDimensions> offset, float clamp, out uint status)
{
__target_switch
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 98af8a228..3e7d71369 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -708,6 +708,7 @@ DIAGNOSTIC(40002, Error, invalidBindingValue, "binding location '$0' is out of v
DIAGNOSTIC(40003, Error, bindingExceedsLimit, "binding location '$0' assigned to component '$1' exceeds maximum limit.")
DIAGNOSTIC(40004, Error, bindingAlreadyOccupiedByModule, "DescriptorSet ID '$0' is already occupied by module instance '$1'.")
DIAGNOSTIC(40005, Error, topLevelModuleUsedWithoutSpecifyingBinding, "top level module '$0' is being used without specifying binding location. Use [Binding: \"index\"] attribute to provide a binding location.")
+DIAGNOSTIC(40006, Error, unimplementedSystemValueSemantic, "unknown system-value semantic '$0'")
DIAGNOSTIC(49999, Error, unknownSystemValueSemantic, "unknown system-value semantic '$0'")
@@ -849,10 +850,9 @@ DIAGNOSTIC(55102, Error, invalidTorchKernelParamType, "'$0' is not a valid param
DIAGNOSTIC(55200, Error, unsupportedBuiltinType, "'$0' is not a supported builtin type for the target.")
DIAGNOSTIC(55201, Error, unsupportedRecursion, "recursion detected in call to '$0', but the current code generation target does not allow recursion.")
DIAGNOSTIC(55202, Error, systemValueAttributeNotSupported, "system value semantic '$0' is not supported for the current target.")
-DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0' should have type '$1' or convertible to type '$1'.")
+DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0' should have type '$1' or be convertible to type '$1'.")
DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'")
-
DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0")
DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.")
DIAGNOSTIC(57003, Error, unknownTessPartitioning, "unknown tessellation partitioning '$0'.")
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index fb9a9613a..83b38b3b6 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4405,7 +4405,7 @@ public:
}
// void addLayoutDecoration(IRInst* value, Layout* layout);
- void addLayoutDecoration(IRInst* value, IRLayout* layout);
+ IRLayoutDecoration* addLayoutDecoration(IRInst* value, IRLayout* layout);
// IRLayout* getLayout(Layout* astLayout);
@@ -4525,9 +4525,9 @@ public:
addDecoration(value, kIROp_ForceUnrollDecoration, getIntValue(getIntType(), iters));
}
- void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0)
+ IRSemanticDecoration* addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0)
{
- addDecoration(value, kIROp_SemanticDecoration, getStringValue(text), getIntValue(getIntType(), index));
+ return as<IRSemanticDecoration>(addDecoration(value, kIROp_SemanticDecoration, getStringValue(text), getIntValue(getIntType(), index)));
}
void addRequireSPIRVDescriptorIndexingExtensionDecoration(IRInst* value)
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index 582af4ac8..7d3e3c533 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -2,9 +2,34 @@
#include "slang-ir-legalize-varying-params.h"
#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir-clone.h"
+#include "slang-parameter-binding.h"
namespace Slang
{
+ // Convert semantic name (ignores case) into equivlent `SystemValueSemanticName`
+ SystemValueSemanticName convertSystemValueSemanticNameToEnum(String rawSemanticName)
+ {
+ auto semanticName = rawSemanticName.toLower();
+
+ SystemValueSemanticName systemValueSemanticName = SystemValueSemanticName::None;
+
+#define CASE(ID, NAME) \
+ if(semanticName == String(#NAME).toLower()) \
+ { \
+ systemValueSemanticName = SystemValueSemanticName::ID; \
+ } \
+ else
+
+ SYSTEM_VALUE_SEMANTIC_NAMES(CASE)
+#undef CASE
+ {
+ systemValueSemanticName = SystemValueSemanticName::Unknown;
+ // no match
+ }
+ return systemValueSemanticName;
+ }
// This pass implements logic to "legalize" the varying parameter
// signature of an entry point.
@@ -51,36 +76,6 @@ namespace Slang
// * Slang allows for `inout` varying parameters, which need to desugar into
// distinct `in` and `out` parameters for targets like GLSL.
-
-#define SYSTEM_VALUE_SEMANTIC_NAMES(M) \
- M(DispatchThreadID, SV_DispatchThreadID) \
- M(GroupID, SV_GroupID) \
- M(GroupThreadID, SV_GroupThreadID) \
- M(GroupThreadIndex, SV_GroupIndex) \
- /* end */
-
- /// A known system-value semantic name that can be applied to a parameter
- ///
-enum class SystemValueSemanticName
-{
- None = 0,
-
- // TODO: Should this enumeration be responsible for differentiating
- // cases where the same semantic name string is allowed in multiple stages,
- // or as both input/output in a single stage, and those different uses
- // might result in different meanings? The alternative is to always
- // pass around the semantic name, stage, and direction together so
- // that code can tell those special cases apart.
-
-#define CASE(ID, NAME) ID,
-SYSTEM_VALUE_SEMANTIC_NAMES(CASE)
-#undef CASE
-
- // TODO: There are many more system-value semantic names that we
- // can/should handle here, but for now I've restricted this list
- // to those that are necessary for translating compute shaders.
-};
-
/// A placeholder that represents the value of a legalized varying
/// parameter, for the purposes of substituting it into IR code.
///
@@ -249,7 +244,7 @@ IRInst* emitCalcDispatchThreadID(
}
/// Emit code to calculate `SV_GroupIndex`
-IRInst* emitCalcGroupThreadIndex(
+IRInst* emitCalcGroupIndex(
IRBuilder& builder,
IRInst* groupThreadID,
IRInst* groupExtents)
@@ -935,23 +930,7 @@ protected:
// avoid all the `String`s we crete and thren throw
// away here.
//
- String semanticNameSpelling = semanticInst->getName();
- auto semanticName = semanticNameSpelling.toLower();
-
- SystemValueSemanticName systemValueSemanticName = SystemValueSemanticName::None;
-
- #define CASE(ID, NAME) \
- if(semanticName == String(#NAME).toLower()) \
- { \
- systemValueSemanticName = SystemValueSemanticName::ID; \
- } \
- else
-
- SYSTEM_VALUE_SEMANTIC_NAMES(CASE)
- #undef CASE
- {
- // no match
- }
+ auto systemValueSemanticName = convertSystemValueSemanticNameToEnum(String(semanticInst->getName()));
if( systemValueSemanticName != SystemValueSemanticName::None )
{
@@ -1223,7 +1202,7 @@ struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegaliz
threadIdxGlobalParam,
blockDimGlobalParam);
- groupThreadIndex = emitCalcGroupThreadIndex(
+ groupThreadIndex = emitCalcGroupIndex(
builder,
threadIdxGlobalParam,
blockDimGlobalParam);
@@ -1254,7 +1233,7 @@ struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegaliz
{
case SystemValueSemanticName::GroupID: return LegalizedVaryingVal::makeValue(blockIdxGlobalParam);
case SystemValueSemanticName::GroupThreadID: return LegalizedVaryingVal::makeValue(threadIdxGlobalParam);
- case SystemValueSemanticName::GroupThreadIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex);
+ case SystemValueSemanticName::GroupIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex);
case SystemValueSemanticName::DispatchThreadID: return LegalizedVaryingVal::makeValue(dispatchThreadID);
default:
return diagnoseUnsupportedSystemVal(info);
@@ -1397,7 +1376,7 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize
dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents);
- groupThreadIndex = emitCalcGroupThreadIndex(builder, groupThreadID, groupExtents);
+ groupThreadIndex = emitCalcGroupIndex(builder, groupThreadID, groupExtents);
}
LegalizedVaryingVal createLegalSystemVaryingValImpl(VaryingParamInfo const& info) SLANG_OVERRIDE
@@ -1414,7 +1393,7 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize
{
case SystemValueSemanticName::GroupID: return LegalizedVaryingVal::makeValue(groupID);
case SystemValueSemanticName::GroupThreadID: return LegalizedVaryingVal::makeValue(groupThreadID);
- case SystemValueSemanticName::GroupThreadIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex);
+ case SystemValueSemanticName::GroupIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex);
case SystemValueSemanticName::DispatchThreadID: return LegalizedVaryingVal::makeValue(dispatchThreadID);
default:
diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h
index 952192def..58efa39a2 100644
--- a/source/slang/slang-ir-legalize-varying-params.h
+++ b/source/slang/slang-ir-legalize-varying-params.h
@@ -1,17 +1,18 @@
// slang-ir-legalize-varying-params.h
#pragma once
+#include "slang-ir-insts.h"
namespace Slang
{
class DiagnosticSink;
-struct IRFunc;
struct IRModule;
struct IRInst;
struct IRFunc;
struct IRVectorType;
struct IRBuilder;
+struct IREntryPointDecoration;
void legalizeEntryPointVaryingParamsForCPU(
IRModule* module,
@@ -21,13 +22,63 @@ void legalizeEntryPointVaryingParamsForCUDA(
IRModule* module,
DiagnosticSink* sink);
-IRInst* emitCalcGroupThreadIndex(
- IRBuilder& builder,
- IRInst* groupThreadID,
- IRInst* groupExtents);
+
+// (#4375) Once `slang-ir-metal-legalize.cpp` is merged with
+// `slang-ir-legalize-varying-params.cpp`, move the following
+// below into `slang-ir-legalize-varying-params.cpp` as well
IRInst* emitCalcGroupExtents(
IRBuilder& builder,
IRFunc* entryPoint,
IRVectorType* type);
+
+IRInst* emitCalcGroupIndex(
+ IRBuilder& builder,
+ IRInst* groupThreadID,
+ IRInst* groupExtents);
+
+// SystemValueSemanticName member definition macro
+#define SYSTEM_VALUE_SEMANTIC_NAMES(M) \
+ M(Position, SV_Position) \
+ M(ClipDistance, SV_ClipDistance) \
+ M(CullDistance, SV_CullDistance) \
+ M(Coverage, SV_Coverage) \
+ M(InnerCoverage, SV_InnerCoverage) \
+ M(Depth, SV_Depth) \
+ M(DepthGreaterEqual, SV_DepthGreaterEqual) \
+ M(DepthLessEqual, SV_DepthLessEqual) \
+ M(DispatchThreadID, SV_DispatchThreadID) \
+ M(DomainLsocation, SV_DomainLsocation) \
+ M(GroupID, SV_GroupID) \
+ M(GroupIndex, SV_GroupIndex) \
+ M(GroupThreadID, SV_GroupThreadID) \
+ M(GSInstanceID, SV_GSInstanceID) \
+ M(InstanceID, SV_InstanceID) \
+ M(IsFrontFace, SV_IsFrontFace) \
+ M(OutputControlPointID, SV_OutputControlPointID)\
+ M(PointSize, SV_PointSize) \
+ M(PrimitiveID, SV_PrimitiveID) \
+ M(RenderTargetArrayIndex, SV_RenderTargetArrayIndex) \
+ M(SampleIndex, SV_SampleIndex) \
+ M(StencilRef, SV_StencilRef) \
+ M(TessFactor, SV_TessFactor) \
+ M(VertexID, SV_VertexID) \
+ M(ViewID, SV_ViewID) \
+ M(ViewportArrayIndex, SV_ViewportArrayIndex) \
+ M(Target, SV_Target) \
+ /* end */
+
+/// A known system-value semantic name that can be applied to a parameter
+///
+enum class SystemValueSemanticName
+{
+ None = 0,
+ Unknown = 0,
+#define CASE(ID, NAME) ID,
+ SYSTEM_VALUE_SEMANTIC_NAMES(CASE)
+#undef CASE
+};
+
+SystemValueSemanticName convertSystemValueSemanticNameToEnum(String rawSemanticName);
+
}
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index 01bcb0295..6b333f892 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -1,5 +1,6 @@
#include "slang-ir-metal-legalize.h"
+#include <set>
#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
@@ -10,7 +11,6 @@
namespace Slang
{
- const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid");
struct EntryPointInfo
{
@@ -18,701 +18,1472 @@ namespace Slang
IREntryPointDecoration* entryPointDecor;
};
- void hoistEntryPointParameterFromStruct(EntryPointInfo entryPoint)
+ const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid");
+ struct LegalizeMetalEntryPointContext
{
- // If an entry point has a input parameter with a struct type, we want to hoist out
- // all the fields of the struct type to be individual parameters of the entry point.
- // This will canonicalize the entry point signature, so we can handle all cases uniformly.
-
- // For example, given an entry point:
- // ```
- // struct VertexInput { float3 pos; float 2 uv; int vertexId : SV_VertexID};
- // void main(VertexInput vin) { ... }
- // ```
- // We will transform it to:
- // ```
- // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) {
- // VertexInput vin = {pos,uv,vertexId};
- // ...
- // }
- // ```
-
- auto func = entryPoint.entryPointFunc;
- List<IRParam*> paramsToProcess;
- for (auto param : func->getParams())
- {
- if (as<IRStructType>(param->getDataType()))
- {
- paramsToProcess.add(param);
+ ShortList<IRType*> permittedTypes_sv_target;
+ Dictionary<IRFunc*, IRInst*> entryPointToGroupThreadId;
+ HashSet<IRStructField*> semanticInfoToRemove;
+
+ DiagnosticSink* m_sink;
+ IRModule* m_module;
+
+ LegalizeMetalEntryPointContext(DiagnosticSink* sink, IRModule* module) : m_sink(sink), m_module(module)
+ {
+ }
+
+ void removeSemanticLayoutsFromLegalizedStructs()
+ {
+ // Metal does not allow duplicate attributes to appear in the same shader.
+ // If we emit our own struct with `[[color(0)]`, all existing uses of `[[color(0)]]`
+ // must be removed.
+ for (auto field : semanticInfoToRemove)
+ {
+ auto key = field->getKey();
+ // Some decorations appear twice, destroy all found
+ for (;;)
+ {
+ if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
+ {
+ semanticDecor->removeAndDeallocate();
+ continue;
+ }
+ else if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>())
+ {
+ layoutDecor->removeAndDeallocate();
+ continue;
+ }
+ break;
+ }
}
}
- IRBuilder builder(func);
- builder.setInsertBefore(func);
- for (auto param : paramsToProcess)
+ void hoistEntryPointParameterFromStruct(EntryPointInfo entryPoint)
{
- auto structType = as<IRStructType>(param->getDataType());
- builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- auto varLayout = findVarLayout(param);
-
- // If `param` already has a semantic, we don't want to hoist its fields out.
- if (varLayout->findSystemValueSemanticAttr() != nullptr ||
- param->findDecoration<IRSemanticDecoration>())
- continue;
-
- IRStructTypeLayout* structTypeLayout = nullptr;
- if (varLayout)
- structTypeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
- Index fieldIndex = 0;
- List<IRInst*> fieldParams;
- for (auto field : structType->getFields())
+ // If an entry point has a input parameter with a struct type, we want to hoist out
+ // all the fields of the struct type to be individual parameters of the entry point.
+ // This will canonicalize the entry point signature, so we can handle all cases uniformly.
+
+ // For example, given an entry point:
+ // ```
+ // struct VertexInput { float3 pos; float 2 uv; int vertexId : SV_VertexID};
+ // void main(VertexInput vin) { ... }
+ // ```
+ // We will transform it to:
+ // ```
+ // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) {
+ // VertexInput vin = {pos,uv,vertexId};
+ // ...
+ // }
+ // ```
+
+ auto func = entryPoint.entryPointFunc;
+ List<IRParam*> paramsToProcess;
+ for (auto param : func->getParams())
{
- auto fieldParam = builder.emitParam(field->getFieldType());
-
- IRCloneEnv cloneEnv;
- cloneInstDecorationsAndChildren(&cloneEnv, builder.getModule(), field->getKey(), fieldParam);
+ if (as<IRStructType>(param->getDataType()))
+ {
+ paramsToProcess.add(param);
+ }
+ }
+
+ IRBuilder builder(func);
+ builder.setInsertBefore(func);
+ for (auto param : paramsToProcess)
+ {
+ auto structType = as<IRStructType>(param->getDataType());
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ auto varLayout = findVarLayout(param);
- IRVarLayout* fieldLayout = structTypeLayout ? structTypeLayout->getFieldLayout(fieldIndex) : nullptr;
+ // If `param` already has a semantic, we don't want to hoist its fields out.
+ if (varLayout->findSystemValueSemanticAttr() != nullptr ||
+ param->findDecoration<IRSemanticDecoration>())
+ continue;
+
+ IRStructTypeLayout* structTypeLayout = nullptr;
if (varLayout)
+ structTypeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
+ Index fieldIndex = 0;
+ List<IRInst*> fieldParams;
+ for (auto field : structType->getFields())
{
- IRVarLayout::Builder varLayoutBuilder(&builder, fieldLayout->getTypeLayout());
- varLayoutBuilder.cloneEverythingButOffsetsFrom(fieldLayout);
- for (auto offsetAttr : fieldLayout->getOffsetAttrs())
+ auto fieldParam = builder.emitParam(field->getFieldType());
+ IRCloneEnv cloneEnv;
+ cloneInstDecorationsAndChildren(&cloneEnv, builder.getModule(), field->getKey(), fieldParam);
+
+ IRVarLayout* fieldLayout = structTypeLayout ? structTypeLayout->getFieldLayout(fieldIndex) : nullptr;
+ if (varLayout)
{
- auto parentOffsetAttr = varLayout->findOffsetAttr(offsetAttr->getResourceKind());
- UInt parentOffset = parentOffsetAttr ? parentOffsetAttr->getOffset() : 0;
- UInt parentSpace = parentOffsetAttr ? parentOffsetAttr->getSpace() : 0;
- auto resInfo = varLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind());
- resInfo->offset = parentOffset + offsetAttr->getOffset();
- resInfo->space = parentSpace + offsetAttr->getSpace();
+ IRVarLayout::Builder varLayoutBuilder(&builder, fieldLayout->getTypeLayout());
+ varLayoutBuilder.cloneEverythingButOffsetsFrom(fieldLayout);
+ for (auto offsetAttr : fieldLayout->getOffsetAttrs())
+ {
+ auto parentOffsetAttr = varLayout->findOffsetAttr(offsetAttr->getResourceKind());
+ UInt parentOffset = parentOffsetAttr ? parentOffsetAttr->getOffset() : 0;
+ UInt parentSpace = parentOffsetAttr ? parentOffsetAttr->getSpace() : 0;
+ auto resInfo = varLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind());
+ resInfo->offset = parentOffset + offsetAttr->getOffset();
+ resInfo->space = parentSpace + offsetAttr->getSpace();
+ }
+ builder.addLayoutDecoration(fieldParam, varLayoutBuilder.build());
}
- builder.addLayoutDecoration(fieldParam, varLayoutBuilder.build());
+ param->insertBefore(fieldParam);
+ fieldParams.add(fieldParam);
+ fieldIndex++;
}
- param->insertBefore(fieldParam);
- fieldParams.add(fieldParam);
- fieldIndex++;
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ auto reconstructedParam = builder.emitMakeStruct(structType, fieldParams.getCount(), fieldParams.getBuffer());
+ param->replaceUsesWith(reconstructedParam);
+ param->removeFromParent();
}
- builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- auto reconstructedParam = builder.emitMakeStruct(structType, fieldParams.getCount(), fieldParams.getBuffer());
- param->replaceUsesWith(reconstructedParam);
- param->removeFromParent();
+ fixUpFuncType(func);
}
- fixUpFuncType(func);
- }
- void packStageInParameters(EntryPointInfo entryPoint)
- {
- // If the entry point has any parameters whose layout contains VaryingInput,
- // we need to pack those parameters into a single `struct` type, and decorate
- // the fields with the appropriate `[[attribute]]` decorations.
- // For other parameters that are not `VaryingInput`, we need to leave them as is.
- //
- // For example, given this code after `hoistEntryPointParameterFromStruct`:
- // ```
- // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) {
- // VertexInput vin = {pos,uv,vertexId};
- // ...
- // }
- // ```
- // We are going to transform it into:
- // ```
- // struct VertexInput {
- // float3 pos [[attribute(0)]];
- // float2 uv [[attribute(1)]];
- // };
- // void main(VertexInput vin, int vertexId : SV_VertexID) {
- // let pos = vin.pos;
- // let uv = vin.uv;
- // ...
- // }
-
- auto func = entryPoint.entryPointFunc;
-
- bool isGeometryStage = false;
- switch (entryPoint.entryPointDecor->getProfile().getStage())
- {
- case Stage::Vertex:
- case Stage::Amplification:
- case Stage::Mesh:
- case Stage::Geometry:
- case Stage::Domain:
- case Stage::Hull:
- isGeometryStage = true;
- break;
- }
+ // Flattens all struct parameters of an entryPoint to ensure parameters are a flat struct
+ void flattenInputParameters(EntryPointInfo entryPoint)
+ {
+ // Goal is to ensure we have a flattened IRParam (0 nested IRStructType members).
+ /*
+ // Assume the following code
+ struct NestedFragment
+ {
+ float2 p3;
+ };
+ struct Fragment
+ {
+ float4 p1;
+ float3 p2;
+ NestedFragment p3_nested;
+ };
- List<IRParam*> paramsToPack;
- for (auto param : func->getParams())
- {
- auto layout = findVarLayout(param);
- if (!layout)
- continue;
- if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
- continue;
- if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
- continue;
- paramsToPack.add(param);
- }
+ // Fragment flattens into
+ struct Fragment
+ {
+ float4 p1;
+ float3 p2;
+ float2 p3;
+ };
+ */
+
+ // This is important since Metal does not allow semantic's on a struct
+ /*
+ // Assume the following code
+ struct NestedFragment1
+ {
+ float2 p3;
+ };
+ struct Fragment1
+ {
+ float4 p1 : SV_TARGET0;
+ float3 p2 : SV_TARGET1;
+ NestedFragment p3_nested : SV_TARGET2; // error, semantic on struct
+ };
+
+ */
+
+ // Metal does allow semantics on members of a nested struct but we are avoiding this
+ // approach since there are senarios where legalization (and verification) is
+ // hard/expensive without creating a flat struct:
+ // 1. Entry points may share structs, semantics may be inconsistent across entry points
+ // 2. Multiple of the same struct may be used in a param list
+ /*
+ // Assume the following code
+ struct NestedFragment
+ {
+ float2 p3;
+ };
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET0;
+ NestedFragment p2 : SV_TARGET1;
+ NestedFragment p3 : SV_TARGET2;
+ };
- if (paramsToPack.getCount() == 0)
- return;
+ // Legalized without flattening -- abandoned
+ struct NestedFragment1
+ {
+ float2 p3 : SV_TARGET1;
+ };
+ struct NestedFragment2
+ {
+ float2 p3 : SV_TARGET2;
+ };
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET0;
+ NestedFragment1 p2;
+ NestedFragment2 p3;
+ };
- IRBuilder builder(func);
- builder.setInsertBefore(func);
- IRStructType* structType = builder.createStructType();
- auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
- builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Input")).getUnownedSlice());
- List<IRStructKey*> keys;
- IRStructTypeLayout::Builder layoutBuilder(&builder);
- for (auto param : paramsToPack)
- {
- auto paramVarLayout = findVarLayout(param);
- auto key = builder.createStructKey();
- param->transferDecorationsTo(key);
- builder.createStructField(structType, key, param->getDataType());
- if (auto varyingInOffsetAttr = paramVarLayout->findOffsetAttr(LayoutResourceKind::VaryingInput))
- {
- if (!key->findDecoration<IRSemanticDecoration>() && !paramVarLayout->findAttr<IRSemanticAttr>())
+ // Legalized with flattening -- current approach
+ struct Fragment
{
- // If the parameter doesn't have a semantic, we need to add one for semantic matching.
- builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varyingInOffsetAttr->getOffset());
- }
- }
- if (isGeometryStage)
+ float4 p1 : SV_TARGET0;
+ float2 p2 : SV_TARGET1;
+ float2 p3 : SV_TARGET2;
+ };
+ */
+
+ auto func = entryPoint.entryPointFunc;
+ bool modified = false;
+ for (auto param : func->getParams())
{
- // For geometric stages, we need to translate VaryingInput offsets to MetalAttribute offsets.
- IRVarLayout::Builder elementVarLayoutBuilder(&builder, paramVarLayout->getTypeLayout());
- elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout);
- for (auto offsetAttr : paramVarLayout->getOffsetAttrs())
+ auto layout = findVarLayout(param);
+ if (!layout)
+ continue;
+ if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
+ continue;
+ if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ continue;
+ // If we find a IRParam with a IRStructType member, we need to flatten the entire IRParam
+ if (auto structType = as<IRStructType>(param->getDataType()))
{
- auto resourceKind = offsetAttr->getResourceKind();
- if (resourceKind == LayoutResourceKind::VaryingInput)
+ IRBuilder builder(func);
+ MapStructToFlatStruct mapOldFieldToNewField;
+
+ // Flatten struct if we have nested IRStructType
+ auto flattenedStruct = maybeFlattenNestedStructs(builder, structType, mapOldFieldToNewField, semanticInfoToRemove);
+ if (flattenedStruct != structType)
{
- resourceKind = LayoutResourceKind::MetalAttribute;
+ // Validate/rearange all semantics which overlap in our flat struct
+ fixFieldSemanticsOfFlatStruct(flattenedStruct);
+
+ // Replace the 'old IRParam type' with a 'new IRParam type'
+ param->setFullType(flattenedStruct);
+
+ // Emit a new variable at EntryPoint of 'old IRParam type'
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ auto dstVal = builder.emitVar(structType);
+ auto dstLoad = builder.emitLoad(dstVal);
+ param->replaceUsesWith(dstLoad);
+ builder.setInsertBefore(dstLoad);
+ // Copy the 'new IRParam type' to our 'old IRParam type'
+ mapOldFieldToNewField.emitCopy<(int)MapStructToFlatStruct::CopyOptions::FlatStructIntoStruct>(builder, dstVal, param);
+
+ modified = true;
}
- auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind);
- resInfo->offset = offsetAttr->getOffset();
- resInfo->space = offsetAttr->getSpace();
}
- paramVarLayout = elementVarLayoutBuilder.build();
}
- layoutBuilder.addField(key, paramVarLayout);
- builder.addLayoutDecoration(key, paramVarLayout);
- keys.add(key);
- }
- builder.setInsertInto(func->getFirstBlock());
- auto packedParam = builder.emitParamAtHead(structType);
- auto typeLayout = layoutBuilder.build();
- IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
-
- // Add a VaryingInput resource info to the packed parameter layout, so that we can emit
- // the needed `[[stage_in]]` attribute in Metal emitter.
- varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput);
- auto paramVarLayout = varLayoutBuilder.build();
- builder.addLayoutDecoration(packedParam, paramVarLayout);
-
- // Replace the original parameters with the packed parameter
- builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- for (Index paramIndex = 0; paramIndex < paramsToPack.getCount(); paramIndex++)
- {
- auto param = paramsToPack[paramIndex];
- auto key = keys[paramIndex];
- auto paramField = builder.emitFieldExtract(param->getDataType(), packedParam, key);
- param->replaceUsesWith(paramField);
- param->removeFromParent();
+ if (modified)
+ fixUpFuncType(func);
}
- fixUpFuncType(func);
- }
-
- struct MetalSystemValueInfo
- {
- String metalSystemValueName;
- IRType* requiredType;
- IRType* altRequiredType;
- bool isUnsupported;
- bool isSpecial;
- };
- IRType* getGroupThreadIdType(IRBuilder& builder)
- {
- return builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
- }
-
- MetalSystemValueInfo getSystemValueInfo(IRBuilder& builder, String inSemanticName)
- {
- inSemanticName = inSemanticName.toLower();
-
- UnownedStringSlice semanticName;
- UnownedStringSlice semanticIndex;
- splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex);
+ void packStageInParameters(EntryPointInfo entryPoint)
+ {
+ // If the entry point has any parameters whose layout contains VaryingInput,
+ // we need to pack those parameters into a single `struct` type, and decorate
+ // the fields with the appropriate `[[attribute]]` decorations.
+ // For other parameters that are not `VaryingInput`, we need to leave them as is.
+ //
+ // For example, given this code after `hoistEntryPointParameterFromStruct`:
+ // ```
+ // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) {
+ // VertexInput vin = {pos,uv,vertexId};
+ // ...
+ // }
+ // ```
+ // We are going to transform it into:
+ // ```
+ // struct VertexInput {
+ // float3 pos [[attribute(0)]];
+ // float2 uv [[attribute(1)]];
+ // };
+ // void main(VertexInput vin, int vertexId : SV_VertexID) {
+ // let pos = vin.pos;
+ // let uv = vin.uv;
+ // ...
+ // }
+
+ auto func = entryPoint.entryPointFunc;
+
+ bool isGeometryStage = false;
+ switch (entryPoint.entryPointDecor->getProfile().getStage())
+ {
+ case Stage::Vertex:
+ case Stage::Amplification:
+ case Stage::Mesh:
+ case Stage::Geometry:
+ case Stage::Domain:
+ case Stage::Hull:
+ isGeometryStage = true;
+ break;
+ }
- MetalSystemValueInfo result = {};
+ List<IRParam*> paramsToPack;
+ for (auto param : func->getParams())
+ {
+ auto layout = findVarLayout(param);
+ if (!layout)
+ continue;
+ if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
+ continue;
+ if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ continue;
+ paramsToPack.add(param);
+ }
- if (semanticName == "sv_position")
- {
- result.metalSystemValueName = toSlice("position");
- result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 4));
- }
- else if (semanticName == "sv_clipdistance")
- {
- result.isSpecial = true;
- }
- else if (semanticName == "sv_culldistance")
- {
- result.isSpecial = true;
- }
- else if (semanticName == "sv_coverage")
- {
- result.metalSystemValueName = toSlice("sample_mask");
- result.requiredType = builder.getBasicType(BaseType::UInt);
+ if (paramsToPack.getCount() == 0)
+ return;
+
+ IRBuilder builder(func);
+ builder.setInsertBefore(func);
+ IRStructType* structType = builder.createStructType();
+ auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
+ builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Input")).getUnownedSlice());
+ List<IRStructKey*> keys;
+ IRStructTypeLayout::Builder layoutBuilder(&builder);
+ for (auto param : paramsToPack)
+ {
+ auto paramVarLayout = findVarLayout(param);
+ auto key = builder.createStructKey();
+ param->transferDecorationsTo(key);
+ builder.createStructField(structType, key, param->getDataType());
+ if (auto varyingInOffsetAttr = paramVarLayout->findOffsetAttr(LayoutResourceKind::VaryingInput))
+ {
+ if (!key->findDecoration<IRSemanticDecoration>() && !paramVarLayout->findAttr<IRSemanticAttr>())
+ {
+ // If the parameter doesn't have a semantic, we need to add one for semantic matching.
+ builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varyingInOffsetAttr->getOffset());
+ }
+ }
+ if (isGeometryStage)
+ {
+ // For geometric stages, we need to translate VaryingInput offsets to MetalAttribute offsets.
+ IRVarLayout::Builder elementVarLayoutBuilder(&builder, paramVarLayout->getTypeLayout());
+ elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout);
+ for (auto offsetAttr : paramVarLayout->getOffsetAttrs())
+ {
+ auto resourceKind = offsetAttr->getResourceKind();
+ if (resourceKind == LayoutResourceKind::VaryingInput)
+ {
+ resourceKind = LayoutResourceKind::MetalAttribute;
+ }
+ auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind);
+ resInfo->offset = offsetAttr->getOffset();
+ resInfo->space = offsetAttr->getSpace();
+ }
+ paramVarLayout = elementVarLayoutBuilder.build();
+ }
+ layoutBuilder.addField(key, paramVarLayout);
+ builder.addLayoutDecoration(key, paramVarLayout);
+ keys.add(key);
+ }
+ builder.setInsertInto(func->getFirstBlock());
+ auto packedParam = builder.emitParamAtHead(structType);
+ auto typeLayout = layoutBuilder.build();
+ IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
+
+ // Add a VaryingInput resource info to the packed parameter layout, so that we can emit
+ // the needed `[[stage_in]]` attribute in Metal emitter.
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(packedParam, paramVarLayout);
+
+ // Replace the original parameters with the packed parameter
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ for (Index paramIndex = 0; paramIndex < paramsToPack.getCount(); paramIndex++)
+ {
+ auto param = paramsToPack[paramIndex];
+ auto key = keys[paramIndex];
+ auto paramField = builder.emitFieldExtract(param->getDataType(), packedParam, key);
+ param->replaceUsesWith(paramField);
+ param->removeFromParent();
+ }
+ fixUpFuncType(func);
}
- else if (semanticName == "sv_innercoverage")
- {
- result.isSpecial = true;
- }
- else if (semanticName == "sv_depth")
- {
- result.metalSystemValueName = toSlice("depth(any)");
- result.requiredType = builder.getBasicType(BaseType::Float);
- }
- else if (semanticName == "sv_depthgreaterequal")
+ struct MetalSystemValueInfo
{
- result.metalSystemValueName = toSlice("depth(greater)");
- result.requiredType = builder.getBasicType(BaseType::Float);
- }
- else if (semanticName == "sv_depthlessequal")
- {
- result.metalSystemValueName = toSlice("depth(less)");
- result.requiredType = builder.getBasicType(BaseType::Float);
- }
- else if (semanticName == "sv_dispatchthreadid")
- {
- result.metalSystemValueName = toSlice("thread_position_in_grid");
- result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
- }
- else if (semanticName == "sv_domainlocation")
- {
- result.metalSystemValueName = toSlice("position_in_patch");
- result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 3));
- result.altRequiredType = builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 2));
- }
- else if (semanticName == "sv_groupid")
- {
- result.metalSystemValueName = toSlice("threadgroup_position_in_grid");
- result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
- }
- else if (semanticName == "sv_groupindex")
- {
- result.isSpecial = true;
- }
- else if (semanticName == "sv_groupthreadid")
- {
- result.metalSystemValueName = toSlice("thread_position_in_threadgroup");
- result.requiredType = getGroupThreadIdType(builder);
- }
- else if (semanticName == "sv_gsinstanceid")
+ String metalSystemValueName;
+ SystemValueSemanticName metalSystemValueNameEnum;
+ ShortList<IRType*> permittedTypes;
+ bool isUnsupported = false;
+ bool isSpecial = false;
+ MetalSystemValueInfo()
+ {
+ // most commonly need 2
+ permittedTypes.reserveOverflowBuffer(2);
+ }
+ };
+
+ IRType* getGroupThreadIdType(IRBuilder& builder)
{
- // Metal does not have geometry shader, so this is invalid.
- result.isUnsupported = true;
+ return builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
}
- else if (semanticName == "sv_instanceid")
+
+ // Get all permitted types of "sv_target" for Metal
+ ShortList<IRType*>& getPermittedTypes_sv_target(IRBuilder& builder)
{
- result.metalSystemValueName = toSlice("instance_id");
- result.requiredType = builder.getBasicType(BaseType::UInt);
+ permittedTypes_sv_target.reserveOverflowBuffer(5 * 4);
+ if (permittedTypes_sv_target.getCount() == 0)
+ {
+ for (auto baseType : { BaseType::Float, BaseType::Half, BaseType::Int, BaseType::UInt, BaseType::Int16, BaseType::UInt16 })
+ {
+ for (IRIntegerValue i = 1; i <= 4; i++)
+ {
+ permittedTypes_sv_target.add(builder.getVectorType(builder.getBasicType(baseType), i));
+ }
+ }
+ }
+ return permittedTypes_sv_target;
}
- else if (semanticName == "sv_isfrontface")
+
+ MetalSystemValueInfo getSystemValueInfo(String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar)
{
- result.metalSystemValueName = toSlice("front_facing");
- result.requiredType = builder.getBasicType(BaseType::Bool);
+ IRBuilder builder(m_module);
+ MetalSystemValueInfo result = {};
+ UnownedStringSlice semanticName;
+ UnownedStringSlice semanticIndex;
+
+ auto hasExplicitIndex = splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex);
+ if (!hasExplicitIndex && optionalSemanticIndex)
+ semanticIndex = optionalSemanticIndex->getUnownedSlice();
+
+ result.metalSystemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName);
+
+ switch (result.metalSystemValueNameEnum)
+ {
+ case SystemValueSemanticName::Position:
+ {
+ result.metalSystemValueName = toSlice("position");
+ result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 4)));
+ break;
+ }
+ case SystemValueSemanticName::ClipDistance:
+ {
+ result.isSpecial = true;
+ break;
+ }
+ case SystemValueSemanticName::CullDistance:
+ {
+ result.isSpecial = true;
+ break;
+ }
+ case SystemValueSemanticName::Coverage:
+ {
+ result.metalSystemValueName = toSlice("sample_mask");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::InnerCoverage:
+ {
+ result.isSpecial = true;
+ break;
+ }
+ case SystemValueSemanticName::Depth:
+ {
+ result.metalSystemValueName = toSlice("depth(any)");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ break;
+ }
+ case SystemValueSemanticName::DepthGreaterEqual:
+ {
+ result.metalSystemValueName = toSlice("depth(greater)");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ break;
+ }
+ case SystemValueSemanticName::DepthLessEqual:
+ {
+ result.metalSystemValueName = toSlice("depth(less)");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ break;
+ }
+ case SystemValueSemanticName::DispatchThreadID:
+ {
+ result.metalSystemValueName = toSlice("thread_position_in_grid");
+ result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)));
+ break;
+ }
+ case SystemValueSemanticName::DomainLsocation:
+ {
+ result.metalSystemValueName = toSlice("position_in_patch");
+ result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 3)));
+ result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 2)));
+ break;
+ }
+ case SystemValueSemanticName::GroupID:
+ {
+ result.metalSystemValueName = toSlice("threadgroup_position_in_grid");
+ result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)));
+ break;
+ }
+ case SystemValueSemanticName::GroupIndex:
+ {
+ result.isSpecial = true;
+ break;
+ }
+ case SystemValueSemanticName::GroupThreadID:
+ {
+ result.metalSystemValueName = toSlice("thread_position_in_threadgroup");
+ result.permittedTypes.add(getGroupThreadIdType(builder));
+ break;
+ }
+ case SystemValueSemanticName::GSInstanceID:
+ {
+ result.isUnsupported = true;
+ break;
+ }
+ case SystemValueSemanticName::InstanceID:
+ {
+ result.metalSystemValueName = toSlice("instance_id");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::IsFrontFace:
+ {
+ result.metalSystemValueName = toSlice("front_facing");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Bool));
+ break;
+ }
+ case SystemValueSemanticName::OutputControlPointID:
+ {
+ // In metal, a hull shader is just a compute shader.
+ // This needs to be handled separately, by lowering into an ordinary buffer.
+ break;
+ }
+ case SystemValueSemanticName::PointSize:
+ {
+ result.metalSystemValueName = toSlice("point_size");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ break;
+ }
+ case SystemValueSemanticName::PrimitiveID:
+ {
+ result.metalSystemValueName = toSlice("patch_id");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt16));
+ break;
+ }
+ case SystemValueSemanticName::RenderTargetArrayIndex:
+ {
+ result.metalSystemValueName = toSlice("render_target_array_index");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt16));
+ break;
+ }
+ case SystemValueSemanticName::SampleIndex:
+ {
+ result.metalSystemValueName = toSlice("sample_id");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::StencilRef:
+ {
+ result.metalSystemValueName = toSlice("stencil");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::TessFactor:
+ {
+ // Tessellation factor outputs should be lowered into a write into a normal buffer.
+ break;
+ }
+ case SystemValueSemanticName::VertexID:
+ {
+ result.metalSystemValueName = toSlice("vertex_id");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::ViewID:
+ {
+ result.isUnsupported = true;
+ break;
+ }
+ case SystemValueSemanticName::ViewportArrayIndex:
+ {
+ result.metalSystemValueName = toSlice("viewport_array_index");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt16));
+ break;
+ }
+ case SystemValueSemanticName::Target:
+ {
+ result.metalSystemValueName = (StringBuilder() << "color("
+ << (semanticIndex.getLength() != 0 ? semanticIndex : toSlice("0"))
+ << ")").produceString();
+ result.permittedTypes = getPermittedTypes_sv_target(builder);
+
+ break;
+ }
+ default:
+ m_sink->diagnose(parentVar, Diagnostics::unimplementedSystemValueSemantic, semanticName);
+ return result;
+ }
+ return result;
}
- else if (semanticName == "sv_outputcontrolpointid")
+
+ void reportUnsupportedSystemAttribute(IRInst* param, String semanticName)
{
- // In metal, a hull shader is just a compute shader.
- // This needs to be handled separately, by lowering into an ordinary buffer.
+ m_sink->diagnose(param->sourceLoc, Diagnostics::systemValueAttributeNotSupported, semanticName);
}
- else if (semanticName == "sv_pointsize")
+
+ void ensureResultStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout)
{
- result.metalSystemValueName = toSlice("point_size");
- result.requiredType = builder.getBasicType(BaseType::Float);
+ // Ensure each field in an output struct type has either a system semantic or a user semantic,
+ // so that signature matching can happen correctly.
+ auto typeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
+ Index index = 0;
+ IRBuilder builder(structType);
+ for (auto field : structType->getFields())
+ {
+ auto key = field->getKey();
+ if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
+ {
+ if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
+ {
+ auto indexAsString = String(UInt(semanticDecor->getSemanticIndex()));
+ auto sysValInfo = getSystemValueInfo(semanticDecor->getSemanticName(), &indexAsString, field);
+ if (sysValInfo.isUnsupported || sysValInfo.isSpecial)
+ {
+ reportUnsupportedSystemAttribute(field, semanticDecor->getSemanticName());
+ }
+ else
+ {
+ builder.addTargetSystemValueDecoration(key, sysValInfo.metalSystemValueName.getUnownedSlice());
+ semanticDecor->removeAndDeallocate();
+ }
+ }
+ index++;
+ continue;
+ }
+ typeLayout->getFieldLayout(index);
+ auto fieldLayout = typeLayout->getFieldLayout(index);
+ if (auto offsetAttr = fieldLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput))
+ {
+ UInt varOffset = 0;
+ if (auto varOffsetAttr = varLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput))
+ varOffset = varOffsetAttr->getOffset();
+ varOffset += offsetAttr->getOffset();
+ builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset);
+ }
+ index++;
+ }
}
- else if (semanticName == "sv_primitiveid")
+
+ // Stores a hicharchy of members and children which map 'oldStruct->member' to 'flatStruct->member'
+ // Note: this map assumes we map to FlatStruct since it is easier/faster to process
+ struct MapStructToFlatStruct
{
- result.metalSystemValueName = toSlice("patch_id");
- result.requiredType = builder.getBasicType(BaseType::UInt);
- result.altRequiredType = builder.getBasicType(BaseType::UInt16);
- }
- else if (semanticName == "sv_rendertargetarrayindex")
+ /*
+ We need a hicharchy map to resolve dependencies for mapping
+ oldStruct to newStruct efficently. Example:
+
+ MyStruct
+ |
+ / | \
+ / | \
+ / | \
+ M0<A> M1<A> M2<B>
+ | | |
+ A_0 A_0 B_0
+
+ Without storing hicharchy information, there will be no way to tell apart
+ `myStruct.M0.A0` from `myStruct.M1.A0` since IRStructKey/IRStructField
+ only has 1 instance of `A::A0`
+ */
+
+ enum CopyOptions : int
{
- result.metalSystemValueName = toSlice("render_target_array_index");
- result.requiredType = builder.getBasicType(BaseType::UInt);
- result.altRequiredType = builder.getBasicType(BaseType::UInt16);
- }
- else if (semanticName == "sv_sampleindex")
+ // Copy a flattened-struct into a struct
+ FlatStructIntoStruct = 0,
+
+ // Copy a struct into a flattened-struct
+ StructIntoFlatStruct = 1,
+ };
+
+ private:
+ // Children of member if applicable.
+ Dictionary<IRStructField*, MapStructToFlatStruct> members;
+
+ // Field correlating to MapStructToFlatStruct Node.
+ IRInst* node;
+ IRStructKey* getKey()
+ {
+ SLANG_ASSERT(as<IRStructField>(node));
+ return as<IRStructField>(node)->getKey();
+ }
+ IRInst* getNode()
+ {
+ return node;
+ }
+ IRType* getFieldType()
+ {
+ SLANG_ASSERT(as<IRStructField>(node));
+ return as<IRStructField>(node)->getFieldType();
+ }
+
+ // Whom node maps to inside target flatStruct
+ IRStructField* targetMapping;
+
+ auto begin()
+ {
+ return members.begin();
+ }
+ auto end()
+ {
+ return members.end();
+ }
+
+ // Copies members of oldStruct to/from newFlatStruct. Assumes members of val1 maps to members in val2
+ // using `MapStructToFlatStruct`
+ template<int copyOptions>
+ static void _emitCopy(IRBuilder& builder, IRInst* val1, IRStructType* type1, IRInst* val2, IRStructType* type2, MapStructToFlatStruct& node)
+ {
+ for (auto& field1Pair : node)
+ {
+ auto& field1 = field1Pair.second;
+
+ // Get member of val1
+ IRInst* fieldAddr1 = nullptr;
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ fieldAddr1 = builder.emitFieldAddress(type1, val1, field1.getKey());
+ }
+ else
+ {
+ if (as<IRPtrTypeBase>(val1))
+ val1 = builder.emitLoad(val1);
+ fieldAddr1 = builder.emitFieldExtract(type1, val1, field1.getKey());
+ }
+
+ // If val1 is a struct, recurse
+ if (auto fieldAsStruct1 = as<IRStructType>(field1.getFieldType()))
+ {
+ _emitCopy<copyOptions>(builder, fieldAddr1, fieldAsStruct1, val2, type2, field1);
+ continue;
+ }
+
+ // Get member of val2 which maps to val1.member
+ auto field2 = field1.getMapping();
+ SLANG_ASSERT(field2);
+ IRInst* fieldAddr2 = nullptr;
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ if (as<IRPtrTypeBase>(val2))
+ val2 = builder.emitLoad(val1);
+ fieldAddr2 = builder.emitFieldExtract(type2, val2, field2->getKey());
+ }
+ else
+ {
+ fieldAddr2 = builder.emitFieldAddress(type2, val2, field2->getKey());
+ }
+
+ // Copy val2/val1 member into val1/val2 member
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ builder.emitStore(fieldAddr1, fieldAddr2);
+ }
+ else
+ {
+ builder.emitStore(fieldAddr2, fieldAddr1);
+ }
+ }
+ }
+
+ public:
+ void setNode(IRInst* newNode)
+ {
+ node = newNode;
+ }
+ // Get 'MapStructToFlatStruct' that is a child of 'parent'.
+ // Make 'MapStructToFlatStruct' if no 'member' is currently mapped to 'parent'.
+ MapStructToFlatStruct& getMember(IRStructField* member)
+ {
+ return members[member];
+ }
+ MapStructToFlatStruct& operator[](IRStructField* member)
+ {
+ return getMember(member);
+ }
+
+ void setMapping(IRStructField* newTargetMapping)
+ {
+ targetMapping = newTargetMapping;
+ }
+ // Get 'MapStructToFlatStruct' that is a child of 'parent'.
+ // Return nullptr if no member is mapped to 'parent'
+ IRStructField* getMapping()
+ {
+ return targetMapping;
+ }
+
+ // Copies srcVal into dstVal using hicharchy map.
+ template<int copyOptions>
+ void emitCopy(IRBuilder& builder, IRInst* dstVal, IRInst* srcVal)
+ {
+ auto dstType = dstVal->getDataType();
+ if (auto dstPtrType = as<IRPtrTypeBase>(dstType))
+ dstType = dstPtrType->getValueType();
+ auto dstStructType = as<IRStructType>(dstType);
+ SLANG_ASSERT(dstStructType);
+
+ auto srcType = srcVal->getDataType();
+ if (auto srcPtrType = as<IRPtrTypeBase>(srcType))
+ srcType = srcPtrType->getValueType();
+ auto srcStructType = as<IRStructType>(srcType);
+ SLANG_ASSERT(srcStructType);
+
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ // CopyOptions::FlatStructIntoStruct copy a flattened-struct (mapped member) into a struct
+ SLANG_ASSERT(node == dstStructType);
+ _emitCopy<copyOptions>(builder, dstVal, dstStructType, srcVal, srcStructType, *this);
+ }
+ else
+ {
+ // CopyOptions::StructIntoFlatStruct copy a struct into a flattened-struct
+ SLANG_ASSERT(node == srcStructType);
+ _emitCopy<copyOptions>(builder, srcVal, srcStructType, dstVal, dstStructType, *this);
+ }
+ }
+ };
+
+ IRStructType* _flattenNestedStructs(IRBuilder& builder, IRStructType* dst, IRStructType* src, IRSemanticDecoration* parentSemanticDecoration,
+ IRLayoutDecoration* parentLayout, MapStructToFlatStruct& mapFieldToField, HashSet<IRStructField*>& varsWithSemanticInfo)
{
- result.metalSystemValueName = toSlice("sample_id");
- result.requiredType = builder.getBasicType(BaseType::UInt);
+ // For all fields ('oldField') of a struct do the following:
+ // 1. Check for 'decorations which carry semantic info' (IRSemanticDecoration, IRLayoutDecoration), store these if found.
+ // * Do not propagate semantic info if the current node has *any* form of semantic information.
+ // Update varsWithSemanticInfo.
+ // 2. If IRStructType:
+ // 2a. Recurse this function with 'decorations that carry semantic info' from parent.
+ // 3. If not IRStructType:
+ // 3a. Emit 'newField' equal to 'oldField', add 'decorations which carry semantic info'.
+ // 3b. Store a mapping from 'oldField' to 'newField' in 'mapFieldToField'. This info is needed to copy between types.
+ for (auto oldField : src->getFields())
+ {
+ auto& fieldMappingNode = mapFieldToField[oldField];
+ fieldMappingNode.setNode(oldField);
+
+ // step 1
+ bool foundSemanticDecor = false;
+ auto oldKey = oldField->getKey();
+ IRSemanticDecoration* fieldSemanticDecoration = parentSemanticDecoration;
+ if (auto oldSemanticDecoration = oldKey->findDecoration<IRSemanticDecoration>())
+ {
+ foundSemanticDecor = true;
+ fieldSemanticDecoration = oldSemanticDecoration;
+ parentLayout = nullptr;
+ }
+
+ IRLayoutDecoration* fieldLayout = parentLayout;
+ if (auto oldLayout = oldKey->findDecoration<IRLayoutDecoration>())
+ {
+ fieldLayout = oldLayout;
+ if (!foundSemanticDecor)
+ fieldSemanticDecoration = nullptr;
+ }
+ if (fieldSemanticDecoration != parentSemanticDecoration || parentLayout != fieldLayout)
+ varsWithSemanticInfo.add(oldField);
+
+ // step 2a
+ if (auto structFieldType = as<IRStructType>(oldField->getFieldType()))
+ {
+ _flattenNestedStructs(builder, dst, structFieldType, fieldSemanticDecoration, fieldLayout, fieldMappingNode, varsWithSemanticInfo);
+ continue;
+ }
+
+ // step 3a
+ auto newKey = builder.createStructKey();
+ copyNameHintAndDebugDecorations(newKey, oldKey);
+
+ auto newField = builder.createStructField(dst, newKey, oldField->getFieldType());
+ copyNameHintAndDebugDecorations(newField, oldField);
+
+ if (fieldSemanticDecoration)
+ builder.addSemanticDecoration(newKey, fieldSemanticDecoration->getSemanticName(), fieldSemanticDecoration->getSemanticIndex());
+
+ if (fieldLayout)
+ {
+ IRLayout* oldLayout = fieldLayout->getLayout();
+ List<IRInst*> instToCopy;
+ // Only copy certain decorations needed for resolving system semantics
+ for (UInt i = 0; i < oldLayout->getOperandCount(); i++)
+ {
+ auto operand = oldLayout->getOperand(i);
+ if (as<IRVarOffsetAttr>(operand) || as<IRUserSemanticAttr>(operand) || as<IRSystemValueSemanticAttr>(operand) || as<IRStageAttr>(operand))
+ instToCopy.add(operand);
+ }
+ IRVarLayout* newLayout = builder.getVarLayout(instToCopy);
+ builder.addLayoutDecoration(newKey, newLayout);
+ }
+ // step 3b
+ fieldMappingNode.setMapping(newField);
+ }
+
+ return dst;
}
- else if (semanticName == "sv_stencilref")
+
+ // Returns a `IRStructType*` without any `IRStructType*` members. `src` may be returned if there was no struct flattening.
+ // @param mapFieldToField Behavior maps all `IRStructField` of `src` to the new struct `IRStructFields`s
+ IRStructType* maybeFlattenNestedStructs(IRBuilder& builder, IRStructType* src, MapStructToFlatStruct& mapFieldToField, HashSet<IRStructField*>& varsWithSemanticInfo)
{
- result.metalSystemValueName = toSlice("stencil");
- result.requiredType = builder.getBasicType(BaseType::UInt);
+ // Find all values inside struct that need flattening and legalization.
+ bool hasStructTypeMembers = false;
+ for (auto field : src->getFields())
+ {
+ if (as<IRStructType>(field->getFieldType()))
+ {
+ hasStructTypeMembers = true;
+ break;
+ }
+ }
+ if (!hasStructTypeMembers)
+ return src;
+
+ // We need to:
+ // 1. Make new struct 1:1 with old struct but without nestested structs (flatten)
+ // 2. Ensure semantic attributes propegate. This will create overlapping semantics (can be handled later).
+ // 3. Store the mapping from old to new struct fields to allow copying a old-struct to new-struct.
+ builder.setInsertAfter(src);
+ auto newStruct = builder.createStructType();
+ copyNameHintAndDebugDecorations(newStruct, src);
+ mapFieldToField.setNode(src);
+ return _flattenNestedStructs(builder, newStruct, src, nullptr, nullptr, mapFieldToField, varsWithSemanticInfo);
}
- else if (semanticName == "sv_tessfactor")
+
+ // Replaces all 'IRReturn' by copying the current 'IRReturn' to a new var of type 'newType'.
+ // Copying logic from 'IRReturn' to 'newType' is controlled by 'copyLogicFunc' function.
+ template<typename CopyLogicFunc>
+ void _replaceAllReturnInst(IRBuilder& builder, IRFunc* targetFunc, IRStructType* newType, CopyLogicFunc copyLogicFunc)
{
- // Tessellation factor outputs should be lowered into a write into a normal buffer.
+ for (auto block : targetFunc->getBlocks())
+ {
+ if (auto returnInst = as<IRReturn>(block->getTerminator()))
+ {
+ builder.setInsertBefore(returnInst);
+ auto returnVal = returnInst->getVal();
+ returnInst->setOperand(0, copyLogicFunc(builder, newType, returnVal));
+ }
+ }
}
- else if (semanticName == "sv_vertexid")
+
+ UInt _returnNonOverlappingAttributeIndex(std::set<UInt>& usedSemanticIndex)
{
- result.metalSystemValueName = toSlice("vertex_id");
- result.requiredType = builder.getBasicType(BaseType::UInt);
+ // Find first unused semantic index of equal semantic type
+ // to fill any gaps in user set semantic bindings
+ UInt prev = 0;
+ for (auto i : usedSemanticIndex)
+ {
+ if (i > prev + 1)
+ {
+ break;
+ }
+ prev = i;
+ }
+ usedSemanticIndex.insert(prev + 1);
+ return prev + 1;
}
- else if (semanticName == "sv_viewid")
+
+ template<typename T>
+ struct AttributeParentPair
{
- result.isUnsupported = true;
- }
- else if (semanticName == "sv_viewportarrayindex")
+ IRLayoutDecoration* layoutDecor;
+ T* attr;
+ };
+
+ IRLayoutDecoration* _replaceAttributeOfLayout(IRBuilder& builder, IRLayoutDecoration* parentLayoutDecor, IRInst* instToReplace, IRInst* instToReplaceWith)
{
- result.metalSystemValueName = toSlice("viewport_array_index");
- result.requiredType = builder.getBasicType(BaseType::UInt);
- result.altRequiredType = builder.getBasicType(BaseType::UInt16);
+ // Replace `instToReplace` with a `instToReplaceWith`
+
+ auto layout = parentLayoutDecor->getLayout();
+ // Find the exact same decoration `instToReplace` in-case multiple of the same type exist
+ List<IRInst*> opList;
+ opList.add(instToReplaceWith);
+ for (UInt i = 0; i < layout->getOperandCount(); i++)
+ {
+ if (layout->getOperand(i) != instToReplace)
+ opList.add(layout->getOperand(i));
+ }
+ auto newLayoutDecor = builder.addLayoutDecoration(parentLayoutDecor->getParent(), builder.getVarLayout(opList));
+ parentLayoutDecor->removeAndDeallocate();
+ return newLayoutDecor;
}
- else if (semanticName.startsWith("sv_target"))
+
+ IRLayoutDecoration* _simplifyUserSemanticNames(IRBuilder& builder, IRLayoutDecoration* layoutDecor)
{
-
- result.metalSystemValueName = (StringBuilder() << "color("
- << (semanticIndex.getLength() != 0 ? semanticIndex : toSlice("0"))
- << ")").produceString();
+ // Ensure all 'ExplicitIndex' semantics such as "SV_TARGET0" are simplified into ("SV_TARGET", 0) using 'IRUserSemanticAttr'
+ // This is done to ensure we can check semantic groups using 'IRUserSemanticAttr1->getName() == IRUserSemanticAttr2->getName()'
+ SLANG_ASSERT(layoutDecor);
+ auto layout = layoutDecor->getLayout();
+ List<IRInst*> layoutOps;
+ layoutOps.reserve(3);
+ bool changed = false;
+ for(auto attr : layout->getAllAttrs())
+ {
+ if(auto userSemantic = as<IRUserSemanticAttr>(attr))
+ {
+ UnownedStringSlice outName;
+ UnownedStringSlice outIndex;
+ bool hasStringIndex = splitNameAndIndex(userSemantic->getName(), outName, outIndex);
+ if (hasStringIndex)
+ {
+ changed = true;
+ auto loweredName = String(outName).toLower();
+ auto loweredNameSlice = loweredName.getUnownedSlice();
+ auto newDecoration = builder.getUserSemanticAttr(loweredNameSlice, stringToInt(outIndex));
+ userSemantic->replaceUsesWith(newDecoration);
+ userSemantic->removeAndDeallocate();
+ userSemantic = newDecoration;
+ }
+ layoutOps.add(userSemantic);
+ continue;
+ }
+ layoutOps.add(attr);
+ }
+ if(changed)
+ {
+ auto parent = layoutDecor->parent;
+ layoutDecor->removeAndDeallocate();
+ builder.addLayoutDecoration(parent, builder.getVarLayout(layoutOps));
+ }
+ return layoutDecor;
}
- else
+ // Find overlapping field semantics and legalize them
+ void fixFieldSemanticsOfFlatStruct(IRStructType* structType)
{
- result.isUnsupported = true;
- }
- return result;
- }
+ // Goal is to ensure we do not have overlapping semantics:
+ /*
+ // Assume the following code
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET;
+ float3 p2 : SV_TARGET;
+ float2 p3 : SV_TARGET;
+ float2 p4 : SV_TARGET;
+ };
+
+ // Translates into
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET0;
+ float3 p2 : SV_TARGET1;
+ float2 p3 : SV_TARGET2;
+ float2 p4 : SV_TARGET3;
+ };
+ */
- void reportUnsupportedSystemAttribute(DiagnosticSink* sink, IRInst* param, String semanticName)
- {
- sink->diagnose(param->sourceLoc, Diagnostics::systemValueAttributeNotSupported, semanticName);
- }
+ IRBuilder builder(this->m_module);
- void ensureResultStructHasUserSemantic(DiagnosticSink* sink, IRStructType* structType, IRVarLayout* varLayout)
- {
- // Ensure each field in an output struct type has either a system semantic or a user semantic,
- // so that signature matching can happen correctly.
- auto typeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
- Index index = 0;
- IRBuilder builder(structType);
- for (auto field : structType->getFields())
- {
- auto key = field->getKey();
- if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
+ List<IRSemanticDecoration*> overlappingSemanticsDecor;
+ Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>> usedSemanticIndexSemanticDecor;
+
+ List<AttributeParentPair<IRVarOffsetAttr>> overlappingVarOffset;
+ Dictionary<UInt, std::set<UInt, std::less<UInt>>> usedSemanticIndexVarOffset;
+
+ List<AttributeParentPair<IRUserSemanticAttr>> overlappingUserSemantic;
+ Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt> >> usedSemanticIndexUserSemantic;
+
+ // We store a map from old `IRLayoutDecoration*` to new `IRLayoutDecoration*` since when legalizing
+ // we may destroy and remake a `IRLayoutDecoration*`
+ Dictionary<IRLayoutDecoration*, IRLayoutDecoration*> oldLayoutDecorToNew;
+
+ // Collect all "semantic info carrying decorations". Any collected decoration will
+ // fill up their respective 'Dictionary<SEMANTIC_TYPE, OrderedHashSet<UInt>>'
+ // to keep track of in-use offsets for a semantic type.
+ // Example: IRSemanticDecoration with name of "SV_TARGET1".
+ // * This will have SEMANTIC_TYPE of "sv_target".
+ // * This will use up index '1'
+ //
+ // Now if a second equal semantic "SV_TARGET1" is found, we add this decoration to
+ // a list of 'overlapping semantic info decorations' so we can legalize this
+ // 'semantic info decoration' later.
+ //
+ // NOTE: this is a flat struct, all members are children of the initial
+ // IRStructType.
+ for (auto field : structType->getFields())
{
- if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
+ auto key = field->getKey();
+ if (auto semanticDecoration = key->findDecoration<IRSemanticDecoration>())
{
- auto sysValInfo = getSystemValueInfo(builder, semanticDecor->getSemanticName());
- if (sysValInfo.isUnsupported || sysValInfo.isSpecial)
+ // Ensure names are in a uniform lowercase format so we can bunch together simmilar semantics
+ UnownedStringSlice outName;
+ UnownedStringSlice outIndex;
+ bool hasStringIndex = splitNameAndIndex(semanticDecoration->getSemanticName(), outName, outIndex);
+ if (hasStringIndex)
{
- reportUnsupportedSystemAttribute(sink, field, semanticDecor->getSemanticName());
+ auto loweredName = String(outName).toLower();
+ auto loweredNameSlice = loweredName.getUnownedSlice();
+ auto newDecoration = builder.addSemanticDecoration(key, loweredNameSlice, stringToInt(outIndex));
+ semanticDecoration->replaceUsesWith(newDecoration);
+ semanticDecoration->removeAndDeallocate();
+ semanticDecoration = newDecoration;
}
+ auto& semanticUse = usedSemanticIndexSemanticDecor[semanticDecoration->getSemanticName()];
+ if (semanticUse.find(semanticDecoration->getSemanticIndex()) != semanticUse.end())
+ overlappingSemanticsDecor.add(semanticDecoration);
else
+ semanticUse.insert(semanticDecoration->getSemanticIndex());
+ }
+ if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>())
+ {
+ // Ensure names are in a uniform lowercase format so we can bunch together simmilar semantics
+ layoutDecor = _simplifyUserSemanticNames(builder, layoutDecor);
+ oldLayoutDecorToNew[layoutDecor] = layoutDecor;
+ auto layout = layoutDecor->getLayout();
+ for (auto attr : layout->getAllAttrs())
{
- builder.addTargetSystemValueDecoration(key, sysValInfo.metalSystemValueName.getUnownedSlice());
- semanticDecor->removeAndDeallocate();
+ if (auto offset = as<IRVarOffsetAttr>(attr))
+ {
+ auto& semanticUse = usedSemanticIndexVarOffset[offset->getResourceKind()];
+ if (semanticUse.find(offset->getOffset()) != semanticUse.end())
+ overlappingVarOffset.add({ layoutDecor, offset });
+ else
+ semanticUse.insert(offset->getOffset());
+ }
+ else if (auto userSemantic = as<IRUserSemanticAttr>(attr))
+ {
+ auto& semanticUse = usedSemanticIndexUserSemantic[userSemantic->getName()];
+ if (semanticUse.find(userSemantic->getIndex()) != semanticUse.end())
+ overlappingUserSemantic.add({ layoutDecor, userSemantic });
+ else
+ semanticUse.insert(userSemantic->getIndex());
+ }
}
}
- index++;
- continue;
}
- typeLayout->getFieldLayout(index);
- auto fieldLayout = typeLayout->getFieldLayout(index);
- if (auto offsetAttr = fieldLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput))
+
+ // Legalize all overlapping 'semantic info decorations'
+ for (auto decor : overlappingSemanticsDecor)
{
- UInt varOffset = 0;
- if (auto varOffsetAttr = varLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput))
- varOffset = varOffsetAttr->getOffset();
- varOffset += offsetAttr->getOffset();
- builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset);
+ auto newOffset = _returnNonOverlappingAttributeIndex(usedSemanticIndexSemanticDecor[decor->getSemanticName()]);
+ builder.addSemanticDecoration(decor->getParent(), decor->getSemanticName(), (int)newOffset);
+ decor->removeAndDeallocate();
+ }
+ for (auto& varOffset : overlappingVarOffset)
+ {
+ auto newOffset = _returnNonOverlappingAttributeIndex(usedSemanticIndexVarOffset[varOffset.attr->getResourceKind()]);
+ auto newVarOffset = builder.getVarOffsetAttr(varOffset.attr->getResourceKind(), newOffset, varOffset.attr->getSpace());
+ oldLayoutDecorToNew[varOffset.layoutDecor] = _replaceAttributeOfLayout(builder, oldLayoutDecorToNew[varOffset.layoutDecor], varOffset.attr, newVarOffset);
+ }
+ for (auto& userSemantic : overlappingUserSemantic)
+ {
+ auto newOffset = _returnNonOverlappingAttributeIndex(usedSemanticIndexUserSemantic[userSemantic.attr->getName()]);
+ auto newUserSemantic = builder.getUserSemanticAttr(userSemantic.attr->getName(), newOffset);
+ oldLayoutDecorToNew[userSemantic.layoutDecor] = _replaceAttributeOfLayout(builder, oldLayoutDecorToNew[userSemantic.layoutDecor], userSemantic.attr, newUserSemantic);
}
- index++;
- }
- }
-
-
- void wrapReturnValueInStruct(DiagnosticSink* sink, EntryPointInfo entryPoint)
- {
- // Wrap return value into a struct if it is not already a struct.
- // For example, given this entry point:
- // ```
- // float4 main() : SV_Target { return float3(1,2,3); }
- // ```
- // We are going to transform it into:
- // ```
- // struct Output {
- // float4 value : SV_Target;
- // };
- // Output main() { return {float3(1,2,3)}; }
-
- auto func = entryPoint.entryPointFunc;
-
- auto returnType = func->getResultType();
- if (as<IRVoidType>(returnType))
- return;
- auto entryPointLayoutDecor = func->findDecoration<IRLayoutDecoration>();
- if (!entryPointLayoutDecor)
- return;
- auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecor->getLayout());
- if (!entryPointLayout)
- return;
- auto resultLayout = entryPointLayout->getResultLayout();
-
- // If return type is already a struct, just make sure every field has a semantic.
- if (auto returnStructType = as<IRStructType>(returnType))
- {
- ensureResultStructHasUserSemantic(sink, returnStructType, resultLayout);
- return;
}
- // If not, we need to wrap the result into a struct type.
- IRBuilder builder(func);
- builder.setInsertBefore(func);
- IRStructType* structType = builder.createStructType();
- auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
- builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Output")).getUnownedSlice());
- auto key = builder.createStructKey();
- builder.addNameHintDecoration(key, toSlice("output"));
- builder.addLayoutDecoration(key, resultLayout);
- builder.addTargetSystemValueDecoration(key, toSlice("color(0)"));
- builder.createStructField(structType, key, returnType);
- IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder);
- structTypeLayoutBuilder.addField(key, resultLayout);
- auto typeLayout = structTypeLayoutBuilder.build();
- IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
- auto varLayout = varLayoutBuilder.build();
- ensureResultStructHasUserSemantic(sink, structType, varLayout);
-
- for (auto block : func->getBlocks())
+ void wrapReturnValueInStruct(EntryPointInfo entryPoint)
{
- if (auto returnInst = as<IRReturn>(block->getTerminator()))
+ // Wrap return value into a struct if it is not already a struct.
+ // For example, given this entry point:
+ // ```
+ // float4 main() : SV_Target { return float3(1,2,3); }
+ // ```
+ // We are going to transform it into:
+ // ```
+ // struct Output {
+ // float4 value : SV_Target;
+ // };
+ // Output main() { return {float3(1,2,3)}; }
+
+ auto func = entryPoint.entryPointFunc;
+
+ auto returnType = func->getResultType();
+ if (as<IRVoidType>(returnType))
+ return;
+ auto entryPointLayoutDecor = func->findDecoration<IRLayoutDecoration>();
+ if (!entryPointLayoutDecor)
+ return;
+ auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecor->getLayout());
+ if (!entryPointLayout)
+ return;
+ auto resultLayout = entryPointLayout->getResultLayout();
+
+ // If return type is already a struct, just make sure every field has a semantic.
+ if (auto returnStructType = as<IRStructType>(returnType))
{
- builder.setInsertBefore(returnInst);
- auto returnVal = returnInst->getVal();
- auto newResult = builder.emitMakeStruct(structType, 1, &returnVal);
- returnInst->setOperand(0, newResult);
+ IRBuilder builder(func);
+ MapStructToFlatStruct mapOldFieldToNewField;
+ // Flatten result struct type to ensure we do not have nested semantics
+ auto flattenedStruct = maybeFlattenNestedStructs(builder, returnStructType, mapOldFieldToNewField, semanticInfoToRemove);
+ if (returnStructType != flattenedStruct)
+ {
+ // Replace all return-values with the flattenedStruct we made.
+ _replaceAllReturnInst(builder, func, flattenedStruct,
+ [&](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst*
+ {
+ auto srcStructType = as<IRStructType>(srcVal->getDataType());
+ SLANG_ASSERT(srcStructType);
+ auto dstVal = copyBuilder.emitVar(dstType);
+ mapOldFieldToNewField.emitCopy<(int)MapStructToFlatStruct::CopyOptions::StructIntoFlatStruct>(copyBuilder, dstVal, srcVal);
+ return builder.emitLoad(dstVal);
+ }
+ );
+ fixUpFuncType(func, flattenedStruct);
+ }
+ // Ensure non-overlapping semantics
+ fixFieldSemanticsOfFlatStruct(flattenedStruct);
+ ensureResultStructHasUserSemantic(flattenedStruct, resultLayout);
+ return;
}
- }
- fixUpFuncType(func, structType);
- }
- void legalizeMeshEntryPoint(EntryPointInfo entryPoint)
- {
- auto func = entryPoint.entryPointFunc;
+ IRBuilder builder(func);
+ builder.setInsertBefore(func);
+ IRStructType* structType = builder.createStructType();
+ auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
+ builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Output")).getUnownedSlice());
+ auto key = builder.createStructKey();
+ builder.addNameHintDecoration(key, toSlice("output"));
+ builder.addLayoutDecoration(key, resultLayout);
+ builder.createStructField(structType, key, returnType);
+ IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder);
+ structTypeLayoutBuilder.addField(key, resultLayout);
+ auto typeLayout = structTypeLayoutBuilder.build();
+ IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
+ auto varLayout = varLayoutBuilder.build();
+ ensureResultStructHasUserSemantic(structType, varLayout);
+
+ _replaceAllReturnInst(builder, func, structType,
+ [](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst*
+ {
+ return copyBuilder.emitMakeStruct(dstType, 1, &srcVal);
+ }
+ );
- if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Mesh)
- {
- return;
+ // Assign an appropriate system value semantic for stage output
+ auto stage = entryPoint.entryPointDecor->getProfile().getStage();
+ switch (stage)
+ {
+ case Stage::Compute:
+ case Stage::Fragment:
+ {
+ builder.addTargetSystemValueDecoration(key, toSlice("color(0)"));
+ break;
+ }
+ case Stage::Vertex:
+ {
+ builder.addTargetSystemValueDecoration(key, toSlice("position"));
+ break;
+ }
+ default:
+ SLANG_ASSERT(false);
+ return;
+ }
+
+ fixUpFuncType(func, structType);
}
- IRBuilder builder{ entryPoint.entryPointFunc->getModule() };
- for (auto param : func->getParams())
+ void legalizeMeshEntryPoint(EntryPointInfo entryPoint)
{
- if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
- {
- IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build());
+ auto func = entryPoint.entryPointFunc;
- varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
- auto paramVarLayout = varLayoutBuilder.build();
- builder.addLayoutDecoration(param, paramVarLayout);
+ if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Mesh)
+ {
+ return;
}
- }
- }
+ IRBuilder builder{ entryPoint.entryPointFunc->getModule() };
+ for (auto param : func->getParams())
+ {
+ if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ {
+ IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build());
+
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(param, paramVarLayout);
+ }
+ }
- void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint)
- {
- if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Amplification)
- {
- return;
}
- // Find out DispatchMesh function
- IRGlobalValueWithCode* dispatchMeshFunc = nullptr;
- for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts())
+
+ void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint)
{
- if (const auto func = as<IRGlobalValueWithCode>(globalInst))
+ if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Amplification)
+ {
+ return;
+ }
+ // Find out DispatchMesh function
+ IRGlobalValueWithCode* dispatchMeshFunc = nullptr;
+ for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts())
{
- if (const auto dec = func->findDecoration<IRKnownBuiltinDecoration>())
+ if (const auto func = as<IRGlobalValueWithCode>(globalInst))
{
- if (dec->getName() == "DispatchMesh")
+ if (const auto dec = func->findDecoration<IRKnownBuiltinDecoration>())
{
- SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found");
- dispatchMeshFunc = func;
+ if (dec->getName() == "DispatchMesh")
+ {
+ SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found");
+ dispatchMeshFunc = func;
+ }
}
}
}
- }
-
- if (!dispatchMeshFunc)
- return;
- IRBuilder builder{ entryPoint.entryPointFunc->getModule() };
+ if (!dispatchMeshFunc)
+ return;
- // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid
- traverseUses(dispatchMeshFunc, [&](const IRUse* use) {
- if (const auto call = as<IRCall>(use->getUser()))
- {
- SLANG_ASSERT(call->getArgCount() == 4);
- const auto payload = call->getArg(3);
+ IRBuilder builder{ entryPoint.entryPointFunc->getModule() };
- const auto payloadPtrType = composeGetters<IRPtrType>(
- payload,
- &IRInst::getDataType
- );
- SLANG_ASSERT(payloadPtrType);
- const auto payloadType = payloadPtrType->getValueType();
- SLANG_ASSERT(payloadType);
+ // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid
+ traverseUses(dispatchMeshFunc, [&](const IRUse* use) {
+ if (const auto call = as<IRCall>(use->getUser()))
+ {
+ SLANG_ASSERT(call->getArgCount() == 4);
+ const auto payload = call->getArg(3);
- builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
- const auto annotatedPayloadType =
- builder.getPtrType(
- kIROp_RefType,
- payloadPtrType->getValueType(),
- AddressSpace::MetalObjectData
+ const auto payloadPtrType = composeGetters<IRPtrType>(
+ payload,
+ &IRInst::getDataType
);
- auto packedParam = builder.emitParam(annotatedPayloadType);
- builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload"));
- IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build());
-
- // Add the MetalPayload resource info, so we can emit [[payload]]
- varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
- auto paramVarLayout = varLayoutBuilder.build();
- builder.addLayoutDecoration(packedParam, paramVarLayout);
-
- // Now we replace the call to DispatchMesh with a call to the mesh grid properties
- // But first we need to create the parameter
- const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType();
- auto mgp = builder.emitParam(meshGridPropertiesType);
- builder.addExternCppDecoration(mgp, toSlice("_slang_mgp"));
- }
- });
- }
+ SLANG_ASSERT(payloadPtrType);
+ const auto payloadType = payloadPtrType->getValueType();
+ SLANG_ASSERT(payloadType);
+
+ builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
+ const auto annotatedPayloadType =
+ builder.getPtrType(
+ kIROp_RefType,
+ payloadPtrType->getValueType(),
+ AddressSpace::MetalObjectData
+ );
+ auto packedParam = builder.emitParam(annotatedPayloadType);
+ builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload"));
+ IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build());
+
+ // Add the MetalPayload resource info, so we can emit [[payload]]
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(packedParam, paramVarLayout);
+
+ // Now we replace the call to DispatchMesh with a call to the mesh grid properties
+ // But first we need to create the parameter
+ const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType();
+ auto mgp = builder.emitParam(meshGridPropertiesType);
+ builder.addExternCppDecoration(mgp, toSlice("_slang_mgp"));
+ }
+ });
+ }
- IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType)
- {
- auto fromType = val->getFullType();
- if (auto fromVector = as<IRVectorType>(fromType))
+ IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType)
{
- if (auto toVector = as<IRVectorType>(toType))
+ auto fromType = val->getFullType();
+ if (auto fromVector = as<IRVectorType>(fromType))
{
- if (fromVector->getElementCount() != toVector->getElementCount())
+ if (auto toVector = as<IRVectorType>(toType))
+ {
+ if (fromVector->getElementCount() != toVector->getElementCount())
+ {
+ fromType = builder.getVectorType(fromVector->getElementType(), toVector->getElementCount());
+ val = builder.emitVectorReshape(fromType, val);
+ }
+ }
+ else if (as<IRBasicType>(toType))
{
- fromType = builder.getVectorType(fromVector->getElementType(), toVector->getElementCount());
- val = builder.emitVectorReshape(fromType, val);
+ UInt index = 0;
+ val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index);
+ if (toType->getOp() == kIROp_VoidType)
+ return nullptr;
}
}
- else if (as<IRBasicType>(toType))
+ else if (auto fromBasicType = as<IRBasicType>(fromType))
{
- UInt index = 0;
- val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index);
+ if (fromBasicType->getOp() == kIROp_VoidType)
+ return nullptr;
+ if (!as<IRBasicType>(toType))
+ return nullptr;
if (toType->getOp() == kIROp_VoidType)
return nullptr;
}
- }
- else if (auto fromBasicType = as<IRBasicType>(fromType))
- {
- if (fromBasicType->getOp() == kIROp_VoidType)
- return nullptr;
- if (!as<IRBasicType>(toType))
- return nullptr;
- if (toType->getOp() == kIROp_VoidType)
+ else
+ {
return nullptr;
+ }
+ return builder.emitCast(toType, val);
}
- else
- {
- return nullptr;
- }
- return builder.emitCast(toType, val);
- }
-
- void legalizeSystemValueParameters(EntryPointInfo entryPoint, DiagnosticSink* sink)
- {
- SLANG_UNUSED(sink);
struct SystemValLegalizationWorkItem
{
- IRParam* param;
+ IRInst* var;
String attrName;
UInt attrIndex;
};
- List<SystemValLegalizationWorkItem> systemValWorkItems;
-
- IRBuilder builder(entryPoint.entryPointFunc);
- for (auto param : entryPoint.entryPointFunc->getParams())
+ std::optional<SystemValLegalizationWorkItem> tryToMakeSystemValWorkItem(IRInst* var)
{
- if (auto semanticDecoration = param->findDecoration<IRSemanticDecoration>())
+ if (auto semanticDecoration = var->findDecoration<IRSemanticDecoration>())
{
if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
{
- systemValWorkItems.add({ param, String(semanticDecoration->getSemanticName()).toLower(), (UInt)semanticDecoration->getSemanticIndex() });
- continue;
+ return { { var, String(semanticDecoration->getSemanticName()).toLower(), (UInt)semanticDecoration->getSemanticIndex() } };
}
}
- auto layoutDecor = param->findDecoration<IRLayoutDecoration>();
+ auto layoutDecor = var->findDecoration<IRLayoutDecoration>();
if (!layoutDecor)
- continue;
+ return {};
auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>();
if (!sysValAttr)
- continue;
+ return {};
auto semanticName = String(sysValAttr->getName());
auto sysAttrIndex = sysValAttr->getIndex();
- systemValWorkItems.add({ param, semanticName, sysAttrIndex });
+
+ return { { var, semanticName, sysAttrIndex } };
+ }
+
+
+ List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(EntryPointInfo entryPoint)
+ {
+ List<SystemValLegalizationWorkItem> systemValWorkItems;
+ for (auto param : entryPoint.entryPointFunc->getParams())
+ {
+ auto maybeWorkItem = tryToMakeSystemValWorkItem(param);
+ if (maybeWorkItem.has_value())
+ systemValWorkItems.add(std::move(maybeWorkItem.value()));
+ }
+ return systemValWorkItems;
}
- IRParam* groupThreadId = nullptr;
- for (auto index = 0; index < systemValWorkItems.getCount(); index++)
+ void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem)
{
- auto workItem = systemValWorkItems[index];
+ IRBuilder builder(entryPoint.entryPointFunc);
- auto param = workItem.param;
+ auto var = workItem.var;
auto semanticName = workItem.attrName;
- auto info = getSystemValueInfo(builder, semanticName);
+ auto indexAsString = String(workItem.attrIndex);
+ auto info = getSystemValueInfo(semanticName, &indexAsString, var);
+
if (info.isSpecial)
{
- if (semanticName == "sv_innercoverage")
+ if (info.metalSystemValueNameEnum == SystemValueSemanticName::InnerCoverage)
{
// Metal does not support conservative rasterization, so this is always false.
auto val = builder.getBoolValue(false);
- param->replaceUsesWith(val);
- param->removeAndDeallocate();
+ var->replaceUsesWith(val);
+ var->removeAndDeallocate();
}
- else if (semanticName == "sv_groupindex")
+ else if (info.metalSystemValueNameEnum == SystemValueSemanticName::GroupIndex)
{
- // Ensure we have a cached "sv_groupthreadid"
- if (!groupThreadId)
+ // Ensure we have a cached "sv_groupthreadid" in our entry point
+ if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc))
{
+ auto systemValWorkItems = collectSystemValFromEntryPoint(entryPoint);
for (auto i : systemValWorkItems)
{
- if (i.attrName == groupThreadIDString)
+ auto indexAsStringGroupThreadId = String(i.attrIndex);
+ if (getSystemValueInfo(i.attrName, &indexAsStringGroupThreadId, i.var).metalSystemValueNameEnum == SystemValueSemanticName::GroupThreadID)
{
- groupThreadId = i.param;
+ entryPointToGroupThreadId[entryPoint.entryPointFunc] = i.var;
}
}
- if (!groupThreadId)
+ if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc))
{
// Add the missing groupthreadid needed to compute sv_groupindex
IRBuilder groupThreadIdBuilder(builder);
groupThreadIdBuilder.setInsertInto(entryPoint.entryPointFunc->getFirstBlock());
- groupThreadId = groupThreadIdBuilder.emitParamAtHead(getGroupThreadIdType(groupThreadIdBuilder));
+ auto groupThreadId = groupThreadIdBuilder.emitParamAtHead(getGroupThreadIdType(groupThreadIdBuilder));
+ entryPointToGroupThreadId[entryPoint.entryPointFunc] = groupThreadId;
groupThreadIdBuilder.addNameHintDecoration(groupThreadId, groupThreadIDString);
-
+
// Since "sv_groupindex" will be translated out to a global var and no longer be considered a system value
// we can reuse its layout and semantic info
Index foundRequiredDecorations = 0;
IRLayoutDecoration* layoutDecoration = nullptr;
UInt semanticIndex = 0;
- for (auto decoration : param->getDecorations())
+ for (auto decoration : var->getDecorations())
{
if (auto layoutDecorationTmp = as<IRLayoutDecoration>(decoration))
{
@@ -731,67 +1502,109 @@ namespace Slang
SLANG_ASSERT(layoutDecoration);
layoutDecoration->removeFromParent();
layoutDecoration->insertAtStart(groupThreadId);
- systemValWorkItems.add({ groupThreadId, groupThreadIDString, semanticIndex });
+ SystemValLegalizationWorkItem newWorkItem = { groupThreadId, groupThreadIDString, semanticIndex };
+ legalizeSystemValue(entryPoint, newWorkItem);
}
}
IRBuilder svBuilder(builder.getModule());
svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst());
auto computeExtent = emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, builder.getVectorType(builder.getUIntType(), builder.getIntValue(builder.getIntType(), 3)));
- auto groupIndexCalc = emitCalcGroupThreadIndex(svBuilder, groupThreadId, computeExtent);
+ auto groupIndexCalc = emitCalcGroupIndex(svBuilder, entryPointToGroupThreadId[entryPoint.entryPointFunc], computeExtent);
svBuilder.addNameHintDecoration(groupIndexCalc, UnownedStringSlice("sv_groupindex"));
- param->replaceUsesWith(groupIndexCalc);
- param->removeAndDeallocate();
+ var->replaceUsesWith(groupIndexCalc);
+ var->removeAndDeallocate();
}
}
if (info.isUnsupported)
{
- reportUnsupportedSystemAttribute(sink, param, semanticName);
- continue;
+ reportUnsupportedSystemAttribute(var, semanticName);
+ return;
}
- if (!info.requiredType)
- continue;
+ if (!info.permittedTypes.getCount())
+ return;
- builder.addTargetSystemValueDecoration(param, info.metalSystemValueName.getUnownedSlice());
+ builder.addTargetSystemValueDecoration(var, info.metalSystemValueName.getUnownedSlice());
- // If the required type is different from the actual type, we need to insert a conversion.
- if (info.requiredType != param->getFullType() && info.altRequiredType != param->getFullType())
+ bool varTypeIsPermitted = false;
+ auto varType = var->getFullType();
+ for (auto& permittedType : info.permittedTypes)
{
- auto targetType = param->getFullType();
- builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
- param->setFullType(info.requiredType);
- List<IRUse*> uses;
- for (auto use = param->firstUse; use; use = use->nextUse)
- uses.add(use);
- auto convertedValue = tryConvertValue(builder, param, targetType);
- copyNameHintAndDebugDecorations(convertedValue, param);
- if (!convertedValue)
- {
- // If we can't convert the value, report an error.
- StringBuilder typeNameSB;
- getTypeNameHint(typeNameSB, info.requiredType);
- sink->diagnose(param->sourceLoc, Diagnostics::systemValueTypeIncompatible, semanticName, typeNameSB.produceString());
- }
- else
+ varTypeIsPermitted = varTypeIsPermitted || permittedType == varType;
+ }
+
+ if (!varTypeIsPermitted)
+ {
+ // Note: we do not currently prefer any conversion
+ // example:
+ // * allowed types for semantic: `float4`, `uint4`, `int4`
+ // * user used, `float2`
+ // * Slang will equally prefer `float4` to `uint4` to `int4`.
+ // This means the type may lose data if slang selects `uint4` or `int4`.
+ bool foundAConversion = false;
+ for (auto permittedType : info.permittedTypes)
{
+ var->setFullType(permittedType);
+ builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
+
+ // get uses before we `tryConvertValue` since this creates a new use
+ List<IRUse*> uses;
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ uses.add(use);
+
+ auto convertedValue = tryConvertValue(builder, var, varType);
+ if (convertedValue == nullptr)
+ continue;
+
+ foundAConversion = true;
+ copyNameHintAndDebugDecorations(convertedValue, var);
+
for (auto use : uses)
builder.replaceOperand(use, convertedValue);
}
+ if (!foundAConversion)
+ {
+ // If we can't convert the value, report an error.
+ for (auto permittedType : info.permittedTypes)
+ {
+ StringBuilder typeNameSB;
+ getTypeNameHint(typeNameSB, permittedType);
+ m_sink->diagnose(var->sourceLoc, Diagnostics::systemValueTypeIncompatible, semanticName, typeNameSB.produceString());
+ }
+ }
}
}
- fixUpFuncType(entryPoint.entryPointFunc);
- }
- void legalizeEntryPointForMetal(EntryPointInfo entryPoint, DiagnosticSink* sink)
- {
- hoistEntryPointParameterFromStruct(entryPoint);
- packStageInParameters(entryPoint);
- legalizeSystemValueParameters(entryPoint, sink);
- wrapReturnValueInStruct(sink, entryPoint);
- legalizeMeshEntryPoint(entryPoint);
- legalizeDispatchMeshPayloadForMetal(entryPoint);
- }
+ void legalizeSystemValueParameters(EntryPointInfo entryPoint)
+ {
+ List<SystemValLegalizationWorkItem> systemValWorkItems = collectSystemValFromEntryPoint(entryPoint);
+
+ for (auto index = 0; index < systemValWorkItems.getCount(); index++)
+ {
+ legalizeSystemValue(entryPoint, systemValWorkItems[index]);
+ }
+ fixUpFuncType(entryPoint.entryPointFunc);
+ }
+
+ void legalizeEntryPointForMetal(EntryPointInfo entryPoint)
+ {
+ //Input Parameter Legalize
+ hoistEntryPointParameterFromStruct(entryPoint);
+ packStageInParameters(entryPoint);
+ flattenInputParameters(entryPoint);
+
+ //System Value Legalize
+ legalizeSystemValueParameters(entryPoint);
+
+ //Output Value Legalize
+ wrapReturnValueInStruct(entryPoint);
+
+ //Other Legalize
+ legalizeMeshEntryPoint(entryPoint);
+ legalizeDispatchMeshPayloadForMetal(entryPoint);
+ }
+ };
void legalizeFuncBody(IRFunc* func)
{
@@ -871,8 +1684,10 @@ namespace Slang
}
}
+ LegalizeMetalEntryPointContext context(sink, module);
for (auto entryPoint : entryPoints)
- legalizeEntryPointForMetal(entryPoint, sink);
+ context.legalizeEntryPointForMetal(entryPoint);
+ context.removeSemanticLayoutsFromLegalizedStructs();
specializeAddressSpace(module);
}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index a4ea100dc..c6409a7e1 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -6267,9 +6267,9 @@ namespace Slang
addDecoration(inst, kIROp_HighLevelDeclDecoration, ptrConst);
}
- void IRBuilder::addLayoutDecoration(IRInst* value, IRLayout* layout)
+ IRLayoutDecoration* IRBuilder::addLayoutDecoration(IRInst* value, IRLayout* layout)
{
- addDecoration(value, kIROp_LayoutDecoration, layout);
+ return as<IRLayoutDecoration>(addDecoration(value, kIROp_LayoutDecoration, layout));
}
IRTypeSizeAttr* IRBuilder::getTypeSizeAttr(
diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp
index 6a1080b0e..046c35ef9 100644
--- a/source/slang/slang-parameter-binding.cpp
+++ b/source/slang/slang-parameter-binding.cpp
@@ -480,8 +480,8 @@ static bool isDigit(char c)
return (c >= '0') && (c <= '9');
}
-void splitNameAndIndex(
- UnownedStringSlice const& text,
+bool splitNameAndIndex(
+ UnownedStringSlice const& text,
UnownedStringSlice& outName,
UnownedStringSlice& outDigits)
{
@@ -489,14 +489,20 @@ void splitNameAndIndex(
char const* digitsEnd = text.end();
char const* nameEnd = digitsEnd;
+ // ExplicitIndex is when a semantic has an index at the end of its name
+ // "SV_TARGET1" has an ExplicitIndex
+ // "SV_TARGET" does not have an ExplicitIndex
+ bool hasExplicitIndex = false;
while( nameEnd != nameBegin && isDigit(*(nameEnd - 1)) )
{
+ hasExplicitIndex = true;
nameEnd--;
}
char const* digitsBegin = nameEnd;
outName = UnownedStringSlice(nameBegin, nameEnd);
outDigits = UnownedStringSlice(digitsBegin, digitsEnd);
+ return hasExplicitIndex;
}
LayoutResourceKind findRegisterClassFromName(UnownedStringSlice const& registerClassName)
diff --git a/source/slang/slang-parameter-binding.h b/source/slang/slang-parameter-binding.h
index 3c0949e37..e2530e34e 100644
--- a/source/slang/slang-parameter-binding.h
+++ b/source/slang/slang-parameter-binding.h
@@ -32,7 +32,7 @@ void generateParameterBindings(
/// Given a string that specifies a name and index (e.g., `COLOR0`),
/// split it into slices for the name part and the index part.
///
-void splitNameAndIndex(
+bool splitNameAndIndex(
UnownedStringSlice const& text,
UnownedStringSlice& outName,
UnownedStringSlice& outDigits);
diff --git a/tests/compute/compile-time-loop.slang b/tests/compute/compile-time-loop.slang
index f69708e0c..9035bde2a 100644
--- a/tests/compute/compile-time-loop.slang
+++ b/tests/compute/compile-time-loop.slang
@@ -1,5 +1,5 @@
//TEST(compute):COMPARE_RENDER_COMPUTE: -shaderobj
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(compute):COMPARE_RENDER_COMPUTE: -mtl -shaderobj
//TEST_INPUT: Texture2D(size=4, content = one):name t
//TEST_INPUT: Sampler:name s
diff --git a/tests/compute/constexpr.slang b/tests/compute/constexpr.slang
index 9aa5c1d56..9c7c9d131 100644
--- a/tests/compute/constexpr.slang
+++ b/tests/compute/constexpr.slang
@@ -1,7 +1,7 @@
// constexpr.slang
//TEST(compute):COMPARE_COMPUTE_EX:-slang -gcompute -shaderobj
//DISABLED://TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -gcompute -shaderobj
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(compute):COMPARE_COMPUTE_EX:-mtl -gcompute -shaderobj
//TEST_INPUT: Texture2D(size=4, content = one):name tex
//TEST_INPUT: Sampler:name samp
diff --git a/tests/compute/discard-stmt.slang b/tests/compute/discard-stmt.slang
index 90a81c0ff..fa00c9ec3 100644
--- a/tests/compute/discard-stmt.slang
+++ b/tests/compute/discard-stmt.slang
@@ -1,5 +1,5 @@
//TEST(compute):COMPARE_RENDER_COMPUTE: -shaderobj
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(compute):COMPARE_RENDER_COMPUTE: -mtl -shaderobj
//TEST_INPUT: Texture2D(size=4, content = one):name tex
//TEST_INPUT: Sampler:name samp
//TEST_INPUT: ubuffer(data=[0 0], stride=4):out,name outputBuffer
diff --git a/tests/compute/texture-sampling.slang b/tests/compute/texture-sampling.slang
index 89041dafb..0e319680a 100644
--- a/tests/compute/texture-sampling.slang
+++ b/tests/compute/texture-sampling.slang
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_RENDER_COMPUTE: -shaderobj -output-using-type
//TEST(compute):COMPARE_RENDER_COMPUTE: -shaderobj -output-using-type -vk
-//DISABLE_TEST(compute):COMPARE_COMPUTE:-slang -shaderobj -mtl
+//TEST(compute):COMPARE_RENDER_COMPUTE: -shaderobj -output-using-type -mtl
//TEST_INPUT: Texture1D(size=4, content = one):name=t1D
@@ -104,7 +104,7 @@ FragmentStageOutput fragmentMain(FragmentStageInput input)
val += tCubeArray.Sample(samplerState, float4(uv, 0.5, 0.0));
val += tCube.Sample(samplerState, float3(uv, 0.5));
- val += t2D.Load(int3(0), int2(0));
+ val += t2D.Load(int3(0));
val += t2dArray.Load(int4(0));
val += t3D[int3(0)];
diff --git a/tests/metal/atomic-intrinsics.slang b/tests/metal/atomic-intrinsics.slang
index 3533ea2aa..5d47db913 100644
--- a/tests/metal/atomic-intrinsics.slang
+++ b/tests/metal/atomic-intrinsics.slang
@@ -2,6 +2,7 @@
//TEST:SIMPLE(filecheck=LIB):-target metallib -entry computeMain -stage compute -DMETAL
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-slang -compute -dx12 -profile cs_6_0 -use-dxil -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-vk -emit-spirv-directly -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-vk -emit-spirv-via-glsl -compute -shaderobj -output-using-type
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -shaderobj -output-using-type
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
diff --git a/tests/metal/nested-struct-fragment-input.slang b/tests/metal/nested-struct-fragment-input.slang
new file mode 100644
index 000000000..727b5b1e5
--- /dev/null
+++ b/tests/metal/nested-struct-fragment-input.slang
@@ -0,0 +1,68 @@
+//TEST:SIMPLE(filecheck=METAL): -target metal -stage fragment -entry fragmentMain
+//TEST:SIMPLE(filecheck=METALLIB): -target metallib -stage fragment -entry fragmentMain
+
+// METAL: COARSEVERTEX_7
+// METAL: COARSEVERTEX_6
+
+// Ensure each attribute which may vary only appears once.
+// Ensure 1, 2, 3, 4 all appear
+
+// METAL-DAG: [[ATTR1:COARSEVERTEX_(1|2|3|4)]]
+
+// METAL-NOT: [[ATTR1]]
+// METAL-DAG: [[ATTR2:COARSEVERTEX_(1|2|3|4)]]
+
+// METAL: COARSEVERTEX{{(_0|())}}
+
+// METAL-NOT: [[ATTR2]]
+// METAL-DAG: [[ATTR3:COARSEVERTEX_(1|2|3|4)]]
+
+// METAL-NOT: [[ATTR3]]
+// METAL-DAG: [[ATTR4:COARSEVERTEX_(1|2|3|4)]]
+
+// METALLIB: @fragmentMain
+
+RWStructuredBuffer<float> outputBuffer;
+
+struct BottomFragment1
+{
+ float p1;
+};
+struct BottomFragment2
+{
+ float p1;
+};
+
+struct MiddleFragment1
+{
+ float p1;
+ BottomFragment1 p2;
+ BottomFragment2 p3;
+};
+struct TopFragment
+{
+ float p1 : CoarseVertex7;
+ MiddleFragment1 p2 : CoarseVertex6;
+ MiddleFragment1 p3 : CoarseVertex0;
+};
+
+struct FragmentStageInput
+{
+ TopFragment coarseVertex : CoarseVertex;
+};
+
+float4 fragmentMain(FragmentStageInput input)
+{
+ // METAL-DAG: {{.*}}->p1{{.*}}=
+
+ // METAL-DAG: {{.*}}->p2{{.*}}->p1{{.*}}=
+ // METAL-DAG: {{.*}}->p2{{.*}}->p2{{.*}}->p1{{.*}}=
+ // METAL-DAG: {{.*}}->p2{{.*}}->p3{{.*}}->p1{{.*}}=
+
+ // METAL-DAG: {{.*}}->p3{{.*}}->p1{{.*}}=
+ // METAL-DAG: {{.*}}->p3{{.*}}->p2{{.*}}->p1{{.*}}=
+ // METAL-DAG: {{.*}}->p3{{.*}}->p3{{.*}}->p1{{.*}}=
+
+ outputBuffer[0] = input.coarseVertex.p1 + input.coarseVertex.p2.p1 + +input.coarseVertex.p3.p1;
+ return float4(0, 0, 0, 0);
+}
diff --git a/tests/metal/nested-struct-fragment-output.slang b/tests/metal/nested-struct-fragment-output.slang
new file mode 100644
index 000000000..1d002c124
--- /dev/null
+++ b/tests/metal/nested-struct-fragment-output.slang
@@ -0,0 +1,74 @@
+//TEST:SIMPLE(filecheck=METAL): -target metal -stage fragment -entry fragmentMain
+//TEST:SIMPLE(filecheck=METALLIB): -target metallib -stage fragment -entry fragmentMain
+
+//METAL-DAG: color(0)
+//METAL-DAG: color(1)
+//METAL-DAG: color(2)
+//METAL-DAG: color(3)
+//METAL-DAG: color(4)
+//METAL-DAG: color(5)
+//METAL-DAG: color(6)
+//METAL-NOT: color(7)
+
+//METALLIB: @fragmentMain
+
+RWStructuredBuffer<float> outputBuffer;
+
+struct BottomFragment1
+{
+ float p1;
+};
+struct BottomFragment2
+{
+ float p1;
+};
+
+struct MiddleFragment1
+{
+ float p1;
+ BottomFragment1 p2;
+ BottomFragment2 p3;
+};
+struct TopFragment
+{
+ float p1;
+ MiddleFragment1 p2;
+ MiddleFragment1 p3;
+};
+
+struct FragmentStageInput
+{
+ float4 coarseVertex : CoarseVertex;
+};
+
+struct FragmentStageOutput
+{
+ TopFragment fragment : SV_Target;
+};
+
+FragmentStageOutput fragmentMain(FragmentStageInput input)
+{
+ FragmentStageOutput output;
+ output.fragment.p1 = 1;
+
+ output.fragment.p2.p1 = 3;
+ output.fragment.p2.p2.p1 = 4;
+ output.fragment.p2.p3.p1 = 5;
+
+ output.fragment.p3.p1 = 8;
+ output.fragment.p3.p2.p1 = 9;
+ output.fragment.p3.p3.p1 = 10;
+
+ // METAL-DAG: ={{.*}}.p1
+
+ // METAL-DAG: ={{.*}}.p2{{.*}}.p1
+ // METAL-DAG: ={{.*}}.p2{{.*}}.p2{{.*}}.p1
+ // METAL-DAG: ={{.*}}.p2{{.*}}.p3{{.*}}.p1
+
+ // METAL-DAG: ={{.*}}.p3{{.*}}.p1
+ // METAL-DAG: ={{.*}}.p3{{.*}}.p2{{.*}}.p1
+ // METAL-DAG: ={{.*}}.p3{{.*}}.p3{{.*}}.p1
+
+ outputBuffer[0] = 1;
+ return output;
+}
diff --git a/tests/metal/nested-struct-multi-entry-point-vertex.slang b/tests/metal/nested-struct-multi-entry-point-vertex.slang
new file mode 100644
index 000000000..779b66704
--- /dev/null
+++ b/tests/metal/nested-struct-multi-entry-point-vertex.slang
@@ -0,0 +1,45 @@
+//TEST:SIMPLE(filecheck=METAL1): -target metal -stage vertex -entry vertexMain1
+//TEST:SIMPLE(filecheck=METALLIB1): -target metallib -stage vertex -entry vertexMain1
+//TEST:SIMPLE(filecheck=METAL2): -target metal -stage vertex -entry vertexMain2
+//TEST:SIMPLE(filecheck=METALLIB2): -target metallib -stage vertex -entry vertexMain2
+
+//METALLIB1: @vertexMain1
+//METAL1-DAG: attribute(0)
+//METAL1-DAG: attribute(1)
+//METAL1-NOT: attribute(2)
+
+//METALLIB2: @vertexMain2
+//METAL2-DAG: attribute(0)
+//METAL2-DAG: attribute(1)
+//METAL2-DAG: attribute(2)
+
+struct SharedStruct
+{
+ float4 position;
+ float4 color;
+};
+
+struct VertexStageInput
+{
+ SharedStruct assembledVertex : CoarseVertex;
+};
+
+float4 vertexMain1(VertexStageInput vertex)
+{
+ return vertex.assembledVertex.position;
+}
+
+struct sharedStructWrapper
+{
+ float2 uv;
+ SharedStruct sharedData;
+};
+struct VertexStageInput2
+{
+ sharedStructWrapper assembledVertex : CoarseVertex;
+};
+
+float4 vertexMain2(VertexStageInput2 vertex)
+{
+ return vertex.assembledVertex.sharedData.position;
+} \ No newline at end of file
diff --git a/tests/metal/no-struct-vertex-output.slang b/tests/metal/no-struct-vertex-output.slang
new file mode 100644
index 000000000..f4988b685
--- /dev/null
+++ b/tests/metal/no-struct-vertex-output.slang
@@ -0,0 +1,12 @@
+//TEST:SIMPLE(filecheck=METAL): -target metallib -stage vertex -entry vertexMain
+//TEST:SIMPLE(filecheck=METALLIB): -target metallib -stage vertex -entry vertexMain
+
+//METAL-DAG: position
+//METALLIB: @vertexMain
+
+// Vertex Shader
+
+float4 vertexMain()
+{
+ return float4(1,1,1,1);
+} \ No newline at end of file
diff --git a/tests/metal/stage-in-2.slang b/tests/metal/stage-in-2.slang
index 2b1e61306..5bdf3f3eb 100644
--- a/tests/metal/stage-in-2.slang
+++ b/tests/metal/stage-in-2.slang
@@ -8,7 +8,7 @@
// CHECK: struct pixelInput
// CHECK-NEXT: {
-// CHECK-NEXT: CoarseVertex{{.*}} coarseVertex{{.*}} {{\[\[}}user(COARSEVERTEX){{\]\]}};
+// CHECK-NEXT: {{\[\[}}user(COARSEVERTEX){{\]\]}};
// Uniform data to be passed from application -> shader.
cbuffer Uniforms
diff --git a/tests/metal/sv_target-complex-1.slang b/tests/metal/sv_target-complex-1.slang
new file mode 100644
index 000000000..a830ff3d2
--- /dev/null
+++ b/tests/metal/sv_target-complex-1.slang
@@ -0,0 +1,34 @@
+//TEST:SIMPLE(filecheck=CHECK): -target metal
+//TEST:SIMPLE(filecheck=CHECK-ASM): -target metallib
+
+struct NestedReturn
+{
+ float4 debug;
+};
+struct NestedReturn2
+{
+ float4 debugAux1;
+ float4 debugAux2;
+};
+
+// Semantics are supposed to ignore uppercase/lowercase differences
+struct Output
+{
+ float4 Diffuse : SV_TarGet0;
+ NestedReturn debug1 : SV_Target1;
+ float4 Material : SV_TArgeT2;
+ NestedReturn2 debug2 : SV_TaRget3;
+}
+
+// CHECK-ASM: define {{.*}} @fragmentMain
+// CHECK: color(0)
+// CHECK-DAG: color(1)
+// CHECK-DAG: color(2)
+// CHECK-DAG: color(3)
+// CHECK-DAG: color(4)
+
+[shader("fragment")]
+Output fragmentMain()
+{
+ return { float4(1), {float4(2)}, float4(3) };
+} \ No newline at end of file
diff --git a/tests/metal/sv_target-complex-2.slang b/tests/metal/sv_target-complex-2.slang
new file mode 100644
index 000000000..9cc59cc5f
--- /dev/null
+++ b/tests/metal/sv_target-complex-2.slang
@@ -0,0 +1,27 @@
+//TEST:SIMPLE(filecheck=CHECK): -target metal
+//TEST:SIMPLE(filecheck=CHECK-ASM): -target metallib
+
+struct NestedReturn
+{
+ float4 debug1;
+ float4 debug2;
+};
+
+struct Output
+{
+ float4 Diffuse : SV_Target0;
+ NestedReturn val : SV_Target1;
+ float4 Material : SV_Target2;
+}
+
+// CHECK-ASM: define {{.*}} @fragmentMain
+// CHECK: color(0)
+// CHECK: color(1)
+// CHECK-DAG: color(3)
+// CHECK-DAG: color(2)
+
+[shader("fragment")]
+Output fragmentMain()
+{
+ return { float4(1), {float4(2), float4(2)}, float4(3) };
+} \ No newline at end of file