summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/hlsl.meta.slang2
-rw-r--r--source/slang/slang-diagnostic-defs.h4
-rw-r--r--source/slang/slang-ir-insts.h6
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp83
-rw-r--r--source/slang/slang-ir-legalize-varying-params.h61
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp1969
-rw-r--r--source/slang/slang-ir.cpp4
-rw-r--r--source/slang/slang-parameter-binding.cpp10
-rw-r--r--source/slang/slang-parameter-binding.h2
9 files changed, 1495 insertions, 646 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 17e2822a3..59a64a192 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -723,7 +723,6 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,1,format>
[__readNone]
[ForceInline]
- [require(hlsl, texture_sm_4_0_fragment)]
T Sample(vector<float, Shape.dimensions+isArray> location, vector<int, Shape.planeDimensions> offset, float clamp, out uint status)
{
__target_switch
@@ -1328,7 +1327,6 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,0,format>
[__readNone]
[ForceInline]
- [require(hlsl, texture_sm_4_0_fragment)]
T Sample(SamplerState s, vector<float, Shape.dimensions+isArray> location, constexpr vector<int, Shape.planeDimensions> offset, float clamp, out uint status)
{
__target_switch
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 98af8a228..3e7d71369 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -708,6 +708,7 @@ DIAGNOSTIC(40002, Error, invalidBindingValue, "binding location '$0' is out of v
DIAGNOSTIC(40003, Error, bindingExceedsLimit, "binding location '$0' assigned to component '$1' exceeds maximum limit.")
DIAGNOSTIC(40004, Error, bindingAlreadyOccupiedByModule, "DescriptorSet ID '$0' is already occupied by module instance '$1'.")
DIAGNOSTIC(40005, Error, topLevelModuleUsedWithoutSpecifyingBinding, "top level module '$0' is being used without specifying binding location. Use [Binding: \"index\"] attribute to provide a binding location.")
+DIAGNOSTIC(40006, Error, unimplementedSystemValueSemantic, "unknown system-value semantic '$0'")
DIAGNOSTIC(49999, Error, unknownSystemValueSemantic, "unknown system-value semantic '$0'")
@@ -849,10 +850,9 @@ DIAGNOSTIC(55102, Error, invalidTorchKernelParamType, "'$0' is not a valid param
DIAGNOSTIC(55200, Error, unsupportedBuiltinType, "'$0' is not a supported builtin type for the target.")
DIAGNOSTIC(55201, Error, unsupportedRecursion, "recursion detected in call to '$0', but the current code generation target does not allow recursion.")
DIAGNOSTIC(55202, Error, systemValueAttributeNotSupported, "system value semantic '$0' is not supported for the current target.")
-DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0' should have type '$1' or convertible to type '$1'.")
+DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0' should have type '$1' or be convertible to type '$1'.")
DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'")
-
DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0")
DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.")
DIAGNOSTIC(57003, Error, unknownTessPartitioning, "unknown tessellation partitioning '$0'.")
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index fb9a9613a..83b38b3b6 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4405,7 +4405,7 @@ public:
}
// void addLayoutDecoration(IRInst* value, Layout* layout);
- void addLayoutDecoration(IRInst* value, IRLayout* layout);
+ IRLayoutDecoration* addLayoutDecoration(IRInst* value, IRLayout* layout);
// IRLayout* getLayout(Layout* astLayout);
@@ -4525,9 +4525,9 @@ public:
addDecoration(value, kIROp_ForceUnrollDecoration, getIntValue(getIntType(), iters));
}
- void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0)
+ IRSemanticDecoration* addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0)
{
- addDecoration(value, kIROp_SemanticDecoration, getStringValue(text), getIntValue(getIntType(), index));
+ return as<IRSemanticDecoration>(addDecoration(value, kIROp_SemanticDecoration, getStringValue(text), getIntValue(getIntType(), index)));
}
void addRequireSPIRVDescriptorIndexingExtensionDecoration(IRInst* value)
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index 582af4ac8..7d3e3c533 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -2,9 +2,34 @@
#include "slang-ir-legalize-varying-params.h"
#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir-clone.h"
+#include "slang-parameter-binding.h"
namespace Slang
{
+ // Convert semantic name (ignores case) into equivlent `SystemValueSemanticName`
+ SystemValueSemanticName convertSystemValueSemanticNameToEnum(String rawSemanticName)
+ {
+ auto semanticName = rawSemanticName.toLower();
+
+ SystemValueSemanticName systemValueSemanticName = SystemValueSemanticName::None;
+
+#define CASE(ID, NAME) \
+ if(semanticName == String(#NAME).toLower()) \
+ { \
+ systemValueSemanticName = SystemValueSemanticName::ID; \
+ } \
+ else
+
+ SYSTEM_VALUE_SEMANTIC_NAMES(CASE)
+#undef CASE
+ {
+ systemValueSemanticName = SystemValueSemanticName::Unknown;
+ // no match
+ }
+ return systemValueSemanticName;
+ }
// This pass implements logic to "legalize" the varying parameter
// signature of an entry point.
@@ -51,36 +76,6 @@ namespace Slang
// * Slang allows for `inout` varying parameters, which need to desugar into
// distinct `in` and `out` parameters for targets like GLSL.
-
-#define SYSTEM_VALUE_SEMANTIC_NAMES(M) \
- M(DispatchThreadID, SV_DispatchThreadID) \
- M(GroupID, SV_GroupID) \
- M(GroupThreadID, SV_GroupThreadID) \
- M(GroupThreadIndex, SV_GroupIndex) \
- /* end */
-
- /// A known system-value semantic name that can be applied to a parameter
- ///
-enum class SystemValueSemanticName
-{
- None = 0,
-
- // TODO: Should this enumeration be responsible for differentiating
- // cases where the same semantic name string is allowed in multiple stages,
- // or as both input/output in a single stage, and those different uses
- // might result in different meanings? The alternative is to always
- // pass around the semantic name, stage, and direction together so
- // that code can tell those special cases apart.
-
-#define CASE(ID, NAME) ID,
-SYSTEM_VALUE_SEMANTIC_NAMES(CASE)
-#undef CASE
-
- // TODO: There are many more system-value semantic names that we
- // can/should handle here, but for now I've restricted this list
- // to those that are necessary for translating compute shaders.
-};
-
/// A placeholder that represents the value of a legalized varying
/// parameter, for the purposes of substituting it into IR code.
///
@@ -249,7 +244,7 @@ IRInst* emitCalcDispatchThreadID(
}
/// Emit code to calculate `SV_GroupIndex`
-IRInst* emitCalcGroupThreadIndex(
+IRInst* emitCalcGroupIndex(
IRBuilder& builder,
IRInst* groupThreadID,
IRInst* groupExtents)
@@ -935,23 +930,7 @@ protected:
// avoid all the `String`s we crete and thren throw
// away here.
//
- String semanticNameSpelling = semanticInst->getName();
- auto semanticName = semanticNameSpelling.toLower();
-
- SystemValueSemanticName systemValueSemanticName = SystemValueSemanticName::None;
-
- #define CASE(ID, NAME) \
- if(semanticName == String(#NAME).toLower()) \
- { \
- systemValueSemanticName = SystemValueSemanticName::ID; \
- } \
- else
-
- SYSTEM_VALUE_SEMANTIC_NAMES(CASE)
- #undef CASE
- {
- // no match
- }
+ auto systemValueSemanticName = convertSystemValueSemanticNameToEnum(String(semanticInst->getName()));
if( systemValueSemanticName != SystemValueSemanticName::None )
{
@@ -1223,7 +1202,7 @@ struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegaliz
threadIdxGlobalParam,
blockDimGlobalParam);
- groupThreadIndex = emitCalcGroupThreadIndex(
+ groupThreadIndex = emitCalcGroupIndex(
builder,
threadIdxGlobalParam,
blockDimGlobalParam);
@@ -1254,7 +1233,7 @@ struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegaliz
{
case SystemValueSemanticName::GroupID: return LegalizedVaryingVal::makeValue(blockIdxGlobalParam);
case SystemValueSemanticName::GroupThreadID: return LegalizedVaryingVal::makeValue(threadIdxGlobalParam);
- case SystemValueSemanticName::GroupThreadIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex);
+ case SystemValueSemanticName::GroupIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex);
case SystemValueSemanticName::DispatchThreadID: return LegalizedVaryingVal::makeValue(dispatchThreadID);
default:
return diagnoseUnsupportedSystemVal(info);
@@ -1397,7 +1376,7 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize
dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents);
- groupThreadIndex = emitCalcGroupThreadIndex(builder, groupThreadID, groupExtents);
+ groupThreadIndex = emitCalcGroupIndex(builder, groupThreadID, groupExtents);
}
LegalizedVaryingVal createLegalSystemVaryingValImpl(VaryingParamInfo const& info) SLANG_OVERRIDE
@@ -1414,7 +1393,7 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize
{
case SystemValueSemanticName::GroupID: return LegalizedVaryingVal::makeValue(groupID);
case SystemValueSemanticName::GroupThreadID: return LegalizedVaryingVal::makeValue(groupThreadID);
- case SystemValueSemanticName::GroupThreadIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex);
+ case SystemValueSemanticName::GroupIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex);
case SystemValueSemanticName::DispatchThreadID: return LegalizedVaryingVal::makeValue(dispatchThreadID);
default:
diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h
index 952192def..58efa39a2 100644
--- a/source/slang/slang-ir-legalize-varying-params.h
+++ b/source/slang/slang-ir-legalize-varying-params.h
@@ -1,17 +1,18 @@
// slang-ir-legalize-varying-params.h
#pragma once
+#include "slang-ir-insts.h"
namespace Slang
{
class DiagnosticSink;
-struct IRFunc;
struct IRModule;
struct IRInst;
struct IRFunc;
struct IRVectorType;
struct IRBuilder;
+struct IREntryPointDecoration;
void legalizeEntryPointVaryingParamsForCPU(
IRModule* module,
@@ -21,13 +22,63 @@ void legalizeEntryPointVaryingParamsForCUDA(
IRModule* module,
DiagnosticSink* sink);
-IRInst* emitCalcGroupThreadIndex(
- IRBuilder& builder,
- IRInst* groupThreadID,
- IRInst* groupExtents);
+
+// (#4375) Once `slang-ir-metal-legalize.cpp` is merged with
+// `slang-ir-legalize-varying-params.cpp`, move the following
+// below into `slang-ir-legalize-varying-params.cpp` as well
IRInst* emitCalcGroupExtents(
IRBuilder& builder,
IRFunc* entryPoint,
IRVectorType* type);
+
+IRInst* emitCalcGroupIndex(
+ IRBuilder& builder,
+ IRInst* groupThreadID,
+ IRInst* groupExtents);
+
+// SystemValueSemanticName member definition macro
+#define SYSTEM_VALUE_SEMANTIC_NAMES(M) \
+ M(Position, SV_Position) \
+ M(ClipDistance, SV_ClipDistance) \
+ M(CullDistance, SV_CullDistance) \
+ M(Coverage, SV_Coverage) \
+ M(InnerCoverage, SV_InnerCoverage) \
+ M(Depth, SV_Depth) \
+ M(DepthGreaterEqual, SV_DepthGreaterEqual) \
+ M(DepthLessEqual, SV_DepthLessEqual) \
+ M(DispatchThreadID, SV_DispatchThreadID) \
+ M(DomainLsocation, SV_DomainLsocation) \
+ M(GroupID, SV_GroupID) \
+ M(GroupIndex, SV_GroupIndex) \
+ M(GroupThreadID, SV_GroupThreadID) \
+ M(GSInstanceID, SV_GSInstanceID) \
+ M(InstanceID, SV_InstanceID) \
+ M(IsFrontFace, SV_IsFrontFace) \
+ M(OutputControlPointID, SV_OutputControlPointID)\
+ M(PointSize, SV_PointSize) \
+ M(PrimitiveID, SV_PrimitiveID) \
+ M(RenderTargetArrayIndex, SV_RenderTargetArrayIndex) \
+ M(SampleIndex, SV_SampleIndex) \
+ M(StencilRef, SV_StencilRef) \
+ M(TessFactor, SV_TessFactor) \
+ M(VertexID, SV_VertexID) \
+ M(ViewID, SV_ViewID) \
+ M(ViewportArrayIndex, SV_ViewportArrayIndex) \
+ M(Target, SV_Target) \
+ /* end */
+
+/// A known system-value semantic name that can be applied to a parameter
+///
+enum class SystemValueSemanticName
+{
+ None = 0,
+ Unknown = 0,
+#define CASE(ID, NAME) ID,
+ SYSTEM_VALUE_SEMANTIC_NAMES(CASE)
+#undef CASE
+};
+
+SystemValueSemanticName convertSystemValueSemanticNameToEnum(String rawSemanticName);
+
}
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index 01bcb0295..6b333f892 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -1,5 +1,6 @@
#include "slang-ir-metal-legalize.h"
+#include <set>
#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
@@ -10,7 +11,6 @@
namespace Slang
{
- const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid");
struct EntryPointInfo
{
@@ -18,701 +18,1472 @@ namespace Slang
IREntryPointDecoration* entryPointDecor;
};
- void hoistEntryPointParameterFromStruct(EntryPointInfo entryPoint)
+ const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid");
+ struct LegalizeMetalEntryPointContext
{
- // If an entry point has a input parameter with a struct type, we want to hoist out
- // all the fields of the struct type to be individual parameters of the entry point.
- // This will canonicalize the entry point signature, so we can handle all cases uniformly.
-
- // For example, given an entry point:
- // ```
- // struct VertexInput { float3 pos; float 2 uv; int vertexId : SV_VertexID};
- // void main(VertexInput vin) { ... }
- // ```
- // We will transform it to:
- // ```
- // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) {
- // VertexInput vin = {pos,uv,vertexId};
- // ...
- // }
- // ```
-
- auto func = entryPoint.entryPointFunc;
- List<IRParam*> paramsToProcess;
- for (auto param : func->getParams())
- {
- if (as<IRStructType>(param->getDataType()))
- {
- paramsToProcess.add(param);
+ ShortList<IRType*> permittedTypes_sv_target;
+ Dictionary<IRFunc*, IRInst*> entryPointToGroupThreadId;
+ HashSet<IRStructField*> semanticInfoToRemove;
+
+ DiagnosticSink* m_sink;
+ IRModule* m_module;
+
+ LegalizeMetalEntryPointContext(DiagnosticSink* sink, IRModule* module) : m_sink(sink), m_module(module)
+ {
+ }
+
+ void removeSemanticLayoutsFromLegalizedStructs()
+ {
+ // Metal does not allow duplicate attributes to appear in the same shader.
+ // If we emit our own struct with `[[color(0)]`, all existing uses of `[[color(0)]]`
+ // must be removed.
+ for (auto field : semanticInfoToRemove)
+ {
+ auto key = field->getKey();
+ // Some decorations appear twice, destroy all found
+ for (;;)
+ {
+ if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
+ {
+ semanticDecor->removeAndDeallocate();
+ continue;
+ }
+ else if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>())
+ {
+ layoutDecor->removeAndDeallocate();
+ continue;
+ }
+ break;
+ }
}
}
- IRBuilder builder(func);
- builder.setInsertBefore(func);
- for (auto param : paramsToProcess)
+ void hoistEntryPointParameterFromStruct(EntryPointInfo entryPoint)
{
- auto structType = as<IRStructType>(param->getDataType());
- builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- auto varLayout = findVarLayout(param);
-
- // If `param` already has a semantic, we don't want to hoist its fields out.
- if (varLayout->findSystemValueSemanticAttr() != nullptr ||
- param->findDecoration<IRSemanticDecoration>())
- continue;
-
- IRStructTypeLayout* structTypeLayout = nullptr;
- if (varLayout)
- structTypeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
- Index fieldIndex = 0;
- List<IRInst*> fieldParams;
- for (auto field : structType->getFields())
+ // If an entry point has a input parameter with a struct type, we want to hoist out
+ // all the fields of the struct type to be individual parameters of the entry point.
+ // This will canonicalize the entry point signature, so we can handle all cases uniformly.
+
+ // For example, given an entry point:
+ // ```
+ // struct VertexInput { float3 pos; float 2 uv; int vertexId : SV_VertexID};
+ // void main(VertexInput vin) { ... }
+ // ```
+ // We will transform it to:
+ // ```
+ // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) {
+ // VertexInput vin = {pos,uv,vertexId};
+ // ...
+ // }
+ // ```
+
+ auto func = entryPoint.entryPointFunc;
+ List<IRParam*> paramsToProcess;
+ for (auto param : func->getParams())
{
- auto fieldParam = builder.emitParam(field->getFieldType());
-
- IRCloneEnv cloneEnv;
- cloneInstDecorationsAndChildren(&cloneEnv, builder.getModule(), field->getKey(), fieldParam);
+ if (as<IRStructType>(param->getDataType()))
+ {
+ paramsToProcess.add(param);
+ }
+ }
+
+ IRBuilder builder(func);
+ builder.setInsertBefore(func);
+ for (auto param : paramsToProcess)
+ {
+ auto structType = as<IRStructType>(param->getDataType());
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ auto varLayout = findVarLayout(param);
- IRVarLayout* fieldLayout = structTypeLayout ? structTypeLayout->getFieldLayout(fieldIndex) : nullptr;
+ // If `param` already has a semantic, we don't want to hoist its fields out.
+ if (varLayout->findSystemValueSemanticAttr() != nullptr ||
+ param->findDecoration<IRSemanticDecoration>())
+ continue;
+
+ IRStructTypeLayout* structTypeLayout = nullptr;
if (varLayout)
+ structTypeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
+ Index fieldIndex = 0;
+ List<IRInst*> fieldParams;
+ for (auto field : structType->getFields())
{
- IRVarLayout::Builder varLayoutBuilder(&builder, fieldLayout->getTypeLayout());
- varLayoutBuilder.cloneEverythingButOffsetsFrom(fieldLayout);
- for (auto offsetAttr : fieldLayout->getOffsetAttrs())
+ auto fieldParam = builder.emitParam(field->getFieldType());
+ IRCloneEnv cloneEnv;
+ cloneInstDecorationsAndChildren(&cloneEnv, builder.getModule(), field->getKey(), fieldParam);
+
+ IRVarLayout* fieldLayout = structTypeLayout ? structTypeLayout->getFieldLayout(fieldIndex) : nullptr;
+ if (varLayout)
{
- auto parentOffsetAttr = varLayout->findOffsetAttr(offsetAttr->getResourceKind());
- UInt parentOffset = parentOffsetAttr ? parentOffsetAttr->getOffset() : 0;
- UInt parentSpace = parentOffsetAttr ? parentOffsetAttr->getSpace() : 0;
- auto resInfo = varLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind());
- resInfo->offset = parentOffset + offsetAttr->getOffset();
- resInfo->space = parentSpace + offsetAttr->getSpace();
+ IRVarLayout::Builder varLayoutBuilder(&builder, fieldLayout->getTypeLayout());
+ varLayoutBuilder.cloneEverythingButOffsetsFrom(fieldLayout);
+ for (auto offsetAttr : fieldLayout->getOffsetAttrs())
+ {
+ auto parentOffsetAttr = varLayout->findOffsetAttr(offsetAttr->getResourceKind());
+ UInt parentOffset = parentOffsetAttr ? parentOffsetAttr->getOffset() : 0;
+ UInt parentSpace = parentOffsetAttr ? parentOffsetAttr->getSpace() : 0;
+ auto resInfo = varLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind());
+ resInfo->offset = parentOffset + offsetAttr->getOffset();
+ resInfo->space = parentSpace + offsetAttr->getSpace();
+ }
+ builder.addLayoutDecoration(fieldParam, varLayoutBuilder.build());
}
- builder.addLayoutDecoration(fieldParam, varLayoutBuilder.build());
+ param->insertBefore(fieldParam);
+ fieldParams.add(fieldParam);
+ fieldIndex++;
}
- param->insertBefore(fieldParam);
- fieldParams.add(fieldParam);
- fieldIndex++;
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ auto reconstructedParam = builder.emitMakeStruct(structType, fieldParams.getCount(), fieldParams.getBuffer());
+ param->replaceUsesWith(reconstructedParam);
+ param->removeFromParent();
}
- builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- auto reconstructedParam = builder.emitMakeStruct(structType, fieldParams.getCount(), fieldParams.getBuffer());
- param->replaceUsesWith(reconstructedParam);
- param->removeFromParent();
+ fixUpFuncType(func);
}
- fixUpFuncType(func);
- }
- void packStageInParameters(EntryPointInfo entryPoint)
- {
- // If the entry point has any parameters whose layout contains VaryingInput,
- // we need to pack those parameters into a single `struct` type, and decorate
- // the fields with the appropriate `[[attribute]]` decorations.
- // For other parameters that are not `VaryingInput`, we need to leave them as is.
- //
- // For example, given this code after `hoistEntryPointParameterFromStruct`:
- // ```
- // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) {
- // VertexInput vin = {pos,uv,vertexId};
- // ...
- // }
- // ```
- // We are going to transform it into:
- // ```
- // struct VertexInput {
- // float3 pos [[attribute(0)]];
- // float2 uv [[attribute(1)]];
- // };
- // void main(VertexInput vin, int vertexId : SV_VertexID) {
- // let pos = vin.pos;
- // let uv = vin.uv;
- // ...
- // }
-
- auto func = entryPoint.entryPointFunc;
-
- bool isGeometryStage = false;
- switch (entryPoint.entryPointDecor->getProfile().getStage())
- {
- case Stage::Vertex:
- case Stage::Amplification:
- case Stage::Mesh:
- case Stage::Geometry:
- case Stage::Domain:
- case Stage::Hull:
- isGeometryStage = true;
- break;
- }
+ // Flattens all struct parameters of an entryPoint to ensure parameters are a flat struct
+ void flattenInputParameters(EntryPointInfo entryPoint)
+ {
+ // Goal is to ensure we have a flattened IRParam (0 nested IRStructType members).
+ /*
+ // Assume the following code
+ struct NestedFragment
+ {
+ float2 p3;
+ };
+ struct Fragment
+ {
+ float4 p1;
+ float3 p2;
+ NestedFragment p3_nested;
+ };
- List<IRParam*> paramsToPack;
- for (auto param : func->getParams())
- {
- auto layout = findVarLayout(param);
- if (!layout)
- continue;
- if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
- continue;
- if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
- continue;
- paramsToPack.add(param);
- }
+ // Fragment flattens into
+ struct Fragment
+ {
+ float4 p1;
+ float3 p2;
+ float2 p3;
+ };
+ */
+
+ // This is important since Metal does not allow semantic's on a struct
+ /*
+ // Assume the following code
+ struct NestedFragment1
+ {
+ float2 p3;
+ };
+ struct Fragment1
+ {
+ float4 p1 : SV_TARGET0;
+ float3 p2 : SV_TARGET1;
+ NestedFragment p3_nested : SV_TARGET2; // error, semantic on struct
+ };
+
+ */
+
+ // Metal does allow semantics on members of a nested struct but we are avoiding this
+ // approach since there are senarios where legalization (and verification) is
+ // hard/expensive without creating a flat struct:
+ // 1. Entry points may share structs, semantics may be inconsistent across entry points
+ // 2. Multiple of the same struct may be used in a param list
+ /*
+ // Assume the following code
+ struct NestedFragment
+ {
+ float2 p3;
+ };
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET0;
+ NestedFragment p2 : SV_TARGET1;
+ NestedFragment p3 : SV_TARGET2;
+ };
- if (paramsToPack.getCount() == 0)
- return;
+ // Legalized without flattening -- abandoned
+ struct NestedFragment1
+ {
+ float2 p3 : SV_TARGET1;
+ };
+ struct NestedFragment2
+ {
+ float2 p3 : SV_TARGET2;
+ };
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET0;
+ NestedFragment1 p2;
+ NestedFragment2 p3;
+ };
- IRBuilder builder(func);
- builder.setInsertBefore(func);
- IRStructType* structType = builder.createStructType();
- auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
- builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Input")).getUnownedSlice());
- List<IRStructKey*> keys;
- IRStructTypeLayout::Builder layoutBuilder(&builder);
- for (auto param : paramsToPack)
- {
- auto paramVarLayout = findVarLayout(param);
- auto key = builder.createStructKey();
- param->transferDecorationsTo(key);
- builder.createStructField(structType, key, param->getDataType());
- if (auto varyingInOffsetAttr = paramVarLayout->findOffsetAttr(LayoutResourceKind::VaryingInput))
- {
- if (!key->findDecoration<IRSemanticDecoration>() && !paramVarLayout->findAttr<IRSemanticAttr>())
+ // Legalized with flattening -- current approach
+ struct Fragment
{
- // If the parameter doesn't have a semantic, we need to add one for semantic matching.
- builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varyingInOffsetAttr->getOffset());
- }
- }
- if (isGeometryStage)
+ float4 p1 : SV_TARGET0;
+ float2 p2 : SV_TARGET1;
+ float2 p3 : SV_TARGET2;
+ };
+ */
+
+ auto func = entryPoint.entryPointFunc;
+ bool modified = false;
+ for (auto param : func->getParams())
{
- // For geometric stages, we need to translate VaryingInput offsets to MetalAttribute offsets.
- IRVarLayout::Builder elementVarLayoutBuilder(&builder, paramVarLayout->getTypeLayout());
- elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout);
- for (auto offsetAttr : paramVarLayout->getOffsetAttrs())
+ auto layout = findVarLayout(param);
+ if (!layout)
+ continue;
+ if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
+ continue;
+ if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ continue;
+ // If we find a IRParam with a IRStructType member, we need to flatten the entire IRParam
+ if (auto structType = as<IRStructType>(param->getDataType()))
{
- auto resourceKind = offsetAttr->getResourceKind();
- if (resourceKind == LayoutResourceKind::VaryingInput)
+ IRBuilder builder(func);
+ MapStructToFlatStruct mapOldFieldToNewField;
+
+ // Flatten struct if we have nested IRStructType
+ auto flattenedStruct = maybeFlattenNestedStructs(builder, structType, mapOldFieldToNewField, semanticInfoToRemove);
+ if (flattenedStruct != structType)
{
- resourceKind = LayoutResourceKind::MetalAttribute;
+ // Validate/rearange all semantics which overlap in our flat struct
+ fixFieldSemanticsOfFlatStruct(flattenedStruct);
+
+ // Replace the 'old IRParam type' with a 'new IRParam type'
+ param->setFullType(flattenedStruct);
+
+ // Emit a new variable at EntryPoint of 'old IRParam type'
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ auto dstVal = builder.emitVar(structType);
+ auto dstLoad = builder.emitLoad(dstVal);
+ param->replaceUsesWith(dstLoad);
+ builder.setInsertBefore(dstLoad);
+ // Copy the 'new IRParam type' to our 'old IRParam type'
+ mapOldFieldToNewField.emitCopy<(int)MapStructToFlatStruct::CopyOptions::FlatStructIntoStruct>(builder, dstVal, param);
+
+ modified = true;
}
- auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind);
- resInfo->offset = offsetAttr->getOffset();
- resInfo->space = offsetAttr->getSpace();
}
- paramVarLayout = elementVarLayoutBuilder.build();
}
- layoutBuilder.addField(key, paramVarLayout);
- builder.addLayoutDecoration(key, paramVarLayout);
- keys.add(key);
- }
- builder.setInsertInto(func->getFirstBlock());
- auto packedParam = builder.emitParamAtHead(structType);
- auto typeLayout = layoutBuilder.build();
- IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
-
- // Add a VaryingInput resource info to the packed parameter layout, so that we can emit
- // the needed `[[stage_in]]` attribute in Metal emitter.
- varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput);
- auto paramVarLayout = varLayoutBuilder.build();
- builder.addLayoutDecoration(packedParam, paramVarLayout);
-
- // Replace the original parameters with the packed parameter
- builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- for (Index paramIndex = 0; paramIndex < paramsToPack.getCount(); paramIndex++)
- {
- auto param = paramsToPack[paramIndex];
- auto key = keys[paramIndex];
- auto paramField = builder.emitFieldExtract(param->getDataType(), packedParam, key);
- param->replaceUsesWith(paramField);
- param->removeFromParent();
+ if (modified)
+ fixUpFuncType(func);
}
- fixUpFuncType(func);
- }
-
- struct MetalSystemValueInfo
- {
- String metalSystemValueName;
- IRType* requiredType;
- IRType* altRequiredType;
- bool isUnsupported;
- bool isSpecial;
- };
- IRType* getGroupThreadIdType(IRBuilder& builder)
- {
- return builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
- }
-
- MetalSystemValueInfo getSystemValueInfo(IRBuilder& builder, String inSemanticName)
- {
- inSemanticName = inSemanticName.toLower();
-
- UnownedStringSlice semanticName;
- UnownedStringSlice semanticIndex;
- splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex);
+ void packStageInParameters(EntryPointInfo entryPoint)
+ {
+ // If the entry point has any parameters whose layout contains VaryingInput,
+ // we need to pack those parameters into a single `struct` type, and decorate
+ // the fields with the appropriate `[[attribute]]` decorations.
+ // For other parameters that are not `VaryingInput`, we need to leave them as is.
+ //
+ // For example, given this code after `hoistEntryPointParameterFromStruct`:
+ // ```
+ // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) {
+ // VertexInput vin = {pos,uv,vertexId};
+ // ...
+ // }
+ // ```
+ // We are going to transform it into:
+ // ```
+ // struct VertexInput {
+ // float3 pos [[attribute(0)]];
+ // float2 uv [[attribute(1)]];
+ // };
+ // void main(VertexInput vin, int vertexId : SV_VertexID) {
+ // let pos = vin.pos;
+ // let uv = vin.uv;
+ // ...
+ // }
+
+ auto func = entryPoint.entryPointFunc;
+
+ bool isGeometryStage = false;
+ switch (entryPoint.entryPointDecor->getProfile().getStage())
+ {
+ case Stage::Vertex:
+ case Stage::Amplification:
+ case Stage::Mesh:
+ case Stage::Geometry:
+ case Stage::Domain:
+ case Stage::Hull:
+ isGeometryStage = true;
+ break;
+ }
- MetalSystemValueInfo result = {};
+ List<IRParam*> paramsToPack;
+ for (auto param : func->getParams())
+ {
+ auto layout = findVarLayout(param);
+ if (!layout)
+ continue;
+ if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
+ continue;
+ if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ continue;
+ paramsToPack.add(param);
+ }
- if (semanticName == "sv_position")
- {
- result.metalSystemValueName = toSlice("position");
- result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 4));
- }
- else if (semanticName == "sv_clipdistance")
- {
- result.isSpecial = true;
- }
- else if (semanticName == "sv_culldistance")
- {
- result.isSpecial = true;
- }
- else if (semanticName == "sv_coverage")
- {
- result.metalSystemValueName = toSlice("sample_mask");
- result.requiredType = builder.getBasicType(BaseType::UInt);
+ if (paramsToPack.getCount() == 0)
+ return;
+
+ IRBuilder builder(func);
+ builder.setInsertBefore(func);
+ IRStructType* structType = builder.createStructType();
+ auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
+ builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Input")).getUnownedSlice());
+ List<IRStructKey*> keys;
+ IRStructTypeLayout::Builder layoutBuilder(&builder);
+ for (auto param : paramsToPack)
+ {
+ auto paramVarLayout = findVarLayout(param);
+ auto key = builder.createStructKey();
+ param->transferDecorationsTo(key);
+ builder.createStructField(structType, key, param->getDataType());
+ if (auto varyingInOffsetAttr = paramVarLayout->findOffsetAttr(LayoutResourceKind::VaryingInput))
+ {
+ if (!key->findDecoration<IRSemanticDecoration>() && !paramVarLayout->findAttr<IRSemanticAttr>())
+ {
+ // If the parameter doesn't have a semantic, we need to add one for semantic matching.
+ builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varyingInOffsetAttr->getOffset());
+ }
+ }
+ if (isGeometryStage)
+ {
+ // For geometric stages, we need to translate VaryingInput offsets to MetalAttribute offsets.
+ IRVarLayout::Builder elementVarLayoutBuilder(&builder, paramVarLayout->getTypeLayout());
+ elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout);
+ for (auto offsetAttr : paramVarLayout->getOffsetAttrs())
+ {
+ auto resourceKind = offsetAttr->getResourceKind();
+ if (resourceKind == LayoutResourceKind::VaryingInput)
+ {
+ resourceKind = LayoutResourceKind::MetalAttribute;
+ }
+ auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind);
+ resInfo->offset = offsetAttr->getOffset();
+ resInfo->space = offsetAttr->getSpace();
+ }
+ paramVarLayout = elementVarLayoutBuilder.build();
+ }
+ layoutBuilder.addField(key, paramVarLayout);
+ builder.addLayoutDecoration(key, paramVarLayout);
+ keys.add(key);
+ }
+ builder.setInsertInto(func->getFirstBlock());
+ auto packedParam = builder.emitParamAtHead(structType);
+ auto typeLayout = layoutBuilder.build();
+ IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
+
+ // Add a VaryingInput resource info to the packed parameter layout, so that we can emit
+ // the needed `[[stage_in]]` attribute in Metal emitter.
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(packedParam, paramVarLayout);
+
+ // Replace the original parameters with the packed parameter
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ for (Index paramIndex = 0; paramIndex < paramsToPack.getCount(); paramIndex++)
+ {
+ auto param = paramsToPack[paramIndex];
+ auto key = keys[paramIndex];
+ auto paramField = builder.emitFieldExtract(param->getDataType(), packedParam, key);
+ param->replaceUsesWith(paramField);
+ param->removeFromParent();
+ }
+ fixUpFuncType(func);
}
- else if (semanticName == "sv_innercoverage")
- {
- result.isSpecial = true;
- }
- else if (semanticName == "sv_depth")
- {
- result.metalSystemValueName = toSlice("depth(any)");
- result.requiredType = builder.getBasicType(BaseType::Float);
- }
- else if (semanticName == "sv_depthgreaterequal")
+ struct MetalSystemValueInfo
{
- result.metalSystemValueName = toSlice("depth(greater)");
- result.requiredType = builder.getBasicType(BaseType::Float);
- }
- else if (semanticName == "sv_depthlessequal")
- {
- result.metalSystemValueName = toSlice("depth(less)");
- result.requiredType = builder.getBasicType(BaseType::Float);
- }
- else if (semanticName == "sv_dispatchthreadid")
- {
- result.metalSystemValueName = toSlice("thread_position_in_grid");
- result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
- }
- else if (semanticName == "sv_domainlocation")
- {
- result.metalSystemValueName = toSlice("position_in_patch");
- result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 3));
- result.altRequiredType = builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 2));
- }
- else if (semanticName == "sv_groupid")
- {
- result.metalSystemValueName = toSlice("threadgroup_position_in_grid");
- result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
- }
- else if (semanticName == "sv_groupindex")
- {
- result.isSpecial = true;
- }
- else if (semanticName == "sv_groupthreadid")
- {
- result.metalSystemValueName = toSlice("thread_position_in_threadgroup");
- result.requiredType = getGroupThreadIdType(builder);
- }
- else if (semanticName == "sv_gsinstanceid")
+ String metalSystemValueName;
+ SystemValueSemanticName metalSystemValueNameEnum;
+ ShortList<IRType*> permittedTypes;
+ bool isUnsupported = false;
+ bool isSpecial = false;
+ MetalSystemValueInfo()
+ {
+ // most commonly need 2
+ permittedTypes.reserveOverflowBuffer(2);
+ }
+ };
+
+ IRType* getGroupThreadIdType(IRBuilder& builder)
{
- // Metal does not have geometry shader, so this is invalid.
- result.isUnsupported = true;
+ return builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3));
}
- else if (semanticName == "sv_instanceid")
+
+ // Get all permitted types of "sv_target" for Metal
+ ShortList<IRType*>& getPermittedTypes_sv_target(IRBuilder& builder)
{
- result.metalSystemValueName = toSlice("instance_id");
- result.requiredType = builder.getBasicType(BaseType::UInt);
+ permittedTypes_sv_target.reserveOverflowBuffer(5 * 4);
+ if (permittedTypes_sv_target.getCount() == 0)
+ {
+ for (auto baseType : { BaseType::Float, BaseType::Half, BaseType::Int, BaseType::UInt, BaseType::Int16, BaseType::UInt16 })
+ {
+ for (IRIntegerValue i = 1; i <= 4; i++)
+ {
+ permittedTypes_sv_target.add(builder.getVectorType(builder.getBasicType(baseType), i));
+ }
+ }
+ }
+ return permittedTypes_sv_target;
}
- else if (semanticName == "sv_isfrontface")
+
+ MetalSystemValueInfo getSystemValueInfo(String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar)
{
- result.metalSystemValueName = toSlice("front_facing");
- result.requiredType = builder.getBasicType(BaseType::Bool);
+ IRBuilder builder(m_module);
+ MetalSystemValueInfo result = {};
+ UnownedStringSlice semanticName;
+ UnownedStringSlice semanticIndex;
+
+ auto hasExplicitIndex = splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex);
+ if (!hasExplicitIndex && optionalSemanticIndex)
+ semanticIndex = optionalSemanticIndex->getUnownedSlice();
+
+ result.metalSystemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName);
+
+ switch (result.metalSystemValueNameEnum)
+ {
+ case SystemValueSemanticName::Position:
+ {
+ result.metalSystemValueName = toSlice("position");
+ result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 4)));
+ break;
+ }
+ case SystemValueSemanticName::ClipDistance:
+ {
+ result.isSpecial = true;
+ break;
+ }
+ case SystemValueSemanticName::CullDistance:
+ {
+ result.isSpecial = true;
+ break;
+ }
+ case SystemValueSemanticName::Coverage:
+ {
+ result.metalSystemValueName = toSlice("sample_mask");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::InnerCoverage:
+ {
+ result.isSpecial = true;
+ break;
+ }
+ case SystemValueSemanticName::Depth:
+ {
+ result.metalSystemValueName = toSlice("depth(any)");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ break;
+ }
+ case SystemValueSemanticName::DepthGreaterEqual:
+ {
+ result.metalSystemValueName = toSlice("depth(greater)");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ break;
+ }
+ case SystemValueSemanticName::DepthLessEqual:
+ {
+ result.metalSystemValueName = toSlice("depth(less)");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ break;
+ }
+ case SystemValueSemanticName::DispatchThreadID:
+ {
+ result.metalSystemValueName = toSlice("thread_position_in_grid");
+ result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)));
+ break;
+ }
+ case SystemValueSemanticName::DomainLsocation:
+ {
+ result.metalSystemValueName = toSlice("position_in_patch");
+ result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 3)));
+ result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 2)));
+ break;
+ }
+ case SystemValueSemanticName::GroupID:
+ {
+ result.metalSystemValueName = toSlice("threadgroup_position_in_grid");
+ result.permittedTypes.add(builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)));
+ break;
+ }
+ case SystemValueSemanticName::GroupIndex:
+ {
+ result.isSpecial = true;
+ break;
+ }
+ case SystemValueSemanticName::GroupThreadID:
+ {
+ result.metalSystemValueName = toSlice("thread_position_in_threadgroup");
+ result.permittedTypes.add(getGroupThreadIdType(builder));
+ break;
+ }
+ case SystemValueSemanticName::GSInstanceID:
+ {
+ result.isUnsupported = true;
+ break;
+ }
+ case SystemValueSemanticName::InstanceID:
+ {
+ result.metalSystemValueName = toSlice("instance_id");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::IsFrontFace:
+ {
+ result.metalSystemValueName = toSlice("front_facing");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Bool));
+ break;
+ }
+ case SystemValueSemanticName::OutputControlPointID:
+ {
+ // In metal, a hull shader is just a compute shader.
+ // This needs to be handled separately, by lowering into an ordinary buffer.
+ break;
+ }
+ case SystemValueSemanticName::PointSize:
+ {
+ result.metalSystemValueName = toSlice("point_size");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ break;
+ }
+ case SystemValueSemanticName::PrimitiveID:
+ {
+ result.metalSystemValueName = toSlice("patch_id");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt16));
+ break;
+ }
+ case SystemValueSemanticName::RenderTargetArrayIndex:
+ {
+ result.metalSystemValueName = toSlice("render_target_array_index");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt16));
+ break;
+ }
+ case SystemValueSemanticName::SampleIndex:
+ {
+ result.metalSystemValueName = toSlice("sample_id");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::StencilRef:
+ {
+ result.metalSystemValueName = toSlice("stencil");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::TessFactor:
+ {
+ // Tessellation factor outputs should be lowered into a write into a normal buffer.
+ break;
+ }
+ case SystemValueSemanticName::VertexID:
+ {
+ result.metalSystemValueName = toSlice("vertex_id");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::ViewID:
+ {
+ result.isUnsupported = true;
+ break;
+ }
+ case SystemValueSemanticName::ViewportArrayIndex:
+ {
+ result.metalSystemValueName = toSlice("viewport_array_index");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt16));
+ break;
+ }
+ case SystemValueSemanticName::Target:
+ {
+ result.metalSystemValueName = (StringBuilder() << "color("
+ << (semanticIndex.getLength() != 0 ? semanticIndex : toSlice("0"))
+ << ")").produceString();
+ result.permittedTypes = getPermittedTypes_sv_target(builder);
+
+ break;
+ }
+ default:
+ m_sink->diagnose(parentVar, Diagnostics::unimplementedSystemValueSemantic, semanticName);
+ return result;
+ }
+ return result;
}
- else if (semanticName == "sv_outputcontrolpointid")
+
+ void reportUnsupportedSystemAttribute(IRInst* param, String semanticName)
{
- // In metal, a hull shader is just a compute shader.
- // This needs to be handled separately, by lowering into an ordinary buffer.
+ m_sink->diagnose(param->sourceLoc, Diagnostics::systemValueAttributeNotSupported, semanticName);
}
- else if (semanticName == "sv_pointsize")
+
+ void ensureResultStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout)
{
- result.metalSystemValueName = toSlice("point_size");
- result.requiredType = builder.getBasicType(BaseType::Float);
+ // Ensure each field in an output struct type has either a system semantic or a user semantic,
+ // so that signature matching can happen correctly.
+ auto typeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
+ Index index = 0;
+ IRBuilder builder(structType);
+ for (auto field : structType->getFields())
+ {
+ auto key = field->getKey();
+ if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
+ {
+ if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
+ {
+ auto indexAsString = String(UInt(semanticDecor->getSemanticIndex()));
+ auto sysValInfo = getSystemValueInfo(semanticDecor->getSemanticName(), &indexAsString, field);
+ if (sysValInfo.isUnsupported || sysValInfo.isSpecial)
+ {
+ reportUnsupportedSystemAttribute(field, semanticDecor->getSemanticName());
+ }
+ else
+ {
+ builder.addTargetSystemValueDecoration(key, sysValInfo.metalSystemValueName.getUnownedSlice());
+ semanticDecor->removeAndDeallocate();
+ }
+ }
+ index++;
+ continue;
+ }
+ typeLayout->getFieldLayout(index);
+ auto fieldLayout = typeLayout->getFieldLayout(index);
+ if (auto offsetAttr = fieldLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput))
+ {
+ UInt varOffset = 0;
+ if (auto varOffsetAttr = varLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput))
+ varOffset = varOffsetAttr->getOffset();
+ varOffset += offsetAttr->getOffset();
+ builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset);
+ }
+ index++;
+ }
}
- else if (semanticName == "sv_primitiveid")
+
+ // Stores a hicharchy of members and children which map 'oldStruct->member' to 'flatStruct->member'
+ // Note: this map assumes we map to FlatStruct since it is easier/faster to process
+ struct MapStructToFlatStruct
{
- result.metalSystemValueName = toSlice("patch_id");
- result.requiredType = builder.getBasicType(BaseType::UInt);
- result.altRequiredType = builder.getBasicType(BaseType::UInt16);
- }
- else if (semanticName == "sv_rendertargetarrayindex")
+ /*
+ We need a hicharchy map to resolve dependencies for mapping
+ oldStruct to newStruct efficently. Example:
+
+ MyStruct
+ |
+ / | \
+ / | \
+ / | \
+ M0<A> M1<A> M2<B>
+ | | |
+ A_0 A_0 B_0
+
+ Without storing hicharchy information, there will be no way to tell apart
+ `myStruct.M0.A0` from `myStruct.M1.A0` since IRStructKey/IRStructField
+ only has 1 instance of `A::A0`
+ */
+
+ enum CopyOptions : int
{
- result.metalSystemValueName = toSlice("render_target_array_index");
- result.requiredType = builder.getBasicType(BaseType::UInt);
- result.altRequiredType = builder.getBasicType(BaseType::UInt16);
- }
- else if (semanticName == "sv_sampleindex")
+ // Copy a flattened-struct into a struct
+ FlatStructIntoStruct = 0,
+
+ // Copy a struct into a flattened-struct
+ StructIntoFlatStruct = 1,
+ };
+
+ private:
+ // Children of member if applicable.
+ Dictionary<IRStructField*, MapStructToFlatStruct> members;
+
+ // Field correlating to MapStructToFlatStruct Node.
+ IRInst* node;
+ IRStructKey* getKey()
+ {
+ SLANG_ASSERT(as<IRStructField>(node));
+ return as<IRStructField>(node)->getKey();
+ }
+ IRInst* getNode()
+ {
+ return node;
+ }
+ IRType* getFieldType()
+ {
+ SLANG_ASSERT(as<IRStructField>(node));
+ return as<IRStructField>(node)->getFieldType();
+ }
+
+ // Whom node maps to inside target flatStruct
+ IRStructField* targetMapping;
+
+ auto begin()
+ {
+ return members.begin();
+ }
+ auto end()
+ {
+ return members.end();
+ }
+
+ // Copies members of oldStruct to/from newFlatStruct. Assumes members of val1 maps to members in val2
+ // using `MapStructToFlatStruct`
+ template<int copyOptions>
+ static void _emitCopy(IRBuilder& builder, IRInst* val1, IRStructType* type1, IRInst* val2, IRStructType* type2, MapStructToFlatStruct& node)
+ {
+ for (auto& field1Pair : node)
+ {
+ auto& field1 = field1Pair.second;
+
+ // Get member of val1
+ IRInst* fieldAddr1 = nullptr;
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ fieldAddr1 = builder.emitFieldAddress(type1, val1, field1.getKey());
+ }
+ else
+ {
+ if (as<IRPtrTypeBase>(val1))
+ val1 = builder.emitLoad(val1);
+ fieldAddr1 = builder.emitFieldExtract(type1, val1, field1.getKey());
+ }
+
+ // If val1 is a struct, recurse
+ if (auto fieldAsStruct1 = as<IRStructType>(field1.getFieldType()))
+ {
+ _emitCopy<copyOptions>(builder, fieldAddr1, fieldAsStruct1, val2, type2, field1);
+ continue;
+ }
+
+ // Get member of val2 which maps to val1.member
+ auto field2 = field1.getMapping();
+ SLANG_ASSERT(field2);
+ IRInst* fieldAddr2 = nullptr;
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ if (as<IRPtrTypeBase>(val2))
+ val2 = builder.emitLoad(val1);
+ fieldAddr2 = builder.emitFieldExtract(type2, val2, field2->getKey());
+ }
+ else
+ {
+ fieldAddr2 = builder.emitFieldAddress(type2, val2, field2->getKey());
+ }
+
+ // Copy val2/val1 member into val1/val2 member
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ builder.emitStore(fieldAddr1, fieldAddr2);
+ }
+ else
+ {
+ builder.emitStore(fieldAddr2, fieldAddr1);
+ }
+ }
+ }
+
+ public:
+ void setNode(IRInst* newNode)
+ {
+ node = newNode;
+ }
+ // Get 'MapStructToFlatStruct' that is a child of 'parent'.
+ // Make 'MapStructToFlatStruct' if no 'member' is currently mapped to 'parent'.
+ MapStructToFlatStruct& getMember(IRStructField* member)
+ {
+ return members[member];
+ }
+ MapStructToFlatStruct& operator[](IRStructField* member)
+ {
+ return getMember(member);
+ }
+
+ void setMapping(IRStructField* newTargetMapping)
+ {
+ targetMapping = newTargetMapping;
+ }
+ // Get 'MapStructToFlatStruct' that is a child of 'parent'.
+ // Return nullptr if no member is mapped to 'parent'
+ IRStructField* getMapping()
+ {
+ return targetMapping;
+ }
+
+ // Copies srcVal into dstVal using hicharchy map.
+ template<int copyOptions>
+ void emitCopy(IRBuilder& builder, IRInst* dstVal, IRInst* srcVal)
+ {
+ auto dstType = dstVal->getDataType();
+ if (auto dstPtrType = as<IRPtrTypeBase>(dstType))
+ dstType = dstPtrType->getValueType();
+ auto dstStructType = as<IRStructType>(dstType);
+ SLANG_ASSERT(dstStructType);
+
+ auto srcType = srcVal->getDataType();
+ if (auto srcPtrType = as<IRPtrTypeBase>(srcType))
+ srcType = srcPtrType->getValueType();
+ auto srcStructType = as<IRStructType>(srcType);
+ SLANG_ASSERT(srcStructType);
+
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ // CopyOptions::FlatStructIntoStruct copy a flattened-struct (mapped member) into a struct
+ SLANG_ASSERT(node == dstStructType);
+ _emitCopy<copyOptions>(builder, dstVal, dstStructType, srcVal, srcStructType, *this);
+ }
+ else
+ {
+ // CopyOptions::StructIntoFlatStruct copy a struct into a flattened-struct
+ SLANG_ASSERT(node == srcStructType);
+ _emitCopy<copyOptions>(builder, srcVal, srcStructType, dstVal, dstStructType, *this);
+ }
+ }
+ };
+
+ IRStructType* _flattenNestedStructs(IRBuilder& builder, IRStructType* dst, IRStructType* src, IRSemanticDecoration* parentSemanticDecoration,
+ IRLayoutDecoration* parentLayout, MapStructToFlatStruct& mapFieldToField, HashSet<IRStructField*>& varsWithSemanticInfo)
{
- result.metalSystemValueName = toSlice("sample_id");
- result.requiredType = builder.getBasicType(BaseType::UInt);
+ // For all fields ('oldField') of a struct do the following:
+ // 1. Check for 'decorations which carry semantic info' (IRSemanticDecoration, IRLayoutDecoration), store these if found.
+ // * Do not propagate semantic info if the current node has *any* form of semantic information.
+ // Update varsWithSemanticInfo.
+ // 2. If IRStructType:
+ // 2a. Recurse this function with 'decorations that carry semantic info' from parent.
+ // 3. If not IRStructType:
+ // 3a. Emit 'newField' equal to 'oldField', add 'decorations which carry semantic info'.
+ // 3b. Store a mapping from 'oldField' to 'newField' in 'mapFieldToField'. This info is needed to copy between types.
+ for (auto oldField : src->getFields())
+ {
+ auto& fieldMappingNode = mapFieldToField[oldField];
+ fieldMappingNode.setNode(oldField);
+
+ // step 1
+ bool foundSemanticDecor = false;
+ auto oldKey = oldField->getKey();
+ IRSemanticDecoration* fieldSemanticDecoration = parentSemanticDecoration;
+ if (auto oldSemanticDecoration = oldKey->findDecoration<IRSemanticDecoration>())
+ {
+ foundSemanticDecor = true;
+ fieldSemanticDecoration = oldSemanticDecoration;
+ parentLayout = nullptr;
+ }
+
+ IRLayoutDecoration* fieldLayout = parentLayout;
+ if (auto oldLayout = oldKey->findDecoration<IRLayoutDecoration>())
+ {
+ fieldLayout = oldLayout;
+ if (!foundSemanticDecor)
+ fieldSemanticDecoration = nullptr;
+ }
+ if (fieldSemanticDecoration != parentSemanticDecoration || parentLayout != fieldLayout)
+ varsWithSemanticInfo.add(oldField);
+
+ // step 2a
+ if (auto structFieldType = as<IRStructType>(oldField->getFieldType()))
+ {
+ _flattenNestedStructs(builder, dst, structFieldType, fieldSemanticDecoration, fieldLayout, fieldMappingNode, varsWithSemanticInfo);
+ continue;
+ }
+
+ // step 3a
+ auto newKey = builder.createStructKey();
+ copyNameHintAndDebugDecorations(newKey, oldKey);
+
+ auto newField = builder.createStructField(dst, newKey, oldField->getFieldType());
+ copyNameHintAndDebugDecorations(newField, oldField);
+
+ if (fieldSemanticDecoration)
+ builder.addSemanticDecoration(newKey, fieldSemanticDecoration->getSemanticName(), fieldSemanticDecoration->getSemanticIndex());
+
+ if (fieldLayout)
+ {
+ IRLayout* oldLayout = fieldLayout->getLayout();
+ List<IRInst*> instToCopy;
+ // Only copy certain decorations needed for resolving system semantics
+ for (UInt i = 0; i < oldLayout->getOperandCount(); i++)
+ {
+ auto operand = oldLayout->getOperand(i);
+ if (as<IRVarOffsetAttr>(operand) || as<IRUserSemanticAttr>(operand) || as<IRSystemValueSemanticAttr>(operand) || as<IRStageAttr>(operand))
+ instToCopy.add(operand);
+ }
+ IRVarLayout* newLayout = builder.getVarLayout(instToCopy);
+ builder.addLayoutDecoration(newKey, newLayout);
+ }
+ // step 3b
+ fieldMappingNode.setMapping(newField);
+ }
+
+ return dst;
}
- else if (semanticName == "sv_stencilref")
+
+ // Returns a `IRStructType*` without any `IRStructType*` members. `src` may be returned if there was no struct flattening.
+ // @param mapFieldToField Behavior maps all `IRStructField` of `src` to the new struct `IRStructFields`s
+ IRStructType* maybeFlattenNestedStructs(IRBuilder& builder, IRStructType* src, MapStructToFlatStruct& mapFieldToField, HashSet<IRStructField*>& varsWithSemanticInfo)
{
- result.metalSystemValueName = toSlice("stencil");
- result.requiredType = builder.getBasicType(BaseType::UInt);
+ // Find all values inside struct that need flattening and legalization.
+ bool hasStructTypeMembers = false;
+ for (auto field : src->getFields())
+ {
+ if (as<IRStructType>(field->getFieldType()))
+ {
+ hasStructTypeMembers = true;
+ break;
+ }
+ }
+ if (!hasStructTypeMembers)
+ return src;
+
+ // We need to:
+ // 1. Make new struct 1:1 with old struct but without nestested structs (flatten)
+ // 2. Ensure semantic attributes propegate. This will create overlapping semantics (can be handled later).
+ // 3. Store the mapping from old to new struct fields to allow copying a old-struct to new-struct.
+ builder.setInsertAfter(src);
+ auto newStruct = builder.createStructType();
+ copyNameHintAndDebugDecorations(newStruct, src);
+ mapFieldToField.setNode(src);
+ return _flattenNestedStructs(builder, newStruct, src, nullptr, nullptr, mapFieldToField, varsWithSemanticInfo);
}
- else if (semanticName == "sv_tessfactor")
+
+ // Replaces all 'IRReturn' by copying the current 'IRReturn' to a new var of type 'newType'.
+ // Copying logic from 'IRReturn' to 'newType' is controlled by 'copyLogicFunc' function.
+ template<typename CopyLogicFunc>
+ void _replaceAllReturnInst(IRBuilder& builder, IRFunc* targetFunc, IRStructType* newType, CopyLogicFunc copyLogicFunc)
{
- // Tessellation factor outputs should be lowered into a write into a normal buffer.
+ for (auto block : targetFunc->getBlocks())
+ {
+ if (auto returnInst = as<IRReturn>(block->getTerminator()))
+ {
+ builder.setInsertBefore(returnInst);
+ auto returnVal = returnInst->getVal();
+ returnInst->setOperand(0, copyLogicFunc(builder, newType, returnVal));
+ }
+ }
}
- else if (semanticName == "sv_vertexid")
+
+ UInt _returnNonOverlappingAttributeIndex(std::set<UInt>& usedSemanticIndex)
{
- result.metalSystemValueName = toSlice("vertex_id");
- result.requiredType = builder.getBasicType(BaseType::UInt);
+ // Find first unused semantic index of equal semantic type
+ // to fill any gaps in user set semantic bindings
+ UInt prev = 0;
+ for (auto i : usedSemanticIndex)
+ {
+ if (i > prev + 1)
+ {
+ break;
+ }
+ prev = i;
+ }
+ usedSemanticIndex.insert(prev + 1);
+ return prev + 1;
}
- else if (semanticName == "sv_viewid")
+
+ template<typename T>
+ struct AttributeParentPair
{
- result.isUnsupported = true;
- }
- else if (semanticName == "sv_viewportarrayindex")
+ IRLayoutDecoration* layoutDecor;
+ T* attr;
+ };
+
+ IRLayoutDecoration* _replaceAttributeOfLayout(IRBuilder& builder, IRLayoutDecoration* parentLayoutDecor, IRInst* instToReplace, IRInst* instToReplaceWith)
{
- result.metalSystemValueName = toSlice("viewport_array_index");
- result.requiredType = builder.getBasicType(BaseType::UInt);
- result.altRequiredType = builder.getBasicType(BaseType::UInt16);
+ // Replace `instToReplace` with a `instToReplaceWith`
+
+ auto layout = parentLayoutDecor->getLayout();
+ // Find the exact same decoration `instToReplace` in-case multiple of the same type exist
+ List<IRInst*> opList;
+ opList.add(instToReplaceWith);
+ for (UInt i = 0; i < layout->getOperandCount(); i++)
+ {
+ if (layout->getOperand(i) != instToReplace)
+ opList.add(layout->getOperand(i));
+ }
+ auto newLayoutDecor = builder.addLayoutDecoration(parentLayoutDecor->getParent(), builder.getVarLayout(opList));
+ parentLayoutDecor->removeAndDeallocate();
+ return newLayoutDecor;
}
- else if (semanticName.startsWith("sv_target"))
+
+ IRLayoutDecoration* _simplifyUserSemanticNames(IRBuilder& builder, IRLayoutDecoration* layoutDecor)
{
-
- result.metalSystemValueName = (StringBuilder() << "color("
- << (semanticIndex.getLength() != 0 ? semanticIndex : toSlice("0"))
- << ")").produceString();
+ // Ensure all 'ExplicitIndex' semantics such as "SV_TARGET0" are simplified into ("SV_TARGET", 0) using 'IRUserSemanticAttr'
+ // This is done to ensure we can check semantic groups using 'IRUserSemanticAttr1->getName() == IRUserSemanticAttr2->getName()'
+ SLANG_ASSERT(layoutDecor);
+ auto layout = layoutDecor->getLayout();
+ List<IRInst*> layoutOps;
+ layoutOps.reserve(3);
+ bool changed = false;
+ for(auto attr : layout->getAllAttrs())
+ {
+ if(auto userSemantic = as<IRUserSemanticAttr>(attr))
+ {
+ UnownedStringSlice outName;
+ UnownedStringSlice outIndex;
+ bool hasStringIndex = splitNameAndIndex(userSemantic->getName(), outName, outIndex);
+ if (hasStringIndex)
+ {
+ changed = true;
+ auto loweredName = String(outName).toLower();
+ auto loweredNameSlice = loweredName.getUnownedSlice();
+ auto newDecoration = builder.getUserSemanticAttr(loweredNameSlice, stringToInt(outIndex));
+ userSemantic->replaceUsesWith(newDecoration);
+ userSemantic->removeAndDeallocate();
+ userSemantic = newDecoration;
+ }
+ layoutOps.add(userSemantic);
+ continue;
+ }
+ layoutOps.add(attr);
+ }
+ if(changed)
+ {
+ auto parent = layoutDecor->parent;
+ layoutDecor->removeAndDeallocate();
+ builder.addLayoutDecoration(parent, builder.getVarLayout(layoutOps));
+ }
+ return layoutDecor;
}
- else
+ // Find overlapping field semantics and legalize them
+ void fixFieldSemanticsOfFlatStruct(IRStructType* structType)
{
- result.isUnsupported = true;
- }
- return result;
- }
+ // Goal is to ensure we do not have overlapping semantics:
+ /*
+ // Assume the following code
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET;
+ float3 p2 : SV_TARGET;
+ float2 p3 : SV_TARGET;
+ float2 p4 : SV_TARGET;
+ };
+
+ // Translates into
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET0;
+ float3 p2 : SV_TARGET1;
+ float2 p3 : SV_TARGET2;
+ float2 p4 : SV_TARGET3;
+ };
+ */
- void reportUnsupportedSystemAttribute(DiagnosticSink* sink, IRInst* param, String semanticName)
- {
- sink->diagnose(param->sourceLoc, Diagnostics::systemValueAttributeNotSupported, semanticName);
- }
+ IRBuilder builder(this->m_module);
- void ensureResultStructHasUserSemantic(DiagnosticSink* sink, IRStructType* structType, IRVarLayout* varLayout)
- {
- // Ensure each field in an output struct type has either a system semantic or a user semantic,
- // so that signature matching can happen correctly.
- auto typeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
- Index index = 0;
- IRBuilder builder(structType);
- for (auto field : structType->getFields())
- {
- auto key = field->getKey();
- if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
+ List<IRSemanticDecoration*> overlappingSemanticsDecor;
+ Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>> usedSemanticIndexSemanticDecor;
+
+ List<AttributeParentPair<IRVarOffsetAttr>> overlappingVarOffset;
+ Dictionary<UInt, std::set<UInt, std::less<UInt>>> usedSemanticIndexVarOffset;
+
+ List<AttributeParentPair<IRUserSemanticAttr>> overlappingUserSemantic;
+ Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt> >> usedSemanticIndexUserSemantic;
+
+ // We store a map from old `IRLayoutDecoration*` to new `IRLayoutDecoration*` since when legalizing
+ // we may destroy and remake a `IRLayoutDecoration*`
+ Dictionary<IRLayoutDecoration*, IRLayoutDecoration*> oldLayoutDecorToNew;
+
+ // Collect all "semantic info carrying decorations". Any collected decoration will
+ // fill up their respective 'Dictionary<SEMANTIC_TYPE, OrderedHashSet<UInt>>'
+ // to keep track of in-use offsets for a semantic type.
+ // Example: IRSemanticDecoration with name of "SV_TARGET1".
+ // * This will have SEMANTIC_TYPE of "sv_target".
+ // * This will use up index '1'
+ //
+ // Now if a second equal semantic "SV_TARGET1" is found, we add this decoration to
+ // a list of 'overlapping semantic info decorations' so we can legalize this
+ // 'semantic info decoration' later.
+ //
+ // NOTE: this is a flat struct, all members are children of the initial
+ // IRStructType.
+ for (auto field : structType->getFields())
{
- if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
+ auto key = field->getKey();
+ if (auto semanticDecoration = key->findDecoration<IRSemanticDecoration>())
{
- auto sysValInfo = getSystemValueInfo(builder, semanticDecor->getSemanticName());
- if (sysValInfo.isUnsupported || sysValInfo.isSpecial)
+ // Ensure names are in a uniform lowercase format so we can bunch together simmilar semantics
+ UnownedStringSlice outName;
+ UnownedStringSlice outIndex;
+ bool hasStringIndex = splitNameAndIndex(semanticDecoration->getSemanticName(), outName, outIndex);
+ if (hasStringIndex)
{
- reportUnsupportedSystemAttribute(sink, field, semanticDecor->getSemanticName());
+ auto loweredName = String(outName).toLower();
+ auto loweredNameSlice = loweredName.getUnownedSlice();
+ auto newDecoration = builder.addSemanticDecoration(key, loweredNameSlice, stringToInt(outIndex));
+ semanticDecoration->replaceUsesWith(newDecoration);
+ semanticDecoration->removeAndDeallocate();
+ semanticDecoration = newDecoration;
}
+ auto& semanticUse = usedSemanticIndexSemanticDecor[semanticDecoration->getSemanticName()];
+ if (semanticUse.find(semanticDecoration->getSemanticIndex()) != semanticUse.end())
+ overlappingSemanticsDecor.add(semanticDecoration);
else
+ semanticUse.insert(semanticDecoration->getSemanticIndex());
+ }
+ if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>())
+ {
+ // Ensure names are in a uniform lowercase format so we can bunch together simmilar semantics
+ layoutDecor = _simplifyUserSemanticNames(builder, layoutDecor);
+ oldLayoutDecorToNew[layoutDecor] = layoutDecor;
+ auto layout = layoutDecor->getLayout();
+ for (auto attr : layout->getAllAttrs())
{
- builder.addTargetSystemValueDecoration(key, sysValInfo.metalSystemValueName.getUnownedSlice());
- semanticDecor->removeAndDeallocate();
+ if (auto offset = as<IRVarOffsetAttr>(attr))
+ {
+ auto& semanticUse = usedSemanticIndexVarOffset[offset->getResourceKind()];
+ if (semanticUse.find(offset->getOffset()) != semanticUse.end())
+ overlappingVarOffset.add({ layoutDecor, offset });
+ else
+ semanticUse.insert(offset->getOffset());
+ }
+ else if (auto userSemantic = as<IRUserSemanticAttr>(attr))
+ {
+ auto& semanticUse = usedSemanticIndexUserSemantic[userSemantic->getName()];
+ if (semanticUse.find(userSemantic->getIndex()) != semanticUse.end())
+ overlappingUserSemantic.add({ layoutDecor, userSemantic });
+ else
+ semanticUse.insert(userSemantic->getIndex());
+ }
}
}
- index++;
- continue;
}
- typeLayout->getFieldLayout(index);
- auto fieldLayout = typeLayout->getFieldLayout(index);
- if (auto offsetAttr = fieldLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput))
+
+ // Legalize all overlapping 'semantic info decorations'
+ for (auto decor : overlappingSemanticsDecor)
{
- UInt varOffset = 0;
- if (auto varOffsetAttr = varLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput))
- varOffset = varOffsetAttr->getOffset();
- varOffset += offsetAttr->getOffset();
- builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset);
+ auto newOffset = _returnNonOverlappingAttributeIndex(usedSemanticIndexSemanticDecor[decor->getSemanticName()]);
+ builder.addSemanticDecoration(decor->getParent(), decor->getSemanticName(), (int)newOffset);
+ decor->removeAndDeallocate();
+ }
+ for (auto& varOffset : overlappingVarOffset)
+ {
+ auto newOffset = _returnNonOverlappingAttributeIndex(usedSemanticIndexVarOffset[varOffset.attr->getResourceKind()]);
+ auto newVarOffset = builder.getVarOffsetAttr(varOffset.attr->getResourceKind(), newOffset, varOffset.attr->getSpace());
+ oldLayoutDecorToNew[varOffset.layoutDecor] = _replaceAttributeOfLayout(builder, oldLayoutDecorToNew[varOffset.layoutDecor], varOffset.attr, newVarOffset);
+ }
+ for (auto& userSemantic : overlappingUserSemantic)
+ {
+ auto newOffset = _returnNonOverlappingAttributeIndex(usedSemanticIndexUserSemantic[userSemantic.attr->getName()]);
+ auto newUserSemantic = builder.getUserSemanticAttr(userSemantic.attr->getName(), newOffset);
+ oldLayoutDecorToNew[userSemantic.layoutDecor] = _replaceAttributeOfLayout(builder, oldLayoutDecorToNew[userSemantic.layoutDecor], userSemantic.attr, newUserSemantic);
}
- index++;
- }
- }
-
-
- void wrapReturnValueInStruct(DiagnosticSink* sink, EntryPointInfo entryPoint)
- {
- // Wrap return value into a struct if it is not already a struct.
- // For example, given this entry point:
- // ```
- // float4 main() : SV_Target { return float3(1,2,3); }
- // ```
- // We are going to transform it into:
- // ```
- // struct Output {
- // float4 value : SV_Target;
- // };
- // Output main() { return {float3(1,2,3)}; }
-
- auto func = entryPoint.entryPointFunc;
-
- auto returnType = func->getResultType();
- if (as<IRVoidType>(returnType))
- return;
- auto entryPointLayoutDecor = func->findDecoration<IRLayoutDecoration>();
- if (!entryPointLayoutDecor)
- return;
- auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecor->getLayout());
- if (!entryPointLayout)
- return;
- auto resultLayout = entryPointLayout->getResultLayout();
-
- // If return type is already a struct, just make sure every field has a semantic.
- if (auto returnStructType = as<IRStructType>(returnType))
- {
- ensureResultStructHasUserSemantic(sink, returnStructType, resultLayout);
- return;
}
- // If not, we need to wrap the result into a struct type.
- IRBuilder builder(func);
- builder.setInsertBefore(func);
- IRStructType* structType = builder.createStructType();
- auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
- builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Output")).getUnownedSlice());
- auto key = builder.createStructKey();
- builder.addNameHintDecoration(key, toSlice("output"));
- builder.addLayoutDecoration(key, resultLayout);
- builder.addTargetSystemValueDecoration(key, toSlice("color(0)"));
- builder.createStructField(structType, key, returnType);
- IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder);
- structTypeLayoutBuilder.addField(key, resultLayout);
- auto typeLayout = structTypeLayoutBuilder.build();
- IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
- auto varLayout = varLayoutBuilder.build();
- ensureResultStructHasUserSemantic(sink, structType, varLayout);
-
- for (auto block : func->getBlocks())
+ void wrapReturnValueInStruct(EntryPointInfo entryPoint)
{
- if (auto returnInst = as<IRReturn>(block->getTerminator()))
+ // Wrap return value into a struct if it is not already a struct.
+ // For example, given this entry point:
+ // ```
+ // float4 main() : SV_Target { return float3(1,2,3); }
+ // ```
+ // We are going to transform it into:
+ // ```
+ // struct Output {
+ // float4 value : SV_Target;
+ // };
+ // Output main() { return {float3(1,2,3)}; }
+
+ auto func = entryPoint.entryPointFunc;
+
+ auto returnType = func->getResultType();
+ if (as<IRVoidType>(returnType))
+ return;
+ auto entryPointLayoutDecor = func->findDecoration<IRLayoutDecoration>();
+ if (!entryPointLayoutDecor)
+ return;
+ auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecor->getLayout());
+ if (!entryPointLayout)
+ return;
+ auto resultLayout = entryPointLayout->getResultLayout();
+
+ // If return type is already a struct, just make sure every field has a semantic.
+ if (auto returnStructType = as<IRStructType>(returnType))
{
- builder.setInsertBefore(returnInst);
- auto returnVal = returnInst->getVal();
- auto newResult = builder.emitMakeStruct(structType, 1, &returnVal);
- returnInst->setOperand(0, newResult);
+ IRBuilder builder(func);
+ MapStructToFlatStruct mapOldFieldToNewField;
+ // Flatten result struct type to ensure we do not have nested semantics
+ auto flattenedStruct = maybeFlattenNestedStructs(builder, returnStructType, mapOldFieldToNewField, semanticInfoToRemove);
+ if (returnStructType != flattenedStruct)
+ {
+ // Replace all return-values with the flattenedStruct we made.
+ _replaceAllReturnInst(builder, func, flattenedStruct,
+ [&](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst*
+ {
+ auto srcStructType = as<IRStructType>(srcVal->getDataType());
+ SLANG_ASSERT(srcStructType);
+ auto dstVal = copyBuilder.emitVar(dstType);
+ mapOldFieldToNewField.emitCopy<(int)MapStructToFlatStruct::CopyOptions::StructIntoFlatStruct>(copyBuilder, dstVal, srcVal);
+ return builder.emitLoad(dstVal);
+ }
+ );
+ fixUpFuncType(func, flattenedStruct);
+ }
+ // Ensure non-overlapping semantics
+ fixFieldSemanticsOfFlatStruct(flattenedStruct);
+ ensureResultStructHasUserSemantic(flattenedStruct, resultLayout);
+ return;
}
- }
- fixUpFuncType(func, structType);
- }
- void legalizeMeshEntryPoint(EntryPointInfo entryPoint)
- {
- auto func = entryPoint.entryPointFunc;
+ IRBuilder builder(func);
+ builder.setInsertBefore(func);
+ IRStructType* structType = builder.createStructType();
+ auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
+ builder.addNameHintDecoration(structType, (String(stageText) + toSlice("Output")).getUnownedSlice());
+ auto key = builder.createStructKey();
+ builder.addNameHintDecoration(key, toSlice("output"));
+ builder.addLayoutDecoration(key, resultLayout);
+ builder.createStructField(structType, key, returnType);
+ IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder);
+ structTypeLayoutBuilder.addField(key, resultLayout);
+ auto typeLayout = structTypeLayoutBuilder.build();
+ IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
+ auto varLayout = varLayoutBuilder.build();
+ ensureResultStructHasUserSemantic(structType, varLayout);
+
+ _replaceAllReturnInst(builder, func, structType,
+ [](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst*
+ {
+ return copyBuilder.emitMakeStruct(dstType, 1, &srcVal);
+ }
+ );
- if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Mesh)
- {
- return;
+ // Assign an appropriate system value semantic for stage output
+ auto stage = entryPoint.entryPointDecor->getProfile().getStage();
+ switch (stage)
+ {
+ case Stage::Compute:
+ case Stage::Fragment:
+ {
+ builder.addTargetSystemValueDecoration(key, toSlice("color(0)"));
+ break;
+ }
+ case Stage::Vertex:
+ {
+ builder.addTargetSystemValueDecoration(key, toSlice("position"));
+ break;
+ }
+ default:
+ SLANG_ASSERT(false);
+ return;
+ }
+
+ fixUpFuncType(func, structType);
}
- IRBuilder builder{ entryPoint.entryPointFunc->getModule() };
- for (auto param : func->getParams())
+ void legalizeMeshEntryPoint(EntryPointInfo entryPoint)
{
- if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
- {
- IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build());
+ auto func = entryPoint.entryPointFunc;
- varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
- auto paramVarLayout = varLayoutBuilder.build();
- builder.addLayoutDecoration(param, paramVarLayout);
+ if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Mesh)
+ {
+ return;
}
- }
- }
+ IRBuilder builder{ entryPoint.entryPointFunc->getModule() };
+ for (auto param : func->getParams())
+ {
+ if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ {
+ IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build());
+
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(param, paramVarLayout);
+ }
+ }
- void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint)
- {
- if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Amplification)
- {
- return;
}
- // Find out DispatchMesh function
- IRGlobalValueWithCode* dispatchMeshFunc = nullptr;
- for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts())
+
+ void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint)
{
- if (const auto func = as<IRGlobalValueWithCode>(globalInst))
+ if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Amplification)
+ {
+ return;
+ }
+ // Find out DispatchMesh function
+ IRGlobalValueWithCode* dispatchMeshFunc = nullptr;
+ for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts())
{
- if (const auto dec = func->findDecoration<IRKnownBuiltinDecoration>())
+ if (const auto func = as<IRGlobalValueWithCode>(globalInst))
{
- if (dec->getName() == "DispatchMesh")
+ if (const auto dec = func->findDecoration<IRKnownBuiltinDecoration>())
{
- SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found");
- dispatchMeshFunc = func;
+ if (dec->getName() == "DispatchMesh")
+ {
+ SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found");
+ dispatchMeshFunc = func;
+ }
}
}
}
- }
-
- if (!dispatchMeshFunc)
- return;
- IRBuilder builder{ entryPoint.entryPointFunc->getModule() };
+ if (!dispatchMeshFunc)
+ return;
- // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid
- traverseUses(dispatchMeshFunc, [&](const IRUse* use) {
- if (const auto call = as<IRCall>(use->getUser()))
- {
- SLANG_ASSERT(call->getArgCount() == 4);
- const auto payload = call->getArg(3);
+ IRBuilder builder{ entryPoint.entryPointFunc->getModule() };
- const auto payloadPtrType = composeGetters<IRPtrType>(
- payload,
- &IRInst::getDataType
- );
- SLANG_ASSERT(payloadPtrType);
- const auto payloadType = payloadPtrType->getValueType();
- SLANG_ASSERT(payloadType);
+ // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid
+ traverseUses(dispatchMeshFunc, [&](const IRUse* use) {
+ if (const auto call = as<IRCall>(use->getUser()))
+ {
+ SLANG_ASSERT(call->getArgCount() == 4);
+ const auto payload = call->getArg(3);
- builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
- const auto annotatedPayloadType =
- builder.getPtrType(
- kIROp_RefType,
- payloadPtrType->getValueType(),
- AddressSpace::MetalObjectData
+ const auto payloadPtrType = composeGetters<IRPtrType>(
+ payload,
+ &IRInst::getDataType
);
- auto packedParam = builder.emitParam(annotatedPayloadType);
- builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload"));
- IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build());
-
- // Add the MetalPayload resource info, so we can emit [[payload]]
- varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
- auto paramVarLayout = varLayoutBuilder.build();
- builder.addLayoutDecoration(packedParam, paramVarLayout);
-
- // Now we replace the call to DispatchMesh with a call to the mesh grid properties
- // But first we need to create the parameter
- const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType();
- auto mgp = builder.emitParam(meshGridPropertiesType);
- builder.addExternCppDecoration(mgp, toSlice("_slang_mgp"));
- }
- });
- }
+ SLANG_ASSERT(payloadPtrType);
+ const auto payloadType = payloadPtrType->getValueType();
+ SLANG_ASSERT(payloadType);
+
+ builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
+ const auto annotatedPayloadType =
+ builder.getPtrType(
+ kIROp_RefType,
+ payloadPtrType->getValueType(),
+ AddressSpace::MetalObjectData
+ );
+ auto packedParam = builder.emitParam(annotatedPayloadType);
+ builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload"));
+ IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build());
+
+ // Add the MetalPayload resource info, so we can emit [[payload]]
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(packedParam, paramVarLayout);
+
+ // Now we replace the call to DispatchMesh with a call to the mesh grid properties
+ // But first we need to create the parameter
+ const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType();
+ auto mgp = builder.emitParam(meshGridPropertiesType);
+ builder.addExternCppDecoration(mgp, toSlice("_slang_mgp"));
+ }
+ });
+ }
- IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType)
- {
- auto fromType = val->getFullType();
- if (auto fromVector = as<IRVectorType>(fromType))
+ IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType)
{
- if (auto toVector = as<IRVectorType>(toType))
+ auto fromType = val->getFullType();
+ if (auto fromVector = as<IRVectorType>(fromType))
{
- if (fromVector->getElementCount() != toVector->getElementCount())
+ if (auto toVector = as<IRVectorType>(toType))
+ {
+ if (fromVector->getElementCount() != toVector->getElementCount())
+ {
+ fromType = builder.getVectorType(fromVector->getElementType(), toVector->getElementCount());
+ val = builder.emitVectorReshape(fromType, val);
+ }
+ }
+ else if (as<IRBasicType>(toType))
{
- fromType = builder.getVectorType(fromVector->getElementType(), toVector->getElementCount());
- val = builder.emitVectorReshape(fromType, val);
+ UInt index = 0;
+ val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index);
+ if (toType->getOp() == kIROp_VoidType)
+ return nullptr;
}
}
- else if (as<IRBasicType>(toType))
+ else if (auto fromBasicType = as<IRBasicType>(fromType))
{
- UInt index = 0;
- val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index);
+ if (fromBasicType->getOp() == kIROp_VoidType)
+ return nullptr;
+ if (!as<IRBasicType>(toType))
+ return nullptr;
if (toType->getOp() == kIROp_VoidType)
return nullptr;
}
- }
- else if (auto fromBasicType = as<IRBasicType>(fromType))
- {
- if (fromBasicType->getOp() == kIROp_VoidType)
- return nullptr;
- if (!as<IRBasicType>(toType))
- return nullptr;
- if (toType->getOp() == kIROp_VoidType)
+ else
+ {
return nullptr;
+ }
+ return builder.emitCast(toType, val);
}
- else
- {
- return nullptr;
- }
- return builder.emitCast(toType, val);
- }
-
- void legalizeSystemValueParameters(EntryPointInfo entryPoint, DiagnosticSink* sink)
- {
- SLANG_UNUSED(sink);
struct SystemValLegalizationWorkItem
{
- IRParam* param;
+ IRInst* var;
String attrName;
UInt attrIndex;
};
- List<SystemValLegalizationWorkItem> systemValWorkItems;
-
- IRBuilder builder(entryPoint.entryPointFunc);
- for (auto param : entryPoint.entryPointFunc->getParams())
+ std::optional<SystemValLegalizationWorkItem> tryToMakeSystemValWorkItem(IRInst* var)
{
- if (auto semanticDecoration = param->findDecoration<IRSemanticDecoration>())
+ if (auto semanticDecoration = var->findDecoration<IRSemanticDecoration>())
{
if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
{
- systemValWorkItems.add({ param, String(semanticDecoration->getSemanticName()).toLower(), (UInt)semanticDecoration->getSemanticIndex() });
- continue;
+ return { { var, String(semanticDecoration->getSemanticName()).toLower(), (UInt)semanticDecoration->getSemanticIndex() } };
}
}
- auto layoutDecor = param->findDecoration<IRLayoutDecoration>();
+ auto layoutDecor = var->findDecoration<IRLayoutDecoration>();
if (!layoutDecor)
- continue;
+ return {};
auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>();
if (!sysValAttr)
- continue;
+ return {};
auto semanticName = String(sysValAttr->getName());
auto sysAttrIndex = sysValAttr->getIndex();
- systemValWorkItems.add({ param, semanticName, sysAttrIndex });
+
+ return { { var, semanticName, sysAttrIndex } };
+ }
+
+
+ List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(EntryPointInfo entryPoint)
+ {
+ List<SystemValLegalizationWorkItem> systemValWorkItems;
+ for (auto param : entryPoint.entryPointFunc->getParams())
+ {
+ auto maybeWorkItem = tryToMakeSystemValWorkItem(param);
+ if (maybeWorkItem.has_value())
+ systemValWorkItems.add(std::move(maybeWorkItem.value()));
+ }
+ return systemValWorkItems;
}
- IRParam* groupThreadId = nullptr;
- for (auto index = 0; index < systemValWorkItems.getCount(); index++)
+ void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem)
{
- auto workItem = systemValWorkItems[index];
+ IRBuilder builder(entryPoint.entryPointFunc);
- auto param = workItem.param;
+ auto var = workItem.var;
auto semanticName = workItem.attrName;
- auto info = getSystemValueInfo(builder, semanticName);
+ auto indexAsString = String(workItem.attrIndex);
+ auto info = getSystemValueInfo(semanticName, &indexAsString, var);
+
if (info.isSpecial)
{
- if (semanticName == "sv_innercoverage")
+ if (info.metalSystemValueNameEnum == SystemValueSemanticName::InnerCoverage)
{
// Metal does not support conservative rasterization, so this is always false.
auto val = builder.getBoolValue(false);
- param->replaceUsesWith(val);
- param->removeAndDeallocate();
+ var->replaceUsesWith(val);
+ var->removeAndDeallocate();
}
- else if (semanticName == "sv_groupindex")
+ else if (info.metalSystemValueNameEnum == SystemValueSemanticName::GroupIndex)
{
- // Ensure we have a cached "sv_groupthreadid"
- if (!groupThreadId)
+ // Ensure we have a cached "sv_groupthreadid" in our entry point
+ if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc))
{
+ auto systemValWorkItems = collectSystemValFromEntryPoint(entryPoint);
for (auto i : systemValWorkItems)
{
- if (i.attrName == groupThreadIDString)
+ auto indexAsStringGroupThreadId = String(i.attrIndex);
+ if (getSystemValueInfo(i.attrName, &indexAsStringGroupThreadId, i.var).metalSystemValueNameEnum == SystemValueSemanticName::GroupThreadID)
{
- groupThreadId = i.param;
+ entryPointToGroupThreadId[entryPoint.entryPointFunc] = i.var;
}
}
- if (!groupThreadId)
+ if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc))
{
// Add the missing groupthreadid needed to compute sv_groupindex
IRBuilder groupThreadIdBuilder(builder);
groupThreadIdBuilder.setInsertInto(entryPoint.entryPointFunc->getFirstBlock());
- groupThreadId = groupThreadIdBuilder.emitParamAtHead(getGroupThreadIdType(groupThreadIdBuilder));
+ auto groupThreadId = groupThreadIdBuilder.emitParamAtHead(getGroupThreadIdType(groupThreadIdBuilder));
+ entryPointToGroupThreadId[entryPoint.entryPointFunc] = groupThreadId;
groupThreadIdBuilder.addNameHintDecoration(groupThreadId, groupThreadIDString);
-
+
// Since "sv_groupindex" will be translated out to a global var and no longer be considered a system value
// we can reuse its layout and semantic info
Index foundRequiredDecorations = 0;
IRLayoutDecoration* layoutDecoration = nullptr;
UInt semanticIndex = 0;
- for (auto decoration : param->getDecorations())
+ for (auto decoration : var->getDecorations())
{
if (auto layoutDecorationTmp = as<IRLayoutDecoration>(decoration))
{
@@ -731,67 +1502,109 @@ namespace Slang
SLANG_ASSERT(layoutDecoration);
layoutDecoration->removeFromParent();
layoutDecoration->insertAtStart(groupThreadId);
- systemValWorkItems.add({ groupThreadId, groupThreadIDString, semanticIndex });
+ SystemValLegalizationWorkItem newWorkItem = { groupThreadId, groupThreadIDString, semanticIndex };
+ legalizeSystemValue(entryPoint, newWorkItem);
}
}
IRBuilder svBuilder(builder.getModule());
svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst());
auto computeExtent = emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, builder.getVectorType(builder.getUIntType(), builder.getIntValue(builder.getIntType(), 3)));
- auto groupIndexCalc = emitCalcGroupThreadIndex(svBuilder, groupThreadId, computeExtent);
+ auto groupIndexCalc = emitCalcGroupIndex(svBuilder, entryPointToGroupThreadId[entryPoint.entryPointFunc], computeExtent);
svBuilder.addNameHintDecoration(groupIndexCalc, UnownedStringSlice("sv_groupindex"));
- param->replaceUsesWith(groupIndexCalc);
- param->removeAndDeallocate();
+ var->replaceUsesWith(groupIndexCalc);
+ var->removeAndDeallocate();
}
}
if (info.isUnsupported)
{
- reportUnsupportedSystemAttribute(sink, param, semanticName);
- continue;
+ reportUnsupportedSystemAttribute(var, semanticName);
+ return;
}
- if (!info.requiredType)
- continue;
+ if (!info.permittedTypes.getCount())
+ return;
- builder.addTargetSystemValueDecoration(param, info.metalSystemValueName.getUnownedSlice());
+ builder.addTargetSystemValueDecoration(var, info.metalSystemValueName.getUnownedSlice());
- // If the required type is different from the actual type, we need to insert a conversion.
- if (info.requiredType != param->getFullType() && info.altRequiredType != param->getFullType())
+ bool varTypeIsPermitted = false;
+ auto varType = var->getFullType();
+ for (auto& permittedType : info.permittedTypes)
{
- auto targetType = param->getFullType();
- builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
- param->setFullType(info.requiredType);
- List<IRUse*> uses;
- for (auto use = param->firstUse; use; use = use->nextUse)
- uses.add(use);
- auto convertedValue = tryConvertValue(builder, param, targetType);
- copyNameHintAndDebugDecorations(convertedValue, param);
- if (!convertedValue)
- {
- // If we can't convert the value, report an error.
- StringBuilder typeNameSB;
- getTypeNameHint(typeNameSB, info.requiredType);
- sink->diagnose(param->sourceLoc, Diagnostics::systemValueTypeIncompatible, semanticName, typeNameSB.produceString());
- }
- else
+ varTypeIsPermitted = varTypeIsPermitted || permittedType == varType;
+ }
+
+ if (!varTypeIsPermitted)
+ {
+ // Note: we do not currently prefer any conversion
+ // example:
+ // * allowed types for semantic: `float4`, `uint4`, `int4`
+ // * user used, `float2`
+ // * Slang will equally prefer `float4` to `uint4` to `int4`.
+ // This means the type may lose data if slang selects `uint4` or `int4`.
+ bool foundAConversion = false;
+ for (auto permittedType : info.permittedTypes)
{
+ var->setFullType(permittedType);
+ builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
+
+ // get uses before we `tryConvertValue` since this creates a new use
+ List<IRUse*> uses;
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ uses.add(use);
+
+ auto convertedValue = tryConvertValue(builder, var, varType);
+ if (convertedValue == nullptr)
+ continue;
+
+ foundAConversion = true;
+ copyNameHintAndDebugDecorations(convertedValue, var);
+
for (auto use : uses)
builder.replaceOperand(use, convertedValue);
}
+ if (!foundAConversion)
+ {
+ // If we can't convert the value, report an error.
+ for (auto permittedType : info.permittedTypes)
+ {
+ StringBuilder typeNameSB;
+ getTypeNameHint(typeNameSB, permittedType);
+ m_sink->diagnose(var->sourceLoc, Diagnostics::systemValueTypeIncompatible, semanticName, typeNameSB.produceString());
+ }
+ }
}
}
- fixUpFuncType(entryPoint.entryPointFunc);
- }
- void legalizeEntryPointForMetal(EntryPointInfo entryPoint, DiagnosticSink* sink)
- {
- hoistEntryPointParameterFromStruct(entryPoint);
- packStageInParameters(entryPoint);
- legalizeSystemValueParameters(entryPoint, sink);
- wrapReturnValueInStruct(sink, entryPoint);
- legalizeMeshEntryPoint(entryPoint);
- legalizeDispatchMeshPayloadForMetal(entryPoint);
- }
+ void legalizeSystemValueParameters(EntryPointInfo entryPoint)
+ {
+ List<SystemValLegalizationWorkItem> systemValWorkItems = collectSystemValFromEntryPoint(entryPoint);
+
+ for (auto index = 0; index < systemValWorkItems.getCount(); index++)
+ {
+ legalizeSystemValue(entryPoint, systemValWorkItems[index]);
+ }
+ fixUpFuncType(entryPoint.entryPointFunc);
+ }
+
+ void legalizeEntryPointForMetal(EntryPointInfo entryPoint)
+ {
+ //Input Parameter Legalize
+ hoistEntryPointParameterFromStruct(entryPoint);
+ packStageInParameters(entryPoint);
+ flattenInputParameters(entryPoint);
+
+ //System Value Legalize
+ legalizeSystemValueParameters(entryPoint);
+
+ //Output Value Legalize
+ wrapReturnValueInStruct(entryPoint);
+
+ //Other Legalize
+ legalizeMeshEntryPoint(entryPoint);
+ legalizeDispatchMeshPayloadForMetal(entryPoint);
+ }
+ };
void legalizeFuncBody(IRFunc* func)
{
@@ -871,8 +1684,10 @@ namespace Slang
}
}
+ LegalizeMetalEntryPointContext context(sink, module);
for (auto entryPoint : entryPoints)
- legalizeEntryPointForMetal(entryPoint, sink);
+ context.legalizeEntryPointForMetal(entryPoint);
+ context.removeSemanticLayoutsFromLegalizedStructs();
specializeAddressSpace(module);
}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index a4ea100dc..c6409a7e1 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -6267,9 +6267,9 @@ namespace Slang
addDecoration(inst, kIROp_HighLevelDeclDecoration, ptrConst);
}
- void IRBuilder::addLayoutDecoration(IRInst* value, IRLayout* layout)
+ IRLayoutDecoration* IRBuilder::addLayoutDecoration(IRInst* value, IRLayout* layout)
{
- addDecoration(value, kIROp_LayoutDecoration, layout);
+ return as<IRLayoutDecoration>(addDecoration(value, kIROp_LayoutDecoration, layout));
}
IRTypeSizeAttr* IRBuilder::getTypeSizeAttr(
diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp
index 6a1080b0e..046c35ef9 100644
--- a/source/slang/slang-parameter-binding.cpp
+++ b/source/slang/slang-parameter-binding.cpp
@@ -480,8 +480,8 @@ static bool isDigit(char c)
return (c >= '0') && (c <= '9');
}
-void splitNameAndIndex(
- UnownedStringSlice const& text,
+bool splitNameAndIndex(
+ UnownedStringSlice const& text,
UnownedStringSlice& outName,
UnownedStringSlice& outDigits)
{
@@ -489,14 +489,20 @@ void splitNameAndIndex(
char const* digitsEnd = text.end();
char const* nameEnd = digitsEnd;
+ // ExplicitIndex is when a semantic has an index at the end of its name
+ // "SV_TARGET1" has an ExplicitIndex
+ // "SV_TARGET" does not have an ExplicitIndex
+ bool hasExplicitIndex = false;
while( nameEnd != nameBegin && isDigit(*(nameEnd - 1)) )
{
+ hasExplicitIndex = true;
nameEnd--;
}
char const* digitsBegin = nameEnd;
outName = UnownedStringSlice(nameBegin, nameEnd);
outDigits = UnownedStringSlice(digitsBegin, digitsEnd);
+ return hasExplicitIndex;
}
LayoutResourceKind findRegisterClassFromName(UnownedStringSlice const& registerClassName)
diff --git a/source/slang/slang-parameter-binding.h b/source/slang/slang-parameter-binding.h
index 3c0949e37..e2530e34e 100644
--- a/source/slang/slang-parameter-binding.h
+++ b/source/slang/slang-parameter-binding.h
@@ -32,7 +32,7 @@ void generateParameterBindings(
/// Given a string that specifies a name and index (e.g., `COLOR0`),
/// split it into slices for the name part and the index part.
///
-void splitNameAndIndex(
+bool splitNameAndIndex(
UnownedStringSlice const& text,
UnownedStringSlice& outName,
UnownedStringSlice& outDigits);